#include "util.h"

#include "potential_function.h"
#include "potential_optimizer.h"

#include "../heuristic.h"
#include "../option_parser.h"
#include "../sampling.h"
#include "../successor_generator.h"
#include "../task_tools.h"

#include "../utils/markup.h"
#include "../pdbs/pattern_database.h"

#include <algorithm>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>

using namespace std;

namespace pdbs {
/*
  Ordering predicate on pattern databases by their mean finite heuristic
  value: a is "less" than b when a's mean is strictly smaller. Used as the
  Compare of a std::priority_queue, the PDB with the largest mean is at
  top(). Both means are recomputed on every call.
*/
struct comp {
    bool operator()(PatternDatabase &a, PatternDatabase &b) {
        const double mean_of_a = a.compute_mean_finite_h();
        const double mean_of_b = b.compute_mean_finite_h();
        return mean_of_b > mean_of_a;
    }
};
}

namespace potentials {
vector<State> sample_without_dead_end_detection(
    PotentialOptimizer &optimizer, int num_samples) {
    /*
      Optimizes the potentials for the task's initial state, then draws
      num_samples states with sample_states_with_random_walks, passing it
      the initial state's potential value and the task's average operator
      cost. The samples are returned as produced; this function performs
      no dead-end filtering of its own.
    */
    const shared_ptr<AbstractTask> task = optimizer.get_task();
    const TaskProxy task_proxy(*task);
    State start = task_proxy.get_initial_state();
    optimizer.optimize_for_state(start);

    SuccessorGenerator succ_gen(task_proxy);
    const int start_h = optimizer.get_potential_function()->get_value(start);
    const auto average_cost = get_average_operator_cost(task_proxy);
    return sample_states_with_random_walks(
        task_proxy, succ_gen, num_samples, start_h, average_cost);
}

string get_admissible_potentials_reference() {
    // Documentation text citing the ICAPS 2015 paper by Seipp, Pommerening
    // and Helmert on optimization functions for potential heuristics.
    const string paper = utils::format_paper_reference(
        {"Jendrik Seipp", "Florian Pommerening", "Malte Helmert"},
        "New Optimization Functions for Potential Heuristics",
        "http://ai.cs.unibas.ch/papers/seipp-et-al-icaps2015.pdf",
        "Proceedings of the 25th International Conference on"
        " Automated Planning and Scheduling (ICAPS 2015)",
        "193-201",
        "AAAI Press 2015");
    return "The algorithm is based on" + paper;
}

void prepare_parser_for_admissible_potentials(OptionParser &parser) {
    /*
      Registers on `parser` the documentation and options used by the
      admissible potential heuristics: language-feature support, heuristic
      properties, the "max_potential" bound, the LP solver option and the
      generic heuristic options (in that order).
    */
    static const char *const language_support[][2] = {
        {"action costs", "supported"},
        {"conditional effects", "not supported"},
        {"axioms", "not supported"},
    };
    for (const auto &feature : language_support)
        parser.document_language_support(feature[0], feature[1]);

    static const char *const properties[][2] = {
        {"admissible", "yes"},
        {"consistent", "yes"},
        {"safe", "yes"},
        {"preferred operators", "no"},
    };
    for (const auto &property : properties)
        parser.document_property(property[0], property[1]);

    parser.add_option<double>(
        "max_potential",
        "Bound potentials by this number",
        "1e8",
        Bounds("0.0", "infinity"));
    lp::add_lp_solver_option_to_parser(parser);
    Heuristic::add_options_to_parser(parser);
}


/*
  Returns a copy of oldPattern with the first occurrence of var_id removed.
  var_id is expected to occur in oldPattern; this is checked by an
  assertion. If it does not occur, the pattern is returned unchanged rather
  than silently dropping an unrelated variable, and an empty input yields an
  empty result instead of an enormous resize.
  This might go into utils.
*/
pdbs::Pattern pattern_without(const pdbs::Pattern &oldPattern,
                                     const int &var_id){
    pdbs::Pattern result(oldPattern);
    const auto position = std::find(result.begin(), result.end(), var_id);
    assert(position != result.end());
    if (position != result.end())
        result.erase(position);
    return result;
}

/*
  Returns the ids of all variables of the task ordered by decreasing
  "importance". A variable's importance is the mean finite heuristic value
  of the PDB built for the single-variable pattern {var}. Variables of equal
  importance stay in ascending id order. `vars` is only used for its size.
*/
vector<int> make_important_vars_list(const TaskProxy &task_proxy,
                                     const VariablesProxy &vars){
    const int num_vars = vars.size();
    vector<double> importance(num_vars);
    pdbs::Pattern pattern(1);
    for (int var = 0; var < num_vars; ++var) {
        pattern[0] = var;
        pdbs::PatternDatabase pdb(task_proxy, pattern);
        importance[var] = pdb.compute_mean_finite_h();
    }

    vector<int> result(num_vars);
    std::iota(result.begin(), result.end(), 0);
    // Stable, so ties keep ascending id order -- the same ordering the
    // former O(n^2) bubble sort produced (it swapped only on strict <).
    std::stable_sort(result.begin(), result.end(),
                     [&importance](int a, int b) {
                         return importance[a] > importance[b];
                     });
    return result;
}

/*
  Recursively enumerates every assignment to the variables in free_vars
  (entries of `values` for other variables are kept as passed in) and
  evaluates `pdb` on each resulting state. Returns the sum and the count of
  those values that are below numeric_limits<int>::max() - 0.5; values at
  numeric_limits<int>::max() (presumably the PDB's infinity marker) are
  skipped. `values` is taken by value, so every recursion level owns a copy.
  NOTE(review): states are built on g_root_task(), so `values` presumably
  holds one entry per root-task variable -- confirm with callers.
*/
IntermediateAverage get_avg_pdbh(vector<int> values,
                        const pdbs::PatternDatabase &pdb,
                        const pdbs::Pattern &free_vars,
                        const vector<size_t> &domain_sizes){
    IntermediateAverage total;
    total.sum = 0.;
    total.num = 0;

    if (free_vars.empty()) {
        // Base case: every free variable is fixed; evaluate one state.
        const double h = pdb.get_value(State(*g_root_task(), values));
        const double eps = 0.5;
        if (h + eps < numeric_limits<int>::max()) {
            total.sum = h;
            total.num = 1;
        }
        return total;
    }

    // Fix the first free variable to each of its values in turn and recurse
    // on the remaining free variables.
    const int var = free_vars[0];
    const size_t num_values = domain_sizes[var];
    const pdbs::Pattern remaining = pattern_without(free_vars, var);
    for (size_t val = 0; val < num_values; ++val) {
        values[var] = val;
        const IntermediateAverage sub =
            get_avg_pdbh(values, pdb, remaining, domain_sizes);
        if (sub.num != 0) {
            total.sum += sub.sum;
            total.num += sub.num;
        }
    }
    return total;
}

/*
  Calculates the average finite h-value over all possible values of the
  variables in free_vars, pooled over all pdbs in pdb_s: the finite values
  found by get_avg_pdbh for each PDB are summed and counted together.
  Returns numeric_limits<int>::max() (as a double) if no finite value exists.
*/
double get_average_pdbh(vector<int> values,
                        const vector<pdbs::PatternDatabase> &pdb_s,
                        const pdbs::Pattern &free_vars,
                        const vector<size_t> &domain_sizes){
    IntermediateAverage interm_avg;
    interm_avg = {0., 0};
    // Bind by const reference: iterating by value copied every
    // PatternDatabase on each loop iteration.
    for (const pdbs::PatternDatabase &pdb : pdb_s) {
        const IntermediateAverage part =
            get_avg_pdbh(values, pdb, free_vars, domain_sizes);
        interm_avg.sum += part.sum;
        interm_avg.num += part.num;
    }
    if (interm_avg.num != 0) {
        return interm_avg.sum / interm_avg.num;
    }
    return static_cast<double>(numeric_limits<int>::max());
}

/*
  Returns, for every variable id in [0, num_vars), how many times it occurs
  across all patterns in `patterns` (a variable listed twice in one pattern
  is counted twice). Every variable id in `patterns` must be < num_vars.
*/
vector<size_t> get_num_patterns_with_var(
    const pdbs::PatternCollection &patterns,const size_t &num_vars){
    vector<size_t> result(num_vars, 0);
    // Iterate by reference; the previous version copied each pattern.
    for (const pdbs::Pattern &pattern : patterns) {
        for (size_t var : pattern) {
            assert(var < num_vars);
            ++result[var];
        }
    }
    return result;
}

/*
  Returns, in increasing order, the indices of all patterns in `patterns`
  that contain the variable var_id (each index at most once).
*/
vector<size_t> get_pattern_indices_with_var(
    const pdbs::PatternCollection &patterns, const size_t &var_id){
    vector<size_t> result;
    for (size_t i = 0; i < patterns.size(); ++i) {
        // Reference instead of the former per-iteration copy.
        const pdbs::Pattern &pattern = patterns[i];
        const bool contains_var = std::any_of(
            pattern.begin(), pattern.end(),
            [&var_id](size_t var) { return var == var_id; });
        if (contains_var)
            result.push_back(i);
    }
    return result;
}

/*
  Extracts all values from a pdb.

  Recursively enumerates every assignment to the variables in `pattern`
  (entries of `values` for other variables are left as passed in),
  evaluates `pdb` on each resulting state and stores the value in
  result[values]. Returns a reference to `result`.

  NOTE(review): task_proxy is not used; states are built on g_root_task().
*/
map<vector<int>,double> &get_all_h(const pdbs::PatternDatabase &pdb,
    const pdbs::Pattern &pattern, const vector<size_t> &domain_sizes,
    map<vector<int>,double> &result, TaskProxy &task_proxy,
    vector<int> values){
    if (pattern.empty()) {
        // Base case: all pattern variables are fixed. The previous version
        // also printed every h value to stdout (leftover debug output).
        const double h = pdb.get_value(State(*g_root_task(), values));
        result[values] = h;
        return result;
    }
    const size_t next_var = pattern[0];
    const size_t domain_size = domain_sizes[next_var];
    const pdbs::Pattern next_pattern = pattern_without(pattern, next_var);
    for (int val = 0; val < static_cast<int>(domain_size); ++val) {
        values[next_var] = val;
        get_all_h(pdb, next_pattern, domain_sizes, result, task_proxy, values);
    }
    return result;
}

/*
  Recursively enumerates every assignment to the variables in free_vars
  (entries of `values` for other variables are kept as passed in) and
  returns the sum and count of the h_tensor entries for those assignments.
  NOTE(review): unlike get_avg_pdbh, no value is filtered out here -- an
  earlier, commented-out version skipped values at
  numeric_limits<int>::max(). If the tensor stores such infinity markers
  they are included in the sum; confirm this is intended.
*/
IntermediateAverage get_avg_h(
    const pdbs::PDBhTensor &h_tensor, vector<int> values,
    const pdbs::Pattern free_vars, const vector<size_t> &domain_sizes){
    IntermediateAverage acc;
    acc.sum = 0.;
    acc.num = 0;

    if (free_vars.empty()) {
        // Base case: a single fully specified entry.
        const double value = h_tensor[values];
        acc.sum = value;
        acc.num = 1;
        return acc;
    }

    const int var = free_vars[0];
    const size_t num_values = domain_sizes[var];
    const pdbs::Pattern rest = pattern_without(free_vars, var);
    for (size_t val = 0; val < num_values; ++val) {
        values[var] = val;
        const IntermediateAverage part =
            get_avg_h(h_tensor, values, rest, domain_sizes);
        if (part.num != 0) {
            acc.sum += part.sum;
            acc.num += part.num;
        }
    }
    return acc;
}

double get_average_h(
        const vector<pdbs::PDBhTensor> &h_tensors, const size_t &var_id,
        const size_t &value, const pdbs::PatternCollection &patterns,
        const vector<size_t> &domain_sizes){
    vector<size_t> patterns_with=get_pattern_indices_with_var(patterns,var_id);
    pdbs::Pattern pattern;
    size_t num_vars = domain_sizes.size();
    vector<int> values;
    values.resize(num_vars);
    values[var_id]=value;
    IntermediateAverage temp, result;
    temp.num = 0;
    result.num = 0;
    temp.sum = 0.;
    result.sum = 0.;
    for(size_t pattern_id:patterns_with){
        /*debug/cout << "*" << flush;/end*/
        pattern = pattern_without(patterns[pattern_id], var_id);
        /*debug/cout << "." << flush;/end*/
        temp = get_avg_h(h_tensors[pattern_id],values,pattern,domain_sizes);
        /*debug/cout << "*" << flush;/end*/
        /*debug/cout << "temp.sum = " << temp.sum << "\n";/end*/
        if(temp.num>0){
            result.sum+=temp.sum;
            result.num+=temp.num;
        }
    }
    /*debug/cout << "\nresult.num = " << result.num << "\n";/end*/
    if(result.num>0){
        /*debug/ cout << "    average h for var_"<< var_id
            << " = " << values[var_id] << ": " << result.sum/result.num << "\n";
            /end*/
        return result.sum/result.num;
    }
    return static_cast<double>(numeric_limits<int>::max());
}
/*
  Builds the dense (num_facts x num_facts) matrix whose entry [f][g] counts
  the complete assignments to all variables (with the given domain sizes) in
  which both fact f and fact g hold. Facts are numbered variable by variable:
  fact (var, val) has index domain_sizes[0] + ... + domain_sizes[var-1] + val.
  Entries are:
    - diagonal of a fact of var:        prod(domain sizes) / domain_sizes[var]
    - two different values of one var:  0
    - facts of two different vars:      prod / domain_sizes[var1] / domain_sizes[var2]
  Presumably this is A^T A for the 0/1 (assignment x fact) matrix A, hence
  the name -- TODO confirm intended use.
  Counts are computed in double: the product of all domain sizes overflows
  size_t for tasks with a few dozen variables. Results are exact while the
  product stays below 2^53.
*/
std::vector<std::vector<double>> make_ata(std::vector<size_t> &domain_sizes){
    const size_t num_vars = domain_sizes.size();
    std::vector<size_t> first_facts(num_vars);
    size_t num_facts = 0;
    double num_assignments = 1.;
    for (size_t var = 0; var < num_vars; ++var) {
        first_facts[var] = num_facts;
        num_facts += domain_sizes[var];
        num_assignments *= static_cast<double>(domain_sizes[var]);
    }

    // Allocate every row up front, zero-initialized. (Resizing loop copies
    // of the rows left them empty, so every write below went out of bounds.)
    std::vector<std::vector<double>> ata(
        num_facts, std::vector<double>(num_facts, 0.));

    for (size_t var1 = 0; var1 < num_vars; ++var1) {
        const size_t dom1 = domain_sizes[var1];
        // Divisions happen only inside loops over a non-empty domain, so a
        // zero domain size never causes a division by zero.
        for (size_t i = 0; i < dom1; ++i) {
            const size_t fact1 = first_facts[var1] + i;
            ata[fact1][fact1] = num_assignments / dom1;
            for (size_t var2 = 0; var2 < num_vars; ++var2) {
                if (var2 == var1)
                    continue;  // distinct values of one variable never co-occur
                const size_t dom2 = domain_sizes[var2];
                for (size_t j = 0; j < dom2; ++j) {
                    ata[fact1][first_facts[var2] + j] =
                        num_assignments / dom1 / dom2;
                }
            }
        }
    }
    return ata;
}

/*
  Builds a PDB for every pattern in all_patterns and returns the (at most)
  max_num_patterns PDBs with the largest mean finite heuristic value,
  ordered by decreasing mean. PDBs with equal means keep the order of
  all_patterns.
*/
vector<pdbs::PatternDatabase> filter_patterns(
    const pdbs::PatternCollection &all_patterns,
    const size_t max_num_patterns,
    const TaskProxy &task_proxy){
    vector<pdbs::PatternDatabase> candidates;
    candidates.reserve(all_patterns.size());
    vector<double> mean_h;
    mean_h.reserve(all_patterns.size());
    for (const pdbs::Pattern &pattern : all_patterns) {
        candidates.emplace_back(task_proxy, pattern);
        // Compute each mean exactly once; a priority_queue comparator
        // would recompute it on every comparison.
        mean_h.push_back(candidates.back().compute_mean_finite_h());
    }

    vector<size_t> order(candidates.size());
    std::iota(order.begin(), order.end(), 0);
    std::stable_sort(order.begin(), order.end(),
                     [&mean_h](size_t a, size_t b) {
                         return mean_h[a] > mean_h[b];
                     });

    const size_t num_selected = std::min(max_num_patterns, candidates.size());
    vector<pdbs::PatternDatabase> selected;
    selected.reserve(num_selected);
    for (size_t k = 0; k < num_selected; ++k) {
        // Move instead of copying each selected PDB.
        selected.push_back(std::move(candidates[order[k]]));
    }
    return selected;
}
}
