#include "bsw_search.h"

#include "search_common.h"

#include "../evaluation_context.h"
#include "../globals.h"
#include "../heuristic.h"
#include "../option_parser.h"
#include "../plugin.h"
#include "../pruning_method.h"
#include "../successor_generator.h"

#include "../algorithms/ordered_set.h"

#include "../open_lists/open_list_factory.h"

#include <cassert>
#include <cstdlib>
#include <memory>
#include <set>

using namespace std;

namespace bsw_search {
    /*
      Builds the engine from the parsed options: the reopening flag, the weight
      w, one state open list per role (open, cleanup, focal, focal_backup),
      each created from the OpenListFactory stored under that key, the f and
      f_inad evaluators (defaulting to nullptr), the preferred-operator
      heuristics and the pruning method.
    */
    BswSearch::BswSearch(const Options &opts)
        : SearchEngine(opts),
          reopen_closed_nodes(opts.get<bool>("reopen_closed")),
          w(opts.get<double>("w")),
          open_list(
              opts.get<shared_ptr<OpenListFactory>>("open")->create_state_open_list()),
          cleanup_list(
              opts.get<shared_ptr<OpenListFactory>>("cleanup")->create_state_open_list()),
          focal_list(
              opts.get<shared_ptr<OpenListFactory>>("focal")->create_state_open_list()),
          focal_backup(
              opts.get<shared_ptr<OpenListFactory>>("focal_backup")->create_state_open_list()),
          f_evaluator(opts.get<ScalarEvaluator *>("f", nullptr)),
          f_inad_evaluator(opts.get<ScalarEvaluator *>("f_inad", nullptr)),
          preferred_operator_heuristics(opts.get_list<Heuristic *>("preferred")),
          pruning_method(opts.get<shared_ptr<PruningMethod>>("pruning")) {
    }

    void BswSearch::initialize() {
        cout << "Conducting best first search"
             << (reopen_closed_nodes ? " with" : " without")
             << " reopening closed nodes, (real) bound = " << bound
             << endl;

        assert(open_list);
        cout << "w: " << w << endl;

        set<Heuristic *> hset;
        open_list->get_involved_heuristics(hset);

        // add heuristics that are used for preferred operators (in case they are
        // not also used in the open list)
        hset.insert(preferred_operator_heuristics.begin(),
                    preferred_operator_heuristics.end());

        // add heuristics that are used in the f_evaluator. They are usually also
        // used in the open list and hence already be included, but we want to be
        // sure.
        if (f_evaluator) {
            f_evaluator->get_involved_heuristics(hset);
        }

        heuristics.assign(hset.begin(), hset.end());
        assert(!heuristics.empty());

        const GlobalState &initial_state = state_registry.get_initial_state();
        for (Heuristic *heuristic : heuristics) {
            heuristic->notify_initial_state(initial_state);
        }

        // Note: we consider the initial state as reached by a preferred
        // operator.
        EvaluationContext eval_context(initial_state, 0, true, &statistics);

        statistics.inc_evaluated_states();

        if (open_list->is_dead_end(eval_context)) {
            cout << "Initial state is a dead end." << endl;
        } else {
            if (search_progress.check_progress(eval_context))
                print_checkpoint_line(0);
            start_f_value_statistics(eval_context);
            SearchNode node = search_space.get_node(initial_state);
            node.open_initial();

            open_list->insert(eval_context, initial_state.get_id());
            cleanup_list->insert(eval_context,initial_state.get_id());
            focal_backup->insert(eval_context,initial_state.get_id());

        }

        print_initial_h_values(eval_context);

        pruning_method->initialize(g_root_task());
    }

    // Prints a progress line of the form "[g=<g>, <basic statistics>]".
    void BswSearch::print_checkpoint_line(int g) const {
        cout << "[g=";
        cout << g << ", ";
        statistics.print_basic_statistics();
        cout << "]";
        cout << endl;
    }

    void BswSearch::print_statistics() const {
        statistics.print_detailed_statistics();
        search_space.print_statistics();
        pruning_method->print_statistics();
    }

    /*
      Performs one expansion step.

      Takes the node chosen by fetch_next_node(), which has already closed it
      and counted it as expanded. Returns FAILED if no node is available and
      SOLVED if the node's state is a goal. Otherwise it generates the
      applicable operators, lets the pruning method filter them, and
      processes each successor:
        - successors whose real g plus the operator's real cost reaches the
          bound are skipped;
        - known dead ends are skipped without re-evaluation;
        - new states are evaluated, marked as dead ends if open_list says so,
          and otherwise opened and inserted into open_list, focal_backup and
          cleanup_list;
        - already-seen states reached by a cheaper (adjusted-cost) path are
          reopened and re-inserted when reopen_closed_nodes is set, or only
          get their parent pointer updated otherwise.
      Returns IN_PROGRESS once every successor has been processed.
    */
    SearchStatus BswSearch::step() {
        pair<SearchNode, bool> n = fetch_next_node();
        if (!n.second) {
            return FAILED;
        }
        SearchNode node = n.first;

        GlobalState s = node.get_state();
        if (check_goal_and_set_plan(s))
            return SOLVED;

        vector<const GlobalOperator *> applicable_ops;
        g_successor_generator->generate_applicable_ops(s, applicable_ops);

        /*
          TODO: When preferred operators are in use, a preferred operator will be
          considered by the preferred operator queues even when it is pruned.
        */
        pruning_method->prune_operators(s, applicable_ops);

        // This evaluates the expanded state (again) to get preferred ops
        EvaluationContext eval_context(s, node.get_g(), false, &statistics, true);
        algorithms::OrderedSet<const GlobalOperator *> preferred_operators =
                collect_preferred_operators(eval_context, preferred_operator_heuristics);

        for (const GlobalOperator *op : applicable_ops) {
            // The cost bound is checked with real (unadjusted) costs.
            if ((node.get_real_g() + op->get_cost()) >= bound)
                continue;

            GlobalState succ_state = state_registry.get_successor_state(s, *op);

            statistics.inc_generated();
            bool is_preferred = preferred_operators.contains(op);

            SearchNode succ_node = search_space.get_node(succ_state);

            // Previously encountered dead end. Don't re-evaluate.
            if (succ_node.is_dead_end())
                continue;


            if (succ_node.is_new()) {
                // We have not seen this state before.
                // Evaluate and create a new node.

                // Careful: succ_node.get_g() is not available here yet,
                // hence the stupid computation of succ_g.
                // TODO: Make this less fragile.
                int succ_g = node.get_g() + get_adjusted_cost(*op);

                EvaluationContext eval_context(
                        succ_state, succ_g, is_preferred, &statistics);
                statistics.inc_evaluated_states();

                if (open_list->is_dead_end(eval_context)) {
                    succ_node.mark_as_dead_end();
                    statistics.inc_dead_ends();
                    continue;
                }
                succ_node.open(node, op);

                // New nodes go into open, focal_backup and cleanup. They are
                // not inserted into focal_list here; fetch_next_node() moves
                // them there from focal_backup.
                open_list->insert(eval_context, succ_state.get_id());
                focal_backup->insert(eval_context, succ_state.get_id());
                cleanup_list->insert(eval_context, succ_state.get_id());

                if (search_progress.check_progress(eval_context)) {
                    print_checkpoint_line(succ_node.get_g());
                    reward_progress();
                }
            } else if (succ_node.get_g() > node.get_g() + get_adjusted_cost(*op)) {
                // We found a new cheapest path to an open or closed state.
                // (Path cost here uses adjusted costs, unlike the bound check.)
                if (reopen_closed_nodes) {
                    if (succ_node.is_closed()) {
                        /*
                          TODO: It would be nice if we had a way to test
                          that reopening is expected behaviour, i.e., exit
                          with an error when this is something where
                          reopening should not occur (e.g. A* with a
                          consistent heuristic).
                        */
                        statistics.inc_reopened();
                    }
                    succ_node.reopen(node, op);

                    EvaluationContext eval_context(
                            succ_state, succ_node.get_g(), is_preferred, &statistics);

                    /*
                      Note: our old code used to retrieve the h value from
                      the search node here. Our new code recomputes it as
                      necessary, thus avoiding the incredible ugliness of
                      the old "set_evaluator_value" approach, which also
                      did not generalize properly to settings with more
                      than one heuristic.

                      Reopening should not happen all that frequently, so
                      the performance impact of this is hopefully not that
                      large. In the medium term, we want the heuristics to
                      remember heuristic values for states themselves if
                      desired by the user, so that such recomputations
                      will just involve a look-up by the Heuristic object
                      rather than a recomputation of the heuristic value
                      from scratch.
                    */
                    // Entries inserted earlier for this state are not removed
                    // here; they stay in the lists next to the new ones.
                    open_list->insert(eval_context, succ_state.get_id());
                    focal_backup->insert(eval_context, succ_state.get_id());
                    cleanup_list->insert(eval_context,succ_state.get_id());
                } else {
                    // If we do not reopen closed nodes, we just update the parent pointers.
                    // Note that this could cause an incompatibility between
                    // the g-value and the actual path that is traced back.
                    succ_node.update_parent(node, op);
                }
            }
        }

        return IN_PROGRESS;
    }

    /*
      Chooses the next node to expand from the open, focal and cleanup lists.

        1. Pop the best non-closed node of open_list (best_open). Its
           f_inad value (bound_value) is the basis of the focal threshold.
        2. Pop the best non-closed node of cleanup_list (best_cleanup). Its
           f value (best_cleanup_value) is the reference for the w bound.
        3. Move nodes from focal_backup to focal_list while the first key of
           focal_backup's minimum is <= int(w * bound_value).
        4. If the best non-closed node of focal_list has
           f_inad < int(w * best_cleanup_value), expand it.
        5. Otherwise expand best_cleanup if
           bound_value >= int(w * best_cleanup_value), else best_open.

      The heads of open_list and cleanup_list that are popped but not chosen
      are re-inserted. Closed entries are dropped lazily whenever they are
      popped. The returned node is already closed and counted as expanded.

      Returns (node, true) for the node to expand, or (dummy node, false)
      when open_list or cleanup_list has no non-closed entry left.

      NOTE(review): the structure resembles Explicit Estimation Search
      (Thayer & Ruml); the published rule uses non-strict comparisons, while
      step 4 here is strict. Confirm this is intended.
    */
    pair<SearchNode, bool> BswSearch::fetch_next_node() {
        // Find best f_inad
        auto from_open = fetch_best_from(open_list);

        SearchNode best_open = get<0>(from_open);
        EvaluationContext eval_context_best_open = get<1>(from_open);

        bool success = get<2>(from_open);
        if(!success)
            return make_pair(best_open, false);

        int bound_value = eval_context_best_open.get_heuristic_value(f_inad_evaluator);


        // Find best f
        auto from_cleanup = fetch_best_from(cleanup_list);
        SearchNode best_cleanup = get<0>(from_cleanup);
        EvaluationContext eval_context_best_cleanup = get<1>(from_cleanup);

        success = get<2>(from_cleanup);
        // NOTE(review): on this path best_open has already been popped from
        // open_list and is not re-inserted. That is harmless only because the
        // caller stops the search with FAILED.
        if(!success)
            return make_pair(best_cleanup, false);

        int best_cleanup_value = eval_context_best_cleanup.get_heuristic_value(f_evaluator);


        // Push from backup to focal
        // NOTE(review): assumes focal_backup's first sort key (key[0]) is the
        // f_inad value. Confirm against
        // search_common::create_bsw_open_list_factory_and_f_eval.
        while(true) {
            if (focal_backup->empty()) {
                break;
            }
            vector<int> key;
            StateID id = focal_backup->remove_min(&key);
            GlobalState s = state_registry.lookup_state(id);
            SearchNode node = search_space.get_node(s);
            if(node.is_closed())
                continue;

            EvaluationContext eval_context(s, node.get_g(), false, &statistics);
            // The threshold w * bound_value is truncated to int.
            if (key[0] <= int(w * bound_value)) {
                focal_list->insert(eval_context, id);
            } else {
                // First node above the threshold: put it back and stop.
                focal_backup->insert(eval_context, id);
                break;
            }
        }

        // Try the best non-closed focal node against the w * f bound.
        while(true){
            if(focal_list->empty())
                break;

            vector<int> key;
            StateID id = focal_list->remove_min(&key);

            GlobalState s = state_registry.lookup_state(id);
            SearchNode node = search_space.get_node(s);

            if (node.is_closed())
                continue;

            EvaluationContext eval_context(s, node.get_g(), false, &statistics);
            int node_inad_value = eval_context.get_heuristic_value(f_inad_evaluator);

            // Rejected: keep it in focal and fall back to steps 5 below.
            if(node_inad_value>=int(w*best_cleanup_value)) {
                focal_list->insert(eval_context,id);
                break;
            }

            // Accepted: the popped open/cleanup heads were not chosen, so
            // put them back before expanding the focal node.
            open_list->insert(eval_context_best_open,best_open.get_state_id());
            cleanup_list->insert(eval_context_best_cleanup,best_cleanup.get_state_id());
            node.close();
            assert(!node.is_dead_end());
            update_f_value_statistics(node);
            statistics.inc_expanded();
            return make_pair(node, true);
        }

        // Expand the best-f node when best_open's f_inad reaches w * best f.
        if(bound_value>=int(w*best_cleanup_value)){
            open_list->insert(eval_context_best_open,best_open.get_state_id());
            best_cleanup.close();
            update_f_value_statistics(best_cleanup);
            statistics.inc_expanded();
            return make_pair(best_cleanup, true);
        }

        // Otherwise expand the best-f_inad node.
        cleanup_list->insert(eval_context_best_cleanup,best_cleanup.get_state_id());
        best_open.close();
        update_f_value_statistics(best_open);
        statistics.inc_expanded();
        return make_pair(best_open, true);
    }

    /*
      Pops entries from current_list until one whose node is not closed is
      found. Returns that node, a fresh evaluation context for it (not
      preferred, using the node's g) and `true`.

      If the list runs out, prints a "no solution" message and returns a
      placeholder node/context built from the initial state together with
      `false`. Callers must check the flag before using the node.
    */
    tuple<SearchNode, EvaluationContext, bool> BswSearch::fetch_best_from(
            unique_ptr<StateOpenList> const& current_list) {
        while (!current_list->empty()) {
            const StateID popped_id = current_list->remove_min(nullptr);
            GlobalState popped_state = state_registry.lookup_state(popped_id);
            SearchNode candidate = search_space.get_node(popped_state);

            // Closed entries are stale; discard them lazily.
            if (candidate.is_closed())
                continue;

            EvaluationContext candidate_context(
                    popped_state, candidate.get_g(), false, &statistics);
            assert(!candidate.is_dead_end());
            return make_tuple(candidate, candidate_context, true);
        }

        cout << "Completely explored state space -- no solution!" << endl;
        // HACK: SearchNode has no default constructor, so a placeholder node
        // for the initial state is returned alongside the `false` flag.
        const GlobalState &initial_state = state_registry.get_initial_state();
        SearchNode placeholder = search_space.get_node(initial_state);
        EvaluationContext placeholder_context(
                placeholder.get_state(), placeholder.get_g(), false, &statistics);
        return make_tuple(placeholder, placeholder_context, false);
    }

    void BswSearch::reward_progress() {
        // Boost the "preferred operator" open lists somewhat whenever
        // one of the heuristics finds a state with a new best h value.
        open_list->boost_preferred();
    }

    void BswSearch::dump_search_space() const {
        search_space.dump();
    }

    // Reports the f value of the given (initial) evaluation context to the
    // statistics; does nothing when no f evaluator is configured.
    void BswSearch::start_f_value_statistics(EvaluationContext &eval_context) {
        if (!f_evaluator)
            return;
        const int initial_f = eval_context.get_heuristic_value(f_evaluator);
        statistics.report_f_value_progress(initial_f);
    }

/* TODO: HACK! This is very inefficient for simply looking up an h value.
   Also, if h values are not saved it would recompute h for each and every state. */
    void BswSearch::update_f_value_statistics(const SearchNode &node) {
        if (f_evaluator) {
            /*
              TODO: This code doesn't fit the idea of supporting
              an arbitrary f evaluator.
            */
            EvaluationContext eval_context(node.get_state(), node.get_g(), false, &statistics);
            int f_value = eval_context.get_heuristic_value(f_evaluator);
            statistics.report_f_value_progress(f_value);
        }
    }

/* TODO: merge this into SearchEngine::add_options_to_parser when all search
         engines support pruning. */
    // Registers the "pruning" option (a PruningMethod) with default "null()".
    // NOTE(review): "null()" presumably names a pruning method that prunes
    // nothing; confirm against the pruning plugins.
    void add_pruning_option(OptionParser &parser) {
        const string pruning_help =
            "Pruning methods can prune or reorder the set of applicable operators in "
            "each state and thereby influence the number and order of successor states "
            "that are considered.";
        const string pruning_default = "null()";
        parser.add_option<shared_ptr<PruningMethod>>(
                "pruning", pruning_help, pruning_default);
    }

    /*
      Plugin parser for "bsw". Declares the evaluator options ("eval",
      "inad", "exp"), the weight "w", the pruning option and the generic
      search-engine options, then parses them. On a dry run it returns
      nullptr. Otherwise it fills in the open-list factories and evaluators
      via search_common::create_bsw_open_list_factory_and_f_eval, forces
      reopening of closed nodes, uses no preferred-operator heuristics, and
      returns a newly allocated BswSearch.

      NOTE(review): "eval", "inad" and "exp" share the same help text; their
      distinct roles are not documented here.
    */
    static SearchEngine *_parse_bsw(OptionParser &parser) {
        parser.add_option<ScalarEvaluator *>("eval", "evaluator for h-value");
        parser.add_option<ScalarEvaluator *>("inad", "evaluator for h-value");
        parser.add_option<ScalarEvaluator *>("exp", "evaluator for h-value");

        parser.add_option<double>("w", "weight (quality bound)", "1.0");

        add_pruning_option(parser);
        SearchEngine::add_options_to_parser(parser);
        Options opts = parser.parse();

        if (parser.dry_run())
            return nullptr;

        auto components = search_common::create_bsw_open_list_factory_and_f_eval(opts);
        opts.set("open", get<0>(components));
        opts.set("cleanup", get<1>(components));
        opts.set("focal", get<2>(components));
        opts.set("focal_backup", get<3>(components));
        opts.set("f_inad", get<4>(components));
        opts.set("f", get<5>(components));
        // NOTE(review): this replaces the "w" value parsed above with the
        // value returned by the helper.
        opts.set("w", get<6>(components));
        opts.set("reopen_closed", true);
        vector<Heuristic *> no_preferred_heuristics;
        opts.set("preferred", no_preferred_heuristics);
        return new BswSearch(opts);
    }

    // Registers this engine under the command-line name "bsw".
    static Plugin<SearchEngine> _plugin("bsw", _parse_bsw);
}
