#include "compression_double.h"

#include "../task_utils/causal_graph.h"
#include "../utils/logging.h"

#include <algorithm>
#include <cassert>
#include <limits>
#include <set>


using namespace std;

namespace pdbs {

    /*
      Log the average of compute_mean_finite_h() over all PDBs in *coll.

      A null or empty collection is reported explicitly. Without this guard
      the division 0.0 / 0 would print NaN, and a null pointer would be
      dereferenced.
    */
    void print_avg_finite_mean_for_PDBCollectionDouble(shared_ptr<PDBCollectionDouble> &coll) {
        if (!coll || coll->size() == 0) {
            utils::g_log << "Average mean_finite_h_value: n/a (no PDBs)" << endl;
            return;
        }
        const size_t entries = coll->size();
        double finite_mean_sum = 0.0;
        for (size_t i = 0; i < entries; ++i) {
            // Bind by reference: no per-element shared_ptr copy (atomic refcount).
            const shared_ptr<PatternDatabaseDouble> &pdb = (*coll)[i];
            finite_mean_sum += pdb->compute_mean_finite_h();
        }
        const double average = finite_mean_sum / static_cast<double>(entries);
        utils::g_log << "Average mean_finite_h_value: " << average << endl;
    }

    /*
      Overload for callers that hold the collection directly rather than via
      a shared_ptr. The zero-one PDB code passes a PDBCollectionDouble, while
      the canonical PDB code passes a shared_ptr to one.

      Logs the average of compute_mean_finite_h() over all PDBs in coll. An
      empty collection is reported explicitly instead of printing the NaN
      that 0.0 / 0 would produce.
    */
    void print_avg_finite_mean_for_PDBCollectionDouble(PDBCollectionDouble &coll) {
        const size_t entries = coll.size();
        if (entries == 0) {
            utils::g_log << "Average mean_finite_h_value: n/a (no PDBs)" << endl;
            return;
        }
        double finite_mean_sum = 0.0;
        for (size_t i = 0; i < entries; ++i) {
            // Bind by reference: no per-element shared_ptr copy (atomic refcount).
            const shared_ptr<PatternDatabaseDouble> &pdb = coll[i];
            finite_mean_sum += pdb->compute_mean_finite_h();
        }
        const double average = finite_mean_sum / static_cast<double>(entries);
        utils::g_log << "Average mean_finite_h_value: " << average << endl;
    }

    /*
      Compress larger_pdb onto the abstract state space of small_pdb.

      larger_pdb's distance table is split into consecutive blocks of
      variable_domain_size entries. Each block is replaced by its minimum
      distance.

      This only yields a correct projection if the variables that larger_pdb
      has in addition to small_pdb sit at the FRONT of its pattern (as
      enlarge_pattern_double arranges), and variable_domain_size is the
      product of their domain sizes.
      NOTE(review): assumes PatternDatabaseDouble gives the first pattern
      variable hash multiplier 1 and does not reorder the pattern, so each
      block is exactly the set of larger states projecting onto one small
      state. Confirm against its constructor.

      The returned PDB reuses small_pdb's pattern, size, hash multipliers and
      operator transitions, together with the compressed distances.
    */
    shared_ptr <PatternDatabaseDouble> min_compress(
            const shared_ptr <PatternDatabaseDouble> &small_pdb,
            const shared_ptr <PatternDatabaseDouble> &larger_pdb,
            int variable_domain_size) {
        // A non-positive block size would never advance the loop below.
        assert(variable_domain_size > 0);
        const size_t block_size = static_cast<size_t>(variable_domain_size);

        // Bind by reference to avoid copying the (potentially huge) table.
        // Lifetime extension keeps this valid if get_distances() returns by value.
        const auto &distances_old = larger_pdb->get_distances();
        assert(distances_old.size() % block_size == 0);

        vector<double> distances_new;
        distances_new.reserve(distances_old.size() / block_size);
        // Only visit complete blocks, so a size mismatch cannot read past the end.
        for (size_t block_start = 0;
             block_start + block_size <= distances_old.size();
             block_start += block_size) {
            auto block_begin = distances_old.begin() + block_start;
            // Take an actual entry as the minimum instead of seeding with
            // numeric_limits<double>::max(). A block consisting only of
            // infinite (dead-end) distances thus stays infinite.
            distances_new.push_back(*min_element(block_begin, block_begin + block_size));
        }

        vector <size_t> hash_mult = small_pdb->get_hashmultipliers();
        return make_shared<PatternDatabaseDouble>(
                small_pdb->get_pattern(),
                small_pdb->get_size(),
                move(distances_new),
                move(hash_mult),
                small_pdb->get_operator_transitions());
    }

    /*
      Return every variable that is a causal-graph successor of some variable
      in pattern but is not itself in pattern. The result is sorted in
      ascending order and contains no duplicates, because it is collected in
      a std::set.

      These are the variables considered for enlarging the pattern.
      NOTE(review): Fast Downward's iPDB hill climbing takes candidates from
      causal-graph predecessors (eff_to_pre) rather than successors. Confirm
      that successors are intended here.
    */
    vector<int> missing_variables_double(const TaskProxy &task_proxy, const Pattern &pattern) {
        const causal_graph::CausalGraph &causal_graph = task_proxy.get_causal_graph();

        set<int> causally_relevant_variables;
        for (int var : pattern) {
            const vector<int> &successors = causal_graph.get_successors(var);
            causally_relevant_variables.insert(successors.begin(), successors.end());
        }
        // Pattern variables may be successors of other pattern variables.
        // Erase them directly; no copy of the set or nested scan is needed.
        for (int var : pattern) {
            causally_relevant_variables.erase(var);
        }

        return vector<int>(causally_relevant_variables.begin(), causally_relevant_variables.end());
    }

    /*
      Return a copy of pattern with variable prepended.

      The new variable deliberately goes first, and the pattern is not
      re-sorted. This keeps the added variables at the front, which is the
      layout the compression step relies on.
    */
    Pattern enlarge_pattern_double(const Pattern &pattern, int variable) {
        Pattern enlarged;
        enlarged.reserve(pattern.size() + 1);
        enlarged.push_back(variable);
        enlarged.insert(enlarged.end(), pattern.begin(), pattern.end());
        return enlarged;
    }

    /*
      Return the fraction of table entries for which cur's distance is
      strictly greater than old's distance, as a value in [0, 1].

      Both PDBs are expected to have distance tables of the same length.
      Returns 0.0 for empty tables, where the original code would have
      produced NaN.
    */
    double compare_pdbs(shared_ptr <PatternDatabaseDouble> cur, shared_ptr <PatternDatabaseDouble> old) {
        // Bind by reference: avoids copying both distance tables on every call.
        // Lifetime extension keeps this valid if get_distances() returns by value.
        const auto &cur_distances = cur->get_distances();
        const auto &old_distances = old->get_distances();
        assert(cur_distances.size() == old_distances.size());

        // Never index past the shorter table, even if the sizes disagree in release builds.
        const size_t num_entries = min(cur_distances.size(), old_distances.size());
        if (num_entries == 0) {
            return 0.0;
        }

        size_t num_improvements = 0;
        for (size_t i = 0; i < num_entries; ++i) {
            if (cur_distances[i] > old_distances[i]) {
                ++num_improvements;
            }
        }
        return static_cast<double>(num_improvements) / static_cast<double>(num_entries);
    }

    /*
      Build the hill-climbing neighbourhood of pattern. There is one candidate
      per variable returned by missing_variables_double: the pattern enlarged
      by that single variable, placed in front.
    */
    vector <Pattern> compute_candidates_double(const Pattern &pattern, const TaskProxy &task) {
        const vector<int> extension_vars = missing_variables_double(task, pattern);
        vector <Pattern> result;
        result.reserve(extension_vars.size());
        for (size_t idx = 0; idx < extension_vars.size(); ++idx) {
            result.push_back(enlarge_pattern_double(pattern, extension_vars[idx]));
        }
        return result;
    }

    /*
      Hill-climbing search for a compressed PDB that improves on current_pdb.

      Each iteration takes the best pattern found so far and considers every
      enlargement of it by one variable (compute_candidates_double). For each
      candidate it does the following:
        - builds the candidate's PDB;
        - min-compresses it back onto current_pdb's abstract state space;
        - scores the result with compare_pdbs against current_pdb.
      current_pdb itself is not modified until the search ends, so every
      candidate is compared against the PDB passed in.

      The search stops when any of these holds:
        - "max_iterations" iterations have run;
        - an iteration found no better candidate;
        - an iteration raised the best score by less than "min_impr_compression".
      Candidates whose PDB would not fit into the remaining state budget are
      skipped before they are built.

      On return current_pdb holds the best compressed PDB found. It is left
      unchanged if nothing improved on it. Returns the total number of
      abstract states of all PDBs built, i.e. the budget consumed.
    */
    int find_pdb_hillclimbing(
            Pattern original_pattern,
            shared_ptr <PatternDatabaseDouble> &current_pdb,
            const TaskProxy &task,
            const Options &opts,
            const int remaining_states,
            const std::vector<double> &operator_costs = std::vector<double>()
    ) {
        const int state_budget = remaining_states;
        const double min_improvement = opts.get<double>("min_impr_compression");
        const int max_iterations = opts.get<int>("max_iterations");
        const VariablesProxy task_variables = task.get_variables();
        // Size of the PDB being improved. A candidate's PDB has this many
        // states times the joint domain size of its added variables.
        const long long original_size = current_pdb->get_size();

        int used_budget = 0;
        shared_ptr <PatternDatabaseDouble> best_pdb = current_pdb;
        double best_improvement_score = 0.0;
        Pattern best_pattern = original_pattern;
        vector <Pattern> candidates = compute_candidates_double(original_pattern, task);

        for (int iteration = 0; iteration < max_iterations; ++iteration) {
            const double score_before_iteration = best_improvement_score;
            bool improved = false;

            for (const Pattern &candidate : candidates) {
                // Added variables are at the front of the candidate (see
                // enlarge_pattern_double). Their joint domain size is the
                // compression block size.
                const size_t num_added = candidate.size() - original_pattern.size();
                long long dom_size = 1;
                for (size_t i = 0; i < num_added; ++i) {
                    dom_size *= task_variables[candidate[i]].get_domain_size();
                    if (dom_size > state_budget) {
                        // Already too large to fit; stopping here also prevents overflow.
                        break;
                    }
                }

                // Skip candidates whose PDB (original_size * dom_size states)
                // would exceed the remaining budget, instead of building them
                // first and overshooting. The comparison uses division to avoid
                // overflow.
                const long long remaining = static_cast<long long>(state_budget) - used_budget;
                if (original_size <= 0 || dom_size > remaining / original_size) {
                    continue;
                }

                shared_ptr <PatternDatabaseDouble> larger_pdb = create_default_pdb_double(
                        task,
                        candidate,
                        false,
                        operator_costs);
                used_budget += larger_pdb->get_size();

                shared_ptr <PatternDatabaseDouble> compressed_pdb =
                        min_compress(current_pdb, larger_pdb, static_cast<int>(dom_size));

                const double improvement_score = compare_pdbs(compressed_pdb, current_pdb);
                if (improvement_score > best_improvement_score) {
                    best_improvement_score = improvement_score;
                    best_pdb = compressed_pdb;
                    best_pattern = candidate;
                    improved = true;
                    utils::g_log << "Compression - Improvement found: " << improvement_score << endl;
                }
            }

            // Judge the gain of this whole iteration, not the gap between the
            // last two individual improvements. Stop if nothing improved:
            // otherwise the same candidates would be rebuilt next iteration.
            const double iteration_gain = best_improvement_score - score_before_iteration;
            if (!improved || iteration_gain < min_improvement) {
                break;
            }
            candidates = compute_candidates_double(best_pattern, task);
        }

        current_pdb = best_pdb;
        return used_budget;
    }

    /*
      Randomized alternative to hill climbing.

      Draws the causal-graph successors missing from original_pattern in
      random order, without replacement, using rand(). Each drawn variable is
      added in front of the pattern if the enlarged pattern's state count
      stays strictly below the budget. The PDB of the resulting pattern is
      then built and min-compressed back onto current_pdb's abstract state
      space. It replaces current_pdb if compare_pdbs reports an improvement
      greater than epsilon.

      Returns the number of abstract states of the enlarged pattern that was
      built. Returns 0 when no variable could be added; in that case nothing
      is built, because it would only reproduce current_pdb.
    */
    int find_pdb_randomwalk(
            Pattern original_pattern,
            shared_ptr <PatternDatabaseDouble> &current_pdb,
            const TaskProxy &task,
            const Options &opts,
            const int remaining_states,
            const std::vector<double> &operator_costs = std::vector<double>()
    ) {
        int num_states = current_pdb->get_size();
        const int states_budget = remaining_states;

        std::vector<int> missing_vars = missing_variables_double(task, original_pattern);
        if (missing_vars.empty()) {
            return 0;
        }

        const VariablesProxy task_variables = task.get_variables();
        Pattern new_pattern = original_pattern;
        // Product of the domain sizes of the added variables, i.e. the
        // compression block size. It stays below states_budget, so it fits in an int.
        int dom_size = 1;
        while (num_states < states_budget && !missing_vars.empty()) {
            const int rnd_num = rand() % missing_vars.size();
            const int candidate = missing_vars[rnd_num];
            const int candidate_dom_size = task_variables[candidate].get_domain_size();
            // 64-bit product: the int product could overflow for large patterns.
            if (static_cast<long long>(candidate_dom_size) * num_states < states_budget) {
                num_states *= candidate_dom_size;
                dom_size *= candidate_dom_size;
                new_pattern = enlarge_pattern_double(new_pattern, candidate);
            }
            missing_vars.erase(missing_vars.begin() + rnd_num);
        }

        // No variable fitted into the budget. Building new_pattern would just
        // duplicate current_pdb and charge its full size against the budget.
        if (new_pattern.size() == original_pattern.size()) {
            return 0;
        }

        std::shared_ptr<PatternDatabaseDouble> larger_pdb = create_default_pdb_double(
                task,
                new_pattern,
                false,
                operator_costs);
        shared_ptr <PatternDatabaseDouble> compressed_pdb = min_compress(current_pdb, larger_pdb, dom_size);
        // Keep the compressed PDB only if it is better somewhere.
        if (compare_pdbs(compressed_pdb, current_pdb) > std::numeric_limits<double>::epsilon()) {
            current_pdb = compressed_pdb;
        }
        return num_states;
    }

    /*
      Try to replace current_pdb with a better compressed PDB, using the
      algorithm selected by the "compr_algo" option.

      The state budget is the smaller of remaining_states_collection and the
      "max_size_compr" option. The chosen algorithm updates current_pdb in
      place, and its return value is passed through unchanged. An unknown
      algorithm value is logged and the process is aborted.
    */
    int compute_best_pdb_double(Pattern original_pattern,
                                shared_ptr <PatternDatabaseDouble> &current_pdb,
                                const TaskProxy &task,
                                const Options &opts,
                                const int remaining_states_collection,
                                const std::vector<double> &operator_costs) {
        const int compression_limit = opts.get<int>("max_size_compr");
        const int max_states = (compression_limit < remaining_states_collection)
                               ? compression_limit
                               : remaining_states_collection;

        switch (opts.get<CompressionAlgorithm>("compr_algo")) {
        case CompressionAlgorithm::HILLCLIMBING:
            return find_pdb_hillclimbing(original_pattern, current_pdb, task, opts, max_states, operator_costs);
        case CompressionAlgorithm::RANDOMWALK:
            return find_pdb_randomwalk(original_pattern, current_pdb, task, opts, max_states, operator_costs);
        default:
            utils::g_log << "Invalid compression algorithm provided. Aborting" << endl;
            abort();
        }
    }
}