#include "potential_calculator_greedy.h"

#include "potential_calculator.h"
//#include "pdb_based_potential_heuristics.h"
#include "util.h"

#include "../pdbs/pattern_database.h"
#include "../pdbs/pdb_h_tensor.h"
#include "../option_parser.h"
#include "../plugin.h"

#include <vector>
#include <map>

using namespace std;

namespace potentials{

// Builds the fact list used by the greedy potential calculator.
//
// Step 1 fills influence_matrix (num_vars x num_vars). For every pattern and
// every variable var1 in it, the diagonal entry [var1][var1] grows by the
// product of the domain sizes of all other pattern variables, and each
// off-diagonal entry [var1][var2] (var2 in the same pattern) grows by that
// product divided once more by the domain size of var2.
//
// Step 2 creates one ListElementFact per (variable, value) pair, initialised
// with the value returned by get_average_h(), stores it in `elements`
// (variable-major order) and links it into the list: the new element is put
// in front of root and bubblesort() then moves it to its place.
//
// Parameters:
//   h_tensors     forwarded unchanged to get_average_h()
//   domain_sizes  domain_sizes[var] is the number of values of variable var
//   patterns      each pattern is iterated as a sequence of variable ids
//
// NOTE(review): the elements are allocated with raw `new` and no matching
// delete is visible in this file - confirm who owns and releases them.
FactList::FactList(const vector<pdbs::PDBhTensor> &h_tensors,
    const vector<size_t> &domain_sizes, const vector<pdbs::Pattern> &patterns):
    root(nullptr),patterns(patterns),domain_sizes(domain_sizes){

    /*debug*/cout<<"generating influence matrix"<<flush;/*end*/
    num_vars = domain_sizes.size();
    const size_t num_patterns = patterns.size();
    influence_matrix.resize(num_vars);
    for(size_t var=0; var<num_vars;++var){
        influence_matrix[var].resize(num_vars);
    }
    // A progress dot is printed roughly every tenth of the patterns; the
    // step is clamped to 1 so the modulo below never divides by zero.
    size_t tenth=num_patterns/10;
    if(tenth==0)tenth++;
    for(size_t pattern_id=0; pattern_id<num_patterns;++pattern_id){
        /*debug*/if(pattern_id%tenth==(tenth-1))cout<<"."<<flush;/*end*/
        // Bind by reference: copying the pattern vector per iteration is
        // unnecessary, it is only read.
        const pdbs::Pattern &pattern = patterns[pattern_id];
        size_t dom_size_prod=1;
        for(int var:pattern){
            dom_size_prod*=domain_sizes[var];
        }
        for(int var1:pattern){
            // Product of the domain sizes of all pattern variables but var1.
            const size_t prod_without_var1=dom_size_prod/domain_sizes[var1];
            influence_matrix[var1][var1]+=prod_without_var1;
            for(int var2:pattern){
                if(var1==var2)continue;
                influence_matrix[var1][var2]
                    +=(prod_without_var1/domain_sizes[var2]);
            }
        }
    }
    /*debug*/cout<<" done\ngenerating fact list"<<flush;/*end*/
    num_facts=0;
    for(size_t dom_size:domain_sizes){
        num_facts+=dom_size;
    }
    elements.resize(num_facts);
    size_t id = 0;
    for(size_t var_id=0; var_id<num_vars; ++var_id){
        const size_t dom_size = domain_sizes[var_id];
        for(size_t val=0; val<dom_size; ++val){
            double delta
                = get_average_h(h_tensors, var_id, val, patterns,domain_sizes);
            ListElementFact* next_element = new ListElementFact(var_id,val,delta);
            elements[id++]=next_element;
            if(root!=nullptr){
                // Push in front of the current root, then let bubblesort()
                // move the element to where it belongs.
                root->prev=next_element;
                next_element->next=root;
                root=next_element;
                bubblesort(next_element);
            }
            else{
                // The very first element becomes the root.
                root=next_element;
            }
        }
    }
    /*debug*/cout<<" done\n";/*end*/

}

// Looks up the list element of a fact.
// var_val[0] is the variable id, var_val[1] the value. `elements` is laid out
// variable by variable, so the position is the sum of the domain sizes of all
// variables with a smaller id plus the value. No bounds checking is done.
ListElementFact* FactList::operator[](const array<size_t,2>& var_val){
    const size_t target_var = var_val[0];
    const size_t value = var_val[1];
    size_t offset = 0;
    for(size_t preceding=0; preceding<target_var; ++preceding){
        offset += domain_sizes[preceding];
    }
    return elements[offset+value];
}
// Moves to_sort to its proper position in the doubly linked list after its
// abs_delta changed (or after it was freshly inserted).
//
// Judging by the comparisons below, the list is kept so that elements closer
// to root have larger abs_delta. to_sort is unlinked and re-inserted either
// closer to root ("bubble up") or closer to the tail ("bubble down"); if
// neither bubble_up() nor bubble_down() returns true the list is untouched.
// NOTE(review): bubble_up()/bubble_down() are defined elsewhere; they
// presumably compare abs_delta with the prev/next neighbour - confirm.
//
// If a walk takes more steps than there are facts the list must be corrupted
// (cyclic); the program is then terminated with CRITICAL_ERROR.
void FactList::bubblesort(ListElementFact* to_sort){
    size_t num_loop = 0;
    if(to_sort->bubble_up()){
        // Unlink to_sort from its current position (it may be the tail).
        if(to_sort->next!=nullptr){
            to_sort->prev->next=to_sort->next;
            to_sort->next->prev=to_sort->prev;
        }
        else{
            to_sort->prev->next=nullptr;
        }
        // Walk towards root until an element that is not smaller is found;
        // iterator==nullptr afterwards means to_sort becomes the new root.
        ListElementFact* iterator=to_sort->prev;
        while(iterator->abs_delta<to_sort->abs_delta){
            if((iterator->prev)!=nullptr){iterator=iterator->prev;}
            else{
                iterator=nullptr;
                ++num_loop;
                break;
            }
            if(++num_loop>num_facts){
                cerr<<"bubbled more than num_fact times! "
                    "This is an endless loop.\n";
                utils::exit_with(utils::ExitCode::CRITICAL_ERROR);
            }
        }
        if(iterator!=nullptr){
            // Re-insert directly behind iterator.
            to_sort->prev=iterator;
            to_sort->next=iterator->next;
            iterator->next=to_sort;
            to_sort->next->prev=to_sort;
        }
        else{
            // Bubbled all the way up: to_sort is the new root.
            to_sort->prev=nullptr;
            to_sort->next=root;
            root->prev=to_sort;
            root=to_sort;
        }
    }
    else{
        if(to_sort->bubble_down()){
            // Unlink to_sort from its current position (it may be root).
            if(to_sort->prev!=nullptr){
                to_sort->prev->next=to_sort->next;
                to_sort->next->prev=to_sort->prev;
            }
            else{
                to_sort->next->prev=nullptr;
                root=to_sort->next;
            }
            // to_sort->next is used as the cursor: advance it until it
            // points at the first element that is not larger than to_sort.
            while(to_sort->next->abs_delta>to_sort->abs_delta){
                if((to_sort->next->next)!=nullptr){
                    to_sort->next=to_sort->next->next;
                }
                else{
                    // Reached the tail: append to_sort behind it.
                    to_sort->prev=to_sort->next;
                    to_sort->next->next=to_sort;
                    to_sort->next=nullptr;
                    ++num_loop;
                    break;
                }
                if(++num_loop>num_facts+1){
                    // Fixed: the two literals used to concatenate to
                    // "times!This" (missing space).
                    cerr<<"bubbled more than num_fact times! "
                        "This is an endless loop.\n";
                    utils::exit_with(utils::ExitCode::CRITICAL_ERROR);
                }
            }
            if(to_sort->next!=nullptr){
                // Re-insert directly in front of to_sort->next.
                if(to_sort->next->prev!=nullptr){
                    to_sort->prev=to_sort->next->prev;
                    to_sort->prev->next=to_sort;
                }
                else{
                    to_sort->prev=nullptr;
                    root=to_sort;
                }
                to_sort->next->prev=to_sort;
            }
            // else: to_sort was already appended at the tail inside the loop.
        }
        // else: to_sort is already in the right place.
    }
}

// Consumes the element returned by peak() (presumably the fact with the
// largest abs_delta - confirm against peak()) and compensates the other
// facts for it.
//
// The consumed fact's delta is reset to 0 and the element is re-sorted.
// For every other variable var2 that shares influence with the fact's
// variable var1, delta * influence[var1][var2] / influence[var2][var2] is
// subtracted from the delta of every value of var2, abs_delta is refreshed
// and the element is re-sorted.
void FactList::iterate(){
    ListElementFact* actual = peak();
    const double delta1 = actual->delta;
    const size_t var1 = actual->var;
    actual->delta=0.;
    actual->abs_delta=0.;
    bubblesort(actual);
    for(size_t var2=0; var2<num_vars; ++var2){
        // Skip the variable itself and variables without shared influence.
        if(var1==var2||influence_matrix[var1][var2]==0){
            continue;
        }
        const double delta2 = delta1
            *influence_matrix[var1][var2]
            /influence_matrix[var2][var2];
        // The compensation must not exceed the consumed delta in magnitude.
        // Compare absolute values: delta1 may be negative (deltas are pushed
        // below zero by the subtraction further down), in which case the
        // former signed check `delta1>=delta2` failed spuriously.
        assert((delta1<0.?-delta1:delta1)>=(delta2<0.?-delta2:delta2));
        for(size_t val2=0; val2<domain_sizes[var2];++val2){
            ListElementFact* fact = (*this)[{var2,val2}];
            assert(fact->var==var2&&fact->val==val2);
            fact->delta-=delta2;
            // Keep abs_delta in sync with the new delta before re-sorting.
            if(fact->delta<0.){fact->abs_delta=-fact->delta;}
            else{fact->abs_delta = fact->delta;}
            bubblesort(fact);
        }
    }
}

// Reads the calculator's settings from the parsed options:
//   "num_iterations"   (int)    -> num_iterations
//   "eps"              (double) -> eps
//   "handle_dead_ends" (bool)   -> handle_dead_ends
PotentialCalculatorGreedy::PotentialCalculatorGreedy(const Options &opts)
    : PotentialCalculator()
{
    handle_dead_ends = opts.get<bool>("handle_dead_ends");
    eps = opts.get<double>("eps");
    num_iterations = opts.get<int>("num_iterations");
}

bool PotentialCalculatorGreedy::iterate(){
    ListElementFact* fact = delta_list.peak();
    if(fact->abs_delta<eps)return false;
    /*debug/cout<<"max delta: "<<fact->delta
        <<" @var_"<<fact->var<<"="<<fact->val<<"\n";/end*/
    potentials[fact->var][fact->val]+=fact->delta;
    /*debug/cout<<"    potential adjusted\ncompensating\n";/end*/
    delta_list.iterate();
    fact = delta_list.peak();
    /*debug/cout<<"    new max delta: "<<fact->delta
        <<" @var_"<<fact->var<<"="<<fact->val<<"\n";/end*/
    /*debug/cout<<"compensating done\n"/end*/;
    return true;
}

// Computes the potentials for all facts of `task` from the given pattern
// databases.
//
// Steps: size the potentials/deltas members, collect the pattern of every
// PDB, build one PDBhTensor per PDB, build the FactList from them, and then
// call iterate() at most num_iterations times, stopping early as soon as
// iterate() returns false.
//
// Parameters:
//   pdbs  pattern databases to derive the potentials from
//   task  task whose variables define the fact space
// Returns a copy of the `potentials` member, indexed as [variable][value].
vector<vector<double>> PotentialCalculatorGreedy::get_potentials(
        const vector<pdbs::PatternDatabase> &pdbs,
        const shared_ptr<AbstractTask> &task){
//--------------------initialization---------------------------
    /*debug*/cout << "initializing\n";/*end*/
    TaskProxy task_proxy(*task);
    VariablesProxy vars = task_proxy.get_variables();
    const size_t num_vars = vars.size();
    const size_t num_patterns = pdbs.size();
    /*debug*/cout << "number of patterns: " << num_patterns << "\n";/*end*/
    potentials.resize(num_vars);
    deltas.resize(num_vars);
    vector<size_t> domain_sizes(num_vars);
    for (VariableProxy var:vars){
        const size_t var_id = var.get_id();
        const size_t domain_size = var.get_domain_size();
        domain_sizes[var_id]=domain_size;
        potentials[var_id].resize(domain_size);
        // NOTE(review): sized by num_vars whereas potentials is sized by
        // domain_size - confirm this is intended.
        deltas[var_id].resize(num_vars);
    }
    pdbs::PatternCollection patterns;
    patterns.reserve(num_patterns);
    // Iterate by const reference: copying a whole pattern database just to
    // read its pattern is wasteful.
    for(const pdbs::PatternDatabase &pdb:pdbs){
        patterns.push_back(pdb.get_pattern());
    }
//---------------------get all h-values------------------------
    /*debug*/cout << "getting h-values... "<<flush;/*end*/
    vector<pdbs::PDBhTensor> h_tensors(num_patterns);
    for(size_t pattern_nr=0; pattern_nr<num_patterns; ++pattern_nr){
        // NOTE(review): the PDB is still copied here because the PDBhTensor
        // constructor signature is not visible from this file; pass
        // pdbs[pattern_nr] directly if it accepts a const reference.
        pdbs::PatternDatabase pdb = pdbs[pattern_nr];
        h_tensors[pattern_nr] = pdbs::PDBhTensor(pdb,vars,handle_dead_ends);
    }
    /*debug*/cout << "done\n";/*end*/
    delta_list = FactList(h_tensors, domain_sizes, patterns);
//--------------------------iterate---------------------------
    /*debug*/cout << "iterating";/*end*/
    const size_t max_iterations = static_cast<size_t>(num_iterations);
    // A progress dot is printed about every fifth of the iteration budget.
    // Clamp the step to 1: with num_iterations < 5 the former expression
    // i%(num_iterations/5) was a modulo by zero.
    size_t progress_step = max_iterations/5;
    if(progress_step==0)progress_step=1;
    for(size_t i=1; i<=max_iterations; ++i){
        if(!iterate()){
            /*debug*/cout << "\neps criterion reached";/*end*/
            break;
        }
        /*debug*/if(i%progress_step==progress_step-1){
            cout << "." << flush;
        }/*end*/
    }
    /*debug*/cout << "\nreturning potentials\n";/*end*/
    return potentials;
}

// Option-parser callback for the greedy potential calculator.
// Documents the plugin, declares its three options (num_iterations, eps,
// handle_dead_ends) with their defaults and bounds, and parses them.
// Returns nullptr on a dry run, otherwise a PotentialCalculatorGreedy
// constructed from the parsed options.
static shared_ptr<PotentialCalculator> _parse(OptionParser &parser) {
    // Adjacent string literals are concatenated verbatim, so the separating
    // space has to be part of one of them (it used to read "subtractsits").
    parser.document_synopsis(
        "Greedy algorithm for potentials from pdbs",
        "Algorithm which fixes a potential, then subtracts "
        "its effects from each pdb. Then it repeats until done.");

    parser.add_option<int>(
        "num_iterations",
        "Maximal number of iterations to be done.",
        "5000",
        Bounds("0", "infinity"));
    parser.add_option<double>(
        "eps",
        "Epsilon at which iteration will end (disregarding num_iterations).",
        "1e-4",
        Bounds("0", "infinity"));
    parser.add_option<bool>(
        "handle_dead_ends",
        "Prevents infinity (max int) as h-value.",
        "true");

    Options opts = parser.parse();
    if (parser.dry_run())
        return nullptr;

    return make_shared<PotentialCalculatorGreedy>(opts);
}

// File-local plugin object constructed with the key "greedy" and the _parse
// callback above. Presumably this registers the calculator so it can be
// selected as "greedy" on the command line - confirm against PluginShared.
static PluginShared<PotentialCalculator> _plugin("greedy", _parse);
}
