#include "opport_uniform_pdbs.h"

#include "compression_double.h"
#include "pattern_database_factory_double.h"

#include "../task_proxy.h"

#include "../utils/logging.h"

#include <algorithm>
#include <array>
#include <iostream>
#include <iterator>
#include <limits>
#include <math.h>
#include <memory>
#include <utility>
#include <vector>

using namespace std;

namespace pdbs {

    /*
     * Offered costs (c~) of a single pattern.
     *
     * For each operator id o:
     *   - 0.0 if pattern_to_operatorrelevance_map[o] is false (operator irrelevant),
     *   - remaining_costs[o] / operator_count[o] otherwise, i.e. an equal share of the
     *     operator's remaining cost.
     *
     * All vectors are indexed by operator id; the result has remaining_costs.size() entries.
     * operator_count[o] must be non-zero for every relevant operator o (not checked here).
     */
    std::vector<double> calculate_offered_costs(const std::vector<double> &remaining_costs,
                                                const std::vector<int> &operator_count,
                                                const std::vector<bool> pattern_to_operatorrelevance_map) {
        std::vector<double> offered(remaining_costs.size(), 0.0);
        for (size_t op_id = 0; op_id < offered.size(); ++op_id) {
            // Irrelevant operators keep the 0.0 they were initialised with.
            if (!pattern_to_operatorrelevance_map[op_id])
                continue;
            offered[op_id] = remaining_costs[op_id] / operator_count[op_id];
        }
        return offered;
    }

    /*
     * Saturated cost function (used costs, c_i) of `pdb`.
     *
     * For every operator accepted by pdb->is_operator_relevant(), the result is the largest
     * difference distances[s] - distances[s'] over the state transitions (s, s') stored for
     * that operator in pdb->get_operator_transitions(), clamped below at 0. Irrelevant
     * operators, and relevant operators without stored transitions, get 0.
     * The result is indexed by operator id and has operators_proxy.size() entries.
     */
    std::vector<double> OpportUniformPDBs::calculate_scf(
            const OperatorsProxy &operators_proxy,
            const std::shared_ptr<PatternDatabaseDouble> &pdb) {

        std::vector<double> used_costs(operators_proxy.size(), 0.0);
        // Bind by const reference: avoids copying the distance table and all transition
        // lists (valid whether the getters return by value or by reference).
        const std::vector<double> &distances = pdb->get_distances();
        const std::vector<std::vector<std::pair<int, int>>> &operator_transitions =
            pdb->get_operator_transitions();

        for (OperatorProxy op : operators_proxy) {
            // Irrelevant operators keep the default value 0 (no need to inspect transitions).
            if (!pdb->is_operator_relevant(op)) {
                continue;
            }

            const int op_id = op.get_id();
            // Must be a double: offered costs (and hence distances) are fractional, and an
            // int accumulator would truncate the drop and under-report the used cost.
            double max_drop = 0.0;
            for (const std::pair<int, int> &transition : operator_transitions[op_id]) {
                const double drop = distances[transition.first] - distances[transition.second];
                max_drop = std::max(max_drop, drop);
            }
            used_costs[op_id] = max_drop;
        }
        return used_costs;
    }

    /*
     * Helper function to determine if an operator is relevant in the provided pattern
     */
    bool OpportUniformPDBs::check_operator_relevance_in_pattern(const OperatorProxy &op, const Pattern &pattern) {
        for (EffectProxy effect : op.get_effects()) {
            int var_id = effect.get_fact().get_variable().get_id();
            if (binary_search(pattern.begin(), pattern.end(), var_id)) {
                return true;
            }
        }
        return false;
    }

    /*
     * Constructs a pattern database collection with opportunistic uniform cost
     * partitioning (OUCP).
     *
     * 1. For every pattern, record which operators are relevant for it (see
     *    check_operator_relevance_in_pattern). Also count, per operator, how many patterns
     *    it is relevant for.
     * 2. Process the patterns in the given order:
     *    - Each operator relevant for pattern i offers its remaining cost, split evenly among
     *      the not-yet-processed patterns it is relevant for (offered costs, c~_i).
     *    - The PDB for pattern i is built under c~_i, and its saturated costs (used costs, c_i)
     *      are computed.
     *    - Only c_i is subtracted from the remaining costs, so unused cost stays available to
     *      the following patterns.
     * 3. If the "compress_pdbs" option is set, compute_best_pdb_double is called for each PDB
     *    while the state budget (initialised from "max_size_compr_col") is positive. Its
     *    return value is deducted from that budget.
     */
    OpportUniformPDBs::OpportUniformPDBs(
            const TaskProxy &task_proxy,
            const PatternCollection &patterns,
            const Options &opts) {
        OperatorsProxy operators = task_proxy.get_operators();
        const size_t operator_size = operators.size();
        int remaining_budget = opts.get<int>("max_size_compr_col");

        // c_0_hat = c: initially every operator still has its full cost available.
        vector<double> remaining_costs;
        remaining_costs.reserve(operator_size);
        for (OperatorProxy op : operators) {
            remaining_costs.push_back(static_cast<double>(op.get_cost()));
        }

        // operator_count[o]: number of not-yet-processed patterns for which o is relevant.
        vector<int> operator_count(operator_size, 0);
        // pattern_to_operatorrelevance_map[i][o]: is operator o relevant for pattern i?
        vector<vector<bool>> pattern_to_operatorrelevance_map;
        pattern_to_operatorrelevance_map.reserve(patterns.size());
        for (size_t i = 0; i < patterns.size(); ++i) {
            const Pattern &current_pattern = patterns[i];
            vector<bool> current_map;
            current_map.reserve(operator_size);
            for (OperatorProxy op : operators) {
                const bool is_relevant =
                    OpportUniformPDBs::check_operator_relevance_in_pattern(op, current_pattern);
                current_map.push_back(is_relevant);
                if (is_relevant) {
                    ++operator_count[op.get_id()];
                }
            }
            pattern_to_operatorrelevance_map.push_back(std::move(current_map));
        }

        pattern_databases.reserve(patterns.size());
        for (size_t i = 0; i < patterns.size(); ++i) {
            const vector<bool> &relevant_operators = pattern_to_operatorrelevance_map[i];

            // c~_i (offered costs)
            std::vector<double> offered_costs = calculate_offered_costs(
                    remaining_costs,
                    operator_count,
                    relevant_operators
                    );

            std::shared_ptr<PatternDatabaseDouble> pdb = create_default_pdb_double(
                    task_proxy,
                    patterns[i],
                    false,
                    offered_costs,
                    true
                    );

            // c_i (used costs): saturated part of c~_i, computed before any compression.
            const std::vector<double> used_costs = OpportUniformPDBs::calculate_scf(operators, pdb);

            if (opts.get<bool>("compress_pdbs") && remaining_budget > 0) {
                int used_state_costs = compute_best_pdb_double(
                        patterns[i],
                        pdb,
                        task_proxy,
                        opts,
                        remaining_budget,
                        offered_costs
                        );
                remaining_budget -= used_state_costs;
            }
            pattern_databases.push_back(pdb);

            for (size_t op_id = 0; op_id < operator_size; ++op_id) {
                // c_{i+1}_hat = c_i_hat - c_i: unused cost is passed on to later patterns.
                remaining_costs[op_id] -= used_costs[op_id];
                // Pattern i is done: it must no longer take a share of this operator's
                // remaining cost. Otherwise the last relevant pattern would only be offered
                // a fraction of the leftover cost.
                if (relevant_operators[op_id]) {
                    --operator_count[op_id];
                }
            }
        }

        print_avg_finite_mean_for_PDBCollectionDouble(pattern_databases);

    }

    double OpportUniformPDBs::get_value(const State &state) const {
        /*
          Because we use cost partitioning, we can simply add up all
          heuristic values of all patterns in the pattern collection.
        */
        double h_val = 0;
        for (const shared_ptr<PatternDatabaseDouble> &pdb : pattern_databases) {
            double pdb_value = pdb->get_value(state);
            if (pdb_value == numeric_limits<double>::max()) {
                return numeric_limits<double>::max();
            }
            h_val += pdb_value;
        }
        return h_val;
    }



    double OpportUniformPDBs::compute_approx_mean_finite_h() const {
        double approx_mean_finite_h = 0;
        for (const shared_ptr<PatternDatabaseDouble> &pdb : pattern_databases) {
            approx_mean_finite_h += pdb->compute_mean_finite_h();
        }
        return approx_mean_finite_h;
    }

    void OpportUniformPDBs::dump() const {
        for (const shared_ptr<PatternDatabaseDouble> &pdb : pattern_databases) {
            utils::g_log << pdb->get_pattern() << endl;
        }
    }
}
