//
// Created by marvin on 20.11.18.
//

//#define BACKWARD_DEBUG

#include "fmm.h"
#include "fmm_sub.h"

#include "../open_list_factory.h"
#include "../option_parser.h"
#include "../evaluator.h"
#include "../tasks/backward_task.h"

#include "../task_utils/successor_generator.h"
#include "../pruning_method.h"
#include "../task_utils/task_properties.h"

using namespace std;

namespace fmm {

    // ----------------------- Utility Functions -----------------------
    /**
     * Formats a fact as "<var,value>" for debug output.
     * @param pair Fact to format.
     * @return Textual representation of the fact.
     */
    string pair_to_string(FactPair pair){
        ostringstream out;
        out << '<' << pair.var << ',' << pair.value << '>';
        return out.str();
    }

    // ----------------------- Directional Search -----------------------
    /**
     * Creates one direction of the fMM bidirectional search.
     * All search components (evaluators, open list, pruning settings) are
     * read from @p opts; @p forwardFlag selects whether this instance is the
     * forward or the backward half, searching over @p input_task.
     */
    FmmSubSearch::FmmSubSearch(const options::Options &opts, bool forwardFlag,
                                         const std::shared_ptr<AbstractTask> input_task)
            : SearchEngine(opts, input_task),
              reopen_closed_nodes(opts.get<bool>("reopen_closed")),
              forward_flag(forwardFlag),
              pruning(opts.get<bool>("mutex_pruning")),
              fraction_p(opts.get<double>("fraction_p")),
              pruning_method(opts.get<shared_ptr<PruningMethod>>("pruning")),
              f_evaluator(opts.get<shared_ptr<Evaluator>>("f_eval")),
              g_evaluator(opts.get<shared_ptr<Evaluator>>("g_eval")),
              open_list(opts.get<shared_ptr<OpenListFactory>>("open")
                                ->create_state_open_list()) {
        const char *direction = forward_flag ? "forward" : "backward";
        cout << "Initializing " << direction << " fMM sub search..." << endl;
    }

    void FmmSubSearch::initialize_public(){
        assert(open_list);

        #ifdef BACKWARD_DEBUG
        initial_dump();
        #endif

        const GlobalState &initial_state = state_registry.get_initial_state();
        f_evaluator->notify_initial_state(initial_state);

        EvaluationContext eval_context(initial_state, 0, true, &statistics);

        statistics.inc_evaluated_states();
        if (open_list->is_dead_end(eval_context)) {
            cout << "Initial state is a dead end." << endl;
        } else {
            if (search_progress.check_progress(eval_context))
                print_checkpoint_line(0);
            //start_f_value_statistics(eval_context);
            SearchNode node = search_space.get_node(initial_state);
            node.open_initial();

            open_list->insert(eval_context, initial_state.get_id());

            if (forward_flag){
                vector<int> key = initial_state.unpack().get_values(); //would need to pop key, if backward search.
                states.emplace(key, initial_state.get_id());
            }
        }
        print_initial_evaluator_values(eval_context);

        //pruning_method->initialize(task);
    }

    /**
     * Expands the best node of this direction's open list.
     *
     * Outdated open-list entries of already closed nodes are skipped. The
     * expanded state is removed from the open-list g/f tallies, checked
     * against the other search's frontier (possibly recording a new plan),
     * and its successors are generated, evaluated and inserted. In the
     * backward direction, successors containing mutex facts are pruned when
     * mutex pruning is enabled. After the first backward expansion, the
     * number of generated states is recorded as the number of real goal
     * states and the basic statistics are reset.
     *
     * Tally invariant kept together with rememberState()/forgetState():
     * min_g/min_f count exactly the nodes that are currently open, keyed by
     * each node's current (cost-adjusted) g value.
     *
     * @param other_engine Search running in the opposite direction.
     * @param backward_task Backward task, used to translate backward operators
     *        when a plan is assembled.
     */
    void FmmSubSearch::step(
            shared_ptr<FmmSubSearch> other_engine,
            shared_ptr<backward_tasks::BackwardTask> backward_task){

        assert(!open_list->empty());

        // Expand next node
        StateID id = open_list->remove_min();
        GlobalState s = state_registry.lookup_state(id);
        SearchNode node = search_space.get_node(s);

        // Outdated duplicate entry of a node that was already expanded.
        if (node.is_closed())
            return;

        node.close();
        assert(!node.is_dead_end());
        statistics.inc_expanded();

        // Use the same (cost-adjusted) g that rememberState() received when
        // this node was opened, so that exactly that tally entry is removed.
        EvaluationContext eval_context(
                node.get_state(), node.get_g(), false, &statistics);
        forgetState(eval_context);
        #ifdef BACKWARD_DEBUG
        cout << (forward_flag?
                 "--> Expand state Forward:":"<-- Expand state Backward:")
                  << " g: " << g_evaluator->compute_result(eval_context).get_evaluator_value()
                  << " ,f: " << f_evaluator->compute_result(eval_context).get_evaluator_value() << endl;
        s.dump_fdr();
        #endif

        check_frontier_and_set_plan(s, move(other_engine), move(backward_task));

        // Generate Successors
        vector<OperatorID> applicable_ops;
        successor_generator.generate_applicable_ops(s, applicable_ops);

        for (OperatorID op_id : applicable_ops) {
            OperatorProxy op = task_proxy.get_operators()[op_id];
            if ((node.get_real_g() + op.get_cost()) >= bound)
                continue;

            GlobalState succ_state = state_registry.get_successor_state(s, op);

            #ifdef BACKWARD_DEBUG
            cout << "Generating state from op: " << op.get_name() << endl;
            succ_state.dump_fdr();
            #endif

            // prune fact mutexes
            if (pruning && !forward_flag && is_mutex(succ_state)){
                #ifdef BACKWARD_DEBUG
                cout << "Pruning illegal state!" << endl;
                #endif
                expanded_statistics.inc_pruned();
                continue;
            }

            // Register the successor for frontier checks by the other search.
            // Pruned states are never opened, so they are not registered; in
            // the backward search the last variable is dropped from the key.
            vector<int> key = succ_state.unpack().get_values();
            if (!forward_flag){
                key.pop_back();
            }
            states.emplace(move(key), succ_state.get_id());

            statistics.inc_generated();

            SearchNode succ_node = search_space.get_node(succ_state);

            if (succ_node.is_dead_end())
                continue;

            if (succ_node.is_new()) {
                f_evaluator->notify_state_transition(s, op_id, succ_state);
                //g_evaluator->notify_state_transition(s, op_id, succ_state);

                int succ_g = node.get_g() + get_adjusted_cost(op);

                EvaluationContext eval_context(
                        succ_state, succ_g, false, &statistics);
                statistics.inc_evaluated_states();
                rememberState(eval_context);

                if (open_list->is_dead_end(eval_context)) {
                    succ_node.mark_as_dead_end();
                    statistics.inc_dead_ends();
                    continue;
                }
                succ_node.open(node, op, get_adjusted_cost(op));
                open_list->insert(eval_context, succ_state.get_id());
                if (search_progress.check_progress(eval_context)) {
                    print_checkpoint_line(succ_node.get_g());
                }
            } else if (succ_node.get_g() > node.get_g() + get_adjusted_cost(op)) {
                // We found a new cheapest path to an open or closed state.
                // Only open nodes are counted in the tallies: closed nodes were
                // already forgotten when they were expanded, so forgetting them
                // again would decrement another state's entry.
                const bool succ_was_open = !succ_node.is_closed();
                if (succ_was_open) {
                    EvaluationContext old_context(
                            succ_state, succ_node.get_g(), false, &statistics);
                    forgetState(old_context);
                }

                if (reopen_closed_nodes) {
                    if (!succ_was_open) {
                        statistics.inc_reopened();
                    }
                    succ_node.reopen(node, op, get_adjusted_cost(op));

                    EvaluationContext eval_context(
                            succ_state, succ_node.get_g(), false, &statistics);

                    rememberState(eval_context);
                    open_list->insert(eval_context, succ_state.get_id());
                } else {
                    succ_node.update_parent(node, op, get_adjusted_cost(op));
                    // The node's g changed; if it is still open, re-count it
                    // under its new g so its later expansion stays balanced.
                    if (succ_was_open) {
                        EvaluationContext new_context(
                                succ_state, succ_node.get_g(), false, &statistics);
                        rememberState(new_context);
                    }
                }
            }
        }
        // Save the generated in the first expansion.
        if (!forward_flag && expanded_statistics.get_real_goal_states() == 0){
            expanded_statistics.inc_real_goal_states(statistics.get_generated());
            // Reset Statistic
            statistics.inc_expanded(-statistics.get_expanded());
            statistics.inc_evaluated_states(-statistics.get_evaluated_states());
            statistics.inc_generated(-statistics.get_generated());
        }
    }

    /**
     * Looks up this search's counterpart of a state from the other direction.
     * The lookup key is the state's variable values; when this is the forward
     * search (so @p state comes from the backward search) the last value is
     * dropped first, matching how the backward search stores its keys.
     * @param state State from the other search direction.
     * @return (node, true) with this search's node for the matching state if
     *         one was registered; otherwise (node of the initial state, false),
     *         where the node is only a placeholder.
     */
    pair<SearchNode, bool> FmmSubSearch::get_node_and_check(const GlobalState &state){
        vector<int> key = state.unpack().get_values();
        if (forward_flag) { // the state stems from the backward search
            key.pop_back();
        }
        auto match = states.find(key);
        if (match == states.end()) {
            SearchNode placeholder = search_space.get_node(state_registry.get_initial_state());
            return make_pair(placeholder, false);
        }
        GlobalState own_state = state_registry.lookup_state(match->second);
        return make_pair(search_space.get_node(own_state), true);
    }

    void FmmSubSearch::check_frontier_and_set_plan(const GlobalState &state,
                                                        shared_ptr<FmmSubSearch> other_engine,
                                                        shared_ptr<backward_tasks::BackwardTask> backward_task) {

        SearchSpace & other_space = other_engine->get_search_space();
        pair<SearchNode, bool> n = other_engine->get_node_and_check(state);
        if (!n.second) {
            // The frontiers didn't meet.
            return;
        }
        SearchNode other_node = n.first;

        // If other_node was already generated -> build path and safe it.
        if (!other_node.is_new()){
            expanded_statistics.inc_frontier_meetings();
            expanded_statistics.set_expansions_before_first_meeting(statistics.get_expanded());

            Plan forward;
            Plan backward;
            search_space.trace_path(state, (forward_flag?forward:backward));
            other_space.trace_path(other_node.get_state(), (forward_flag?backward:forward));

            reverse(backward.begin(), backward.end());

            OperatorID artificial_goal = OperatorID(-1);
            for(const auto &op : backward){
                OperatorID id = backward_task->get_forward_op_id(op);
                if (id != artificial_goal){
                    forward.push_back(id);
                }
            }

            int plan_cost = calculate_plan_cost(forward,task_proxy);

            if (current_plan_cost < 0 || plan_cost < current_plan_cost){
                set_plan_with_cost(forward, plan_cost);
                other_engine->set_plan_with_cost(forward, plan_cost);
                set_current_plan(state, other_node.get_state(), other_engine);
            }
        }
    }

    void FmmSubSearch::set_current_plan(const GlobalState &state, const GlobalState &other_state, shared_ptr<FmmSubSearch> other_engine){
        SearchNode node = search_space.get_node(state);
        SearchNode other_node = other_engine->get_search_space().get_node(other_state);
        EvaluationContext eval_context(state, node.get_real_g(), false, NULL);
        EvaluationContext other_eval_context(other_state, other_node.get_real_g(), false, NULL);
        current_plan_g = node.get_real_g();
        current_plan_h = f_evaluator->compute_result(eval_context).get_evaluator_value();
        other_engine->current_plan_g = other_node.get_real_g();
        other_engine->current_plan_h = other_engine->f_evaluator->compute_result(other_eval_context).get_evaluator_value();
    }

    bool FmmSubSearch::is_mutex(const GlobalState &state){
        for (size_t i = 0; i < state.unpack().size(); i++){
            FactPair a(i, state[i]);
            for (size_t j = 0; j < state.unpack().size(); j++){
                if (i == j) continue;
                FactPair b(j, state[j]);
                if (task->are_facts_mutex(a,b)){
                    return true;
                }
            }
        }
        return false;
    }

    int FmmSubSearch::calculate_plan_cost(
            const Plan &plan, const TaskProxy &task_proxy) {
        OperatorsProxy operators = task_proxy.get_operators();
        int plan_cost = 0;
        for (OperatorID op_id : plan) {
            plan_cost += operators[op_id].get_cost();
        }
        return plan_cost;
    }

    void FmmSubSearch::set_plan_with_cost(const Plan &p, int cost){
        set_plan(p);
        current_plan_cost = cost;
    }

    double FmmSubSearch::get_lowest_open_value(){
        if (open_list->empty()){
            return fmm::INFTY;
        }

        StateID id = open_list->remove_min();
        GlobalState s = state_registry.lookup_state(id);
        SearchNode node = search_space.get_node(s);

        EvaluationContext eval_context(
                s, node.get_g(), false, &statistics);

        open_list->insert(eval_context, s.get_id());

        int f = f_evaluator->compute_result(eval_context).get_evaluator_value();
        int g = g_evaluator->compute_result(eval_context).get_evaluator_value();
//        cout << "Get Lowest: " << f << ", " << g << endl;
//        cout << "Max: " << max(g*fraction_p, static_cast<double>(f)) << endl;
//        cout << "Fraction:" << fraction_p << endl;
        return max((fraction_p==fmm::INFTY)?fraction_p:g*fraction_p, static_cast<double>(f));
    }

    /**
     * Increments the per-value tallies min_g / min_f for the g and f values
     * of @p context. step() calls this when a state is put on the open list;
     * forgetState() undoes it.
     */
    void FmmSubSearch::rememberState(EvaluationContext &context){
        const int g_value = g_evaluator->compute_result(context).get_evaluator_value();
        const int f_value = f_evaluator->compute_result(context).get_evaluator_value();
        ++min_g[g_value];
        ++min_f[f_value];
    }

    /**
     * Reverses rememberState(): decrements the tallies for the g and f values
     * of @p context, erasing an entry instead of letting its count drop to
     * zero. Values that have no entry are left untouched.
     */
    void FmmSubSearch::forgetState(EvaluationContext &context){
        const int g_value = g_evaluator->compute_result(context).get_evaluator_value();
        const int f_value = f_evaluator->compute_result(context).get_evaluator_value();

        auto g_entry = min_g.find(g_value);
        if (g_entry != min_g.end()) {
            if (g_entry->second > 1)
                --g_entry->second;
            else
                min_g.erase(g_entry);
        }

        auto f_entry = min_f.find(f_value);
        if (f_entry != min_f.end()) {
            if (f_entry->second > 1)
                --f_entry->second;
            else
                min_f.erase(f_entry);
        }
    }

// ----------------------- Getter / Setter -----------------------

    /// Gives mutable access to this direction's search space; the opposite
    /// search uses it to trace its half of a plan and to look up nodes.
    SearchSpace& FmmSubSearch::get_search_space(){
        SearchSpace &space = this->search_space;
        return space;
    }

//----------------------- Output and Statistics -----------------------

    void FmmSubSearch::update_f_value_statistics(int clb) {
        statistics.report_f_value_progress(clb);
    }

    /// Prints a progress line of the form "[g=<g>, <basic statistics>]".
    void FmmSubSearch::print_checkpoint_line(int g) const {
        cout << "[g=" << g;
        cout << ", ";
        statistics.print_basic_statistics();
        cout << "]" << endl;
    }

    /**
     * Debug dump of this direction's task: whether the open list is empty,
     * the goals, every fact, and every operator with its preconditions and
     * effects. In this file it is only called when BACKWARD_DEBUG is defined.
     */
    void FmmSubSearch::initial_dump() const{
        cout << "Open List is empty: " << open_list->empty() << endl;
        cout << "Task Proxy things: " << ((forward_flag)?"Forward":"Backward") << endl;
        cout << "Goals: ";
        for (const auto &goal : task_proxy.get_goals()) {
            cout << pair_to_string(goal.get_pair()) << ", ";
        }
        cout << endl << "--" << endl;
        for (const auto &fact : task_proxy.get_variables().get_facts()){
            cout << "Fact: " << fact.get_name() << endl;
            cout << pair_to_string(fact.get_pair()) << endl;
        }
        for (const auto &op : task_proxy.get_operators()){
            // Operators are labelled as such (they were printed as "Fact: ").
            cout << "Operator: " << op.get_name() << endl;
            cout << "pre: ";
            for (const auto &pre : op.get_preconditions()){
                cout << pair_to_string(pre.get_pair()) << " + ";
            }
            cout << endl << "eff: ";
            for (const auto &eff : op.get_effects()){
                cout << pair_to_string(eff.get_fact().get_pair()) << " + ";
            }
            cout << endl;
        }
        cout << "End Task Proxy things!" << endl;
    }

    /**
     * Prints per-direction statistics prefixed with "F_" or "B_": the number
     * of actions (for the backward direction, without artificial operators),
     * the branching factor (generated / expanded, 0 if nothing was expanded)
     * and the detailed search statistics.
     */
    void FmmSubSearch::printDirectionalStatistic(
            shared_ptr<backward_tasks::BackwardTask> backward) const{
        const string prefix = forward_flag ? "F_" : "B_";

        cout << prefix << "Actions: ";
        if (!forward_flag) {
            cout << backward->get_num_operators_without_artificial();
        } else {
            cout << task_proxy.get_operators().size();
        }
        cout << "." << endl;

        double branching = 0;
        if (statistics.get_expanded() != 0) {
            branching = static_cast<double>(statistics.get_generated())
                        / statistics.get_expanded();
        }
        cout << prefix << "Branching: " << branching
             << " generations per expansion." << endl;
        statistics.print_detailed_statistics_prefix(prefix);
    }
}