#include <functional>
#include <limits>
#include <queue>
#include <utility>
#include "abstract_solution_data.h"
#include "../task_utils/task_properties.h"
#include "../utils/rng.h"

using namespace std;

namespace pdbs {


/*
  Builds the pattern database for `pattern` over the parent task and sets up
  the abstracted task (and its proxy) for the PDB's pattern. If the abstract
  initial state has a finite PDB value, an abstract plan is extracted into
  `plan` and printed; otherwise `is_solvable` is set to false and nothing
  else is done.

  Plan extraction uses A* when the abstracted task has zero-cost operators
  (they can leave the PDB value unchanged, so a strictly descending walk is
  not enough) and the greedy descent on PDB values otherwise.
*/
AbstractSolutionData::AbstractSolutionData(
        const std::shared_ptr<AbstractTask> &parent,
        Pattern &pattern,
        const std::shared_ptr<utils::RandomNumberGenerator> &rng)
        : concrete_task_proxy(*parent),
          pdb(new PatternDatabase(concrete_task_proxy,pattern)),
          abstracted_task(parent, pdb->get_pattern()),
          abs_task_proxy(abstracted_task),
          is_solvable(true),
          op_index(0) {
    State init = abs_task_proxy.get_initial_state();

    // A PDB value of numeric_limits<int>::max() is treated as "no abstract
    // goal reachable": there is no plan to extract.
    if (pdb->get_value_abstracted(init) == numeric_limits<int>::max()) {
        is_solvable = false;
        return;
    }

    // Only needed for plan extraction, so it is not built on the
    // unsolvable path above.
    successor_generator::SuccessorGenerator succ_gen(abs_task_proxy);
    if (abstracted_task.has_zero_cost_operators()) {
        find_plan_astar(init, succ_gen, rng);
    } else {
        find_plan_greedy(init, succ_gen, rng);
    }

    print_plan();
}

/*
  Extracts an abstract plan from `init` by descending along the PDB values,
  appending the chosen operator IDs to `plan`.

  The PDB value of an abstract state is used as its goal distance: in every
  non-goal state some applicable operator satisfies
      cost(op) + h(successor) == h(state),
  and following only such operators yields a cheapest abstract plan (the
  same optimality the A* variant enforces with g + h == h(init)). Applicable
  operators are shuffled with `rng` first, so ties are broken randomly.

  This is only called when the abstracted task has no zero-cost operators,
  so every step strictly decreases h and the walk terminates.

  `init` is not modified: the walk operates on a copy. If no operator on an
  optimal path exists (which should not happen for exact PDB values), the
  planner exits with ExitCode::UNSOLVABLE, like find_plan_astar.
*/
void AbstractSolutionData::find_plan_greedy(
        State &init,
        const successor_generator::SuccessorGenerator& succ_gen,
        const std::shared_ptr<utils::RandomNumberGenerator> &rng) {
    const int infinity = numeric_limits<int>::max();
    State current = init;
    int current_h = pdb->get_value_abstracted(current);

    while (current_h != 0) {
        std::vector<OperatorID> applicable_ops;
        succ_gen.generate_applicable_ops(current, applicable_ops);
        rng->shuffle(applicable_ops);

        bool advanced = false;
        for (OperatorID op_id : applicable_ops) {
            OperatorProxy op = abs_task_proxy.get_operators()[op_id];
            State successor = current.get_successor(op);
            int successor_h = pdb->get_value_abstracted(successor);
            // Dead ends are skipped before any addition: cost + infinity
            // would overflow int.
            if (successor_h == infinity)
                continue;
            if (op.get_cost() + successor_h == current_h) {
                plan.push_back(op_id);
                current = std::move(successor);
                current_h = successor_h;
                advanced = true;
                break;
            }
        }

        if (!advanced) {
            // No applicable operator lies on an optimal path; the original
            // code would have indexed an empty vector or applied an
            // arbitrary operator here.
            utils::exit_with(utils::ExitCode::UNSOLVABLE);
        }
    }
}

/*
  Extracts a cheapest abstract plan from `init` with A* guided by the PDB
  values and stores its operator IDs in `plan`. The root node carries a
  dummy operator, which is removed.

  The PDB value is used as the exact goal distance. Any successor with
  g + h != h(init) therefore cannot be on an optimal plan and is not
  generated. Open-list order is lowest f first, with ties broken by lowest h.
  Successor order is randomized with `rng`. States are closed by their
  abstract state index when expanded. If the open list runs empty, the
  planner exits with ExitCode::UNSOLVABLE.
*/
void AbstractSolutionData::find_plan_astar(
        State &init,
        const successor_generator::SuccessorGenerator& succ_gen,
        const std::shared_ptr<utils::RandomNumberGenerator> &rng) {
    using NodePtr = shared_ptr<AstarSearchNode>;
    const int infinity = numeric_limits<int>::max();
    // Cost of an optimal plan; loop-invariant, so computed once.
    const int init_h = pdb->get_value_abstracted(init);

    // std::priority_queue pops the "largest" element, so comparing with '>'
    // turns it into a min-queue on f, preferring lower h on ties.
    auto by_f_then_h = [](const NodePtr &a, const NodePtr &b) -> bool {
        if (a->get_f() == b->get_f()) {
            return a->get_h() > b->get_h();
        }
        return a->get_f() > b->get_f();
    };
    std::priority_queue<NodePtr, vector<NodePtr>, decltype(by_f_then_h)>
        open(by_f_then_h);
    std::unordered_set<int> closed;

    OperatorID dummy_id(-1);
    open.push(make_shared<AstarSearchNode>(
            dummy_id, init, 0, init_h, nullptr));

    while (!open.empty()) {
        NodePtr curr = open.top();
        open.pop();

        // insert() reports whether the index was new; a state that was
        // already expanded is skipped.
        if (!closed.insert(
                pdb->get_abstract_state_index(curr->get_state())).second) {
            continue;
        }

        if (task_properties::is_goal_state(abs_task_proxy, curr->get_state())) {
            for (NodePtr n = curr; n != nullptr; n = n->get_prev()) {
                plan.push_front(n->get_operator_id());
            }
            plan.pop_front(); // remove root node, as it is not associated with any operator
            return;
        }

        std::vector<OperatorID> applicable_ops;
        succ_gen.generate_applicable_ops(curr->get_state(), applicable_ops);
        rng->shuffle(applicable_ops);
        for (OperatorID op_id : applicable_ops) {
            OperatorProxy op = abs_task_proxy.get_operators()[op_id];
            State successor = curr->get_state().get_successor(op);
            int h = pdb->get_value_abstracted(successor);
            // Dead end: skip before computing g + h, which would be signed
            // integer overflow (undefined behavior) for h == infinity.
            if (h == infinity)
                continue;
            int g = curr->get_g() + op.get_cost();
            // Prune everything that is not on an optimal abstract plan.
            if (g + h != init_h)
                continue;
            open.push(make_shared<AstarSearchNode>(op_id, successor, g, h, curr));
        }
    }

    utils::exit_with(utils::ExitCode::UNSOLVABLE);
}

void AbstractSolutionData::print_plan() const {
    cout << "CEGAR_PDBs: " << "the plan for pattern ";
    print_pattern();
    cout << " is " << endl;
    for(OperatorID opid : plan) {
        OperatorProxy op = abs_task_proxy.get_operators()[opid];
        cout << "CEGAR_PDBs: " << op.get_name() << " " << endl;
    }
}

}
