#include "abstract_search.h"

#include "abstract_state.h"
#include "transition_system.h"
#include "utils.h"

#include "../utils/memory.h"

#include <cassert>

using namespace std;

namespace cegar {
/*
  Sets up the abstract search.

  operator_costs: operator costs indexed by operator id.
  hupd:           strategy for maintaining h-values between refinements
                  (find_solution refuses greedy mode with
                  HUpdateStrategy::OLD).
  greedy:         if true, find_solution() reads the path off the stored
                  shortest-path tree instead of running A*.

  search_info and shortest_path start with one entry each (presumably for
  the initial single-state abstraction).
*/
AbstractSearch::AbstractSearch(
    const vector<int> &operator_costs,
    const HUpdateStrategy hupd,
    bool greedy)
    : operator_costs(operator_costs)
    , search_info(1)
    , shortest_path(1)
    , hupd(hupd)
    , greedy(greedy) {
}

/*
  Prepares for a new search over num_states abstract states. It empties the
  open list, resizes search_info to num_states and calls reset() on every
  entry.
  NOTE(review): find_solution() calls forward_search() right after this, and
  forward_search() reads h-values. AbstractSearchInfo::reset() therefore
  presumably keeps h — confirm.
*/
void AbstractSearch::reset(int num_states) {
    open_queue.clear();
    search_info.resize(num_states);
    for (auto it = search_info.begin(); it != search_info.end(); ++it)
        it->reset();
}

/*
  Rebuilds the path found by astar_search() by following incoming
  transitions backwards from goal_id to init_id.

  astar_search() records each incoming transition as
  Transition(op_id, predecessor), so target_id here is the predecessor.
  Each emitted entry pairs the operator with the state it leads to. Entries
  are prepended, so the result is in init-to-goal order.
*/
unique_ptr<Solution> AbstractSearch::extract_solution(int init_id, int goal_id) const {
    auto solution = utils::make_unique_ptr<Solution>();
    for (int state = goal_id; state != init_id;) {
        const Transition &incoming = search_info[state].get_incoming_transition();
        assert(incoming.target_id != state);
        solution->emplace_front(incoming.op_id, state);
        state = incoming.target_id;
    }
    return solution;
}

void AbstractSearch::update_goal_distances(const Solution &solution) {
    /*
      Originally, we only updated the goal distances of states that are part of
      the trace (see Seipp and Helmert, JAIR 2018). The code below generalizes
      this idea and potentially updates the goal distances of all states.

      Let C* be the cost of the trace and g(s) be the g value of states s when
      A* finds the trace. Then for all states s with g(s) < INF (i.e., s has
      been reached by the search), C*-g(s) is a lower bound on the goal
      distance. This is the case since

      g(s) >= g*(s) [1]

      and

          f*(s) >= C*         (optimality of A* with an admissible heuristic)
      ==> g*(s) + h*(s) >= C* (definition of f values)
      ==> g(s) + h*(s) >= C*  (using [1])
      ==> h*(s) >= C* - g(s)  (arithmetic)

      Together with our existing lower bound h*(s) >= h(s), i.e., the h values
      from the last iteration, for each abstract state s with g(s) < INF, we
      can set h(s) = max(h(s), C*-g(s)).
    */
    int solution_cost = 0;
    for (const Transition &transition : solution) {
        solution_cost += operator_costs[transition.op_id];
    }
    for (auto &info : search_info) {
        if (info.get_g_value() < INF) {
            int new_h = max(info.get_h_value(), solution_cost - info.get_g_value());
            info.increase_h_value_to(new_h);
        }
    }
}

/*
  Looks for an abstract path from init_id to any state in goal_ids.

  Non-greedy mode runs A* over `transitions` using the stored h-values. If
  no goal is reached, h(init_id) is raised to INF and nullptr is returned.

  Greedy mode reads the path off the stored shortest-path tree via
  forward_search(). It terminates the process with exit code 34 when hupd
  is HUpdateStrategy::OLD, because it needs perfect h-values.
*/
unique_ptr<Solution> AbstractSearch::find_solution(
    const vector<Transitions> &transitions,
    int init_id,
    const Goals &goal_ids) {
    assert(init_id >= 0);
    reset(transitions.size());
    search_info[init_id].decrease_g_value_to(0);

    if (greedy) {
        if (hupd == HUpdateStrategy::OLD) {
            cout << "Greedy search needs perfect h" << endl;
            exit(34); // SEARCH_UNSUPPORTED
        }
        return forward_search(init_id, goal_ids);
    }

    open_queue.push(search_info[init_id].get_h_value(), init_id);
    const int goal_id = astar_search(transitions, goal_ids);
    open_queue.clear();

    if (goal_id == UNDEFINED) {
        // No goal is reachable from the initial state.
        search_info[init_id].increase_h_value_to(INF);
        return nullptr;
    }
    /*
      NOTE(review): the goal-distance update for HUpdateStrategy::OLD is
      currently disabled. It used to read:
          if (hupd == HUpdateStrategy::OLD)
              update_goal_distances(*solution);
    */
    return extract_solution(init_id, goal_id);
}

/*
  Reads an abstract path from init_id to a goal off the stored shortest-path
  tree. shortest_path[s] is the transition s takes towards a goal. No search
  is performed. h-values are used only to detect dead ends and, in debug
  builds, to sanity-check the tree.

  Returns nullptr if h(init_id) == INF. Otherwise returns the transitions
  followed. If init_id is already a goal, the solution is empty; this
  matches the A* branch of find_solution(). The previous debug assert
  rejected that legal case even though release builds handled it.
*/
unique_ptr<Solution> AbstractSearch::forward_search(
    int init_id,
    const Goals &goals) {
    int current_state = init_id;
    assert(current_state != -1);
    /* h* == INF iff goal is unreachable from this state */
    if (search_info[current_state].get_h_value() == INF)
        return nullptr;

    Solution solution;
    while (!goals.count(current_state)) {
        assert(utils::in_bounds(current_state, shortest_path));
        const Transition &t = shortest_path[current_state];
        assert(!(t == Transition()));
        assert(t.target_id != current_state);
        assert(t.target_id != -1);
        assert(search_info[t.target_id].get_h_value() <= search_info[current_state].get_h_value());
        solution.push_back(t);
        /*
          A path in a tree visits every state at most once, so it has fewer
          transitions than there are states. A longer path means the
          shortest-path tree contains a cycle.
        */
        assert(solution.size() < shortest_path.size());
        current_state = t.target_id;
    }
    return utils::make_unique_ptr<Solution>(std::move(solution));
}

int AbstractSearch::astar_search(
    const vector<Transitions> &transitions, const Goals &goals) {
    while (!open_queue.empty()) {
        pair<int, int> top_pair = open_queue.pop();
        int old_f = top_pair.first;
        int state_id = top_pair.second;

        const int g = search_info[state_id].get_g_value();
        assert(0 <= g && g < INF);
        int new_f = g + search_info[state_id].get_h_value();
        assert(new_f <= old_f);
        if (new_f < old_f)
            continue;
        if (goals.count(state_id)) {
            return state_id;
        }
        assert(utils::in_bounds(state_id, transitions));
        for (const Transition &transition : transitions[state_id]) {
            int op_id = transition.op_id;
            int succ_id = transition.target_id;

            assert(utils::in_bounds(op_id, operator_costs));
            const int op_cost = operator_costs[op_id];
            assert(op_cost >= 0);
            int succ_g = (op_cost == INF) ? INF : g + op_cost;
            assert(succ_g >= 0);

            if (succ_g < search_info[succ_id].get_g_value()) {
                search_info[succ_id].decrease_g_value_to(succ_g);
                int h = search_info[succ_id].get_h_value();
                if (h == INF)
                    continue;
                int f = succ_g + h;
                assert(f >= 0);
                assert(f != INF);
                open_queue.push(f, succ_id);
                search_info[succ_id].set_incoming_transition(Transition(op_id, state_id));
            }
        }
    }
    return UNDEFINED;
}

int AbstractSearch::get_h_value(int state_id) const {
    assert(utils::in_bounds(state_id, search_info));
    return search_info[state_id].get_h_value();
}

void AbstractSearch::set_h_value(int state_id, int h) {
    assert(utils::in_bounds(state_id, search_info));
    search_info[state_id].increase_h_value_to(h);
}


/*
  After state v has been split into v1 and v2, gives both the h-value v had
  before the split. One new entry is appended to search_info; shortest_path
  is not resized here. h(v) is read before the append, so it is unaffected
  by the resize.
*/
void AbstractSearch::copy_h_value_to_children(int v, int v1, int v2) {
    const int parent_h = get_h_value(v);
    search_info.emplace_back();
    set_h_value(v1, parent_h);
    set_h_value(v2, parent_h);
}

/*
  Resizes search_info to values.size(), then raises the h-value of every
  state id to values[id] via increase_h_value_to().
*/
void AbstractSearch::set_h_values(std::vector<int> values) {
    search_info.resize(values.size());
    auto info_it = search_info.begin();
    for (auto value_it = values.begin(); value_it != values.end(); ++value_it, ++info_it)
        info_it->increase_h_value_to(*value_it);
}

/*
  Incrementally restores perfect h-values (goal distances) and the
  shortest-path tree after an abstract state has been split.

  transitions: indexed by state. The code treats transitions[s] as the
      transitions *into* s: t.target_id is a predecessor of s and t.op_id is
      the operator leading from it to s (see the checks on
      shortest_path[t.target_id].target_id and the pointers written as
      Transition(t.op_id, state) below).
  split_states: (first, second) are the two states resulting from the
      split. second must be the newly added, highest state id.

  Only "orphans" are recomputed: states whose shortest-path pointer chain
  ran through the split. All other states keep their h-values. h-values are
  only ever changed through increase_h_value_to().
*/
void AbstractSearch::update_perfect_h(
    const vector<Transitions> &transitions,                                
    pair<int,int> split_states) {
    /* 
     * Assumption: all h-values correspond to the perfect heuristic for the
     * state space before the split. The second value of the split corresponds
     * to the "good" state for which h is still perfect, the first to the "bad".
     * 
     */
    // Make room for the new state split_states.second.
    search_info.resize(search_info.size() +1);
    shortest_path.resize(shortest_path.size()+1);
    assert(static_cast<size_t>(split_states.second) == transitions.size()-1);
    assert(search_info.size() == transitions.size());
    assert(shortest_path.size() == transitions.size());
    // The new state starts out with the old state's h-value and tree pointer.
    set_h_value(split_states.second, search_info[split_states.first].get_h_value());
    // NOTE(review): compute_full_distances gives goal states an empty pointer
    // (Transition()), so this assert rejects splitting a goal state. Confirm
    // that callers never do that; update_perfect_h_alt has no such assert.
    assert(!(shortest_path[split_states.first] == Transition()));
    shortest_path[split_states.second] = shortest_path[split_states.first];
    /* 
     * orphans holds the newly computed reverse g-values (i.e. h-values) for 
     * orphaned states, and -1 for settled states. A state is orphaned if its
     * successor in the shortest-path tree is orphaned, starting with s_1 and s_2.
     * We start by assuming g=INF for all orphaned states.
     */
    vector<int> orphans(transitions.size(), -1); 
    // Marks `state` as orphaned and recursively orphans every predecessor
    // whose tree pointer leads to it.
    // NOTE(review): recursion depth equals the height of the orphaned
    // subtree; very deep trees could exhaust the stack.
    std::function<void(int)> insert = [&] (int state) {
        orphans[state] = INF;
        shortest_path[state] = Transition(); //shortest path information is invalid now
        for (const Transition& t : transitions[state]) {
            assert(t.target_id != state);
            assert(t.target_id != UNDEFINED);
            if (orphans[t.target_id] == -1 
                && shortest_path[t.target_id].target_id == state) {
                insert(t.target_id);
            }
        }
    };
    insert(split_states.first);
    
    // Predecessors of the new state whose pointer still leads to the (now
    // orphaned) old state are orphaned too.
    for (const Transition& t : transitions[split_states.second]) {
        assert(t.target_id != split_states.second);
        assert(t.target_id != UNDEFINED);
        if (orphans[t.target_id] == -1 
            && shortest_path[t.target_id].target_id == split_states.first) {
            insert(t.target_id);
        }
    }
    
    // The pointer copied into the new state above is discarded as well.
    insert(split_states.second);
    assert (orphans[split_states.first] == INF);
    //expand settled states: every settled, goal-reachable state offers
    //h(settled) + cost as a tentative distance to each orphaned predecessor
//    int orph_count = 0;
    for (unsigned int state=0; state<orphans.size(); ++state) {
        //consider only settled states...
        if (orphans[state] == -1) { 
            const int this_g = search_info[state].get_h_value();
            //...that are reachable...
            if(this_g == INF) continue;
            for (const Transition& t : transitions[state]) {
                const int opcost = operator_costs[t.op_id];
                assert(opcost >= 0);
                assert(opcost <=INF);
                // ...leading to orphans
                if (orphans[t.target_id] != -1) {
                    assert(INF-opcost > this_g);
                    const int succ_g = this_g + opcost;
                    int& old_g = orphans[t.target_id];
                    if (succ_g < old_g) {
                        old_g = succ_g;
                        shortest_path[t.target_id] = Transition(t.op_id, state);
                    }
                }
            }
        } else {
 //           ++orph_count;
        }
    }
//    cout << float(orph_count)/orphans.size()*100 << "% orphans" << endl;
    //seed the queue with the orphans that received a finite tentative
    //distance from a settled successor
    open_queue.clear();
    for (size_t i=0; i<orphans.size(); ++i) {
        if (orphans[i] != -1 && orphans[i] != INF) {
            assert(orphans[i]>=0);
            open_queue.push(orphans[i], i);
        }
    }
    // Requires at least one orphan to be adjacent to a settled,
    // goal-reachable state.
    assert (!open_queue.empty());
    //start regular Dijkstra (backwards, restricted to orphaned states;
    //settled states keep their h-values)
    while(!open_queue.empty()) {
        pair<int,int> top_pair = open_queue.pop();
        const int g_value = top_pair.first;
        const int state = top_pair.second;
        assert(utils::in_bounds(state, orphans));
        // Skip outdated queue entries.
        if (g_value > orphans[state]) 
            continue;
        assert (g_value == orphans[state]);
        assert (g_value !=INF);
        search_info[state].increase_h_value_to(g_value);
        assert(utils::in_bounds(state, transitions));
        for (Transition t : transitions[state]) {
            const int opcost = operator_costs[t.op_id];
            assert(opcost !=INF);
            if (orphans[t.target_id] != -1) {
                const int succ_g = g_value + opcost;
                if (succ_g < orphans[t.target_id]) {
                    orphans[t.target_id] = succ_g;
                    shortest_path[t.target_id] = Transition(t.op_id, state);
                    open_queue.push(succ_g, t.target_id);
                }
            }
        }
    }
    // Orphans that never received a finite distance cannot reach a goal.
    for (size_t i=0; i<orphans.size(); ++i)
        if (orphans[i]==INF) search_info[i].increase_h_value_to(INF);
    // Disabled consistency check of the shortest-path tree, kept for debugging:
/*        
#ifndef NDEBUG
    for (size_t i=0; i<search_info.size(); ++i) {
        if (shortest_path[i].target_id != UNDEFINED) {
            bool check = false;
            for (Transition t : transitions[shortest_path[i].target_id]) {
                if (t.target_id >=0 && size_t(t.target_id) == i) {
                    check=true;
                    break;
                }
            }
            if(!check) {
                std::cout << "While splitting state " << split_states.second << " from " << split_states.first << std::endl;
                std::cout << "State " << i << " of " << search_info.size() << " leads to " << shortest_path[i].target_id << ": No such transition" << std::endl;
                shortest_path[i] = Transition();
            }
        }
        if (shortest_path[i].target_id != UNDEFINED && search_info[shortest_path[i].target_id].get_h_value() + operator_costs[shortest_path[i].op_id] != search_info[i].get_h_value()) {
            std::cout << "Discrepancy in state " << i << ": has h = " << search_info[i].get_h_value() << " != " << search_info[shortest_path[i].target_id].get_h_value() << " + " << operator_costs[shortest_path[i].op_id] << std::endl;
        }
    }

#endif
*/
}

/*
  Variant of update_perfect_h that uses both transition directions.

  in:  in[s] lists transitions into s; t.target_id is a predecessor of s.
  out: out[s] lists transitions out of s; t.target_id is a successor of s.
       These are stored directly as tree pointers: shortest_path[s] = t.
  split_states: (first, second) are the two states resulting from the
       split. second must be the newly added, highest state id.

  Orphaned states first get tentative distances from their settled
  successors (a forward scan over out[]). A backward Dijkstra restricted to
  orphans then finalizes them. orphan_ratio is multiplied by the fraction
  of states that were seeded into that Dijkstra.
*/
void AbstractSearch::update_perfect_h_alt(
    const vector<Transitions> &in,  
    const vector<Transitions> &out, 
    pair<int,int> split_states) {
    /* 
     * Assumption: all h-values correspond to the perfect heuristic for the
     * state space before the split. The second value of the split corresponds
     * to the "good" state for which h is still perfect, the first to the "bad".
     * 
     */
    // Make room for the new state split_states.second.
    search_info.resize(search_info.size() +1);
    shortest_path.resize(shortest_path.size()+1);
    assert(in.size() == out.size());
    assert(static_cast<size_t>(split_states.second) == in.size()-1);
    assert(search_info.size() == in.size());
    assert(shortest_path.size() == in.size());
    // The new state starts out with the old state's h-value and tree pointer.
    set_h_value(split_states.second, search_info[split_states.first].get_h_value());
    shortest_path[split_states.second] = shortest_path[split_states.first];
    /* 
     * orphans holds the newly computed reverse g-values (i.e. h-values) for 
     * orphaned states, and -1 for settled states. A state is orphaned if its
     * successor in the shortest-path tree is orphaned, starting with s_1 (and
     * with s_2 and the predecessors of s_2 that still point to s_1).
     * We start by assuming g=INF for all orphaned states.
     */
    vector<int> orphans(in.size(), -1); 
    // Marks `state` as orphaned and recursively orphans every predecessor
    // whose tree pointer leads to it.
    // NOTE(review): recursion depth equals the height of the orphaned subtree.
    std::function<void(int)> insert = [&] (int state) {
        orphans[state] = INF;
        shortest_path[state] = Transition(); //shortest path information is invalid now
        for (const Transition& t : in[state]) {
            assert(t.target_id != state);
            assert(t.target_id != UNDEFINED);
            if (orphans[t.target_id] == -1 
                && shortest_path[t.target_id].target_id == state) {
                insert(t.target_id);
            }
        }
    };
    insert(split_states.first);
    // Predecessors of the new state whose pointer still leads to the old
    // state are orphaned too.
    for (const Transition& t : in[split_states.second]) {
        assert(t.target_id != split_states.second);
        assert(t.target_id != UNDEFINED);
        if (orphans[t.target_id] == -1 
            && shortest_path[t.target_id].target_id == split_states.first) {
            insert(t.target_id);
        }
    }
    // The pointer copied into the new state above is discarded as well.
    insert(split_states.second);
    assert (orphans[split_states.first] == INF);
    //expand settled states by searching forward from orphaned to settled states
    // orph_count counts only orphans with a finite tentative distance,
    // i.e. the ones seeded into the queue below.
    int orph_count = 0;
    open_queue.clear();
    for (size_t state=0; state<orphans.size(); ++state) {
        if (orphans[state] == -1) continue; //ignore settled states
        int& old_g = orphans[state];
        for (const Transition& t : out[state]) {
            const int opcost = operator_costs[t.op_id];
            assert(opcost !=INF);
            if (orphans[t.target_id] == -1) {
                const int prev_g = search_info[t.target_id].get_h_value();
                const int new_g = (prev_g<INF) ? (prev_g+opcost) : INF;
                if (new_g < old_g) {
                    old_g = new_g;
                    shortest_path[state] = t;
                }
            }
        }
        //we have evaluated all transitions from this orphan to settled states,
        //so its tentative distance via settled successors is complete
        if(orphans[state] != INF) {
            open_queue.push(orphans[state], state);
            ++orph_count;
        }
    }
    // NOTE(review): the ratio is multiplied into orphan_ratio rather than
    // summed; presumably a product for a geometric mean — confirm.
    orphan_ratio *= float(orph_count)/orphans.size();
    //the queue was seeded in the loop above; at least one orphan must have
    //a settled, goal-reachable successor
    assert (!open_queue.empty());
    
    //start regular Dijkstra (backwards over in[], restricted to orphans)
    while(!open_queue.empty()) {
        pair<int,int> top_pair = open_queue.pop();
        const int g_value = top_pair.first;
        const int state = top_pair.second;
        // Skip outdated queue entries.
        if (g_value > orphans[state]) 
            continue;
        assert (g_value == orphans[state]);
        assert (g_value >= 0);
        assert (g_value != INF);
        search_info[state].increase_h_value_to(g_value);
        for (Transition t : in[state]) {
            const int opcost = operator_costs[t.op_id];
            assert(opcost >=0);
            assert(opcost != INF);
            if (orphans[t.target_id] != -1) {
                assert(INF-opcost > g_value);
                const int succ_g = g_value + opcost;
                if (succ_g < orphans[t.target_id]) {
                    orphans[t.target_id] = succ_g;
                    shortest_path[t.target_id] = Transition(t.op_id, state);
                    open_queue.push(succ_g, t.target_id);
                }
            }
        }
    }
    // Orphans that never received a finite distance cannot reach a goal.
    for (size_t i=0; i<orphans.size(); ++i)
        if (orphans[i]==INF) search_info[i].increase_h_value_to(INF);
}

/*
  Variant of update_perfect_h_alt that first tries to reattach affected
  states to the unaffected part of the shortest-path tree at unchanged cost.
  Only states that cannot be reattached become orphans and are recomputed.

  in:  in[s] lists transitions into s; t.target_id is a predecessor of s.
  out: out[s] lists transitions out of s; t.target_id is a successor of s.
  split_states: (first, second) are the two states resulting from the
       split. second must be the newly added, highest state id.

  Debug builds assert that every operator cost is strictly positive.
*/
void AbstractSearch::update_alg2(
    const vector<Transitions> &in,  
    const vector<Transitions> &out, 
    pair<int,int> split_states) {
    /* 
     * Assumption: all h-values correspond to the perfect heuristic for the
     * state space before the split. The second value of the split corresponds
     * to the "good" state for which h is still perfect, the first to the "bad".
     * 
     * Also, we assume there are no 0-cost operators, so the h-value of a
     * successor state will always be strictly less than that of the previous state.
     * 
     */
#ifndef NDEBUG    
    for (int cost: operator_costs) assert(cost > 0);
#endif    
    // Make room for the new state split_states.second.
    search_info.resize(search_info.size() +1);
    shortest_path.resize(shortest_path.size()+1);
    assert(in.size() == out.size());
    assert(static_cast<size_t>(split_states.second) == in.size()-1);
    assert(search_info.size() == in.size());
    assert(shortest_path.size() == in.size());
    // The new state starts out with the old state's h-value and tree pointer.
    set_h_value(split_states.second, search_info[split_states.first].get_h_value());
    shortest_path[split_states.second] = shortest_path[split_states.first];
    /* 
     * orphans holds the newly computed reverse g-values (i.e. h-values) for 
     * orphaned states, and -1 for settled states. A state is orphaned if its
     * parent in the shortest-path tree (SPT) is orphaned, starting with s_1.
     * We start by assuming g=INF for all orphaned states.
     * 
     * Instead of just recursively inserting all orphans, we first push them
     * into a candidate queue that is sorted by (old, possibly too low) h-values.
     * Then, we try to reconnect them to a non-orphaned state at no additional
     * cost. Only if that fails, we flag the candidate as orphaned and push its 
     * SPT-children (who have strictly larger h-values due to no 0-cost operators)
     * into the candidate queue.
     * 
     */
    vector<int> orphans(in.size(), -1);
    candidate_queue.clear();
    /* 
     * The split-off state as well as all of its SPT-children will at least 
     * require updating their shortest-path pointer.
     */ 
    candidate_queue.push(search_info[split_states.first].get_h_value(), split_states.first);
    for (const Transition& t : in[split_states.second]) {
        assert(t.target_id != split_states.second);
        assert(t.target_id != UNDEFINED);
        if (shortest_path[t.target_id].target_id == split_states.first) {
            candidate_queue.push(search_info[t.target_id].get_h_value(), t.target_id);
        }
    }
    candidate_queue.push(search_info[split_states.second].get_h_value(), split_states.second);
    assert(!candidate_queue.empty());
    /*
     * Try to reconnect candidates to non-orphaned states for free.
     * NOTE(review): a state can be pushed more than once, e.g. a predecessor
     * of both split states is pushed above and again when split_states.first
     * is orphaned. Reprocessing appears harmless but costs time.
     */
    while(!candidate_queue.empty()) {
        int state = candidate_queue.pop().second;
        bool reconnected = false;
        // try to reconnect to non-orphaned non-descendant state
        // (a successor through which the old h-value is still achieved exactly)
        for (const Transition& t : out[state]) {
            if (orphans[t.target_id] == -1 
                && search_info[t.target_id].get_h_value() + operator_costs[t.op_id] 
                == search_info[state].get_h_value()) {
                shortest_path[state] = t;
                orphans[state] = -1;
                reconnected = true;
                break;
            }
        }
        if (!reconnected) {
            // flag as orphaned, add children
            orphans[state] = INF;
            shortest_path[state] = Transition();
            for (const Transition& t : in[state]) {
                if (shortest_path[t.target_id].target_id == state
                    && orphans[t.target_id] == -1) {
                    candidate_queue.push(search_info[t.target_id].get_h_value(), t.target_id);
                }
            }
        }
    }
    
    //from here on, this is the same as update_perfect_h_alt
    
    //expand settled states by searching forward from orphaned to settled states
    // orph_count counts only orphans with a finite tentative distance,
    // i.e. the ones seeded into the queue below.
    int orph_count = 0;
    open_queue.clear();
    for (size_t state=0; state<orphans.size(); ++state) {
        if (orphans[state] == -1) continue; //ignore settled states
        int& old_g = orphans[state];
        for (const Transition& t : out[state]) {
            const int opcost = operator_costs[t.op_id];
            assert(opcost !=INF);
            if (orphans[t.target_id] == -1) {
                const int prev_g = search_info[t.target_id].get_h_value();
                const int new_g = (prev_g<INF) ? (prev_g+opcost) : INF;
                if (new_g < old_g) {
                    old_g = new_g;
                    shortest_path[state] = t;
                }
            }
        }
        //we have evaluated all transitions from this orphan to settled states,
        //so its tentative distance via settled successors is complete
        if(orphans[state] != INF) {
            open_queue.push(orphans[state], state);
            ++orph_count;
        }
    }
    // NOTE(review): the ratio is multiplied into orphan_ratio rather than
    // summed; presumably a product for a geometric mean — confirm.
    orphan_ratio *= float(orph_count)/orphans.size();
    
    //start regular Dijkstra (backwards over in[], restricted to orphans).
    //Unlike update_perfect_h_alt, the queue may be empty here because every
    //candidate may have been reconnected.
    while(!open_queue.empty()) {
        pair<int,int> top_pair = open_queue.pop();
        const int g_value = top_pair.first;
        const int state = top_pair.second;
        // Skip outdated queue entries.
        if (g_value > orphans[state]) 
            continue;
        assert (g_value == orphans[state]);
        assert (g_value >= 0);
        assert (g_value != INF);
        search_info[state].increase_h_value_to(g_value);
        for (Transition t : in[state]) {
            const int opcost = operator_costs[t.op_id];
            assert(opcost >=0);
            assert(opcost != INF);
            if (orphans[t.target_id] != -1) {
                assert(INF-opcost > g_value);
                const int succ_g = g_value + opcost;
                if (succ_g < orphans[t.target_id]) {
                    orphans[t.target_id] = succ_g;
                    shortest_path[t.target_id] = Transition(t.op_id, state);
                    open_queue.push(succ_g, t.target_id);
                }
            }
        }
    }
    // Orphans that never received a finite distance cannot reach a goal.
    for (size_t i=0; i<orphans.size(); ++i)
        if (orphans[i]==INF) search_info[i].increase_h_value_to(INF);
}

/*
  Computes exact goal distances for all abstract states with a backward
  Dijkstra search from `goals`, and rebuilds the shortest-path tree.

  Afterwards, shortest_path[s] is the transition s takes towards a nearest
  goal. Goals, and states from which no goal is reachable, hold an empty
  Transition(). The tree is rebuilt from scratch, so no pointer from an
  earlier call survives for a state that can no longer reach a goal. The
  distances are stored as h-values via set_h_values(), with INF for
  states that cannot reach a goal.

  in: indexed by state. in[s] lists transitions whose target_id is a
      predecessor of s; t.op_id leads from that predecessor to s.
*/
void AbstractSearch::compute_full_distances(
    const std::vector<Transitions> &in,
    const std::unordered_set<int> &goals) {
    open_queue.clear();
    shortest_path.assign(in.size(), Transition());
    vector<int> distances(in.size(), INF);
    for (int goal : goals) {
        distances[goal] = 0;
        open_queue.push(0, goal);
    }
    while (!open_queue.empty()) {
        pair<int, int> top_pair = open_queue.pop();
        int old_g = top_pair.first;
        int state_id = top_pair.second;

        const int g = distances[state_id];
        assert(0 <= g && g < INF);
        assert(g <= old_g);
        if (g < old_g)
            continue;  // outdated queue entry
        assert(utils::in_bounds(state_id, in));
        for (const Transition &t : in[state_id]) {
            const int op_cost = operator_costs[t.op_id];
            assert(op_cost >= 0);
            int succ_g = (op_cost == INF) ? INF : g + op_cost;
            assert(succ_g >= 0);
            int succ_id = t.target_id;
            if (succ_g < distances[succ_id]) {
                distances[succ_id] = succ_g;
                shortest_path[succ_id] = Transition(t.op_id, state_id);
                open_queue.push(succ_g, succ_id);
            }
        }
    }
    set_h_values(distances);
}

/*
  Debug check: recomputes goal distances from scratch with
  compute_distances() and compares them with the stored h-values.
  On the first mismatch it prints a message to stdout and returns false.
  Returns true if all values match. Assumes search_info has at least
  in.size() entries.
*/
bool AbstractSearch::test_distances(
    const std::vector<Transitions> &in,
    const std::unordered_set<int> &goals) {
    const std::vector<int> expected = compute_distances(in, operator_costs, goals);
    for (size_t id = 0; id < expected.size(); ++id) {
        const int stored_h = search_info[id].get_h_value();
        if (stored_h == expected[id])
            continue;
        std::cout << "Discrepancy for state " << id << " of "
                  << expected.size() << ": h is " << stored_h
                  << ", should be " << expected[id] << std::endl;
        return false;
    }
    return true;
}


/*
  Dijkstra from every state in start_ids over `transitions`.

  transitions[s] lists the transitions leaving s in the searched direction;
  t.target_id is the state reached and costs[t.op_id] is the edge cost.
  Operators with cost INF are never relaxed.

  Returns one distance per state, with INF for states that are never
  reached.
*/
vector<int> compute_distances(
    const vector<Transitions> &transitions,
    const vector<int> &costs,
    const unordered_set<int> &start_ids) {
    vector<int> distances(transitions.size(), INF);
    priority_queues::AdaptiveQueue<int> frontier;
    for (int start : start_ids) {
        distances[start] = 0;
        frontier.push(0, start);
    }
    while (!frontier.empty()) {
        const pair<int, int> entry = frontier.pop();
        const int state = entry.second;
        const int g = distances[state];
        assert(0 <= g && g < INF);
        assert(g <= entry.first);
        if (g < entry.first)
            continue;  // outdated queue entry
        assert(utils::in_bounds(state, transitions));
        for (const Transition &t : transitions[state]) {
            const int cost = costs[t.op_id];
            assert(cost >= 0);
            const int succ_g = (cost == INF) ? INF : g + cost;
            assert(succ_g >= 0);
            int &succ_distance = distances[t.target_id];
            if (succ_g < succ_distance) {
                succ_distance = succ_g;
                frontier.push(succ_g, t.target_id);
            }
        }
    }
    return distances;
}
}
