#include "saturated_pdbs.h"

#include "compression.h"
#include "pattern_database.h"
#include "pattern_database_factory.h"

#include "../task_proxy.h"

#include "../utils/logging.h"

#include <algorithm>
#include <iostream>
#include <iterator>
#include <limits>
#include <vector>

using namespace std;

namespace pdbs {

    /*
      Compute the saturated cost function (SCF) of the given PDB.

      For an abstraction heuristic h, the saturated cost of operator o is
      the maximum of h(s) - h(s') over all abstract transitions s -> s'
      induced by o (Seipp et al., ICAPS 2017, saturated cost partitioning,
      p. 4). Here the maximum is clamped below at 0, and operators that are
      not relevant for the PDB keep saturated cost 0.

      Transitions that touch an abstract state with infinite distance
      (numeric_limits<int>::max()) are skipped: "infinity" minus a finite
      value would otherwise masquerade as a huge finite saturated cost and
      drive the remaining operator costs far below zero.

      Returns one saturated cost per operator, indexed by operator id.
    */
    std::vector<int> SaturatedPDBs::calculate_scf(
        const OperatorsProxy &operators_proxy,
        const shared_ptr<PatternDatabase> &pdb) {
        const int infinity = std::numeric_limits<int>::max();
        std::vector<int> used_costs(operators_proxy.size(), 0);

        // Bind by reference: copying the whole distance table and every
        // operator's transition list for each PDB is needless work.
        // (If the getters return by value, the temporaries' lifetime is
        // extended to match these references.)
        const std::vector<int> &distances = pdb->get_distances();
        const std::vector<std::vector<std::pair<int, int>>> &operator_transitions =
            pdb->get_operator_transitions();

        for (OperatorProxy op : operators_proxy) {
            // Irrelevant operators keep the default saturated cost 0; their
            // transitions are not inspected at all.
            if (!pdb->is_operator_relevant(op)) {
                continue;
            }

            int saturated_cost = 0;
            // Each transition is read as (source, target), matching h(s) - h(s').
            for (const std::pair<int, int> &transition : operator_transitions[op.get_id()]) {
                const int source_h = distances[transition.first];
                const int target_h = distances[transition.second];
                if (source_h == infinity || target_h == infinity) {
                    continue;
                }
                saturated_cost = std::max(saturated_cost, source_h - target_h);
            }
            used_costs[op.get_id()] = saturated_cost;
        }
        return used_costs;
    }

/*
  Build the PDB collection with saturated cost partitioning: each pattern's
  PDB is computed under the operator costs still remaining after all earlier
  PDBs, and its saturated costs are then subtracted from those remaining
  costs before the next pattern is processed.

  If the "compress_pdbs" option is set, compute_best_pdb is additionally
  called for each pattern while the state budget ("max_size_compr_col") is
  positive; the budget it reports as used is deducted.
*/
SaturatedPDBs::SaturatedPDBs(
    const TaskProxy &task_proxy, const PatternCollection &patterns, const Options &opts) {
    OperatorsProxy operators = task_proxy.get_operators();

    // Initial cost vector c_0: the original operator costs.
    vector<int> remaining_operator_costs;
    remaining_operator_costs.reserve(operators.size());
    for (OperatorProxy op : operators) {
        remaining_operator_costs.push_back(op.get_cost());
    }

    // Both options are loop-invariant, so look them up only once.
    const bool compress_pdbs = opts.get<bool>("compress_pdbs");
    int remaining_budget = opts.get<int>("max_size_compr_col");

    pattern_databases.reserve(patterns.size());

    for (const Pattern &pattern : patterns) {
        shared_ptr<PatternDatabase> pdb = create_default_pdb(
            task_proxy, pattern, false, remaining_operator_costs, true);
        // Used costs have to be calculated on the original pdb, i.e. before
        // compute_best_pdb is given the chance to work on it.
        vector<int> used_costs = calculate_scf(operators, pdb);

        if (compress_pdbs && remaining_budget > 0) {
            int used_budget = compute_best_pdb(
                pattern,
                pdb,
                task_proxy,
                opts,
                remaining_budget,
                remaining_operator_costs);
            remaining_budget -= used_budget;
        }

        // c_{i+1} = c_i - scf_i, computed in place instead of through a
        // temporary vector (std::transform permits output == first input).
        std::transform(remaining_operator_costs.begin(),
                       remaining_operator_costs.end(),
                       used_costs.begin(),
                       remaining_operator_costs.begin(),
                       [](int remaining_cost, int used_cost) {
                           return remaining_cost - used_cost;
                       });

        pattern_databases.push_back(pdb);
    }
    print_avg_finite_mean_for_PDBCollection(pattern_databases);
}

int SaturatedPDBs::get_value(const State &state) const {
    /*
      Because we use cost partitioning, we can simply add up all
      heuristic values of all patterns in the pattern collection.
    */
    int h_val = 0;
    for (const shared_ptr<PatternDatabase> &pdb : pattern_databases) {
        int pdb_value = pdb->get_value(state);
        if (pdb_value == numeric_limits<int>::max())
            return numeric_limits<int>::max();
        h_val += pdb_value;
    }
    return h_val;
}

double SaturatedPDBs::compute_approx_mean_finite_h() const {
    double approx_mean_finite_h = 0;
    for (const shared_ptr<PatternDatabase> &pdb : pattern_databases) {
        approx_mean_finite_h += pdb->compute_mean_finite_h();
    }
    return approx_mean_finite_h;
}

void SaturatedPDBs::dump() const {
    for (const shared_ptr<PatternDatabase> &pdb : pattern_databases) {
        utils::g_log << pdb->get_pattern() << endl;
    }
}
}
