#include "uniform_pdbs.h"

#include "compression_double.h"
#include "pattern_database_double.h"
#include "pattern_database_factory_double.h"

#include "../task_proxy.h"

#include "../utils/logging.h"

#include <algorithm>
#include <array>
#include <iostream>
#include <limits>
#include <math.h>
#include <memory>
#include <utility>
#include <vector>

using namespace std;

namespace pdbs {

    /*
     * Helper function to determine if an operator is relevant in the provided pattern
     */
    bool check_operator_relevance_in_pattern(const OperatorProxy &op, const Pattern &pattern) {
        for (EffectProxy effect : op.get_effects()) {
            int var_id = effect.get_fact().get_variable().get_id();
            if (binary_search(pattern.begin(), pattern.end(), var_id)) {
                return true;
            }
        }
        return false;
    }

    UniformPDBs::UniformPDBs(const TaskProxy &task_proxy, const PatternCollection &patterns, const Options &opts) {
        OperatorsProxy operators = task_proxy.get_operators();
        const size_t operator_size = operators.size();
        vector<double> operator_costs;
        operator_costs.reserve(operator_size);
        vector<int> operator_count;
        operator_count.reserve(operator_size);

        int state_budget = opts.get<int>("max_size_compr_col");
        int remaining_budget = state_budget;

        // Setting up operator vectors
        for (OperatorProxy op : operators) {
            operator_costs.push_back(static_cast<double>(op.get_cost()));
            operator_count.push_back(0);
        }

        // Vector to map pattern to relevance
        vector<vector<bool>> pattern_to_operatorrelevance_map;
        pattern_to_operatorrelevance_map.reserve(patterns.size());
        // For each pattern check each operator if it is relevant and add to relevance map and operator counter
        for (size_t i = 0; i < patterns.size(); i++) {
            vector<bool> current_map;
            const Pattern current_pattern = patterns[i];
            current_map.reserve(operator_size);
            for (OperatorProxy op: operators) {
                // Check if relevant
                bool is_relevant = check_operator_relevance_in_pattern(op, current_pattern);
                current_map.push_back(is_relevant);
                if (is_relevant) {
                    operator_count[op.get_id()] += 1;
                }
            }
            pattern_to_operatorrelevance_map.push_back(current_map);
        }

        pattern_databases.reserve(patterns.size());
        // Build operator cost vectors according to CP and calculate pdbs
        for (size_t i = 0; i < patterns.size(); i++) {
            vector<double> current_costs;
            current_costs.reserve(operator_size);
            for (OperatorProxy op : operators) {
                double costs = 0;
                // if the operator is irrelevant in the given pattern, costs = 0
                if (pattern_to_operatorrelevance_map[i][op.get_id()]) {
                    costs = operator_costs[op.get_id()] /
                            static_cast<double>(operator_count[op.get_id()]);
                }
                current_costs.push_back(costs);
            }
            shared_ptr<PatternDatabaseDouble> pdb = create_default_pdb_double(
                    task_proxy, patterns[i], false, current_costs);

            if (opts.get<bool>("compress_pdbs") && remaining_budget > 0) {
                int used_state_costs = compute_best_pdb_double(
                        patterns[i],
                        pdb,
                        task_proxy,
                        opts,
                        remaining_budget,
                        current_costs
                );
                remaining_budget -= used_state_costs;
            }
            pattern_databases.push_back(pdb);
        }

        print_avg_finite_mean_for_PDBCollectionDouble(pattern_databases);

    }


    double UniformPDBs::get_value(const State &state) const {
        /*
          Because we use cost partitioning, we can simply add up all
          heuristic values of all patterns in the pattern collection.
        */
        double h_val = 0;
        for (const shared_ptr<PatternDatabaseDouble> &pdb : pattern_databases) {
            double pdb_value = pdb->get_value(state);
            if (pdb_value == numeric_limits<double>::max())
                return numeric_limits<double>::max();
            h_val += pdb_value;
        }
        return h_val;
    }



    double UniformPDBs::compute_approx_mean_finite_h() const {
        double approx_mean_finite_h = 0;
        for (const shared_ptr<PatternDatabaseDouble> &pdb : pattern_databases) {
            approx_mean_finite_h += pdb->compute_mean_finite_h();
        }
        return approx_mean_finite_h;
    }

    void UniformPDBs::dump() const {
        for (const shared_ptr<PatternDatabaseDouble> &pdb : pattern_databases) {
            utils::g_log << pdb->get_pattern() << endl;
        }
    }
}
