#include "distances.h"

#include "label_equivalence_relation.h"
#include "transition_system.h"

#include "../priority_queue.h"

#include <cassert>
#include <deque>

using namespace std;

namespace merge_and_shrink {
// Out-of-class definition of the static constant whose value is given in the
// class declaration (a const static member defined here without an
// initializer must be initialized in-class). A definition is needed because
// the constant is ODR-used in this file, e.g. bound to the `const T &` value
// parameter of the vector constructor in apply_abstraction.
const int Distances::DISTANCE_UNKNOWN;

/*
  Create a distance table for the given transition system. Only a reference
  to the transition system is stored, so it must outlive this object. No
  distances are computed yet: the object starts in the "not computed" state
  established by clear_distances().
*/
Distances::Distances(const TransitionSystem &ts)
    : transition_system(ts) {
    clear_distances();
}

// Nothing to release by hand: all owned data lives in standard containers.
Distances::~Distances() = default;

/*
  Return to the "not computed" state: both per-state distance tables become
  empty and the three cached maxima are reset to DISTANCE_UNKNOWN.
  are_distances_computed() relies on exactly this combination.
*/
void Distances::clear_distances() {
    init_distances.clear();
    goal_distances.clear();
    max_f = max_g = max_h = DISTANCE_UNKNOWN;
}

// Number of states of the underlying transition system; the distance tables
// are sized to this value when distances are computed.
int Distances::get_num_states() const {
    const int num_states = transition_system.get_size();
    return num_states;
}

/*
  Return true iff every label group of the transition system has cost 1.
  This is vacuously true when there are no label groups.

  TODO (carried over): is this a good implementation? It reads label costs
  the same way the shortest-path routines below do (via each group's
  LabelGroup::get_cost()). It differs from an earlier version in
  transition_system.cc, which would have needed access to more attributes.
*/
bool Distances::is_unit_cost() const {
    bool all_costs_are_one = true;
    for (const GroupAndTransitions &group_and_transitions : transition_system) {
        if (group_and_transitions.label_group.get_cost() != 1) {
            all_costs_are_one = false;
            break;
        }
    }
    return all_costs_are_one;
}

/*
  Propagate unit-cost distances outward from the states currently in `queue`.

  graph:     adjacency lists; graph[s] holds the states reachable from s in
             one step.
  queue:     seed states. Each must already hold its final distance in
             `distances`. The queue is empty on return.
  distances: per-state distances. Non-seed states should start with a value
             larger than any reachable distance (callers use INF). Every state
             whose distance improves by one more step is updated and enqueued.
*/
static void breadth_first_search(
    const vector<vector<int>> &graph, deque<int> &queue,
    vector<int> &distances) {
    while (!queue.empty()) {
        const int current = queue.front();
        queue.pop_front();
        // distances[current] cannot change inside the loop below (a self-loop
        // never improves it), so the candidate distance is loop-invariant.
        const int next_distance = distances[current] + 1;
        for (int neighbor : graph[current]) {
            if (next_distance < distances[neighbor]) {
                distances[neighbor] = next_distance;
                queue.push_back(neighbor);
            }
        }
    }
}

void Distances::compute_init_distances_unit_cost() {
    vector<vector<int>> forward_graph(get_num_states());
    for (const GroupAndTransitions &gat : transition_system) {
        const vector<Transition> &transitions = gat.transitions;
        for (const Transition &transition : transitions) {
            forward_graph[transition.src].push_back(transition.target);
        }
    }

    deque<int> queue;
    // TODO: This is an oddly inefficient initialization! Fix it.
    for (int state = 0; state < get_num_states(); ++state) {
        if (state == transition_system.get_init_state()) {
            init_distances[state] = 0;
            queue.push_back(state);
        }
    }
    breadth_first_search(forward_graph, queue, init_distances);
}

void Distances::compute_goal_distances_unit_cost() {
    vector<vector<int>> backward_graph(get_num_states());
    for (const GroupAndTransitions &gat : transition_system) {
        const vector<Transition> &transitions = gat.transitions;
        for (const Transition &transition : transitions) {
            backward_graph[transition.target].push_back(transition.src);
        }
    }

    deque<int> queue;
    for (int state = 0; state < get_num_states(); ++state) {
        if (transition_system.is_goal_state(state)) {
            goal_distances[state] = 0;
            queue.push_back(state);
        }
    }
    breadth_first_search(backward_graph, queue, goal_distances);
}

static void dijkstra_search(
    const vector<vector<pair<int, int>>> &graph,
    AdaptiveQueue<int> &queue,
    vector<int> &distances) {
    while (!queue.empty()) {
        pair<int, int> top_pair = queue.pop();
        int distance = top_pair.first;
        int state = top_pair.second;
        int state_distance = distances[state];
        assert(state_distance <= distance);
        if (state_distance < distance)
            continue;
        for (size_t i = 0; i < graph[state].size(); ++i) {
            const pair<int, int> &transition = graph[state][i];
            int successor = transition.first;
            int cost = transition.second;
            int successor_cost = state_distance + cost;
            if (distances[successor] > successor_cost) {
                distances[successor] = successor_cost;
                queue.push(successor_cost, successor);
            }
        }
    }
}

void Distances::compute_init_distances_general_cost() {
    vector<vector<pair<int, int>>> forward_graph(get_num_states());
    for (const GroupAndTransitions &gat : transition_system) {
        const LabelGroup &label_group = gat.label_group;
        const vector<Transition> &transitions = gat.transitions;
        int cost = label_group.get_cost();
        for (const Transition &transition : transitions) {
            forward_graph[transition.src].push_back(
                make_pair(transition.target, cost));
        }
    }

    // TODO: Reuse the same queue for multiple computations to save speed?
    //       Also see compute_goal_distances_general_cost.
    AdaptiveQueue<int> queue;
    // TODO: This is an oddly inefficient initialization! Fix it.
    for (int state = 0; state < get_num_states(); ++state) {
        if (state == transition_system.get_init_state()) {
            init_distances[state] = 0;
            queue.push(0, state);
        }
    }
    dijkstra_search(forward_graph, queue, init_distances);
}

void Distances::compute_goal_distances_general_cost() {
    vector<vector<pair<int, int>>> backward_graph(get_num_states());
    for (const GroupAndTransitions &gat : transition_system) {
        const LabelGroup &label_group = gat.label_group;
        const vector<Transition> &transitions = gat.transitions;
        int cost = label_group.get_cost();
        for (const Transition &transition : transitions) {
            backward_graph[transition.target].push_back(
                make_pair(transition.src, cost));
        }
    }

    // TODO: Reuse the same queue for multiple computations to save speed?
    //       Also see compute_init_distances_general_cost.
    AdaptiveQueue<int> queue;
    for (int state = 0; state < get_num_states(); ++state) {
        if (transition_system.is_goal_state(state)) {
            goal_distances[state] = 0;
            queue.push(0, state);
        }
    }
    dijkstra_search(backward_graph, queue, goal_distances);
}

/*
  Distances count as computed exactly when max_h holds a real value. In the
  "not computed" state (see clear_distances()), the debug build additionally
  checks that every other piece of cached data is reset as well.
*/
bool Distances::are_distances_computed() const {
    const bool computed = (max_h != DISTANCE_UNKNOWN);
    if (!computed) {
        assert(max_f == DISTANCE_UNKNOWN);
        assert(max_g == DISTANCE_UNKNOWN);
        assert(init_distances.empty());
        assert(goal_distances.empty());
    }
    return computed;
}

vector<bool> Distances::compute_distances(Verbosity verbosity) {
    /*
      Compute, for every abstract state:
      - its distance from the abstract initial state ("abstract g"), stored
        in init_distances, and
      - its distance to the nearest abstract goal state ("abstract h"),
        stored in goal_distances.
      Breadth-first search is used when every label group has cost 1;
      otherwise Dijkstra is used.

      Afterwards max_f, max_g and max_h hold the maxima of g + h, g and h
      over the states that are both reachable and relevant (0 if there are
      none). For an empty transition system all three are set to INF.

      Returns one flag per state, true if the state can be pruned because it
      is unreachable (g is infinite) or irrelevant (h is infinite). The
      vector is empty for an empty transition system.

      Must only be called while no distances are stored.
    */
    const bool verbose = verbosity >= Verbosity::VERBOSE;
    if (verbose) {
        cout << transition_system.tag();
    }
    assert(!are_distances_computed());
    assert(init_distances.empty() && goal_distances.empty());

    const int num_states = get_num_states();
    if (num_states == 0) {
        if (verbose) {
            cout << "empty transition system, no distances to compute" << endl;
        }
        max_f = max_g = max_h = INF;
        return vector<bool>();
    }

    init_distances.resize(num_states, INF);
    goal_distances.resize(num_states, INF);
    const bool unit_cost = is_unit_cost();
    if (verbose) {
        cout << (unit_cost
                 ? "computing distances using unit-cost algorithm"
                 : "computing distances using general-cost algorithm")
             << endl;
    }
    if (unit_cost) {
        compute_init_distances_unit_cost();
        compute_goal_distances_unit_cost();
    } else {
        compute_init_distances_general_cost();
        compute_goal_distances_general_cost();
    }

    max_f = max_g = max_h = 0;

    vector<bool> prunable_states(num_states, false);
    int unreachable_count = 0;
    int irrelevant_count = 0;
    for (int state = 0; state < num_states; ++state) {
        const int g = init_distances[state];
        const int h = goal_distances[state];
        if (g == INF || h == INF) {
            prunable_states[state] = true;
            // A state that is both unreachable and irrelevant is counted
            // only as unreachable; the choice does not matter.
            if (g == INF)
                ++unreachable_count;
            else
                ++irrelevant_count;
            continue;
        }
        max_f = max(max_f, g + h);
        max_g = max(max_g, g);
        max_h = max(max_h, h);
    }

    const bool found_prunable = unreachable_count > 0 || irrelevant_count > 0;
    if (verbose && found_prunable) {
        cout << transition_system.tag()
             << "unreachable: " << unreachable_count << " states, "
             << "irrelevant: " << irrelevant_count << " states" << endl;
    }
    assert(are_distances_computed());
    return prunable_states;
}

void Distances::apply_abstraction(
    const StateEquivalenceRelation &state_equivalence_relation,
    Verbosity verbosity) {
    assert(are_distances_computed());
    assert(state_equivalence_relation.size() < init_distances.size());
    assert(state_equivalence_relation.size() < goal_distances.size());

    int new_num_states = state_equivalence_relation.size();
    vector<int> new_init_distances(new_num_states, DISTANCE_UNKNOWN);
    vector<int> new_goal_distances(new_num_states, DISTANCE_UNKNOWN);

    bool must_recompute = false;
    for (int new_state = 0; new_state < new_num_states; ++new_state) {
        const StateEquivalenceClass &state_equivalence_class =
            state_equivalence_relation[new_state];
        assert(!state_equivalence_class.empty());

        StateEquivalenceClass::const_iterator pos = state_equivalence_class.begin();
        int new_init_dist = init_distances[*pos];
        int new_goal_dist = goal_distances[*pos];

        ++pos;
        for (; pos != state_equivalence_class.end(); ++pos) {
            if (init_distances[*pos] != new_init_dist) {
                must_recompute = true;
                break;
            }
            if (goal_distances[*pos] != new_goal_dist) {
                must_recompute = true;
                break;
            }
        }

        if (must_recompute)
            break;

        new_init_distances[new_state] = new_init_dist;
        new_goal_distances[new_state] = new_goal_dist;
    }

    if (must_recompute) {
        if (verbosity >= Verbosity::VERBOSE) {
            cout << transition_system.tag()
                 << "simplification was not f-preserving!" << endl;
        }
        clear_distances();
        compute_distances(verbosity);
    } else {
        init_distances = move(new_init_distances);
        goal_distances = move(new_goal_distances);
    }
}

/*
  Debug output on stdout: prints only the goal distances (abstract h values),
  as "state: distance, " pairs on a single line.
*/
void Distances::dump() const {
    cout << "Distances: ";
    size_t state = 0;
    for (int h : goal_distances) {
        cout << state++ << ": " << h << ", ";
    }
    cout << endl;
}

/*
  Print a one-line summary prefixed with the transition system's tag. Possible
  outputs: the initial state's goal distance and the max f/g/h values (when
  distances exist and the transition system is solvable), a note that it is
  unsolvable, or a note that no distances have been computed.
*/
void Distances::statistics() const {
    cout << transition_system.tag();
    const bool computed = are_distances_computed();
    if (computed && transition_system.is_solvable()) {
        cout << "init h=" << get_goal_distance(transition_system.get_init_state())
             << ", max f=" << get_max_f()
             << ", max g=" << get_max_g()
             << ", max h=" << get_max_h();
    } else if (computed) {
        cout << "transition system is unsolvable";
    } else {
        cout << "distances not computed";
    }
    cout << endl;
}
}
