#include "mscp_heuristic.h"

#include "../merge_and_shrink/cost_partitioning.h"
#include "../merge_and_shrink/cp_mas.h"
#include "../merge_and_shrink/distances.h"
#include "../merge_and_shrink/factored_transition_system.h"
#include "../merge_and_shrink/fts_factory.h"
#include "../merge_and_shrink/labels.h"
#include "../merge_and_shrink/merge_and_shrink_algorithm.h"
#include "../merge_and_shrink/merge_and_shrink_representation.h"
#include "../merge_and_shrink/saturated_cost_partitioning.h"
#include "../merge_and_shrink/shrink_bisimulation.h"
#include "../merge_and_shrink/shrink_strategy.h"
#include "../merge_and_shrink/transition_system.h"
#include "../merge_and_shrink/utils.h"

#include "../option_parser.h"
#include "../plugin.h"

#include "../task_utils/task_properties.h"

#include "../utils/countdown_timer.h"

#include <cassert>
#include <iostream>
#include <limits>
#include <set>
#include <string>

using namespace std;

namespace mscp_heuristic {

// File-local helpers: the anonymous namespace gives them internal linkage so
// they cannot clash with identically named functions in other translation units.
namespace {
/*
  Build one Abstraction per factor index in `indices` (in the given order).

  Each Abstraction pairs a raw pointer to the factor's transition system
  (owned by `fts`, not copied) with a newly allocated merge-and-shrink
  representation constructed from the factor's current representation
  (presumably a copy; the Leaf/Merge pointer constructors are assumed to copy).
  The returned abstractions therefore refer into `fts` and must not outlive
  the referenced factors.

  Every index must denote an active factor (checked in debug builds).
*/
vector<unique_ptr<merge_and_shrink::Abstraction>> get_abstractions(
        const merge_and_shrink::FactoredTransitionSystem &fts, const vector<int> &indices) {
    vector<unique_ptr<merge_and_shrink::Abstraction>> abstractions;
    abstractions.reserve(indices.size());

    for (int index : indices) {
        assert(fts.is_active(index));
        const merge_and_shrink::TransitionSystem *transition_system =
                fts.get_transition_system_raw_ptr(index);
        const auto *representation = fts.get_mas_representation_raw_ptr(index);

        // Dispatch on the dynamic type once and build a new node of the same type.
        unique_ptr<merge_and_shrink::MergeAndShrinkRepresentation> mas_representation;
        if (const auto *leaf =
                dynamic_cast<const merge_and_shrink::MergeAndShrinkRepresentationLeaf *>(representation)) {
            mas_representation =
                    utils::make_unique_ptr<merge_and_shrink::MergeAndShrinkRepresentationLeaf>(leaf);
        } else {
            const auto *merge =
                    dynamic_cast<const merge_and_shrink::MergeAndShrinkRepresentationMerge *>(representation);
            // a representation is expected to be either a leaf or a merge node
            assert(merge);
            mas_representation =
                    utils::make_unique_ptr<merge_and_shrink::MergeAndShrinkRepresentationMerge>(merge);
        }
        abstractions.push_back(utils::make_unique_ptr<merge_and_shrink::Abstraction>(
                transition_system, move(mas_representation)));
    }
    return abstractions;
}
}

/*
  Seed the merge queue from the initial factored transition system.

  If `prune` is set, every factor is first pruned (prune_counter counts the
  factors that changed). Then, for every pair (i, j) with i < j of active
  factors whose estimated product size |i| * |j| does not exceed
  merge_threshold, a merge quality is pushed as (quality, (i, j)).

  The quality compares two values for the initial state:
    - merged_dist: goal distance in the product of i and j (built from shrunk
      copies of both factors when check_shrink is set);
    - cp_dist: the better of the two saturated cost partitionings over i and j
      (orders i,j and j,i).
  Without consider_size the quality is merged_dist - cp_dist. With
  consider_size it additionally relates the product size to |i| + |j|.

  Side effect: the member merged_dist holds the last computed product distance.
*/
void MSCPHeuristic::prepare_priority_queue(PriorityQueue &pq_merged_systems, const vector<int> &label_costs) {
    // prune the atomic projections
    if (prune) {
        for (int i = 0; i < fts.get_size(); i++) {
            if (merge_and_shrink::prune_step(fts, i, true, true, verbosity)) {
                ++prune_counter;
            }
        }
    }

    // NOTE(review): candidate scoring requires the "cost_partitioning" option to
    // be a saturated cost partitioning factory; with any other factory the
    // static_cast below is undefined behavior. Checked in debug builds.
    assert(dynamic_cast<merge_and_shrink::SaturatedCostPartitioningFactory *>(cp_factory.get()));
    merge_and_shrink::SaturatedCostPartitioningFactory *scp_factory =
            static_cast<merge_and_shrink::SaturatedCostPartitioningFactory *>(cp_factory.get());

    // Goal distance of the initial state in `product` (stored in the member
    // merged_dist) and the number of states of `product`.
    auto evaluate_product = [&](const merge_and_shrink::TransitionSystem &product) {
        merge_and_shrink::Distances distances(product);
        distances.compute_distances(false, true, verbosity);
        merged_dist = distances.get_goal_distance(product.get_init_state());
        return product.get_size();
    };

    // Initial-state value of the saturated cost partitioning over the factors
    // (first, second), processed in the given order ({0, 1} or {1, 0}).
    auto scp_value_for_order = [&](int first, int second, vector<int> order) {
        vector<int> label_costs_copy(label_costs);
        vector<unique_ptr<merge_and_shrink::Abstraction>> abstractions =
                get_abstractions(fts, {first, second});
        unique_ptr<merge_and_shrink::CostPartitioning> cp =
                scp_factory->generate_for_order(move(label_costs_copy), move(abstractions), order, verbosity);
        return cp->compute_value(task_proxy.get_initial_state());
    };

    // Merge quality of a candidate pair. The distance difference is computed in
    // double so that "infinite" distances cannot overflow an int subtraction.
    auto merge_quality = [&](int merged_size, int cp_size, int product_dist, int cp_dist) -> double {
        const double dist_gain = static_cast<double>(product_dist) - static_cast<double>(cp_dist);
        if (!consider_size) {
            return dist_gain;
        }
        if (merged_size > cp_size) {
            // the product is larger than the two factors together: gain per extra state
            return dist_gain / static_cast<double>(merged_size - cp_size);
        }
        if (product_dist > cp_dist) {
            // no size penalty and a strictly better estimate: always preferred
            cout << " ----> merged_size = " << merged_size << endl;
            cout << " ----> cp_size = " << cp_size << endl;
            cout << " ----> merged_dist = " << product_dist << endl;
            cout << " ----> cp_dist = " << cp_dist << endl;
            return 99999;
        }
        if (product_dist == cp_dist) {
            return 1.0;
        }
        return dist_gain;
    };

    for (int i = 0; i < fts.get_size() - 1; i++) {
        for (int j = i + 1; j < fts.get_size(); j++) {
            // only pairs of factors that were not merged before
            if (!fts.is_active(i) || !fts.is_active(j)) {
                continue;
            }

            const merge_and_shrink::TransitionSystem &tr_sys0 = fts.get_transition_system(i);
            const merge_and_shrink::TransitionSystem &tr_sys1 = fts.get_transition_system(j);
            int i_size = tr_sys0.get_size();
            int j_size = tr_sys1.get_size();

            // Widen before multiplying: the product of two sizes near
            // merge_threshold does not fit into an int (signed overflow is UB
            // and typically wraps to a value that slips past this check).
            const long long estimated_merged_size = static_cast<long long>(i_size) * j_size;
            if (estimated_merged_size > merge_threshold) {
                continue;
            }

            int merged_size;
            if (check_shrink) {
                // shrink copies of both factors, then merge the copies
                pair<merge_and_shrink::TransitionSystem, merge_and_shrink::TransitionSystem> ts =
                        shrink_before_merge_step_and_keep(
                                fts,
                                i,
                                j,
                                max_states,
                                max_states_before_merge,
                                max_states_before_merge,
                                *shrink_strategy,
                                verbosity);
                const unique_ptr<merge_and_shrink::TransitionSystem> merged_tr_sys =
                        fts.merge_and_keep(ts.first, ts.second, verbosity);
                merged_size = evaluate_product(*merged_tr_sys);
            } else {
                const unique_ptr<merge_and_shrink::TransitionSystem> merged_tr_sys =
                        fts.merge_and_keep(tr_sys0, tr_sys1, verbosity);
                merged_size = evaluate_product(*merged_tr_sys);
            }

            // best cost partitioning over both orders (evaluated i,j first, then j,i)
            int cp_dist = scp_value_for_order(i, j, {0, 1});
            const int cp_dist2 = scp_value_for_order(i, j, {1, 0});
            cp_dist = max(cp_dist, cp_dist2);

            const double curr_quality = merge_quality(merged_size, i_size + j_size, merged_dist, cp_dist);

            // store the merge quality for future access
            pq_merged_systems.push(make_pair(curr_quality, make_pair(i, j)));
        }
    }
}

/*
  Extend the merge queue after a merge.

  For every active factor i other than last_merged_index (the factor created by
  the most recent merge) whose estimated product size with it does not exceed
  merge_threshold, push (quality, (i, last_merged_index)). The quality is
  computed exactly as in prepare_priority_queue: the initial-state goal
  distance in the (optionally pre-shrunk) product is compared with the better
  of the two saturated cost partitionings over the pair. With consider_size,
  the product size relative to the summed factor sizes is also considered.
  Entries already in the queue are left untouched.

  Side effect: the member merged_dist holds the last computed product distance.
*/
void MSCPHeuristic::extend_priority_queue(PriorityQueue &pq_merged_systems, const vector<int> &label_costs, int last_merged_index) {
    assert(fts.is_active(last_merged_index));

    // NOTE(review): candidate scoring requires the "cost_partitioning" option to
    // be a saturated cost partitioning factory; with any other factory the
    // static_cast below is undefined behavior. Checked in debug builds.
    assert(dynamic_cast<merge_and_shrink::SaturatedCostPartitioningFactory *>(cp_factory.get()));
    merge_and_shrink::SaturatedCostPartitioningFactory *scp_factory =
            static_cast<merge_and_shrink::SaturatedCostPartitioningFactory *>(cp_factory.get());

    // Goal distance of the initial state in `product` (stored in the member
    // merged_dist) and the number of states of `product`.
    auto evaluate_product = [&](const merge_and_shrink::TransitionSystem &product) {
        merge_and_shrink::Distances distances(product);
        distances.compute_distances(false, true, verbosity);
        merged_dist = distances.get_goal_distance(product.get_init_state());
        return product.get_size();
    };

    // Initial-state value of the saturated cost partitioning over the factors
    // (first, second), processed in the given order ({0, 1} or {1, 0}).
    auto scp_value_for_order = [&](int first, int second, vector<int> order) {
        vector<int> label_costs_copy(label_costs);
        vector<unique_ptr<merge_and_shrink::Abstraction>> abstractions =
                get_abstractions(fts, {first, second});
        unique_ptr<merge_and_shrink::CostPartitioning> cp =
                scp_factory->generate_for_order(move(label_costs_copy), move(abstractions), order, verbosity);
        return cp->compute_value(task_proxy.get_initial_state());
    };

    // Merge quality of a candidate pair. The distance difference is computed in
    // double so that "infinite" distances cannot overflow an int subtraction.
    auto merge_quality = [&](int merged_size, int cp_size, int product_dist, int cp_dist) -> double {
        const double dist_gain = static_cast<double>(product_dist) - static_cast<double>(cp_dist);
        if (!consider_size) {
            return dist_gain;
        }
        if (merged_size > cp_size) {
            // the product is larger than the two factors together: gain per extra state
            return dist_gain / static_cast<double>(merged_size - cp_size);
        }
        if (product_dist > cp_dist) {
            // no size penalty and a strictly better estimate: always preferred
            cout << " ----> merged_size = " << merged_size << endl;
            cout << " ----> cp_size = " << cp_size << endl;
            cout << " ----> merged_dist = " << product_dist << endl;
            cout << " ----> cp_dist = " << cp_dist << endl;
            return 99999;
        }
        if (product_dist == cp_dist) {
            return 1.0;
        }
        return dist_gain;
    };

    // the system merged in the previous iteration
    const merge_and_shrink::TransitionSystem &prev_merged_tr_sys = fts.get_transition_system(last_merged_index);

    for (int i = 0; i < fts.get_size(); i++) {
        if (i == last_merged_index || !fts.is_active(i)) {
            continue;
        }

        const merge_and_shrink::TransitionSystem &curr_tr_sys = fts.get_transition_system(i);
        int i_size = prev_merged_tr_sys.get_size();
        int j_size = curr_tr_sys.get_size();

        // Widen before multiplying: composite factors can each have up to
        // merge_threshold states, and their product overflows an int (UB).
        const long long estimated_merged_size = static_cast<long long>(i_size) * j_size;
        if (estimated_merged_size > merge_threshold) {
            continue;
        }

        int merged_size;
        if (check_shrink) {
            // shrink copies of both factors, then merge the copies
            pair<merge_and_shrink::TransitionSystem, merge_and_shrink::TransitionSystem> ts =
                    shrink_before_merge_step_and_keep(
                            fts,
                            i,
                            last_merged_index,
                            max_states,
                            max_states_before_merge,
                            max_states_before_merge,
                            *shrink_strategy,
                            verbosity);
            const unique_ptr<merge_and_shrink::TransitionSystem> merged_tr_sys =
                    fts.merge_and_keep(ts.first, ts.second, verbosity);
            merged_size = evaluate_product(*merged_tr_sys);
        } else {
            const unique_ptr<merge_and_shrink::TransitionSystem> merged_tr_sys =
                    fts.merge_and_keep(curr_tr_sys, prev_merged_tr_sys, verbosity);
            merged_size = evaluate_product(*merged_tr_sys);
        }

        // best cost partitioning over both orders (evaluated i,last first, then last,i)
        int cp_dist = scp_value_for_order(i, last_merged_index, {0, 1});
        const int cp_dist2 = scp_value_for_order(i, last_merged_index, {1, 0});
        cp_dist = max(cp_dist, cp_dist2);

        const double curr_quality = merge_quality(merged_size, i_size + j_size, merged_dist, cp_dist);

        // store the merge quality for future access
        pq_merged_systems.push(make_pair(curr_quality, make_pair(i, last_merged_index)));
    }
}

/*
  Build the heuristic by greedy, quality-driven merging.

  Starting from the atomic factored transition system, each iteration scores
  candidate pairs (prepare_priority_queue on the first iteration,
  extend_priority_queue afterwards). It then pops the best-quality pair whose
  factors are both still active, breaking ties uniformly at random, and merges
  it. The pair is optionally shrunk first and the result optionally pruned.

  The loop stops when any of these holds:
    - one factor is left;
    - main_loop_max_time expires;
    - no pair is queued;
    - no queued pair has two active factors;
    - the best quality does not exceed quality_threshold and the allowance
      max_occ_min_threshold is used up.

  Finally, cp_factory builds a cost partitioning over all remaining factors.
  compute_heuristic evaluates that cost partitioning.
*/
MSCPHeuristic::MSCPHeuristic(const Options &opts)
        : Heuristic(opts),
          verbosity(utils::Verbosity::NORMAL),
          fts(merge_and_shrink::create_factored_transition_system(task_proxy, true, true, verbosity)),
          main_loop_max_time(opts.get<double>("main_loop_max_time")),
          merge_threshold(opts.get<int>("merge_threshold")),
          max_states(opts.get<int>("max_states")),
          max_states_before_merge(opts.get<int>("max_states_before_merge")),
          check_shrink(opts.get<bool>("check_shrink")),
          perform_shrink(opts.get<bool>("perform_shrink")),
          cp_factory(opts.get<shared_ptr<merge_and_shrink::CostPartitioningFactory>>("cost_partitioning")),
          quality_threshold(opts.get<int>("quality_threshold")),
          consider_size(opts.get<bool>("consider_size")),
          nonlinear_merge(opts.get<bool>("nonlinear_merge")),
          prune(opts.get<bool>("prune")),
          max_occ_min_threshold(opts.get<int>("max_occ_min_threshold")),
          shrink_strategy(opts.get<shared_ptr<merge_and_shrink::ShrinkStrategy>>("shrink_strategy")){

    utils::CountdownTimer timer(main_loop_max_time);
    int iteration = 0;

    // current cost of every label of the factored transition system
    auto collect_label_costs = [this]() {
        const int num_labels = fts.get_labels().get_size();
        vector<int> label_costs(num_labels, -1);
        for (int label_no = 0; label_no < num_labels; ++label_no) {
            label_costs[label_no] = fts.get_labels().get_label_cost(label_no);
        }
        return label_costs;
    };

    cout << " --------------------------------------------------------" << endl;
    cout << " -- FTS size = " << fts.get_num_active_entries() << endl;

    // merge qualities of candidate pairs: (quality, (factor, factor))
    priority_queue<pair<double, pair<int, int>>, vector<pair<double, pair<int, int>>>, compare_pairs> pq_merged_systems;
    int last_merged_index = -1;

    // the shrink limits never exceed the merge limit
    if (max_states > merge_threshold){
        max_states = merge_threshold;
    }

    if (max_states_before_merge > merge_threshold){
        max_states_before_merge = merge_threshold;
    }

    // main loop
    while (fts.get_num_active_entries() > 1) {
        // check if limit run time is not up
        if (timer.is_expired()) {
            break;
        }

        const vector<int> label_costs = collect_label_costs();
        if (iteration == 0) {
            // queue setup
            prepare_priority_queue(pq_merged_systems, label_costs);
        } else {
            // queue extend
            extend_priority_queue(pq_merged_systems, label_costs, last_merged_index);
        }

        // if no pair is viable, do not merge
        if (pq_merged_systems.empty()) {
            break;
        }

        int best_i = -1;
        int best_j = -1;
        double quality = -9999.0;

        if (nonlinear_merge) {
            // TODO: non-linear merging needs the merged transition systems to
            // remain alive after merging; until then this option terminates
            // the process.
            // NOTE(review): EXIT_SUCCESS reports success for an unsupported
            // configuration.
            cout << "Not Implemented!" << endl;
            exit(EXIT_SUCCESS);
        } else {
            while (!pq_merged_systems.empty()) {
                const pair<double, pair<int, int>> top = pq_merged_systems.top();
                pq_merged_systems.pop();

                // entries whose factors were consumed by an earlier merge are stale
                if (!fts.is_active(top.second.first) || !fts.is_active(top.second.second)) {
                    continue;
                }

                // Collect every other still-valid pair tied with the best quality.
                // The emptiness test must come first: top() on an empty queue is UB.
                vector<pair<int, int>> best_pairs = {top.second};
                while (!pq_merged_systems.empty() && pq_merged_systems.top().first == top.first) {
                    const pair<double, pair<int, int>> tied = pq_merged_systems.top();
                    pq_merged_systems.pop();
                    if (fts.is_active(tied.second.first) && fts.is_active(tied.second.second)) {
                        best_pairs.push_back(tied.second);
                    }
                }

                // if more pairs have the best quality, choose one at random
                const int random_index = static_cast<int>(rand() % best_pairs.size());
                best_i = best_pairs[random_index].first;
                best_j = best_pairs[random_index].second;
                quality = top.first;

                // put the tied pairs not chosen back in the queue
                best_pairs.erase(best_pairs.begin() + random_index);
                for (const pair<int, int> &other : best_pairs) {
                    pq_merged_systems.push(make_pair(top.first, other));
                }
                break;
            }
        }

        // every queued pair was stale: there is nothing valid left to merge
        if (best_i < 0 || best_j < 0) {
            cout << " -- No queued pair of active factors is left to merge." << endl;
            break;
        }

        cout << " --------------------------------------------------------" << endl;
        cout << " -- Best pair to merge is: " << best_i << ", "
            << best_j << " with merge quality: " << quality << endl;

        // While the allowance max_occ_min_threshold is positive, merges are
        // performed regardless of quality; each low-quality merge uses it up.
        const bool allow_min_threshold = max_occ_min_threshold > 0;
        if (allow_min_threshold && quality <= quality_threshold) {
            --max_occ_min_threshold;
        }

        if (quality > quality_threshold || allow_min_threshold) {
            // Decision: Merge
            if (perform_shrink) {
                if (timer.is_expired()) {
                    break;
                }

                pair<bool, bool> shrunk = merge_and_shrink::shrink_before_merge_step(
                        fts,
                        best_i,
                        best_j,
                        max_states,
                        max_states_before_merge,
                        max_states_before_merge,
                        *shrink_strategy,
                        verbosity);

                if (shrunk.first || shrunk.second){
                    ++shrink_counter;
                }
            }

            last_merged_index = fts.merge(best_i, best_j, verbosity);

            // pruning
            if (prune) {
                if (timer.is_expired()) {
                    break;
                }

                if (merge_and_shrink::prune_step(
                        fts,
                        last_merged_index,
                        true,
                        true,
                        verbosity)) {
                    ++prune_counter;
                }
            }

            ++iteration;
        } else {
            // Decision: Don't merge
            break;
        }
    }

    cout << " --------------------------------------------------------" << endl;
    cout << " - There have been " << iteration << " merges." << endl;
    cout << " - Running time: " << timer.get_elapsed_time() << " seconds." << endl;
    cout << " - We pruned " << prune_counter << " abstractions" << endl;
    cout << " - We shrank " << shrink_counter << " abstractions" << endl;
    cout << " --------------------------------------------------------" << endl;

    // get the remaining active factors
    vector<int> alive_factors;
    for (int i = 0; i < fts.get_size(); i++) {
        if (fts.is_active(i)) {
            alive_factors.push_back(i);
        }
    }

    // generate a cost partitioning over the abstractions of the active factors
    vector<int> label_costs = collect_label_costs();
    vector<unique_ptr<merge_and_shrink::Abstraction>> abstractions = get_abstractions(fts, alive_factors);
    cost_partitioning =
            cp_factory->generate(move(label_costs), move(abstractions), verbosity);
}

// Declares dead ends reported by this heuristic as reliable; unconditionally true.
bool MSCPHeuristic::dead_ends_are_reliable() const { return true; }

int MSCPHeuristic::compute_heuristic(const GlobalState &global_state) {
    State state = convert_global_state(global_state);

    // return the heuristic value
    return cost_partitioning->compute_value(state);
}

/*
  Option parser for the "mscp" heuristic.

  Declares all options read by the MSCPHeuristic constructor and, unless this
  is a dry run, constructs the heuristic.
  Note: `prune` previously had "true" as its help text and no default value,
  which made it a mandatory option. It now defaults to true.
*/
static shared_ptr<Heuristic> _parse(OptionParser &parser) {
    parser.document_synopsis("mscp heuristic", "");
    parser.add_option<int>(
            "quality_threshold",
            "merges are only performed while the best merge quality exceeds "
            "this value (see also max_occ_min_threshold)",
            "0");
    parser.add_option<double>(
            "main_loop_max_time",
            "time limit in seconds for the merge loop",
            "900");
    parser.add_option<int>(
            "merge_threshold",
            "pairs of factors whose size product exceeds this value are not "
            "considered for merging; also caps max_states and "
            "max_states_before_merge",
            "50000");
    parser.add_option<int>(
            "max_states",
            "size limit for the merged transition system used by the shrink "
            "step",
            "50000");
    parser.add_option<int>(
            "max_states_before_merge",
            "size limit for each factor before merging used by the shrink step",
            "50000");
    parser.add_option<int>(
            "max_occ_min_threshold",
            "number of merges that are still performed although their quality "
            "does not exceed quality_threshold",
            "0");
    parser.add_option<bool>("nonlinear_merge", "allow nonlinear merges", "false");
    parser.add_option<bool>(
            "check_shrink",
            "shrink copies of candidate factors before estimating merge quality",
            "false");
    parser.add_option<bool>(
            "perform_shrink",
            "shrink the chosen factors before merging them",
            "false");
    parser.add_option<bool>("consider_size", "allow quality computation to consider size", "false");
    Heuristic::add_options_to_parser(parser);
    parser.add_option<shared_ptr<merge_and_shrink::CostPartitioningFactory>>(
            "cost_partitioning",
            "A method for computing cost partitionings over intermediate "
            "'snapshots' of the factored transition system. Candidate merges "
            "are scored with saturated cost partitioning, so this must be a "
            "saturated cost partitioning factory.");
    parser.add_option<shared_ptr<merge_and_shrink::ShrinkStrategy>>(
            "shrink_strategy",
            "See detailed documentation for shrink strategies. "
            "We currently recommend non-greedy shrink_bisimulation, which can be "
            "achieved using {{{shrink_strategy=shrink_bisimulation(greedy=false)}}}");
    parser.add_option<bool>(
            "prune",
            "prune unreachable and irrelevant abstract states of the factors",
            "true");

    Options opts = parser.parse();
    if (parser.dry_run())
        return nullptr;
    else
        return make_shared<MSCPHeuristic>(opts);
}

// Register this heuristic with the plugin system under the name "mscp".
static Plugin<Evaluator> _mscp_plugin("mscp", _parse);
}
