#include "pdb_h_tensor.h"

#include "pattern_database.h"
#include "types.h"

#include "../globals.h"
//#include "../utils/system.h"
#include "../potentials/util.h"

#include <vector>
#include <limits>

using namespace std;
 
namespace pdbs{

/*
 * Returns the position of variable var_id inside pattern.
 * Terminates the planner with CRITICAL_ERROR if var_id is not part of the
 * pattern. The bound is checked before every access, so an empty pattern is
 * reported as an error instead of reading pattern[0] out of bounds.
 */
size_t PDBhTensor::get_var_nr(int var_id){
    size_t var_nr = 0;
    while (var_nr < pattern.size() && pattern[var_nr] != var_id)
        ++var_nr;
    if (var_nr == pattern.size()) {
        cerr << "var_"<< var_id << " not in pattern\n";
        utils::exit_with(utils::ExitCode::CRITICAL_ERROR);
    }
    return var_nr;
}

void PDBhTensor::get_all_h(const PatternDatabase &pdb,Pattern free_vars,
    vector<int> values){
    if(free_vars.size()==0){
        size_t index = 0;
        size_t forward_prod = 1;
        for (size_t var_nr=0; var_nr<pattern.size()-1; ++var_nr){
            index += forward_prod*values[pattern[var_nr]];
            forward_prod *= domain_sizes[var_nr];
        }
        index += forward_prod*values[pattern[pattern.size()-1]];
        assert(index<dom_size_prod);
        double mean = pdb.compute_mean_finite_h();
        if(!mean || mean<1.)mean=1.;
        double value = pdb.get_value(State(*g_root_task(), values));
        double eps = 0.5;/*might be unnecessary (already done in get_avg_h)*/
        if(handle_dead_ends&&value+eps>numeric_limits<int>::max())value=mean;
        h_values[index]=value;
        /*debug/cout << "h_values[" << index << "]=" << h_values[index] <<"\n";
            /end*/
    }
    else{
        int next_var = free_vars[0];
        Pattern next_free_vars = potentials::pattern_without(free_vars, next_var);
        size_t var_nr = get_var_nr(next_var);
        for(int val=0; val<static_cast<int>(domain_sizes[var_nr]); ++val){
            values[next_var]=val;
            get_all_h(pdb, next_free_vars, values);
        }
    }
}
/*
 * Builds a dense table holding the heuristic value of pdb for every
 * assignment to the variables of its pattern.
 *
 * pdb:              pattern database whose values are tabulated
 * vars:             task variables; supplies the domain sizes and the length
 *                   of the full state vector
 * handle_dead_ends: stored and consulted by get_all_h while filling the table
 */
PDBhTensor::PDBhTensor(const PatternDatabase &pdb,
    const VariablesProxy &vars, const bool &handle_dead_ends)
        :pattern(pdb.get_pattern()), handle_dead_ends(handle_dead_ends){
    const size_t num_pattern_vars = pattern.size();
    domain_sizes.resize(num_pattern_vars);
    dom_size_prod = 1;
    for (size_t var_nr = 0; var_nr < num_pattern_vars; ++var_nr) {
        const int var_id = pattern[var_nr];
        domain_sizes[var_nr] = vars[var_id].get_domain_size();
        dom_size_prod *= domain_sizes[var_nr];
    }
    // Every entry is written by get_all_h below, which enumerates all
    // assignments to the pattern variables.
    h_values = new double[dom_size_prod];
    // Full-size state vector; entries of non-pattern variables stay 0.
    vector<int> values(vars.size());
    get_all_h(pdb, pattern, values);
}

/*
 * Returns a reference to the stored h-value of the state `values` projected
 * onto the pattern.
 *
 * values: full state vector indexed by variable id; only the entries of
 *         pattern variables are read.
 *
 * The reference aliases the h_values table, so writing through it changes the
 * stored value even though this member function is const. Looping over all
 * pattern variables avoids the unsigned underflow of size()-1 for an empty
 * pattern (the single entry then has index 0).
 */
double &PDBhTensor::operator[](vector<int> values) const{
    size_t index = 0;
    size_t forward_prod = 1;
    for (size_t var_nr = 0; var_nr < pattern.size(); ++var_nr) {
        index += forward_prod * values[pattern[var_nr]];
        forward_prod *= domain_sizes[var_nr];
    }
    assert(index < dom_size_prod);
    return h_values[index];
}

/*
 * Returns (by value) the stored h-value of the state `values` projected onto
 * the pattern.
 *
 * values: full state vector indexed by variable id; only the entries of
 *         pattern variables are read.
 *
 * Looping over all pattern variables avoids the unsigned underflow of
 * size()-1 for an empty pattern (the single entry then has index 0).
 */
double PDBhTensor::get(vector<int> values) const{
    size_t index = 0;
    size_t forward_prod = 1;
    for (size_t var_nr = 0; var_nr < pattern.size(); ++var_nr) {
        index += forward_prod * values[pattern[var_nr]];
        forward_prod *= domain_sizes[var_nr];
    }
    assert(index < dom_size_prod);
    return h_values[index];
}
/*
 * Returns the sum of all dom_size_prod entries of the h-value table,
 * accumulated front to back.
 */
double PDBhTensor::get_total_sum() const{
    const double *const table_end = h_values + dom_size_prod;
    double total = 0.;
    for (const double *entry = h_values; entry != table_end; ++entry)
        total += *entry;
    return total;
}
}
