#include "transition_normal_form_task.h"

#include "explicit_task.h"

#include "../option_parser.h"
#include "../plugin.h"
#include "../task_proxy.h"
#include "../task_utils/task_properties.h"

#include "../utils/collections.h"

#include <algorithm>
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using namespace std;

namespace extra_tasks {
#ifndef NDEBUG
/*
  Debug check used by TransitionNormalFormTask: returns true iff op has as
  many preconditions as effects and both sides mention exactly the same
  set of variables.
*/
static bool is_in_transition_normal_form(const explicit_tasks::ExplicitOperator &op) {
    if (op.effects.size() != op.preconditions.size())
        return false;

    unordered_set<int> vars_with_effect;
    for (const explicit_tasks::ExplicitEffect &eff : op.effects)
        vars_with_effect.insert(eff.fact.var);

    unordered_set<int> vars_with_precondition;
    for (const FactPair &pre : op.preconditions)
        vars_with_precondition.insert(pre.var);

    return vars_with_effect == vars_with_precondition;
}
#endif

/*
  Task transformation for TNF tasks.

  Since the constructor has to be called with the correct arguments,
  describing a task in TNF, we keep the class in the .cc file and
  provide a factory function in the header file for creating TNF tasks.

  Besides the explicit task data, the class stores bookkeeping about the
  transformation that is exposed through get_forget_op_ids() and
  get_unknown_pre_op_ids().
*/
class TransitionNormalFormTask : public tasks::ExplicitTask {
    // Indices (in this task's operator list) of the added "forget" operators.
    unordered_set<int> forget_op_ids;
    // Variable ID -> IDs of the operators whose precondition on that
    // variable was set to the variable's "unknown" value.
    unordered_map<int, vector<int>> unknown_pre_op_ids;
public:
    TransitionNormalFormTask(
        const shared_ptr<AbstractTask> &parent,
        vector<explicit_tasks::ExplicitVariable> &&variables,
        vector<vector<set<FactPair>>> &&mutex_facts,
        vector<explicit_tasks::ExplicitOperator> &&operators,
        vector<int> &&initial_state_values,
        vector<FactPair> &&goals,
        unordered_set<int> &&forget_op_ids,
        unordered_map<int, vector<int>> unknown_pre_op_ids)
        : ExplicitTask(
              parent,
              move(variables),
              move(mutex_facts),
              move(operators),
              // NOTE(review): presumably the (empty) axiom list; the
              // factory verifies the parent task has no axioms.
              {},
              move(initial_state_values),
              move(goals)),
          forget_op_ids(move(forget_op_ids)),
          unknown_pre_op_ids(move(unknown_pre_op_ids)) {
        /*
          All constructor arguments have been moved from at this point, so
          the sanity checks must inspect the constructed task, not the
          (now unspecified) parameters.
        */
        // TNF requires a fully defined goal: one goal fact per variable.
        assert(get_num_variables() == get_num_goals());
        assert(all_of(
                   this->operators.begin(), this->operators.end(),
                   is_in_transition_normal_form));
    }

    unordered_set<int> get_forget_op_ids() const override {
        return forget_op_ids;
    }

    // Returns the IDs of operators with an "unknown" precondition on var,
    // or an empty vector if there are none.
    vector<int> get_unknown_pre_op_ids(int var) const override {
        auto it = unknown_pre_op_ids.find(var);
        if (it != unknown_pre_op_ids.end())
            return it->second;
        else
            return {};
    }

    /*
      Since TNF tasks only *extend* variable domains, no changes are
      needed to convert a given state into a TNF state.
    */
    void convert_state_values_from_parent(
        vector<int> &) const override {
    }
};


static int get_unknown_value(
    const TaskProxy &parent_task_proxy, int var_id) {
    return parent_task_proxy.get_variables()[var_id].get_domain_size();
}

// Returns the IDs of all variables occurring in op's preconditions or
// effects, sorted ascending and without duplicates.
static set<int> get_ordered_mentioned_variables(const OperatorProxy &op) {
    set<int> mentioned;
    for (EffectProxy eff : op.get_effects())
        mentioned.insert(eff.get_fact().get_variable().get_id());
    for (FactProxy pre : op.get_preconditions())
        mentioned.insert(pre.get_variable().get_id());
    return mentioned;
}

/*
  Returns a vector of length num_vars indexed by variable ID that holds
  op's precondition value for each variable, or -1 if op has no
  precondition on it.
*/
static vector<int> get_precondition_values(const OperatorProxy &op, int num_vars) {
    vector<int> value_by_var(num_vars, -1);
    for (FactProxy pre : op.get_preconditions()) {
        const FactPair pair = pre.get_pair();
        value_by_var[pair.var] = pair.value;
    }
    return value_by_var;
}

/*
  Returns a vector of length num_vars indexed by variable ID that holds
  the value op assigns to each variable, or -1 if op does not affect it.
*/
static vector<int> get_effect_values(const OperatorProxy &op, int num_vars) {
    vector<int> value_by_var(num_vars, -1);
    for (EffectProxy eff : op.get_effects()) {
        const FactPair pair = eff.get_fact().get_pair();
        value_by_var[pair.var] = pair.value;
    }
    return value_by_var;
}

/*
  Returns, for every TNF fact (var, value), the set of parent facts that
  are mutex with it. Only parent facts are compared; a variable's optional
  "unknown" value gets an empty set and never occurs in any other set.

  TODO: This function runs in time O(|facts|^2). We could think about a
  faster way of passing mutex information between tasks (see issue661).
*/
static vector<vector<set<FactPair>>> create_mutexes(
    const TaskProxy &parent_task_proxy,
    const vector<bool> &unknown_fact_needed) {
    VariablesProxy variables = parent_task_proxy.get_variables();

    // Allocate one (initially empty) mutex set per TNF value.
    vector<vector<set<FactPair>>> mutexes_by_fact(variables.size());
    for (VariableProxy var : variables) {
        const int id = var.get_id();
        const int num_values =
            var.get_domain_size() + (unknown_fact_needed[id] ? 1 : 0);
        mutexes_by_fact[id].resize(num_values);
    }

    // Compare every pair of parent facts.
    FactsProxy all_facts = variables.get_facts();
    for (FactProxy fact : all_facts) {
        const FactPair pair = fact.get_pair();
        set<FactPair> &mutex_set = mutexes_by_fact[pair.var][pair.value];
        for (FactProxy other : all_facts) {
            if (fact.is_mutex(other))
                mutex_set.insert(other.get_pair());
        }
    }
    return mutexes_by_fact;
}

/*
  Copies the parent variables, extending the domain of every variable
  flagged in unknown_fact_needed by one extra value named
  "<var name> unknown".
*/
static vector<explicit_tasks::ExplicitVariable> create_variables(
    const TaskProxy &parent_task_proxy,
    const vector<bool> &unknown_fact_needed) {
    VariablesProxy parent_vars = parent_task_proxy.get_variables();
    vector<explicit_tasks::ExplicitVariable> result;
    result.reserve(parent_vars.size());
    for (VariableProxy var : parent_vars) {
        const bool add_unknown = unknown_fact_needed[var.get_id()];
        const int parent_size = var.get_domain_size();
        const int domain_size = add_unknown ? parent_size + 1 : parent_size;
        string name = var.get_name();

        vector<string> value_names;
        value_names.reserve(domain_size);
        for (int val = 0; val < parent_size; ++val)
            value_names.push_back(var.get_fact(val).get_name());
        if (add_unknown)
            value_names.push_back(name + " " + "unknown");

        // NOTE(review): the two trailing -1 arguments presumably mean
        // "no axiom layer / no axiom default value" -- confirm.
        result.emplace_back(
            domain_size, move(name), move(value_names), -1, -1);
    }
    return result;
}

/*
  Builds the TNF version of parent operator op: for every variable that op
  mentions (in ascending ID order) add exactly one precondition and one
  effect.
  - Effect without precondition: the precondition becomes the variable's
    "unknown" value; the variable is flagged in unknown_fact_needed and
    op's ID is appended to unknown_pre_op_ids[var].
  - Precondition without effect: the effect re-assigns the precondition
    value.
*/
static explicit_tasks::ExplicitOperator create_normal_operator(
    const TaskProxy &parent_task_proxy,
    vector<bool> &unknown_fact_needed,
    unordered_map<int, vector<int>> &unknown_pre_op_ids,
    const OperatorProxy &op) {
    const int num_vars = parent_task_proxy.get_variables().size();
    const vector<int> pre_by_var = get_precondition_values(op, num_vars);
    const vector<int> eff_by_var = get_effect_values(op, num_vars);

    vector<FactPair> tnf_preconditions;
    vector<explicit_tasks::ExplicitEffect> tnf_effects;
    for (int var : get_ordered_mentioned_variables(op)) {
        const bool has_pre = (pre_by_var[var] != -1);
        const bool has_eff = (eff_by_var[var] != -1);
        // A mentioned variable occurs in a precondition or an effect.
        assert(has_pre || has_eff);

        int pre_value;
        if (has_pre) {
            pre_value = pre_by_var[var];
        } else {
            unknown_fact_needed[var] = true;
            pre_value = get_unknown_value(parent_task_proxy, var);
            unknown_pre_op_ids[var].push_back(op.get_id());
        }
        const int post_value = has_eff ? eff_by_var[var] : pre_value;

        tnf_preconditions.emplace_back(var, pre_value);
        tnf_effects.emplace_back(var, post_value, vector<FactPair>());
    }
    return explicit_tasks::ExplicitOperator(
        move(tnf_preconditions), move(tnf_effects),
        op.get_cost(), op.get_name(), false);
}

/*
  Appends the TNF version of every parent operator to operators, in parent
  operator order (so if operators starts empty, the i-th parent operator
  ends up at index i). Also updates unknown_fact_needed and
  unknown_pre_op_ids via create_normal_operator.
*/
static void create_normal_operators(
    const TaskProxy &parent_task_proxy,
    vector<bool> &unknown_fact_needed,
    unordered_map<int, vector<int>> &unknown_pre_op_ids,
    vector<explicit_tasks::ExplicitOperator> &operators) {
    for (OperatorProxy parent_op : parent_task_proxy.get_operators()) {
        explicit_tasks::ExplicitOperator tnf_op = create_normal_operator(
            parent_task_proxy, unknown_fact_needed, unknown_pre_op_ids,
            parent_op);
        operators.push_back(move(tnf_op));
    }
}

/*
  Creates the zero-cost operator "forget <fact>" that requires fact and
  sets fact's variable to post_value.
*/
static explicit_tasks::ExplicitOperator create_forget_operator(
    const FactProxy &fact, int post_value) {
    const FactPair from = fact.get_pair();
    vector<FactPair> preconditions(1, from);
    vector<explicit_tasks::ExplicitEffect> effects;
    effects.emplace_back(from.var, post_value, vector<FactPair>());
    return explicit_tasks::ExplicitOperator(
        move(preconditions), move(effects), 0,
        "forget " + fact.get_name(), false);
}

/*
  For every variable flagged in unknown_fact_needed, appends one forget
  operator per parent value (moving the variable to its "unknown" value)
  and records each new operator's index in forget_op_ids.
*/
static void create_forget_operators(
    const TaskProxy &parent_task_proxy,
    const vector<bool> &unknown_fact_needed,
    unordered_set<int> &forget_op_ids,
    vector<explicit_tasks::ExplicitOperator> &operators) {
    for (VariableProxy var : parent_task_proxy.get_variables()) {
        const int var_id = var.get_id();
        if (!unknown_fact_needed[var_id])
            continue;
        const int unknown_value = get_unknown_value(parent_task_proxy, var_id);
        const int domain_size = var.get_domain_size();
        for (int value = 0; value < domain_size; ++value) {
            const int op_id = static_cast<int>(operators.size());
            operators.push_back(
                create_forget_operator(var.get_fact(value), unknown_value));
            forget_op_ids.insert(op_id);
        }
    }
}

/*
  Creates a fully defined goal state: one goal fact per variable, using
  the parent's goal value where there is one and the variable's "unknown"
  value otherwise.
*/
static vector<FactPair> create_goals(const TaskProxy &parent_task_proxy) {
    VariablesProxy variables = parent_task_proxy.get_variables();
    vector<FactPair> full_goal;
    full_goal.reserve(variables.size());
    for (VariableProxy var : variables) {
        const int id = var.get_id();
        const int unknown_value = get_unknown_value(parent_task_proxy, id);
        full_goal.emplace_back(id, unknown_value);
    }
    // Overwrite the defaults with the parent's actual goal facts.
    for (FactProxy parent_goal : parent_task_proxy.get_goals()) {
        const FactPair pair = parent_goal.get_pair();
        full_goal[pair.var] = pair;
    }
    return full_goal;
}

shared_ptr<AbstractTask> create_transition_normal_form_task(
    const shared_ptr<AbstractTask> &parent) {
    TaskProxy parent_task_proxy(*parent);
    // The transformation supports neither axioms nor conditional effects.
    task_properties::verify_no_axioms(parent_task_proxy);
    task_properties::verify_no_conditional_effects(parent_task_proxy);

    // Bookkeeping handed to the TNF task (e.g. for ConstraintConstructor).
    unordered_set<int> forget_op_ids;
    unordered_map<int, vector<int>> unknown_pre_op_ids;

    /*
      A variable gets an extra "unknown" value if it occurs in an effect
      without a matching precondition, or if the goal does not mention it.
    */
    const int num_vars = parent_task_proxy.get_variables().size();
    vector<bool> unknown_fact_needed(num_vars, false);

    // TNF versions of the parent operators; this flags variables that need
    // "unknown" because of effects without preconditions.
    vector<explicit_tasks::ExplicitOperator> operators;
    create_normal_operators(
        parent_task_proxy, unknown_fact_needed, unknown_pre_op_ids, operators);

    // Goal variables defaulted to "unknown" also need the extra value.
    vector<FactPair> goals = create_goals(parent_task_proxy);
    for (const FactPair &goal : goals) {
        const bool goal_is_unknown =
            goal.value == get_unknown_value(parent_task_proxy, goal.var);
        if (goal_is_unknown)
            unknown_fact_needed[goal.var] = true;
    }

    // Forget operators for every variable that has an "unknown" value.
    create_forget_operators(
        parent_task_proxy, unknown_fact_needed, forget_op_ids, operators);

    vector<explicit_tasks::ExplicitVariable> variables =
        create_variables(parent_task_proxy, unknown_fact_needed);
    vector<vector<set<FactPair>>> mutexes =
        create_mutexes(parent_task_proxy, unknown_fact_needed);
    return make_shared<TransitionNormalFormTask>(
        parent,
        move(variables),
        move(mutexes),
        move(operators),
        parent->get_initial_state_values(),
        move(goals),
        move(forget_op_ids),
        move(unknown_pre_op_ids));
}

/*
  Option parser for the TNF transformation: wraps the task given by the
  "transform" option (default "no_transform") in a TNF task.

  NOTE(review): no Plugin registration referencing this static function is
  visible in this file, so it appears unused as written -- confirm whether
  a registration is missing.
*/
static shared_ptr<AbstractTask> _parse(OptionParser &parser) {
    parser.document_language_support("conditional effects", "not supported");
    parser.document_language_support("axioms", "not supported");
    parser.add_option<shared_ptr<AbstractTask>>(
        "transform", "Parent task transformation", "no_transform");

    Options opts = parser.parse();
    if (parser.dry_run())
        return nullptr;

    shared_ptr<AbstractTask> parent_task =
        opts.get<shared_ptr<AbstractTask>>("transform");
    return create_transition_normal_form_task(parent_task);
}
}
