#include "constraint_constructor.h"

#include "augmented_matrix.h"
#include "mutex_finder.h"

#include "../option_parser.h"
#include "../task_proxy.h"

#include "../algorithms/dynamic_bitset.h"
#include "../task_utils/task_properties.h"
#include "../tasks/transition_normal_form_task.h"
#include "../utils/logging.h"

#include <unordered_set>

using namespace std;
using Bitset = dynamic_bitset::DynamicBitset<>;

namespace parity_potentials {
/*
  Build the transition-normal-form (TNF) version of the task given by the
  "transform" option and allocate the ids of all constraint variables for it.

  The task is checked up front for axioms and conditional effects via the
  task_properties::verify_* helpers (presumably these abort on unsupported
  tasks -- confirm in task_properties).
*/
ConstraintConstructor::ConstraintConstructor(const options::Options &opts)
    : tnf_task(extra_tasks::create_transition_normal_form_task(
          opts.get<shared_ptr<AbstractTask>>("transform"))),
      num_con_vars(0),
      num_1d_tuples(0),
      num_tuples(0) {
    TaskProxy task_proxy = get_tnf_task_proxy();

    // Reject task features this constructor does not handle.
    task_properties::verify_no_axioms(task_proxy);
    task_properties::verify_no_conditional_effects(task_proxy);

    initialize(task_proxy);

    OperatorsProxy operators = task_proxy.get_operators();
    utils::g_log << "Number of operators: " << operators.size() << endl;
}

void ConstraintConstructor::initialize(const TaskProxy &tnf_task_proxy) {
    // Initialize one-dimensional constraint variables.
    VariablesProxy vars = tnf_task_proxy.get_variables();
    con_var_ids.resize(vars.size());
    for (VariableProxy var : vars) {
        int var_id = var.get_id();
        int num_values = var.get_domain_size();
        con_var_ids[var_id].resize(num_values);
        for (int val = 0; val < num_values; ++val) {
            FactProxy fact_proxy = var.get_fact(val);
            con_var_ids[var_id][val] = num_con_vars++;
        }
    }
    num_1d_tuples = num_con_vars;

    // Initialize two-dimensional constraint variables.
    con_pair_var_ids.resize(vars.size());
    for (VariableProxy var1 : vars) {
        int var1_id = var1.get_id();
        int num_values1 = var1.get_domain_size();
        con_pair_var_ids[var1_id].resize(num_values1);
        for (int val1 = 0; val1 < num_values1; ++val1) {
            vector<vector<int>> &fact1_ids = con_pair_var_ids[var1_id][val1];
            fact1_ids.resize(var1_id); 
            for (VariableProxy var2 : vars) {
                int var2_id = var2.get_id();
                if (var2_id >= var1_id) {
                    /* TODO: This is a break; in the code from Florian.
                             Can it be one here? */
                    continue;
                }
                int num_values2 = var2.get_domain_size();
                fact1_ids[var2_id].resize(num_values2);
                for (int val2 = 0; val2 < num_values2; ++val2) {
                    fact1_ids[var2_id][val2] = num_con_vars++;
                    // This vector is for MutexFinder.
                    FactProxy fact1 = var1.get_fact(val1);
                    FactProxy fact2 = var2.get_fact(val2);
                    pair_id_to_facts.emplace_back(fact1, fact2);
                }
            }
        }
    }
    num_tuples = num_con_vars;

    // Initialize context constraint variables.
    // TODO: Possible optimization by only adding relevant variables.
    OperatorsProxy ops = get_tnf_task_proxy().get_operators();
    con_ctx_var_ids.resize(ops.size());
    for (OperatorProxy op : ops) {
        con_ctx_var_ids[op.get_id()].resize(vars.size());
        for (VariableProxy var : vars) {
            bool var_in_op = false;
            for (EffectProxy eff : op.get_effects()) {
                if (eff.get_fact().get_variable().get_id() == var.get_id()) {
                    var_in_op = true;
                    break;
                }
            }
            if (!var_in_op) {
                con_ctx_var_ids[op.get_id()][var.get_id()] = num_con_vars++;
            }
        }
    }
}

/*
  Append the two separation rows to constraint_matrix and create
  mutex_finder.

  - Goal row ("potential of goal state must be 1"): sets the bit of the last
    column (result_index), the one-dimensional variable of every goal fact,
    and the pair variable of every pair of goal facts on distinct variables.
  - Initial-state row ("potential of initial state must be 0"): the same
    features for the initial state, with the last-column bit left unset.

  mutex_finder is then built from the first num_tuples bits of the
  initial-state row and from pair_id_to_facts, which is MOVED into it.
  Because of that move, this function must run at most once per object. It
  must also run before anything queries mutex_finder.
*/
void ConstraintConstructor::construct_separation_constraints(
    AugmentedMatrix &constraint_matrix) {
    TaskProxy tnf_task_proxy = get_tnf_task_proxy();

    // Potential of goal state must be 1.
    GoalsProxy goals = tnf_task_proxy.get_goals();
    assert(goals.size() == tnf_task_proxy.get_variables().size());
    Bitset &goal_constraint = constraint_matrix.add_zero_row();
    // This is the leftmost, most significant bit.
    int result_index = constraint_matrix.get_num_cols() - 1;
    goal_constraint.set(result_index);
    for (FactProxy goal : goals) {
        goal_constraint.set(get_con_var_id(goal));
        int goal_var_id = goal.get_variable().get_id();
        for (FactProxy goal2 : goals) {
            /* Take every unordered pair once, with the larger variable id
               first as get_con_var_id(fact1, fact2) expects. The goal
               iteration order is not shown here to be sorted by variable
               id, so this must stay `continue` rather than `break`. */
            if (goal2.get_variable().get_id() >= goal_var_id)
                continue;
            goal_constraint.set(get_con_var_id(goal, goal2));
        }
    }

    // Potential of initial state must be 0.
    State initial_state = tnf_task_proxy.get_initial_state();
    assert(initial_state.size() == tnf_task_proxy.get_variables().size());
    Bitset &initial_constraint = constraint_matrix.add_zero_row();
    for (FactProxy initial_fact : initial_state) {
        initial_constraint.set(get_con_var_id(initial_fact));
        int initial_var_id = initial_fact.get_variable().get_id();
        for (FactProxy initial_fact2 : initial_state) {
            /* Same pairing rule as for the goal row. If State yields its
               facts in variable-id order (presumably), `break` would also be
               correct here; `continue` does not depend on that order. */
            if (initial_fact2.get_variable().get_id() >= initial_var_id)
                continue;
            initial_constraint.set(
                get_con_var_id(initial_fact, initial_fact2));
        }
    }

    // Keep only the one- and two-dimensional features (ids < num_tuples).
    Bitset init_tuples = initial_constraint;
    init_tuples.resize(num_tuples);

    /* pair_id_to_facts is moved into the MutexFinder below. A second call
       would therefore build a finder from a moved-from vector. */
    assert(!mutex_finder);
    mutex_finder = unique_ptr<MutexFinder>(
        new MutexFinder(
            *this,
            tnf_task_proxy,
            num_1d_tuples,
            init_tuples,
            move(pair_id_to_facts)
        )
    );
}

/*
  Append the consistency constraints of a single operator op to
  constraint_matrix:

  1. One row "0 = C_o^ind xor C_o^ctx":
     - C_o^ind: for every effect whose value differs from its precondition,
       the one-dimensional variables of the pre and post facts. For every
       pair of effects (var2 < var) where at least one of them changes its
       value, the pair variables of (pre, pre2) and of (post, post2).
     - C_o^ctx: the context variable get_con_var_id(op, var) of every
       variable that has no precondition in op.
  2. For every variable V without a precondition in op, and every value of V
     that is not mutex with a precondition or effect of op, one row
     "C_o,V^ctx". It sets the context variable of (op, V) plus, for every
     effect whose value changes (var2, pre2 -> post2), the pair variables of
     (V=value, var2=pre2) and (V=value, var2=post2).

  Assumes every effect variable of op also has a precondition, as expected
  in transition normal form; the .at() lookups below throw otherwise.
  Calls are_mutex(), which uses mutex_finder, so
  construct_separation_constraints() must have run first.
*/
void ConstraintConstructor::construct_op_consistency_constraints(
    AugmentedMatrix &constraint_matrix,
    const VariablesProxy &vars,
    const OperatorProxy &op) {

    // Variable id -> value required by op's precondition on that variable.
    unordered_map<int, int> var_to_precondition;
    for (FactProxy pre : op.get_preconditions()) {
        var_to_precondition[pre.get_variable().get_id()] = pre.get_value();
    }
    
    // Construct 0 = C_o^ind xor C_o^ctx constraint.
    /* NOTE(review): consistency_constraint is a reference returned by
       add_zero_row(). It is finished before the next add_zero_row() call
       below. Keep it that way in case AugmentedMatrix stores its rows in a
       container that reallocates on append -- confirm in augmented_matrix.h. */
    Bitset& consistency_constraint = constraint_matrix.add_zero_row();
    // Add C_o^ind.
    for (EffectProxy effect : op.get_effects()) {
        VariableProxy var = effect.get_fact().get_variable();
        int var_id = var.get_id();
        // Handle one-dimensional features
        int pre = var_to_precondition.at(var_id);
        int post = effect.get_fact().get_value();
        /* Only add one-dimensional features if they are consumed/produced
           but forward static atoms to two-dimensional case anyway. */
        if (pre != post) {
            int pre_id = con_var_ids[var_id][pre];
            int post_id = con_var_ids[var_id][post];
            consistency_constraint.set(pre_id);
            consistency_constraint.set(post_id);
        }
        // Handle two-dimensional features.
        // Each unordered pair of effects is visited once (var2_id < var_id),
        // matching the [larger var][..][smaller var][..] index layout.
        for (EffectProxy effect2 : op.get_effects()) {
            VariableProxy var2 = effect2.get_fact().get_variable();
            int var2_id = var2.get_id();
            if (var2_id < var_id) {
                int pre2 = var_to_precondition.at(var2_id);
                int post2 = effect2.get_fact().get_value();
                // Pairs where both effects keep their value are skipped.
                if (pre2 != post2 || pre != post) {
                    // Covers 2x consumed OR 1x static, 1x consumed case.
                    int pre_id =
                        con_pair_var_ids[var_id][pre][var2_id][pre2];
                    // Covers 2x produced OR 1x static, 1x produced case.
                    int post_id =
                        con_pair_var_ids[var_id][post][var2_id][post2];
                    consistency_constraint.set(pre_id);
                    consistency_constraint.set(post_id);
                }
            }
        }
    }
    // Add C_o^ctx.
    // One context variable per variable that op has no precondition on.
    for (VariableProxy var : vars) {
        if (!var_to_precondition.count(var.get_id())) {
            int ctx_id = get_con_var_id(op, var);
            consistency_constraint.set(ctx_id);
        }
    }

    // Construct C_o,V^ctx constraints.
    // From here on, consistency_constraint must not be used: new rows are
    // appended below.
    for (VariableProxy var : vars) {
        if (!var_to_precondition.count(var.get_id())) {
            int ctx_id = get_con_var_id(op, var);
            int var_id = var.get_id();
            for (int value = 0; value < var.get_domain_size(); ++value) {
                // Values of V that are mutex with a precondition or an
                // effect of op get no row.
                FactProxy current_fact = var.get_fact(value);
                if (are_mutex(current_fact, op)) {
                    continue;
                }

                Bitset& ctx_constraint = constraint_matrix.add_zero_row();
                ctx_constraint.set(ctx_id);
                for (EffectProxy effect : op.get_effects()) {
                    VariableProxy var2 = effect.get_fact().get_variable();
                    int var2_id = var2.get_id();
                    int pre2 = var_to_precondition.at(var2_id);
                    int post2 = effect.get_fact().get_value();
                    // Skip if atoms are not in flips_o.
                    if (pre2 == post2) {
                        continue;
                    }
                    int pre_id, post_id;
                    // var has no precondition while every effect variable
                    // has one, so they can never coincide.
                    assert(!(var2_id == var_id));
                    // Index pair ids with the larger variable id first.
                    if (var2_id < var_id) {
                        pre_id =
                            con_pair_var_ids[var_id][value][var2_id][pre2];
                        post_id =
                            con_pair_var_ids[var_id][value][var2_id][post2];
                    } else {
                        pre_id =
                            con_pair_var_ids[var2_id][pre2][var_id][value];
                        post_id =
                            con_pair_var_ids[var2_id][post2][var_id][value];
                    }
                    ctx_constraint.set(pre_id);
                    ctx_constraint.set(post_id);
                }
            }
        }
    }
}

/*
  Append the consistency constraints of every operator of the TNF task.
  Forget operators (ids in tnf_task->get_forget_op_ids()) are skipped when
  forget_op_is_unreachable() returns true for them.

  Both forget_op_is_unreachable() and construct_op_consistency_constraints()
  query mutex_finder, so construct_separation_constraints() must run first.
*/
void ConstraintConstructor::construct_consistency_constraints(
    AugmentedMatrix &constraint_matrix) {
    TaskProxy tnf_task_proxy = get_tnf_task_proxy();
    OperatorsProxy operators = tnf_task_proxy.get_operators();
    VariablesProxy vars = tnf_task_proxy.get_variables();

    /* Bind by const reference. This avoids copying the set if
       get_forget_op_ids() returns a reference, and extends the temporary's
       lifetime if it returns by value. */
    const unordered_set<int> &forget_op_ids = tnf_task->get_forget_op_ids();

    for (const OperatorProxy &op : operators) {
        bool is_forget_op = forget_op_ids.count(op.get_id()) > 0;
        if (is_forget_op && forget_op_is_unreachable(tnf_task_proxy, op))
            continue;
        construct_op_consistency_constraints(constraint_matrix, vars, op);
    }
}

/*
  Build and return the full constraint system: the two separation rows
  followed by the consistency rows of all operators.

  The matrix is created from num_con_vars. construct_separation_constraints()
  addresses an extra right-hand-side bit at get_num_cols() - 1, so
  AugmentedMatrix presumably adds that column itself -- confirm.

  Order matters: construct_separation_constraints() creates mutex_finder,
  which construct_consistency_constraints() relies on. Because
  pair_id_to_facts is moved into that finder, call this at most once.
*/
AugmentedMatrix ConstraintConstructor::construct_constraints() {
    AugmentedMatrix constraint_matrix(num_con_vars);
    /* No rows are reserved up front: the row count depends on mutex checks
       and on which forget operators are skipped, so it is not known in
       advance without doing most of the work. */

    construct_separation_constraints(constraint_matrix);
    construct_consistency_constraints(constraint_matrix);

    return constraint_matrix;
}

bool ConstraintConstructor::forget_op_is_unreachable(
    const TaskProxy &tnf_task_proxy,
    const OperatorProxy &op) const {
    assert(op.get_preconditions().size() == 1);
    assert(op.get_effects().size() == 1);
    assert(op.get_preconditions()[0].get_variable().get_id() ==
           op.get_effects()[0].get_fact().get_variable().get_id());

    FactProxy fact_to_forget = op.get_preconditions()[0];
    int var_id = fact_to_forget.get_variable().get_id();
    int unknown_value = op.get_effects()[0].get_fact().get_value();

    // Check goal conditions.
    bool goal_mutex = false;
    for (const FactProxy &goal : tnf_task_proxy.get_goals()) {
        if (goal.get_variable().get_id() == var_id) {
            if (goal.get_value() != unknown_value) {
                return false;
            } else {
                if (goal_mutex)
                    break;
            }
        } else {
            if (!goal_mutex) {
                if (mutex_finder->are_mutex(goal, fact_to_forget))
                    goal_mutex = true;
            }
        }
    }
    if (!goal_mutex)
        return false;

    // Check operator conditions.
    vector<int> unknown_pre_op_ids = tnf_task->get_unknown_pre_op_ids(var_id);
    for (const int unknown_pre_op_id : unknown_pre_op_ids) {
        bool op_mutex = false;
        OperatorProxy unknown_pre_op =
            tnf_task_proxy.get_operators()[unknown_pre_op_id];
        for (const FactProxy &pre : unknown_pre_op.get_preconditions()) {
            assert(!(pre.get_variable().get_id() == var_id &&
                     pre.get_value() != unknown_value));
            if (pre.get_variable().get_id() == var_id)
                continue;
            if (mutex_finder->are_mutex(pre, fact_to_forget)) {
                op_mutex = true;
                break;
            }
        }
        if (!op_mutex)
            return false;
    }
    return true;
}

bool ConstraintConstructor::are_mutex(
    const FactProxy &fact,
    const OperatorProxy &op) const {
    for (const FactProxy &pre : op.get_preconditions()) {
        if (mutex_finder->are_mutex(fact, pre)) {
            return true;
        }
    }
    for (const EffectProxy &eff : op.get_effects()) {
        FactProxy eff_fact = eff.get_fact();
        if (mutex_finder->are_mutex(fact, eff_fact)) {
            return true;
        }
    }
    return false;
}

int ConstraintConstructor::get_con_var_id(const FactProxy &fact_proxy) const {
    const FactPair fact = fact_proxy.get_pair();
    assert(utils::in_bounds(fact.var, con_var_ids));
    assert(utils::in_bounds(fact.value, con_var_ids[fact.var]));
    return con_var_ids[fact.var][fact.value];
}

int ConstraintConstructor::get_con_var_id(
    const FactProxy &fact1, const FactProxy &fact2) const {
    int var1_id = fact1.get_variable().get_id();
    int value1 = fact1.get_value();
    int var2_id = fact2.get_variable().get_id();
    int value2 = fact2.get_value();
    assert(var1_id >= var2_id);
    assert(utils::in_bounds(var1_id, con_pair_var_ids));
    assert(utils::in_bounds(value1, con_pair_var_ids[var1_id]));
    const vector<vector<int>> &fact1_ids = con_pair_var_ids[var1_id][value1];
    assert(utils::in_bounds(var2_id, fact1_ids));
    assert(utils::in_bounds(value2, fact1_ids[var2_id]));
    return fact1_ids[var2_id][value2];
}

// Id of the context constraint variable for operator op and variable var.
int ConstraintConstructor::get_con_var_id(
    const OperatorProxy &op, const VariableProxy &var) const {
    const int op_id = op.get_id();
    assert(utils::in_bounds(op_id, con_ctx_var_ids));
    const auto &ctx_ids_of_op = con_ctx_var_ids[op_id];
    const int var_id = var.get_id();
    assert(utils::in_bounds(var_id, ctx_ids_of_op));
    return ctx_ids_of_op[var_id];
}

// Fresh proxy onto the transition-normal-form task owned by this object.
TaskProxy ConstraintConstructor::get_tnf_task_proxy() const {
    return TaskProxy{*tnf_task};
}
}


