#include "novelty_heuristic.h"

#include "../evaluation_context.h"
#include "../global_state.h"
#include "../option_parser.h"
#include "../plugin.h"
#include "../task_tools.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

using namespace std;

namespace novelty_heuristic {
    // Builds the lookup-table layout used by compute_heuristic.
    //
    // Options read: "bh" (the bucket heuristic; its value selects which lookup
    // table a state is recorded in) and "n" (the novelty bound, i.e. the
    // largest variable-subset size that is tracked). For every subset of
    // 1..n variables, comb() records the subset and the end of its region in
    // the flat table; the final accumulated offset is the total table size.
    NoveltyHeuristic::NoveltyHeuristic(const Options &opts)
            : Heuristic(opts),
              bucket_h(opts.get<ScalarEvaluator *>("bh")),
              novelty_bound(opts.get<int>("n")) {
        cout << "Initializing novelty search heuristic..." << endl;

        VariablesProxy variables = task_proxy.get_variables();
        int nr_of_var = variables.size();
        assert(novelty_bound < nr_of_var);

        // Number of variable subsets of size exactly novelty_bound, i.e. the
        // binomial coefficient C(nr_of_var, novelty_bound). The previous
        // factorial-based formula subtracted instead of multiplying in the
        // denominator (dividing by zero e.g. for 2 variables and bound 1) and
        // overflowed for more than 12 variables. The multiplicative formula
        // keeps unique_sets == C(nr_of_var - novelty_bound + i, i) after step i,
        // so every division is exact.
        unsigned long long unique_sets = 1;
        for (int i = 1; i <= novelty_bound; ++i)
            unique_sets = unique_sets * (nr_of_var - novelty_bound + i) / i;
        cout << "Unique sets: " << unique_sets << endl;

        for (int var_id = 0; var_id < nr_of_var; ++var_id)
            domain_sizes.push_back(variables[var_id].get_domain_size());

        // Offsets accumulate region boundaries in the flat table; seed with 0
        // so the first subset starts at index 0.
        offsets = {0};
        for (int subset_size = 1; subset_size <= novelty_bound; ++subset_size)
            comb(subset_size);

        restrict_ins = novelty_bound; // Global var: Discard higher novelties on insert

        // The last offset is one past the final subset's region, i.e. the
        // table size. It is not the start of any subset, so drop it; afterwards
        // offsets[j] is the start index of subsets[j].
        lookup_size = offsets.back();
        offsets.pop_back();
    }

    // Nothing to release by hand: every member is cleaned up by its own
    // destructor, so the compiler-generated body is exactly what is needed.
    NoveltyHeuristic::~NoveltyHeuristic() = default;

    int NoveltyHeuristic::compute_heuristic(const GlobalState &global_state) {
        /**
         * Returns 0 if state is a goal state.
         * Loops through all subsets of size [1 - novelty bound] and updates
         * the lookup table if a new partial state is encountered.
         * Returns the size of the smallest new partial state or
         * [novelty bound + 1] (no new partial states encountered).
         */
        State state = convert_global_state(global_state);
        if (is_goal_state(task_proxy, state))
            return 0;
        else {
            EvaluationContext eval_context(global_state);
            int h_value = eval_context.get_heuristic_value(bucket_h);
            if(lookup.find(h_value)==lookup.end()){
                lookup[h_value] = std::vector<bool>(lookup_size,false);
            }

            std::vector<bool> &current_bucket = lookup.find(h_value)->second;
            vector<int> values = state.get_values();
            unsigned int novelty = novelty_bound + 1;
            int offset_index = 0;

            for(auto subset = subsets.begin();subset != subsets.end(); subset++){
                int offset = offsets[offset_index];
                for(auto position = subset->begin(); position != subset->end(); position++){
                    if(position == subset->end() -1){
                        offset += values[*position];
                    } else {
                        for (auto next_pos = position + 1; next_pos != subset->end(); next_pos++){
                            offset += values[*position] * domain_sizes[*next_pos];
                        }
                    }
                }
                if(!current_bucket[offset]){
                    novelty = std::min(novelty, subset->size());
                    current_bucket[offset] = true;
                }
                offset_index++;
            }
            return novelty;
        }
    }

    // Option parser and factory for the "novelty" heuristic.
    // Declares "n" (novelty bound, default 1) and "bh" (bucket heuristic, no
    // default given here) plus the generic heuristic options. Returns nullptr
    // on a dry run, otherwise a newly allocated NoveltyHeuristic.
    static Heuristic *_parse(OptionParser &parser) {
        parser.document_synopsis("Novelty heuristic",
                                 "Returns novelty if novelty is at most the bound, "
                                         "0 for goal states and "
                                         "bound + 1 where novelties are exceeding the bound");
        parser.document_language_support("action costs", "irrelevant");
        parser.document_language_support("conditional effects", "irrelevant");
        parser.document_language_support("axioms", "unsupported");
        parser.document_property("admissible", "no");
        parser.document_property("consistent", "no");
        parser.document_property("safe", "yes");
        parser.document_property("preferred operators", "no");
        parser.add_option<int>("n", "novelty bound", "1");
        parser.add_option<ScalarEvaluator *>("bh", "bucket heuristic");

        Heuristic::add_options_to_parser(parser);
        Options opts = parser.parse();
        if (parser.dry_run())
            return nullptr;
        return new NoveltyHeuristic(opts);
    }

    // n! as an unsigned int, computed as the product n * (n-1) * ... * 2.
    // Unsigned arithmetic wraps on overflow, which for a 32-bit unsigned int
    // happens for any n above 12; 0! and 1! are 1.
    unsigned int NoveltyHeuristic::factorial(unsigned int n)
    {
        unsigned int product = 1;
        for (unsigned int factor = n; factor > 1; --factor)
            product *= factor;
        return product;
    }

    // Appends every subset of exactly K variables (ids in increasing order)
    // to `subsets`. For each one it also appends to `offsets` the end of that
    // subset's region in the flat lookup table: the previous end plus the
    // product of the subset's domain sizes (the number of distinct partial
    // states over it). Does nothing when K is outside [1, number of variables].
    void NoveltyHeuristic::comb(int K)
    {
        int N = domain_sizes.size();
        // Without this guard, K > N would truncate the mask to N ones and
        // re-emit the full variable set, and K == 0 would emit an empty subset.
        if (K <= 0 || K > N)
            return;

        // Selection mask: K leading 1s followed by N-K 0s. prev_permutation
        // steps through all arrangements in descending lexicographic order,
        // visiting every K-subset of {0, ..., N-1} exactly once.
        string bitmask(K, 1);
        bitmask.resize(N, 0);

        do {
            vector<int> subset;
            int region_size = 1;
            for (int var = 0; var < N; ++var) {
                if (bitmask[var]) {
                    subset.push_back(var);
                    region_size *= domain_sizes[var];
                }
            }
            subsets.push_back(std::move(subset));
            offsets.push_back(offsets.back() + region_size);
        } while (prev_permutation(bitmask.begin(), bitmask.end()));
    }

    // Static registration of this heuristic under the key "novelty" with
    // _parse as its factory (presumably making it selectable as novelty(...)
    // in a search configuration — confirm against the plugin registry).
    static Plugin<Heuristic> _plugin("novelty", _parse);
}
