#include "cycle_constraints.h"

#include "../option_parser.h"
#include "../plugin.h"
#include "../landmarks/landmark_factory.h"
#include "../algorithms/johnson_cycle_detection.h"
#include "../lp/lp_solver.h"

#include <algorithm>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <utility>
#include <vector>

using namespace std;
using namespace landmarks;

namespace operator_counting {
/*
  Reads the configuration from the parsed options:
    - "lm": landmark factory used to build (and, with recomp, rebuild) the
      landmark graph,
    - "cycle_constraints": whether constraints for landmark-graph cycles are
      generated in addition to the per-landmark constraints,
    - "recomp": whether the landmark graph and all constraints are rebuilt for
      every state instead of being built once and having their bounds updated,
    - "consider_orderings": whether the ordering type of the edge entering a
      cycle node decides if that node is counted in the cycle constraint.
*/
CycleConstraints::CycleConstraints(const options::Options &opts)
    : lm(opts.get<shared_ptr<LandmarkFactory>>("lm")),
      use_cycle_constraints(opts.get<bool>("cycle_constraints")),
      recomp(opts.get<bool>("recomp")),
      consider_orderings(opts.get<bool>("consider_orderings")) {
}

void CycleConstraints::compute_orderings() {
    ordering_types.clear();
    for (auto &node : lm_graph->get_nodes()) {
        for (auto &child : node->children) {
            if (consider_orderings) {
                ordering_types[make_pair(node->get_id(),
                                         child.first->get_id())] =
                        (int) child.second;
            }
            if (!recomp && use_cycle_constraints) {
                lm_orderings.emplace_back(make_pair(node->get_id(),
                                                    child.first->get_id()));
            }
        }
    }
    if (!recomp && use_cycle_constraints) {
        reached_orderings = make_shared<PerStateBitset>(
                vector<bool>(lm_orderings.size(), true));
    }
}

void CycleConstraints::initialize_constraints(
        const shared_ptr<AbstractTask> &task,
        vector<lp::LPConstraint> &constraints, double infinity) {

    lm_graph = lm->compute_lm_graph(task);
    if (consider_orderings || (!recomp && use_cycle_constraints)) {
        compute_orderings();
    }

//    lm_graph->dump(task_proxy->get_variables());
    if (recomp) {
        // The following constraint is necessary to avoid an error when using IPs.
        constraints.emplace_back(0.0, infinity);
        return;
    }

    reached_lms = make_shared<PerStateBitset>(
            vector<bool>(lm_graph->number_of_landmarks(), true));
    shared_ptr<TaskProxy> task_proxy = make_shared<TaskProxy>(*task);
    compute_constraints(task_proxy->get_initial_state(), infinity);

    for (const auto &lm_constraint : lm_constraints) {
        constraints.push_back(lm_constraint);
    }

    if (use_cycle_constraints) {
        for (const auto &cycle_constraint : cycle_constraints) {
            constraints.push_back(cycle_constraint);
        }
    }
}

void CycleConstraints::compute_constraints(const State &state,
                                           double infinity) {
    lm_constraints.clear();
    for (auto &node : lm_graph->get_nodes()) {
        if (node->is_true_in_state(state)) {
            if (recomp) continue;
            /* We need the 'useless' constraint for indexing if we don't
               recompute the LM graph for every state. */
            lm_constraints.emplace_back(0.0, infinity);
        } else {
            lm_constraints.emplace_back(1.0, infinity);
        }
        lp::LPConstraint &constraint = lm_constraints.back();
        /* This probably only works since all nodes in the LM graph that we get
           are fact (or disjunctive fact) landmarks. */
        for (int op_id : node->possible_achievers) {
            if (op_id < 0) continue;
            constraint.insert(op_id, 1.0);
        }
    }

    if (use_cycle_constraints) compute_cycle_constraints(infinity);
}

void CycleConstraints::compute_cycle_constraints(double infinity) {

    cycle_constraints.clear();
    vector<vector<int>> adj = lm_graph->to_adj_list();
    vector<vector<int>> cycles = johnson_cycles::compute_elementary_cycles(adj);
//    double repeated_action_probability = 0;
    if (cycles.size() > max_cycles) {
        max_cycles = cycles.size();
        cout << "Found more cycles than in all expanded states before: max_cycle="
             << fixed << setprecision(0) << max_cycles << endl;
    }
    size_t loss_counter = 0;
    for (const auto &cycle : cycles) {
        /* Check whether all components of the cycle can be achieved simultaneously
           via one operator (possible due to the definition of reasonable orderings). */
        set<int> intersected_achievers =
                lm_graph->get_nodes()[cycle[0]]->possible_achievers;
        map<int, size_t> achievers;
        size_t counter = 1;
        for (size_t i = 0; i < cycle.size(); ++i) {
            pair<int, int> ord = make_pair(cycle[(i + cycle.size() - 1) % cycle.size()],
                                           cycle[i]);
            set<int> a = lm_graph->get_nodes()[ord.second]->possible_achievers;
            if (!consider_orderings || ordering_types[ord] < 2) {
                // only if the ordering is non-natural
                for (auto op : a) {
                    achievers[op] = achievers.count(op) ? achievers[op] + 1 : 1;
                }
                ++counter;
            } else if (first) {
                ++loss_counter;
            }
            vector<int> intersection(a.size());
            auto it = set_intersection(intersected_achievers.begin(),
                                       intersected_achievers.end(),
                                       a.begin(), a.end(), intersection.begin());
            intersection.resize(it - intersection.begin());
            intersected_achievers.clear();
            for (auto elem : intersection) intersected_achievers.insert(elem);
        }
        /* This means there is an operator which achieves all components of
           the cycle simultaneously. We do not consider such cycles. */
        if (intersected_achievers.empty()) {
            // This means not all components of the cycle can be achieved simultaneously.
            cycle_constraints.emplace_back(counter, infinity);
            lp::LPConstraint &constraint = cycle_constraints.back();
//            size_t c1 = 0, c2 = 0;
            for (auto op : achievers) {
                if (op.first < 0) continue;
                constraint.insert(op.first, op.second);
//                if (first) {
//                    ++c1;
//                    if (op.second > 1) ++c2;
//                }
            }
//            if (first) repeated_action_probability += (double) c2 / (double) c1;
        } // Otherwise don't add the constraint.

        if (!recomp) {
            vector<int> orderings;
            for (size_t i = 0; i < cycle.size(); ++i) {
                int from = cycle[i];
                int to = cycle[(i + 1) % cycle.size()];
                auto it = find(lm_orderings.begin(), lm_orderings.end(),
                               make_pair(from, to));
                assert(it != lm_orderings.end());
                orderings.push_back(distance(lm_orderings.begin(), it));
            }
            orderings_by_cycle.push_back(orderings);
        }
    }
    if (first) {
//        cout << cycles.size() << " cycles were found in the initial state." << endl;
//        cout << "LP Coefficient > 1 probability: "
//             << (cycles.empty() ? 0 : (repeated_action_probability * 100) / cycles.size())
//             << "%" << endl;
        cout << "There were " << loss_counter
             << " LM nodes in cycles with incoming natural orderings." << endl;
    }
}

bool CycleConstraints::update_constraints(const State &state,
                                          lp::LPSolver &lp_solver) {
    assert(recomp);
    vector<lp::LPConstraint> constraints;

    if (!first) lm_graph = lm->recompute_lm_graph(state);

    compute_constraints(state, lp_solver.get_infinity());
    if (first) first = false;
    for (auto &constraint : lm_constraints) constraints.push_back(constraint);
    if (use_cycle_constraints) {
        if (consider_orderings) compute_orderings();
        for (auto &constraint : cycle_constraints) constraints.push_back(constraint);
    }
    if (constraints.empty()) constraints.emplace_back(0.0, lp_solver.get_infinity());
    lp_solver.add_temporary_constraints(constraints);
//    for (auto &constraint : constraints) {
//        for (auto &var : constraint.get_variables()) {
//            cout << var << ", ";
//        }
//        cout << endl;
//    }
    return false;
}

bool CycleConstraints::update_constraints(const GlobalState &state,
                                          lp::LPSolver &lp_solver) {
    assert(!recomp);
    assert(lp_solver.get_num_constraints() == (*reached_lms)[state].size() +
                                              orderings_by_cycle.size());
    for (size_t i = 0; i < lp_solver.get_num_constraints(); ++i) {
        if (i < (*reached_lms)[state].size()) {
            lp_solver.set_constraint_lower_bound(
                    i, (*reached_lms)[state].test(i) ? 0 : 1);
        } else {
            assert(use_cycle_constraints);
            size_t index = i - (*reached_lms)[state].size();
            assert(index >= 0 && index < (*reached_orderings)[state].size());

            bool active = true;
            for (int j : orderings_by_cycle[index]) {
                if ((*reached_orderings)[state].test(j)) {
                    // If the ordering has been taken care of, deactivate the cycle.
                    active = false;
                    break;
                }
            }
            // If none of the orderings in the cycle is reached, keep the constraint.
            lp_solver.set_constraint_lower_bound(
                    i, active ? cycle_constraints[index].get_lower_bound() : 0);
        }
    }
    return false;
}

void CycleConstraints::reject_lms_required_again(const GlobalState &state) {
    for (size_t i = 0; i < lm_orderings.size(); ++i) {
        if (!(*reached_orderings)[state].test(i)) {
            (*reached_lms)[state].reset(lm_orderings[i].first);
            (*reached_lms)[state].reset(lm_orderings[i].second);
        }
    }
}

/*
  Initializes the per-state bits for the initial state. Landmarks that do not
  hold there are marked unreached. With cycle constraints, orderings whose
  first landmark does not hold are also marked unreached, and then the
  landmarks of all unreached orderings are marked unreached as well.
*/
void CycleConstraints::notify_initial_state(const GlobalState &initial_state) {
    assert(!recomp);
    cout << "NOTIFYING INITIAL STATE" << endl;
    const auto &nodes = lm_graph->get_nodes();

    // Accept reached LMs.
    for (int lm_id = 0; lm_id < lm_graph->number_of_landmarks(); ++lm_id) {
        if (nodes[lm_id]->is_true_in_state(initial_state)) {
            continue;
        }
        (*reached_lms)[initial_state].reset(lm_id);
    }

    if (!use_cycle_constraints) {
        return;
    }

    for (size_t ordering_id = 0; ordering_id < lm_orderings.size(); ++ordering_id) {
        const int source_lm = lm_orderings[ordering_id].first;
        if (nodes[source_lm]->is_true_in_state(initial_state)) {
            continue;
        }
        (*reached_orderings)[initial_state].reset(ordering_id);
    }
    reject_lms_required_again(initial_state);
}

bool CycleConstraints::notify_state_transition(const GlobalState &parent_state,
                                               OperatorID /*op_id*/,
                                               const GlobalState &global_state) {
    assert(!recomp);
    for (int i = 0; i < lm_graph->number_of_landmarks(); ++i) {
        if (!lm_graph->get_nodes()[i]->is_true_in_state(global_state)
        && (!(*reached_lms)[global_state].test(i) || !(*reached_lms)[parent_state].test(i))) {
            /* If LM is not true in current state and was missed through any paths to this state,
               this will set it to false. */
            (*reached_lms)[global_state].reset(i);
        }
    }
    if (use_cycle_constraints) {
        for (size_t i = 0; i < lm_orderings.size(); ++i) {
            /* If the ordering is not covered in the current state and has not been covered on
               all paths to this state, it will be set to false here. */
            // TODO: check whether this makes sense. The only ordering that is not necessarily
            //   covered by all plans are reasonable orderings. Is there a possibility that we
            //   destroy some progress when eliminating them at this point? Since they are not
            //   mandatory, not every plan must have them covered...? On the other hand, if
            //   they are not reached in one plans, it consequently means that the
            //   corresponding parent LM is not reached and must be achieved still, doesn't it?
            if (!lm_graph->get_nodes()[lm_orderings[i].first]->is_true_in_state(global_state)
            && (!(*reached_orderings)[parent_state].test(i)
            || !(*reached_orderings)[global_state].test(i))) {
                (*reached_orderings)[global_state].reset(i);
            }
        }
        reject_lms_required_again(global_state);
    }

    // TODO: Remove constraints that are supersets of other constraints. (This should
    //   happen in each iteration, because if the subset constraint is no longer active,
    //   the superset constraint can still be.)
    return false;
}

/*
  Option parser and plugin registration for the "cycle_constraints"
  constraint generator. Every option read by the CycleConstraints constructor
  ("lm", "cycle_constraints", "recomp", "consider_orderings") is registered
  here. "recomp" is appended last so the positional order of the existing
  options is unchanged.
*/
static shared_ptr<ConstraintGenerator> _parse(OptionParser &parser) {
    parser.document_synopsis("Cycle-covering constraints", ""/*TODO*/);
    parser.add_option<shared_ptr<LandmarkFactory>>("lm", "landmark factory",
            "lm_rhw(reasonable_orders=true)");
    parser.add_option<bool>("cycle_constraints",
                            "tells whether cycle constraints should be added to LP or not",
                            "true");
    parser.add_option<bool>("consider_orderings",
                            "Manipulates constraints to require incoming ordering to be non-natural.",
                            "true");
    // The constructor reads "recomp"; it was previously never registered.
    parser.add_option<bool>("recomp",
                            "Recompute the landmark graph and all constraints for every "
                            "evaluated state instead of building them once for the initial "
                            "state and only updating their bounds afterwards.",
                            "false");

    Options opts = parser.parse();
    if (parser.dry_run()) return nullptr;
    return make_shared<CycleConstraints>(opts);
}

static Plugin<ConstraintGenerator> _plugin("cycle_constraints", _parse);
}
