#include "pattern_database_factory.h"

#include "match_tree.h"
#include "pattern_database.h"

#include "../algorithms/priority_queues.h"
#include "../task_utils/task_properties.h"
#include "../utils/collections.h"
#include "../utils/logging.h"
#include "../utils/math.h"
#include "../utils/timer.h"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

using namespace std;

namespace pdbs {
    class AbstractOperator {
        /*
          This class represents an abstract operator as it is needed for the
          regression search performed during the PDB construction. As all
          abstract states are represented as a number, abstract operators
          don't have "usual" effects but "hash effects", i.e. the change (as
          a number) the abstract operator implies on a given abstract state
          index.
        */

        int cost;

        // Id of the concrete operator this abstract operator was built from.
        int original_operator_id;

        /*
          Preconditions for the regression search; they correspond to the
          effects and prevail conditions of the concrete operator.
        */
        std::vector<FactPair> regression_preconditions;

        /*
          Effect of the operator during regression search on a given
          abstract state number.
        */
        std::size_t hash_effect;
    public:
        /*
          Abstract operators are built from concrete operators. The
          parameters follow the usual naming convention of SAS+ operators,
          meaning prevail, preconditions and effects all refer to
          progression search.
        */
        AbstractOperator(const std::vector<FactPair> &prevail,
                         const std::vector<FactPair> &preconditions,
                         const std::vector<FactPair> &effects,
                         int cost,
                         const std::vector<std::size_t> &hash_multipliers,
                         int original_operator_id);
        ~AbstractOperator();

        /*
          The user-declared destructor above would otherwise suppress the
          implicit move operations. Default them explicitly so that
          std::vector<AbstractOperator> moves (instead of copies) the
          precondition vectors on push_back and on reallocation.
        */
        AbstractOperator(const AbstractOperator &) = default;
        AbstractOperator &operator=(const AbstractOperator &) = default;
        AbstractOperator(AbstractOperator &&) = default;
        AbstractOperator &operator=(AbstractOperator &&) = default;

        /*
          Returns variable value pairs which represent the preconditions of
          the abstract operator in a regression search (sorted).
        */
        const std::vector<FactPair> &get_regression_preconditions() const {
            return regression_preconditions;
        }

        // Returns the id of the concrete operator this one was derived from.
        int get_original_op_id() const {
            return original_operator_id;
        }

        /*
          Returns the effect of the abstract operator in form of a value
          change (+ or -) to an abstract state index. The value is stored as
          std::size_t; negative changes are represented modulo 2^N, so adding
          it to a state index with unsigned wraparound yields the
          predecessor index.
        */
        std::size_t get_hash_effect() const {return hash_effect;}

        /*
          Returns the cost of the abstract operator (the cost that was
          passed in for the originating concrete operator).
        */
        int get_cost() const {return cost;}

        // Writes a human-readable description of the operator to utils::g_log.
        void dump(const Pattern &pattern,
                  const VariablesProxy &variables) const;
    };

    /*
      Builds an abstract operator for the regression search.

      prev_pairs: prevail conditions on pattern variables (progression view).
      pre_pairs:  preconditions on pattern variables that are also changed
                  by the operator; pre_pairs[i] must refer to the same
                  variable as eff_pairs[i].
      eff_pairs:  effects on those variables.
      hash_multipliers: per-pattern-variable multipliers of the state hash.

      The regression preconditions are prevail ∪ effects, sorted for the
      MatchTree. The hash effect is sum((pre - eff) * multiplier), i.e. the
      offset that maps a successor index to its predecessor index.
    */
    AbstractOperator::AbstractOperator(const vector<FactPair> &prev_pairs,
                                       const vector<FactPair> &pre_pairs,
                                       const vector<FactPair> &eff_pairs,
                                       int cost,
                                       const vector<size_t> &hash_multipliers,
                                       int original_operator_id)
            // Initializers listed in member declaration order (-Wreorder).
            : cost(cost),
              original_operator_id(original_operator_id),
              regression_preconditions(prev_pairs),
              hash_effect(0) {
        regression_preconditions.insert(regression_preconditions.end(),
                                        eff_pairs.begin(),
                                        eff_pairs.end());
        // Sort preconditions for MatchTree construction.
        sort(regression_preconditions.begin(), regression_preconditions.end());
        for (size_t i = 1; i < regression_preconditions.size(); ++i) {
            assert(regression_preconditions[i].var !=
                   regression_preconditions[i - 1].var);
        }
        assert(pre_pairs.size() == eff_pairs.size());
        for (size_t i = 0; i < pre_pairs.size(); ++i) {
            int var = pre_pairs[i].var;
            assert(var == eff_pairs[i].var);
            int old_val = eff_pairs[i].value;
            int new_val = pre_pairs[i].value;
            assert(new_val != -1);
            /*
              The delta may be negative. Converting it to size_t and relying
              on unsigned wraparound is intentional: the sum is only ever
              added to a state index, where it wraps back to the correct
              predecessor index.
            */
            size_t effect =
                static_cast<size_t>(new_val - old_val) * hash_multipliers[var];
            hash_effect += effect;
        }
    }

    // No explicit cleanup is needed: all members release their own storage.
    AbstractOperator::~AbstractOperator() = default;

    void AbstractOperator::dump(const Pattern &pattern,
                                const VariablesProxy &variables) const {
        utils::g_log << "AbstractOperator:" << endl;
        utils::g_log << "Regression preconditions:" << endl;
        for (size_t i = 0; i < regression_preconditions.size(); ++i) {
            int var_id = regression_preconditions[i].var;
            int val = regression_preconditions[i].value;
            utils::g_log << "Variable: " << var_id << " (True name: "
                         << variables[pattern[var_id]].get_name()
                         << ", Index: " << i << ") Value: " << val << endl;
        }
        utils::g_log << "Hash effect:" << hash_effect << endl;
    }

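    /*
      Builds a pattern database for a given pattern. generate() computes all
      abstract operators, builds the match tree (successor generator) and then
      runs a Dijkstra regression search from the abstract goal states to
      compute all final h-values (stored in distances). operator_costs can
      specify an individual cost for each operator (e.g. for action cost
      partitioning); if it is left empty, the operators' default costs are
      used. If save_operator_transitions is set, the abstract transitions
      induced by each concrete operator are recorded as well.
    */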
    class PatternDatabaseFactory {
        const TaskProxy &task_proxy;
        const Pattern &pattern;
        bool dump;
        const std::vector<int> &operator_costs;
        std::vector<std::size_t> hash_multipliers;
        bool save_operator_transitions = false;

        /*
          Recursive method; called by build_abstract_operators. In the case
          of a precondition with value = -1 in the concrete operator, all
          multiplied out abstract operators are computed, i.e. for all
          possible values of the variable (with precondition = -1), one
          abstract operator with a concrete value (!= -1) is computed.
        */
        void multiply_out(
                int pos, int cost,
                std::vector<FactPair> &prev_pairs,
                std::vector<FactPair> &pre_pairs,
                std::vector<FactPair> &eff_pairs,
                const std::vector<FactPair> &effects_without_pre,
                const VariablesProxy &variables,
                std::vector<AbstractOperator> &operators,
                int original_op_id);

        /*
          Computes all abstract operators for a given concrete operator (by
          its global operator number). Initializes data structures for initial
          call to recursive method multiply_out. variable_to_index maps
          variables in the task to their index in the pattern or -1.
        */
        void build_abstract_operators(
                const OperatorProxy &op, int cost,
                const std::vector<int> &variable_to_index,
                const VariablesProxy &variables,
                std::vector<AbstractOperator> &operators);

        /*
          For a given abstract state (given as index), the according values
          for each variable in the state are computed and compared with the
          given pairs of goal variables and values. Returns true iff the
          state is a goal state.
        */
        bool is_goal_state(
                std::size_t state_index,
                const std::vector<FactPair> &abstract_goals,
                const VariablesProxy &variables) const;
    public:
        PatternDatabaseFactory(
                const TaskProxy &task_proxy,
                const Pattern &pattern,
                bool dump,
                const std::vector<int> &operator_costs,
                bool save_operator_transitions);
        ~PatternDatabaseFactory() = default;
        shared_ptr<PatternDatabase> generate();
    };

    PatternDatabaseFactory::PatternDatabaseFactory(
            const TaskProxy &task_proxy,
            const Pattern &pattern,
            bool dump,
            const std::vector<int> &operator_costs,
            bool save_operator_transitions)
            : task_proxy(task_proxy),
              pattern(pattern),
              dump(dump),
              operator_costs(operator_costs),
              save_operator_transitions(save_operator_transitions) {
    }

    void PatternDatabaseFactory::multiply_out(
            int pos, int cost, vector<FactPair> &prev_pairs,
            vector<FactPair> &pre_pairs,
            vector<FactPair> &eff_pairs,
            const vector<FactPair> &effects_without_pre,
            const VariablesProxy &variables,
            vector<AbstractOperator> &operators,
            int original_op_id) {
        if (pos == static_cast<int>(effects_without_pre.size())) {
            // All effects without precondition have been checked: insert op.
            if (!eff_pairs.empty()) {
                operators.push_back(
                        AbstractOperator(prev_pairs, pre_pairs, eff_pairs, cost,
                                         hash_multipliers, original_op_id));
            }
        } else {
            // For each possible value for the current variable, build an
            // abstract operator.
            int var_id = effects_without_pre[pos].var;
            int eff = effects_without_pre[pos].value;
            VariableProxy var = variables[pattern[var_id]];
            for (int i = 0; i < var.get_domain_size(); ++i) {
                if (i != eff) {
                    pre_pairs.emplace_back(var_id, i);
                    eff_pairs.emplace_back(var_id, eff);
                } else {
                    prev_pairs.emplace_back(var_id, i);
                }
                multiply_out(pos + 1, cost, prev_pairs, pre_pairs, eff_pairs,
                             effects_without_pre, variables, operators, original_op_id);
                if (i != eff) {
                    pre_pairs.pop_back();
                    eff_pairs.pop_back();
                } else {
                    prev_pairs.pop_back();
                }
            }
        }
    }

    void PatternDatabaseFactory::build_abstract_operators(
            const OperatorProxy &op, int cost,
            const vector<int> &variable_to_index,
            const VariablesProxy &variables,
            vector<AbstractOperator> &operators) {
        // All variable value pairs that are a prevail condition
        vector<FactPair> prev_pairs;
        // All variable value pairs that are a precondition (value != -1)
        vector<FactPair> pre_pairs;
        // All variable value pairs that are an effect
        vector<FactPair> eff_pairs;
        // All variable value pairs that are a precondition (value = -1)
        vector<FactPair> effects_without_pre;

        size_t num_vars = variables.size();
        vector<bool> has_precond_and_effect_on_var(num_vars, false);
        vector<bool> has_precondition_on_var(num_vars, false);

        for (FactProxy pre : op.get_preconditions())
            has_precondition_on_var[pre.get_variable().get_id()] = true;

        for (EffectProxy eff : op.get_effects()) {
            int var_id = eff.get_fact().get_variable().get_id();
            int pattern_var_id = variable_to_index[var_id];
            int val = eff.get_fact().get_value();
            if (pattern_var_id != -1) {
                if (has_precondition_on_var[var_id]) {
                    has_precond_and_effect_on_var[var_id] = true;
                    eff_pairs.emplace_back(pattern_var_id, val);
                } else {
                    effects_without_pre.emplace_back(pattern_var_id, val);
                }
            }
        }
        for (FactProxy pre : op.get_preconditions()) {
            int var_id = pre.get_variable().get_id();
            int pattern_var_id = variable_to_index[var_id];
            int val = pre.get_value();
            if (pattern_var_id != -1) { // variable occurs in pattern
                if (has_precond_and_effect_on_var[var_id]) {
                    pre_pairs.emplace_back(pattern_var_id, val);
                } else {
                    prev_pairs.emplace_back(pattern_var_id, val);
                }
            }
        }
        multiply_out(0, cost, prev_pairs, pre_pairs, eff_pairs, effects_without_pre,
                     variables, operators, op.get_id());
    }

    bool PatternDatabaseFactory::is_goal_state(
            const size_t state_index,
            const vector<FactPair> &abstract_goals,
            const VariablesProxy &variables) const {
        for (const FactPair &abstract_goal : abstract_goals) {
            int pattern_var_id = abstract_goal.var;
            int var_id = pattern[pattern_var_id];
            VariableProxy var = variables[var_id];
            int temp = state_index / hash_multipliers[pattern_var_id];
            int val = temp % var.get_domain_size();
            if (val != abstract_goal.value) {
                return false;
            }
        }
        return true;
    }

    /*
      Builds the pattern database:
        1. computes the hash multipliers (perfect hash of abstract states)
           and aborts if the number of abstract states exceeds INT_MAX;
        2. builds all abstract operators and a match tree over their
           regression preconditions;
        3. runs a Dijkstra regression search from all abstract goal states
           to obtain the goal distance of every abstract state;
        4. if save_operator_transitions is set, records for every concrete
           operator all (predecessor, successor) abstract state pairs induced
           by its abstract operators.
      hash_multipliers is moved into the returned PatternDatabase.
    */
    shared_ptr<PatternDatabase> PatternDatabaseFactory::generate() {
        task_properties::verify_no_axioms(task_proxy);
        task_properties::verify_no_conditional_effects(task_proxy);
        assert(operator_costs.empty() ||
               operator_costs.size() == task_proxy.get_operators().size());
        assert(utils::is_sorted_unique(pattern));

        utils::Timer timer;

        VariablesProxy variables = task_proxy.get_variables();

        // Start from a clean slate: the member is moved out at the end, so a
        // repeated call must not depend on its moved-from contents.
        hash_multipliers.clear();
        hash_multipliers.reserve(pattern.size());
        size_t num_states = 1;
        for (int var_id : pattern) {
            hash_multipliers.push_back(num_states);
            int domain_size = variables[var_id].get_domain_size();
            if (utils::is_product_within_limit(num_states, domain_size,
                                               numeric_limits<int>::max())) {
                num_states *= domain_size;
            } else {
                cerr << "Given pattern is too large! (Overflow occured): " << endl;
                cerr << pattern << endl;
                utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
            }
        }

        // Map task variable id -> position in the pattern (or -1).
        vector<int> variable_to_index(variables.size(), -1);
        for (size_t i = 0; i < pattern.size(); ++i) {
            variable_to_index[pattern[i]] = static_cast<int>(i);
        }

        // compute all abstract operators
        vector<AbstractOperator> operators;
        for (OperatorProxy op : task_proxy.get_operators()) {
            int op_cost = operator_costs.empty()
                ? op.get_cost()
                : operator_costs[op.get_id()];
            build_abstract_operators(
                    op, op_cost, variable_to_index, variables, operators);
        }

        // build the match tree
        MatchTree match_tree(task_proxy, pattern, hash_multipliers);
        for (size_t op_id = 0; op_id < operators.size(); ++op_id) {
            match_tree.insert(op_id, operators[op_id].get_regression_preconditions());
        }

        // compute abstract goal var-val pairs
        vector<FactPair> abstract_goals;
        for (FactProxy goal : task_proxy.get_goals()) {
            int var_id = goal.get_variable().get_id();
            if (variable_to_index[var_id] != -1) {
                abstract_goals.emplace_back(variable_to_index[var_id],
                                            goal.get_value());
            }
        }

        // All states start as unreachable; goal states get distance 0 and
        // seed the queue (priority = distance, value = state index).
        vector<int> distances(num_states, numeric_limits<int>::max());
        priority_queues::AdaptiveQueue<size_t> pq;
        for (size_t state_index = 0; state_index < num_states; ++state_index) {
            if (is_goal_state(state_index, abstract_goals, variables)) {
                distances[state_index] = 0;
                pq.push(0, state_index);
            }
        }

        // Reused buffer for match tree queries (avoids one heap allocation
        // per state); cleared before every query.
        vector<int> applicable_operator_ids;

        // Dijkstra loop
        while (!pq.empty()) {
            pair<int, size_t> node = pq.pop();
            int distance = node.first;
            size_t state_index = node.second;
            if (distance > distances[state_index]) {
                continue;  // stale queue entry
            }

            // regress abstract_state
            applicable_operator_ids.clear();
            match_tree.get_applicable_operator_ids(state_index, applicable_operator_ids);
            for (int op_id : applicable_operator_ids) {
                const AbstractOperator &op = operators[op_id];
                // Unsigned wraparound of the hash effect is intended.
                size_t predecessor = state_index + op.get_hash_effect();
                int alternative_cost = distances[state_index] + op.get_cost();
                if (alternative_cost < distances[predecessor]) {
                    distances[predecessor] = alternative_cost;
                    pq.push(alternative_cost, predecessor);
                }
            }
        }

        // One (predecessor, successor) list per concrete operator; the lists
        // stay empty unless save_operator_transitions is set.
        vector<vector<pair<int, int>>> operator_transitions(
                task_proxy.get_operators().size());
        if (save_operator_transitions) {
            for (size_t successor = 0; successor < num_states; ++successor) {
                applicable_operator_ids.clear();
                match_tree.get_applicable_operator_ids(successor, applicable_operator_ids);
                for (int op_id : applicable_operator_ids) {
                    const AbstractOperator &op = operators[op_id];
                    const size_t predecessor = successor + op.get_hash_effect();
                    operator_transitions[op.get_original_op_id()].emplace_back(
                            static_cast<int>(predecessor),
                            static_cast<int>(successor));
                }
            }
        }

        if (dump)
            utils::g_log << "PDB construction time: " << timer << endl;

        return make_shared<PatternDatabase>(
                pattern, num_states, move(distances), move(hash_multipliers),
                move(operator_transitions));
    }

    /*
      Convenience entry point: builds a PatternDatabaseFactory from the given
      arguments and returns the pattern database it generates.
    */
    shared_ptr<PatternDatabase> create_default_pdb(
            const TaskProxy &task_proxy,
            const Pattern &pattern,
            bool dump,
            const std::vector<int> &operator_costs,
            bool save_operator_transitions) {
        return PatternDatabaseFactory(task_proxy, pattern, dump,
                                      operator_costs, save_operator_transitions)
            .generate();
    }
}