#include "mutex_finder.h"

#include "../task_proxy.h"

#include "../utils/logging.h"
#include "../utils/system.h"

#include <utility>

using namespace std;
using utils::ExitCode;
using Bitset = dynamic_bitset::DynamicBitset<>;

namespace parity_potentials {
/* Build the table of reachable tuples for the given task and stop the
   planner early if two goal facts turn out to be mutex.

   The table starts as a copy of init_tuples. Table indices below
   num_1d_tuples are treated as single facts; index num_1d_tuples + k is the
   fact pair stored in pair_id_to_facts[k].

   Fixpoint computation: every operator whose precondition tuples are all in
   the table contributes its effect tuples, and each single effect fact is
   additionally combined with other reached facts via extend_to_pair(). The
   pass over all operators is repeated until it changes nothing.

   Afterwards all pairs of goal facts are checked with are_mutex(). On the
   first mutex pair the global timer is stopped, a message is logged and the
   process ends through utils::exit_with(ExitCode::SUCCESS).

   NOTE: pair_id_to_facts is a const rvalue reference, so the member is
   copy-initialized from it (a const object cannot be moved from). */
MutexFinder::MutexFinder(
    const ConstraintConstructor &cc,
    const TaskProxy &tnf_task_proxy,
    const int num_1d_tuples,
    const Bitset &init_tuples,
    const vector<pair<FactProxy, FactProxy>> &&pair_id_to_facts)
    : cc(cc),
      tnf_task_proxy(tnf_task_proxy),
      num_1d_tuples(num_1d_tuples),
      num_tuples(init_tuples.size()),
      table(init_tuples),
      pair_id_to_facts(pair_id_to_facts) {
    // Every table index above the single-fact range must belong to a pair.
    assert(pair_id_to_facts.size() ==
           static_cast<size_t>(num_tuples - num_1d_tuples));

    utils::g_log << "    Finding mutexes... ";
    bool was_updated;
    do {
        was_updated = false;
        for (const OperatorProxy &op : tnf_task_proxy.get_operators()) {
            Bitset pre = get_tuples(op.get_preconditions());
            // Skip operators whose precondition tuples are not all reached.
            if (!pre.is_subset_of(table))
                continue;

            Bitset eff = get_tuples(op.get_effects());
            if (add_to_table(eff))
                was_updated = true;

            // Cut the effect bitset down to its single-fact part and try to
            // extend each reached effect fact to new pairs.
            eff.resize(num_1d_tuples);
            for (size_t i = 0; i < eff.size(); ++i) {
                if (eff[i] && extend_to_pair(static_cast<int>(i), pre, op))
                    was_updated = true;
            }
        }
    } while (was_updated);
    utils::g_log << "done!" << endl;

    // A goal that contains a mutex pair can never be satisfied.
    GoalsProxy goals = tnf_task_proxy.get_goals();
    const size_t num_goals = goals.size();
    for (size_t i = 0; i < num_goals; ++i) {
        for (size_t j = i + 1; j < num_goals; ++j) {
            if (are_mutex(goals[i], goals[j])) {
                utils::g_timer.stop();
                utils::g_log << "Total time: " << utils::g_timer << endl;
                utils::g_log << "Goal contains mutex! "
                             << "Task is unsolvable." << endl;
                utils::exit_with(ExitCode::SUCCESS);
            }
        }
    }
}

bool MutexFinder::add_to_table(const Bitset &to_add) {
    Bitset new_table = table | to_add;
    if (table != new_table) {
        table = new_table;
        return true;
    } else {
        return false;
    }
}

/* Look for the fact with id fact_id (as reported by cc.get_con_var_id())
   inside the pair p.
   Returns a pointer to the *other* fact of the pair if one of the two facts
   matches, checking p.first before p.second, and nullptr if neither does.
   The returned pointer refers into p. */
const FactProxy* MutexFinder::contains(
    const pair<FactProxy, FactProxy> &p,
    const int fact_id) const {
    const int first_id = cc.get_con_var_id(p.first);
    const int second_id = cc.get_con_var_id(p.second);

    if (first_id == fact_id)
        return &p.second;
    if (second_id == fact_id)
        return &p.first;
    return nullptr;
}

/* Returns true iff some fact in facts is over the same variable as
   fact_to_check but carries a different value. */
template <class FactProxyCollection>
bool MutexFinder::contradicts(
    const FactProxy &fact_to_check,
    const FactProxyCollection &facts) const {
    for (const FactProxy &other : facts) {
        const bool same_variable =
            other.get_variable() == fact_to_check.get_variable();
        if (same_variable &&
            other.get_value() != fact_to_check.get_value()) {
            return true;
        }
    }
    return false;
}
bool MutexFinder::contradicts(
    const FactProxy &fact_to_check,
    const EffectsProxy &effs) const {
    vector<FactProxy> facts;
    facts.reserve(effs.size());
    for (const EffectProxy &eff : effs) {
        facts.push_back(eff.get_fact());
    }
    return contradicts(fact_to_check, facts);
}

bool MutexFinder::extend_to_pair(
    int eff_fact_id,
    const Bitset &pre,
    const OperatorProxy &op) {
    /* Get pairs that contain fact_id.
       candidate_ids.first is the pair id.
       candidate_ids.second is the other fact in the pair. */
    vector<pair<int, FactProxy>> candidates;
    for (size_t i = 0; i < pair_id_to_facts.size(); ++i) {
        const FactProxy *other = contains(pair_id_to_facts[i], eff_fact_id);
        int candidate_id = i + num_1d_tuples;
        if (other && !table[candidate_id])
            candidates.emplace_back(candidate_id, *other);
    }

    /* Remove candidates that contradict
       operator effects or preconditions. 
       Also remove when the other fact in
       the pair is not in table.*/
    for (auto it = candidates.begin(); it < candidates.end(); /* */) {
        int other_id = cc.get_con_var_id(it->second);
        if (!table[other_id] ||
            contradicts(it->second, op.get_effects()) ||
            contradicts(it->second, op.get_preconditions()))
            it = candidates.erase(it);
        else
            ++it;
    }

    Bitset result(num_tuples);
    bool was_changed = false;
    for (const pair<int, FactProxy> &candidate : candidates) {
        assert(candidate.first >= num_1d_tuples &&
               candidate.first < num_tuples);

        // TODO: Optimize this.
        vector<FactProxy> new_pre_facts;
        for (const FactProxy &fact : op.get_preconditions()) {
            new_pre_facts.push_back(fact);
        }
        new_pre_facts.push_back(candidate.second);
        Bitset new_pre = get_tuples(new_pre_facts);
        if (new_pre.is_subset_of(table)) {
            Bitset to_add(num_tuples);
            to_add.set(candidate.first);
            if (add_to_table(to_add))
                was_changed = true;
        }
    }
    return was_changed;
}

/* Overload for operator effects: gathers the fact of every effect into a
   vector and returns the tuples of that fact collection. */
Bitset MutexFinder::get_tuples(const EffectsProxy &effects) const {
    vector<FactProxy> facts_of_effects;
    facts_of_effects.reserve(effects.size());
    for (const EffectProxy &eff : effects)
        facts_of_effects.push_back(eff.get_fact());

    return get_tuples(facts_of_effects);
}

/* Build a bitset of the same size as the table in which the bit of every
   fact in facts is set, as well as the bit of every pair of facts from
   facts whose variable ids differ. For a pair, the fact with the larger
   variable id is passed first to cc.get_con_var_id(). */
template <class FactProxyCollection>
Bitset MutexFinder::get_tuples(const FactProxyCollection &facts) const {
    Bitset tuples(table.size());
    for (const FactProxy &first : facts) {
        tuples.set(cc.get_con_var_id(first));

        const int first_var_id = first.get_variable().get_id();
        for (const FactProxy &second : facts) {
            /* Handle each unordered pair exactly once. Non-matching facts
               are skipped rather than ending the inner loop, so the facts
               do not have to be sorted by variable. */
            if (first_var_id > second.get_variable().get_id())
                tuples.set(cc.get_con_var_id(first, second));
        }
    }
    return tuples;
}

/* Report whether two facts are mutex according to the table.
   Two facts over the same variable are mutex iff their values differ.
   Facts over different variables are mutex iff the bit of their pair is
   not set in the table. */
bool MutexFinder::are_mutex(const FactProxy &fact1, const FactProxy &fact2) const {
    const int first_var = fact1.get_variable().get_id();
    const int second_var = fact2.get_variable().get_id();

    if (first_var == second_var)
        return fact1.get_value() != fact2.get_value();

    // The fact with the larger variable id goes first in the pair lookup.
    const int pair_id = (first_var > second_var)
        ? cc.get_con_var_id(fact1, fact2)
        : cc.get_con_var_id(fact2, fact1);
    return !table[pair_id];
}

/*
void MutexFinder::find_mutexes(const TaskProxy &tnf_task_proxy, const Bitset &init_tuples) {
    assert(init_tuples.size() == num_tuples);
    table = init_tuples;

    vector<OperatorProxy> ops;
    const OperatorsProxy ops_proxy = tnf_task_proxy.get_operators();
    ops.reserve(ops_proxy.size());
    for (const OperatorProxy &op : ops_proxy) {
        ops.push_back(op);
    }
    assert(ops.size() > 0);

    bool was_updated;
    do {
        was_updated = false;
        for (auto it = ops.begin(); it < ops.end();) {
            vector<FactProxy> pre_facts = get_facts(it->get_preconditions());
            if (facts_hold(pre_facts)) {
                vector<FactProxy> eff_facts = get_facts(it->get_effects());
                for (const FactProxy &eff_fact : eff_facts) {
                    if (!fact_holds(eff_fact)) {
                        was_updated = true;
                        set_table_1d(eff_fact);
                    }
                }
                if (was_updated) {
                    ops.erase(it);
                    break;
                } else {
                    it = ops.erase(it);
                    continue;
                }
            }
        }
    } while (was_updated);

}
*/

/*
void MutexFinder::find_mutexes(const TaskProxy &tnf_task_proxy) {
    // Set tuples true in initial state.
    vector<Tuple> initial_tuples =
        get_tuples(tnf_task_proxy.get_initial_state());
    for (Tuple tuple : initial_tuples) {
        set_table_2d(tuple);
    }
    cout << "MUTEX AFTER INIT:" << endl;
    for (int i = table_2d.size()-1; i >= 0; --i) {
        cout << table_2d[i];
    }
    cout << endl;

    // Loop until no changes
    bool was_updated = false;
    do {
        was_updated = false;
        for (const OperatorProxy &op : tnf_task_proxy.get_operators()) {
            // TODO: Make get_tuple return a bitset and make this a bit
            // operation.
            vector<Tuple> pre_tuples = get_tuples(op.get_preconditions());
            if (tuples_hold(pre_tuples)) {
                // TODO: Make get_tuple return a bitset and make this a bit
                // operation.
                vector<Tuple> eff_tuples = get_tuples(op.get_effects());
                for (const Tuple &eff_tuple : eff_tuples) {
                    if (!tuple_holds(eff_tuple)) {
                        was_updated = true;
                        set_table_2d(eff_tuple);
                        cout << "Added reached tuple: " << eff_tuple.first.get_name() << ", " << eff_tuple.second.get_name() << endl;
                    }
                }
            }
        }
    } while (was_updated);

    cout << "MUTEX RESULT:" << endl;
    for (int i = table_2d.size()-1; i >= 0; --i) {
        cout << table_2d[i];
    }
    cout << endl;
}
*/

/*
bool MutexFinder::are_mutex(const FactProxy &fact1, const FactProxy &fact2) {
    // TODO: Add asserts.
    return !tuple_holds(fact1, fact2);
}

void MutexFinder::set_table_1d(const FactProxy &fact) {
    int fact_id = con_1d_ids[fact.get_variable().get_id()][fact.get_value()];
    table_1d.set(fact_id);
}


void MutexFinder::set_table_2d(const Tuple &tuple) {
    int tuple_id = get_tuple_id(tuple);
    table_2d.set(tuple_id);
}

bool MutexFinder::facts_hold(const vector<FactProxy> &vec) const {
    bool result = true;
    for (const FactProxy &fact : vec) {
        if (!fact_holds(fact)) {
            result = false;
            break;
        }
    }
    return result;
}

bool MutexFinder::fact_holds(const FactProxy &fact) const {
    int fact_id = con_1d_ids[fact.get_variable().get_id()][fact.get_value()];
    return table_1d[fact_id];
}

vector<FactProxy> MutexFinder::get_facts(const EffectsProxy &effects) const {
    vector<FactProxy> result;
    result.reserve(effects.size());
    for (const EffectProxy &effect : effects) {
        result.push_back(effect.get_fact());
    }
    return result;
}

template <class FactProxyCollection>
vector<FactProxy> MutexFinder::get_facts(const FactProxyCollection &facts) const {
    vector<FactProxy> result;
    result.reserve(facts.size());
    for (const FactProxy &fact : facts) {
        result.push_back(fact);
    }
    return result;
}

bool MutexFinder::tuples_hold(const vector<Tuple> &vec) const {
    bool result = true;
    for (const Tuple &tuple : vec) {
        if (!tuple_holds(tuple)) {
            result = false;
            break;
        }
    }
    return result;
}

bool MutexFinder::tuple_holds(const Tuple &tuple) const {
    return tuple_holds(tuple.first, tuple.second);
}

bool MutexFinder::tuple_holds(
    const FactProxy &fact1,
    const FactProxy &fact2) const {
    int tuple_id = get_tuple_id(fact1, fact2);
    return table_2d[tuple_id];
}

int MutexFinder::get_tuple_id(const Tuple &tuple) const {
    return get_tuple_id(tuple.first, tuple.second);
}

int MutexFinder::get_tuple_id(
    const FactProxy &fact1,
    const FactProxy &fact2) const { 
    int var1_id, val1, var2_id, val2, con_id;
    var1_id = fact1.get_variable().get_id();
    val1 = fact1.get_value();
    var2_id = fact2.get_variable().get_id();
    val2 = fact2.get_value();
    if (var1_id > var2_id) {
        con_id = con_2d_ids[var1_id][val1][var2_id][val2];
    } else {
        con_id = con_2d_ids[var2_id][val2][var1_id][val1];
    }
    return con_id - offset;
}

vector<Tuple> MutexFinder::get_tuples(const EffectsProxy &effects) {
    vector<FactProxy> effect_facts;
    for (const EffectProxy &effect : effects) {
        effect_facts.push_back(effect.get_fact());
    }
    return get_tuples(effect_facts);
}

template <class FactProxyCollection>
vector<Tuple> MutexFinder::get_tuples(const FactProxyCollection &facts) {
    vector<Tuple> result;
    for (const FactProxy &fact1 : facts) {
        for (const FactProxy &fact2 : facts) {
            if (fact1 != fact2)
                result.emplace_back(fact1, fact2);
        }
    }
    return result;
}
*/
}
