#include "potential_calculator_algebraic.h"

#include "potential_calculator.h"
#include "util.h"

#include "../option_parser.h"
#include "../plugin.h"

#include "../pdbs/pattern_database.h"
#include "../utils/lin_alg.h"

#include <vector>
#include <map>
#include <queue>

using namespace std;

namespace potentials{

/*
  Construct the calculator from parsed options. The only option read here
  is the boolean "handle_dead_ends", which is stored for later use when the
  h-value tensors are built in get_potentials().
*/
PotentialCalculatorAlgebraic::PotentialCalculatorAlgebraic(const Options &opts)
    : PotentialCalculator() {
    const bool dead_end_flag = opts.get<bool>("handle_dead_ends");
    handle_dead_ends = dead_end_flag;
}

/*
  Build the right-hand sides of the per-variable linear systems for one PDB.

  The result has one entry per pattern variable (in pattern order). Entry i
  holds one value for every domain value of pattern[i] except the last one:

      b = (sum * dom_size - total_sum) / dom_size_prod

  where `sum` is the `sum` field returned by get_avg_h() with pattern[i]
  fixed to that value and the remaining pattern variables passed as free
  variables, `total_sum` is h_tensor.get_total_sum(), `dom_size` is the
  domain size of pattern[i] and `dom_size_prod` is the product of the
  domain sizes of all pattern variables.

  h_tensor:     h-values of the PDB.
  pattern:      variable ids of the PDB's pattern.
  domain_sizes: domain size of every task variable, indexed by variable id.
*/
vector<vector<double>> PotentialCalculatorAlgebraic::make_Atb(
    const pdbs::PDBhTensor h_tensor,
    const pdbs::Pattern pattern,
    const vector<size_t> domain_sizes){
    const size_t pattern_size = pattern.size();
    vector<vector<double>> atb(pattern_size);

    // Value assignment handed to get_avg_h; one slot per task variable.
    vector<int> values(domain_sizes.size());

    size_t dom_size_prod = 1;
    for(size_t i=0; i<pattern_size; ++i){
        dom_size_prod *= domain_sizes[pattern[i]];
    }

    // Loop-invariant terms, converted to double once. Doing the arithmetic
    // explicitly in double (as get_potentials does for the same quantities)
    // rules out unsigned wrap-around and integer division in case the sums
    // are stored as integers.
    const double total_sum = static_cast<double>(h_tensor.get_total_sum());
    const double denominator = static_cast<double>(dom_size_prod);

    for(size_t i=0; i<pattern_size; ++i){
        const int var_id = pattern[i];
        const size_t dom_size = domain_sizes[var_id];
        const pdbs::Pattern free_vars = pattern_without(pattern, var_id);
        if(dom_size>1){
            atb[i].reserve(dom_size-1);
        }
        // The last domain value is deliberately left out.
        for(int val=0; val<static_cast<int>(dom_size)-1; ++val){
            values[var_id] = val;
            // sum over all h-vals, where var_id==val
            const IntermediateAverage temp
                = get_avg_h(h_tensor, values, free_vars, domain_sizes);
            const double b
                = (static_cast<double>(temp.sum)*static_cast<double>(dom_size)
                   - total_sum)/denominator;
            atb[i].push_back(b);
        }
    }
    return atb;
}

/*
  Build one square coefficient matrix per pattern variable.

  For a variable with domain size d the matrix has d-1 rows and columns.
  Every entry is -1/d, and the diagonal entries are additionally increased
  by 1. The matrices are returned in pattern order.

  pattern:      variable ids of the pattern.
  domain_sizes: domain size of every task variable, indexed by variable id.
*/
vector<vector<vector<double>>> PotentialCalculatorAlgebraic::make_AtA(
    pdbs::Pattern pattern, vector<size_t> domain_sizes){
    vector<vector<vector<double>>> blocks;
    blocks.reserve(pattern.size());
    for(const auto var_id : pattern){
        const size_t dom_size = domain_sizes[var_id];
        const size_t dim = dom_size-1;
        const double off_diagonal = -1./dom_size;
        // Fill the whole matrix with the off-diagonal value first ...
        blocks.emplace_back(dim, vector<double>(dim, off_diagonal));
        // ... then lift the diagonal by one.
        vector<vector<double>> &block = blocks.back();
        for(size_t k=0; k<dim; ++k){
            block[k][k] += 1;
        }
    }
    return blocks;
}
/*
  Derive one potential per (variable, value) pair from the given PDBs.

  For every PDB with a non-empty pattern, the h-values are loaded into a
  PDBhTensor and one small linear system is solved per pattern variable
  (right-hand sides from make_Atb, matrices from make_AtA):

    - For all pattern variables but the last, the system covers the first
      dom_size-1 values; the last value gets potential 0.
    - For the last pattern variable the system is extended to
      dom_size x dom_size by a column of -1/dom_size and a row of
      1/dom_size. Its extra right-hand side entry `last_b` starts as the
      total h sum divided by the product of the pattern's domain sizes and
      is reduced by the mean potential of every other pattern variable.

  Each solution is accumulated twice: into the member `potentials`, divided
  by the number of patterns containing the variable, and into the local
  `weighted_pots`, weighted by 1/pdb.get_size(). Only `potentials` is
  returned; `weighted_pots` is normalized by the summed weights and printed
  next to `potentials` as diagnostic output on stdout.

  pdbs: pattern databases to extract potentials from. PDBs with an empty
        pattern contribute nothing and are skipped.
  task: task whose variables define the shape of the result.

  Returns potentials[var][val] for every variable and value of the task.
*/
vector<vector<double>> PotentialCalculatorAlgebraic::get_potentials(
        const vector<pdbs::PatternDatabase> &pdbs,
        const shared_ptr<AbstractTask> &task){
    TaskProxy task_proxy(*task);
    VariablesProxy vars = task_proxy.get_variables();
    const size_t num_vars = vars.size();

    // `potentials` is accumulated with += below, so it has to start from
    // zero; otherwise a second call would add to the previous result.
    potentials.clear();
    potentials.resize(num_vars);
    vector<vector<double>> weighted_pots(num_vars);
    vector<size_t> domain_sizes(num_vars);
    for(VariableProxy var : vars){
        const size_t var_id = var.get_id();
        const size_t domain_size = var.get_domain_size();
        domain_sizes[var_id] = domain_size;
        potentials[var_id].resize(domain_size);
        weighted_pots[var_id].resize(domain_size);
    }

    // Collect the patterns and, per variable, the sum of 1/size over all
    // PDBs whose pattern contains it. PDBs are bound by reference to avoid
    // copying them.
    vector<double> var_weights(num_vars, 0.);
    pdbs::PatternCollection patterns;
    patterns.reserve(pdbs.size());
    for(const pdbs::PatternDatabase &pdb : pdbs){
        const pdbs::Pattern &pattern = pdb.get_pattern();
        patterns.push_back(pattern);
        for(size_t var : pattern){
            var_weights[var] += 1./pdb.get_size();
        }
    }
    const vector<size_t> num_patterns_with_var
        = get_num_patterns_with_var(patterns, num_vars);

    for(const pdbs::PatternDatabase &pdb : pdbs){
        const pdbs::Pattern &pattern = pdb.get_pattern();
        const size_t pattern_size = pattern.size();
        if(pattern_size == 0){
            // No variables to assign potentials to; also keeps
            // pattern_size-1 below from wrapping around.
            continue;
        }
        const size_t last = pattern_size-1;
        const double weight = 1./static_cast<double>(pdb.get_size());

        size_t dom_size_prod = 1;
        for(size_t var : pattern){
            dom_size_prod *= domain_sizes[var];
        }

        const pdbs::PDBhTensor h_tensor(pdb, vars, handle_dead_ends);
        vector<vector<double>> atbs = make_Atb(h_tensor, pattern, domain_sizes);
        vector<vector<vector<double>>> atas = make_AtA(pattern, domain_sizes);
        double last_b = static_cast<double>(h_tensor.get_total_sum())
            /static_cast<double>(dom_size_prod);

        // All pattern variables except the last one.
        for(size_t i=0; i<last; ++i){
            const size_t var = pattern[i];
            vector<double> pots = utils::solve_gauss(atas[i], atbs[i]);
            // The value left out of the system gets potential 0.
            pots.push_back(0.);
            for(size_t val=0; val<pots.size(); ++val){
                potentials[var][val] += pots[val]/num_patterns_with_var[var];
                weighted_pots[var][val] += pots[val]*weight;
                last_b -= pots[val]/domain_sizes[var];
            }
        }

        // Last pattern variable: extend its system by one column and one
        // row so that all of its values are solved for.
        const size_t last_var = pattern[last];
        const size_t dom_size = domain_sizes[last_var];
        for(size_t row=0; row<dom_size-1; ++row){
            atas[last][row].push_back(-1./dom_size);
        }
        atas[last].push_back(vector<double>(dom_size, 1./dom_size));
        atbs[last].push_back(last_b);
        const vector<double> last_pots
            = utils::solve_gauss(atas[last], atbs[last]);
        for(size_t val=0; val<last_pots.size(); ++val){
            potentials[last_var][val]
                += last_pots[val]/num_patterns_with_var[last_var];
            weighted_pots[last_var][val] += last_pots[val]*weight;
        }
    }

    // Normalize the weighted variant and print both variants side by side.
    for(VariableProxy var_proxy : vars){
        const size_t var = var_proxy.get_id();
        /*debug*/cout<<"var_weights["<<var<<"]= "<<var_weights[var]<<"\n";
        // Skip variables whose weight is (numerically) zero, i.e.
        // |weight| < 1e-10, to avoid dividing by it.
        if(var_weights[var]*var_weights[var]<1e-20){
            continue;
        }
        for(size_t val=0; val<domain_sizes[var]; ++val){
            weighted_pots[var][val] /= var_weights[var];
            /*debug*/cout<<"["<<var<<"]["<<val<<
                "] :\n    w_pots: "<<weighted_pots[var][val]<<
                "\n    pots:   "<<potentials[var][val]<<"\n";/*end*/
        }
    }
    return potentials;/*weighted_pots;*/
}

/*
  Option-parser callback for this calculator: documents the plugin, declares
  the "handle_dead_ends" option (default "true") and, unless the parser is
  in a dry run, creates the PotentialCalculatorAlgebraic.
*/
static shared_ptr<PotentialCalculator> _parse(OptionParser &parser) {
    parser.document_synopsis(
        "Algebraic algorithm for potentials from pdbs",
        "Solves the least error squares problem algebraicly");

    parser.add_option<bool>(
        "handle_dead_ends",
        "Prevents having infinity (max int) as h-value.",
        "true");

    Options parsed_options = parser.parse();
    if (!parser.dry_run()) {
        return make_shared<PotentialCalculatorAlgebraic>(parsed_options);
    }
    // Dry run: nothing is constructed.
    return nullptr;
}

// Registers _parse as the factory for the PotentialCalculator plugin key
// "algebraic".
static PluginShared<PotentialCalculator> _plugin("algebraic", _parse);
}

