#include "cluster_based_open_list.h"

#include "open_list.h"

#include "../globals.h"
#include "../option_parser.h"
#include "../plugin.h"

#include "../utils/collections.h"
#include "../utils/markup.h"
#include "../utils/memory.h"
#include "../utils/rng.h"
#include "../utils/logging.h"
#include "../utils/timer.h"
#include "../utils/states_helper.h"
#include "../utils/countdown_timer.h"

#include <cmath>
#include <memory>
#include <unordered_map>
#include <vector>

using namespace std;

template<class Entry>
class ClusterBasedOpenList : public OpenList<Entry> {
    // Debug helper: prints the number of buckets and the size of each one.
    void dump() {
        g_log << "Amount of Buckets: " << buckets.size() << endl;
        for (size_t i = 0; i < buckets.size(); i++) {
            g_log << "Bucket " << i << " Size: " << buckets[i].size() << endl;
        }
        // Use the same log stream as above instead of mixing in cout.
        g_log << endl << endl;
    }

    // An open-list entry together with the fact bit mask of its state
    // (as produced by utils::create_bit_mask).
    struct StateTuple {
        Entry entry;
        vector<bool> bit_mask;
        StateTuple(Entry e, vector<bool> b) : entry(e), bit_mask(std::move(b)) {}
    };


    // A cluster of entries. For every fact the bucket tracks how many of its
    // entries have that fact set (facts_count) and the corresponding
    // fraction (facts_mean), which acts as the bucket's centroid.
    class Bucket {
        vector<StateTuple> states;
        vector<int> facts_count;
        vector<double> facts_mean;

    public:
        Bucket()
            : facts_count(utils::get_num_facts()),
              facts_mean(utils::get_num_facts()) {
        }

        // Orders buckets by number of contained entries.
        bool operator <(const Bucket &other) const {
            return states.size() < other.states.size();
        }

        // L1 distance between this bucket's mean vector and a bit mask.
        double get_distance(const vector<bool> &bit_mask) const {
            assert(static_cast<int>(bit_mask.size()) == utils::get_num_facts());
            double sum = 0;
            for (size_t i = 0; i < bit_mask.size(); ++i) {
                sum += std::abs(facts_mean[i] - bit_mask[i]);
            }
            return sum;
        }

        int size() const {
            return states.size();
        }

        // Removes all entries and resets counts and means to zero.
        void clear() {
            states.clear();
            fill(facts_count.begin(), facts_count.end(), 0);
            fill(facts_mean.begin(), facts_mean.end(), 0);
        }

        // Removes all entries and resets counts, but keeps the means so they
        // can serve as centroids for the next clustering pass.
        void clear_but_keep_means() {
            states.clear();
            fill(facts_count.begin(), facts_count.end(), 0);
        }

        // Adds an entry and updates the fact counts. The means are only
        // updated when keepMeans is false.
        void insert(const StateTuple &state, bool keepMeans) {
            states.push_back(state);
            for (int i = 0; i < utils::get_num_facts(); ++i) {
                if (state.bit_mask[i]) {
                    facts_count[i]++;
                }
                if (!keepMeans) {
                    facts_mean[i] = static_cast<double>(facts_count[i]) / size();
                }
            }
        }

        // Removes the entry at position pos (order is not preserved) and
        // updates counts and means. An emptied bucket is fully reset.
        Entry pop(int pos) {
            assert(utils::in_bounds(pos, states));
            StateTuple removed = utils::swap_and_pop_from_vector(states, pos);
            if (states.empty()) {
                clear();
            } else {
                // Use the removed entry's own bit mask: after the swap,
                // states[pos] is a different entry (or out of bounds).
                for (int i = 0; i < utils::get_num_facts(); ++i) {
                    if (removed.bit_mask[i]) {
                        facts_count[i]--;
                    }
                    facts_mean[i] = static_cast<double>(facts_count[i]) / states.size();
                }
            }
            return removed.entry;
        }

        // Removes and returns the last entry without touching counts or
        // means; callers are responsible for resetting them.
        StateTuple pop_end_state_tuple() {
            assert(!states.empty());
            StateTuple result = states.back();
            states.pop_back();
            return result;
        }

        bool empty() const {
            return states.empty();
        }

        vector<double> &get_facts_mean() {
            return facts_mean;
        }

        vector<int> &get_facts_count() {
            return facts_count;
        }

        // Recomputes the means from the counts. An empty bucket gets random
        // means in [0, 1) so it can attract entries in the next pass.
        void recalculate_mean() {
            if (!states.empty()) {
                for (int i = 0; i < utils::get_num_facts(); i++) {
                    facts_mean[i] = static_cast<double>(facts_count[i]) / states.size();
                }
            } else {
                // Use the planner's RNG (as elsewhere in this file) so runs
                // are reproducible under the configured seed.
                generate(facts_mean.begin(), facts_mean.end(), []() {
                    return (*g_rng())();
                });
            }
        }
    };


    // Returns the index of a bucket whose mean is closest to bit_mask;
    // ties are broken uniformly at random.
    int get_min_index(const vector<bool> &bit_mask) {
        double min_distance = numeric_limits<double>::max();
        vector<int> min_indexes;
        for (int i = 0; i < static_cast<int>(buckets.size()); i++) {
            double distance = buckets[i].get_distance(bit_mask);
            if (distance < min_distance) {
                min_distance = distance;
                min_indexes.assign(1, i);
            } else if (distance == min_distance) {
                // 'else' so a new minimum is not recorded twice, which
                // would bias the random tie-break towards it.
                min_indexes.push_back(i);
            }
        }
        assert(!min_indexes.empty());
        if (min_indexes.size() == 1) {
            assert(utils::in_bounds(min_indexes[0], buckets));
            return min_indexes[0];
        }
        int index = (*g_rng())(static_cast<int>(min_indexes.size()));
        assert(utils::in_bounds(min_indexes[index], buckets));
        return min_indexes[index];
    }

    // Re-clusters the entries currently in the open list. Starting from the
    // current bucket means, it alternates between assigning every entry to
    // the closest bucket and recomputing the means, until stop_time expires.
    // Afterwards the non-empty buckets occupy [0, num_filled), largest first.
    void k_means() {
        utils::CountdownTimer stop_timer(stop_time);

        // Collect only the entries that are still open; entries already
        // returned by remove_min() must not be re-inserted.
        vector<StateTuple> open_states;
        for (Bucket &bucket : buckets) {
            while (!bucket.empty()) {
                open_states.push_back(bucket.pop_end_state_tuple());
            }
        }
        // all_states also holds removed entries and is not needed for
        // re-clustering; drop it so it does not grow without bound.
        all_states.clear();

        // At least one assignment pass, so the collected entries always
        // end up back in a bucket.
        do {
            for (Bucket &bucket : buckets) {
                bucket.clear_but_keep_means();
            }
            for (const StateTuple &state : open_states) {
                buckets[get_min_index(state.bit_mask)].insert(state, true);
            }
            // By reference: recomputing the means of copies has no effect.
            for (Bucket &bucket : buckets) {
                bucket.recalculate_mean();
            }
        } while (!stop_timer.is_expired());

        // Sort by size, descending, so that empty buckets come last.
        sort(buckets.rbegin(), buckets.rend());
        num_filled = 0;
        while (num_filled < static_cast<int>(buckets.size()) &&
               !buckets[num_filled].empty()) {
            ++num_filled;
        }
    }

    int num_buckets;
    double interval;
    double stop_time;
    // Declared after num_buckets, which must be initialized first.
    vector<Bucket> buckets = vector<Bucket>(num_buckets);
    // Buckets [0, num_filled) are in use (non-empty).
    int num_filled = 0;
    utils::Timer timer;
    // Appended to by do_insertion(); k_means() does not read it and clears it.
    vector<StateTuple> all_states;

protected:
    virtual void do_insertion(
        EvaluationContext &eval_context, const Entry &entry) override;

public:
    explicit ClusterBasedOpenList(const Options &opts);
    virtual ~ClusterBasedOpenList() override = default;

    virtual Entry remove_min(vector<int> *key = nullptr) override;
    virtual bool empty() const override;
    virtual void clear() override;
    virtual bool is_dead_end(EvaluationContext &eval_context) const override;
    virtual bool is_reliable_dead_end(
        EvaluationContext &eval_context) const override;
    virtual void get_involved_heuristics(set<Heuristic *> &hset) override;
};

template<class Entry>
void ClusterBasedOpenList<Entry>::do_insertion(
    EvaluationContext &eval_context, const Entry &entry) {
    // Pair the entry with the fact bit mask of its state.
    const GlobalState &global_state = eval_context.get_state();
    StateTuple tuple(entry, utils::create_bit_mask(global_state));

    // Seed the next unused bucket if there is one; otherwise join the
    // bucket with the closest mean.
    const bool has_unused_bucket = num_filled < num_buckets;
    int target_bucket =
        has_unused_bucket ? num_filled : get_min_index(tuple.bit_mask);
    if (has_unused_bucket) {
        assert(utils::in_bounds(target_bucket, buckets));
        ++num_filled;
        // Once the last bucket has been seeded, start timing towards the
        // next re-clustering.
        if (interval > 0 && num_filled == num_buckets) {
            timer.reset();
        }
    }

    buckets[target_bucket].insert(tuple, false);
    all_states.push_back(tuple);

    // Re-cluster periodically, but only while every bucket is in use.
    const bool reclustering_due =
        interval > 0 && num_filled == num_buckets && timer() > interval;
    if (reclustering_due) {
        k_means();
        timer.reset();
    }
}


template<class Entry>
ClusterBasedOpenList<Entry>::ClusterBasedOpenList(const Options &opts)
    : num_buckets(opts.get<int>("num_buckets")),
      interval(opts.get<double>("interval")),
      stop_time(opts.get<double>("stop_time")),
      buckets(num_buckets) {
    // buckets is declared after num_buckets, so num_buckets is already
    // initialized when the bucket vector is sized here.
}

template<class Entry>
Entry ClusterBasedOpenList<Entry>::remove_min(vector<int> *) {
    // Choose one of the buckets in use uniformly at random, then one of its
    // entries uniformly at random.
    assert(num_filled > 0);
    int index = (*g_rng())(num_filled);
    assert(utils::in_bounds(index, buckets));
    Bucket &bucket = buckets[index];
    assert(bucket.size() > 0);
    int pos = (*g_rng())(bucket.size());
    Entry result = bucket.pop(pos);
    if (bucket.empty()) {
        // Keep buckets [0, num_filled) non-empty by moving the last bucket
        // in use into the freed slot. Swapping instead of copy-assigning
        // avoids copying every entry of that bucket; the emptied bucket
        // ends up at the end and is reset.
        int last = num_filled - 1;
        assert(utils::in_bounds(last, buckets));
        swap(buckets[index], buckets[last]);
        buckets[last].clear();
        --num_filled;
    }
    return result;
}

template<class Entry>
bool ClusterBasedOpenList<Entry>::empty() const {
    // The open list is empty exactly when no bucket is in use.
    const bool no_bucket_in_use = (num_filled == 0);
    return no_bucket_in_use;
}

template<class Entry>
void ClusterBasedOpenList<Entry>::clear() {
    // Empty every bucket but keep num_buckets of them: do_insertion()
    // indexes buckets[num_filled], so the vector itself must not shrink.
    for (Bucket &bucket : buckets) {
        bucket.clear();
    }
    // Reset the in-use counter so empty() is true again and remove_min()
    // does not index into cleared buckets.
    num_filled = 0;
    all_states.clear();
}

template<class Entry>
bool ClusterBasedOpenList<Entry>::is_dead_end(
    EvaluationContext &eval_context) const {
    // A state counts as a dead end only if it is a reliable dead end.
    return is_reliable_dead_end(eval_context);
}

template<class Entry>
bool ClusterBasedOpenList<Entry>::is_reliable_dead_end(
    EvaluationContext & /*eval_context*/) const {
    // This open list never reports a state as a reliable dead end.
    constexpr bool reports_dead_end = false;
    return reports_dead_end;
}

template<class Entry>
void ClusterBasedOpenList<Entry>::get_involved_heuristics(
    set<Heuristic *> & /*hset*/) {
    // Intentionally empty: this open list holds no heuristics, so nothing
    // is added to the set.
}


ClusterBasedOpenListFactory::ClusterBasedOpenListFactory(
    const Options &opts)
    : options(opts) {
    // The stored options are later passed to every open list created by
    // create_state_open_list() / create_edge_open_list().
}


unique_ptr<StateOpenList>
ClusterBasedOpenListFactory::create_state_open_list() {
    // Builds a new cluster-based open list over state entries.
    using StateList = ClusterBasedOpenList<StateOpenListEntry>;
    return utils::make_unique_ptr<StateList>(options);
}

unique_ptr<EdgeOpenList>
ClusterBasedOpenListFactory::create_edge_open_list() {
    // Builds a new cluster-based open list over edge entries.
    using EdgeList = ClusterBasedOpenList<EdgeOpenListEntry>;
    return utils::make_unique_ptr<EdgeList>(options);
}

// Documents and parses the options of the cluster-based open list and, unless
// this is a dry run, returns a factory configured with them.
static shared_ptr<OpenListFactory> _parse(OptionParser &parser) {
    // The synopsis describes what this open list actually does: clustering by
    // state facts (not by evaluator values, as in the type-based open list).
    parser.document_synopsis(
        "Cluster-based open list",
        "Partitions entries into buckets (clusters) according to the "
        "facts that hold in their states. While unused buckets remain, "
        "each new entry is placed in an unused bucket; afterwards it "
        "joins the bucket whose mean fact vector is closest (L1 "
        "distance, ties broken uniformly at random). "
        "When retrieving an entry, a bucket is chosen uniformly at "
        "random and one of the contained entries is selected "
        "uniformly randomly. "
        "If interval > 0, the buckets are periodically re-clustered "
        "with k-means. "
        "The algorithm is based on" + utils::format_paper_reference(
            {"Fan Xie", "Martin Mueller", "Robert Holte", "Tatsuya Imai"},
            "Cluster-Based Exploration with Multiple Search Queues for"
            " Satisficing Planning",
            "http://www.aaai.org/ocs/index.php/AAAI/AAAI14/paper/view/8472/8705",
            "Proceedings of the Twenty-Eighth AAAI Conference"
            " on Artificial Intelligence (AAAI 2014)",
            "2395-2401",
            "AAAI Press 2014"));
    parser.add_option<int>(
        "num_buckets",
        "Amount of buckets.",
        "100",
        Bounds("1", "infinity"));
    parser.add_option<double>(
        "interval",
        "Time interval that defines how often k-means clustering is applied - 0 means not reclustering with kmeans",
        "0",
        Bounds("0", "infinity"));
    parser.add_option<double>(
        "stop_time",
        "Time that defines how long k-means clustering is applied",
        "2",
        Bounds("0.1", "infinity"));

    Options opts = parser.parse();
    if (parser.dry_run())
        return nullptr;
    else
        return make_shared<ClusterBasedOpenListFactory>(opts);
}

// Registers this open list with the plugin registry under the key "cluster_based".
static PluginShared<OpenListFactory> _plugin{"cluster_based", _parse};
