#include "offline_heuristic.h"

#include "approximation.h"
#include "iterative_deepening_search.h"
#include "thts.h"
#include "utils/random.h"
#include "utils/system_utils.h"
#include "utils/string_utils.h"

#include <dirent.h> 
#include <fstream>
#include <iostream>
#include <numeric>
#include <limits>
#include <stdlib.h>
#include <sstream>
#include <vector>

/******************************************************************
                OfflineHeuristic Search Engine Creation
******************************************************************/

// Static cache of Q-value estimates keyed by state, shared by all
// OfflineHeuristic instances. estimateQValues() fills it when caching is
// enabled; the stored values are the ones computed before the scaling by the
// number of remaining steps.
OfflineHeuristic::HashMap OfflineHeuristic::rewardCache;

// Creates the engine with its default learning configuration:
// - stochastic gradient descent with step 0.5, a single epoch, learn-only mode;
// - no momentum, and a learning-rate decay of 5% (applied every 50 epochs in
//   learnFromTrainingSet());
// - combineFeatures == 1 (no combination of features) and logging disabled;
// - "weights" as the output folder for generated training data;
// - oldTotalMSE/newTotalMSE at -1, which performTraining() treats as "no
//   previous epoch yet".
OfflineHeuristic::OfflineHeuristic()
    : ProbabilisticSearchEngine("OfflineHeuristic"),
      trainingSetGenerator(nullptr),
      methodType(MethodType::StochasticGradientDescent),
      step(0.5),
      loadHeuristicFrom(""),
      loadTrainingSetFrom(""),
      numberOfEpochs(1),
      learnType(LearnType::LearnOnly),
      momentum(0), // Usually 0.9 (ref. https://arxiv.org/pdf/1609.04747.pdf)
      learningDecayRate(0.05),
      details(""),
      combineFeatures(1),
      log(false),
      outputDir("weights"),
      oldTotalMSE(-1),
      newTotalMSE(-1) {}

// Entry point of the offline learning phase.
//
// If a training set generator is configured, it learns first and provides the
// reward bounds. Then the heuristic is either loaded from loadHeuristicFrom or
// learned from the training set and written to heuristics-offline/<details>.
// Learn-only mode trains on all samples; learn-and-test mode trains on 70% of
// them.
void OfflineHeuristic::learn() {
    if (trainingSetGenerator) {
        trainingSetGenerator->learn();
        minReward = trainingSetGenerator->rewardCPF->getMinVal();
        maxReward = trainingSetGenerator->rewardCPF->getMaxVal();
    }

    if (!loadHeuristicFrom.empty()) {
        loadHeuristic();
        return;
    }

    // Fix the description, and with it the output file name, BEFORE training.
    // Learning-rate decay modifies `step` while training, but loadHeuristic()
    // looks for files by the configured step. Without this, a heuristic
    // trained for more than 50 epochs with logging disabled would be stored
    // under a name that can never be found again.
    std::string const heuristicFile = "heuristics-offline/" + getDetails();

    if (learnType == LearnType::LearnOnly) {
        learnFromTrainingSet(1);
    } else if (learnType == LearnType::LearnAndTest) {
        learnFromTrainingSet(0.7);
    } else {
        std::cout << "No learning done!" << std::endl;
    }

    std::ofstream outfile(heuristicFile);
    if (!outfile) {
        // Do not lose the learned heuristic silently (e.g. missing folder).
        std::cout << "Warning: could not open " << heuristicFile
                  << " for storing the learned heuristic." << std::endl;
        return;
    }
    storeHeuristic(outfile);
}

// Learns the heuristic from the training data.
//
// learningSetSize is the fraction (in [0, 1]) of every action's samples used
// for training; the remaining samples are used for testing. The training data
// is either generated (and written to outputDir) or loaded from
// loadTrainingSetFrom. Training runs for at most numberOfEpochs epochs and
// stops early when performTraining() returns false. Testing only runs if
// there is test data and logging is enabled, because its results only go to
// the log. The data set is cleared afterwards.
void OfflineHeuristic::learnFromTrainingSet(double learningSetSize) {
    if (loadTrainingSetFrom == "") {
        storeTrainingData(learningSetSize);
        outputTrainingData(); // for writing training data to a file
    } else {
        loadTrainingData();
    }

    std::cout << learningSetSize * 100;
    std::cout << "% of data will be used for training." << std::endl;

    std::cout << (1 - learningSetSize) * 100;
    std::cout << "% of data will be used for testing." << std::endl;

    // Whether testing happens does not change between epochs, so decide (and
    // report) it once instead of on every epoch.
    bool const hasTestData = !MathUtils::doubleIsEqual(1, learningSetSize);
    bool const runTests = hasTestData && loggingOn();
    if (hasTestData && !runTests) {
        std::cout << "Testing won't be performed";
        std::cout << " since logging is disabled." << std::endl;
    }

    for (unsigned j = 0; j < numberOfEpochs; j++) {
        // Early stopping -- if the change in evaluation falls below threshold
        if (!performTraining(learningSetSize, j)) {
            break;
        }

        // Learning decay rate: anneal the step by learningDecayRate every
        // 50 epochs.
        if (j > 0 && j % 50 == 0) {
            step -= step * learningDecayRate;
        }

        // Shuffle the training part of every action's samples before the
        // next epoch; the test part keeps its position.
        for (auto& entry : dataSet) {
            size_t const trainingEnd =
                static_cast<size_t>(entry.second.size() * learningSetSize);
            std::random_shuffle(entry.second.begin(),
                                entry.second.begin() + trainingEnd);
        }

        if (runTests) {
            performTesting(learningSetSize, j);
        }
    }
    // Clearing data set after it was being used
    dataSet.clear();
}

// Generates training samples with the training set generator.
//
// For the first learningSetSize fraction of trainingSet:
// - the generator estimates the Q-values of all actions, with stepsToGo set to
//   the horizon;
// - each estimate is divided by the horizon and stored in dataSet[action],
//   together with the parsed feature values of the state.
//
// An action i uses the estimate at index actions[i]; for an action that is
// applicable on its own, that is its own estimate. Entries of actions[] that
// are not a valid index into the estimates are skipped, because they cannot
// be looked up. Negative values presumably mark inapplicable actions (TODO
// confirm against getApplicableActions). Before this fix such entries caused
// an out-of-bounds read.
void OfflineHeuristic::storeTrainingData(double learningSetSize) {
    std::cout << "Training set collection started." << std::endl;
    for (unsigned it = 0; it < (trainingSet.size() * learningSetSize); it++) {
        std::vector<double> estimates(numberOfActions);
        std::vector<int> actions = getApplicableActions(trainingSet[it]);

        // Clear the caches of the search engines so that every state is
        // evaluated from scratch.
        // TODO: change this so that it loads these values from a file
        ProbabilisticSearchEngine::stateValueCache.clear();
        ProbabilisticSearchEngine::applicableActionsCache.clear();

        DeterministicSearchEngine::stateValueCache.clear();
        DeterministicSearchEngine::applicableActionsCache.clear();

        IDS::rewardCache.clear();

        assert(trainingSetGenerator);

        State s(trainingSet[it]);
        s.stepsToGo() = SearchEngine::horizon;
        trainingSetGenerator->estimateQValues(s, actions, estimates);

        // The features only depend on the state, so compute them once
        // instead of once per action.
        std::vector<double> const features = parseStateToValues(s);

        for (unsigned i = 0; i < estimates.size(); i++) {
            int const source = actions[i];
            if (source < 0 || static_cast<size_t>(source) >= estimates.size()) {
                continue;
            }

            // TODO: Scale the reward between 0 and 1
            double const reward =
                estimates[source] / (double)SearchEngine::horizon;
            dataSet[i].push_back(
                std::pair<std::vector<double>, double>(features, reward));
        }
    }

    // Shuffling dataset
    for (auto& entry : dataSet) {
        std::random_shuffle(entry.second.begin(), entry.second.end());
    }

    std::cout << "Training set collection finished." << std::endl;
}

// Writes the collected training set to
// <outputDir>/<name>-<taskName>_ so that loadTrainingData() can find it.
// Every line has the format <action>;<state values separated by ' '>;<reward>.
// Values are written with full double precision so that the reward survives
// the round trip through the file. If the file cannot be opened, a warning is
// printed and nothing is written.
void OfflineHeuristic::outputTrainingData() {
    std::string const fileName = outputDir + "/" + SearchEngine::name + "-"
                                 + SearchEngine::taskName + "_";
    std::ofstream outfile(fileName, std::ios_base::out);
    if (!outfile) {
        std::cout << "Warning: could not open " << fileName
                  << " for writing the training data set." << std::endl;
        return;
    }
    outfile.precision(std::numeric_limits<double>::max_digits10);

    for (auto const& actionEntry : dataSet) {
        for (auto const& stateRewardEntry : actionEntry.second) {
            outfile << actionEntry.first << ";";
            for (double f : stateRewardEntry.first) {
                outfile << f << " ";
            }
            outfile << ";" << stateRewardEntry.second << '\n';
        }
    }
    outfile.close();
    std::cout << "Outputing training data set done!" << std::endl;
}

// Loads a training set previously written by outputTrainingData().
//
// The first directory entry of loadTrainingSetFrom whose name contains
// "-<taskName>_" is used. "First" follows readdir order and is not sorted.
// Every line has the format
//     <action>;<space separated state values>;<reward>
// The following lines are reported and skipped:
// - lines with fewer than three fields;
// - lines with a non-numeric action;
// - lines whose number of state values differs from the first accepted line.
// The loaded samples are shuffled per action afterwards.
void OfflineHeuristic::loadTrainingData() {
    std::cout << "Trying to find training set for instance ";
    std::cout << SearchEngine::taskName << " in folder ";
    std::cout << loadTrainingSetFrom << std::endl;

    std::string fileName(loadTrainingSetFrom + "/");
    DIR* dir = opendir(loadTrainingSetFrom.c_str());
    if (dir != nullptr) {
        std::string const pattern = "-" + SearchEngine::taskName + "_";
        bool found = false;
        struct dirent* ent;
        while ((ent = readdir(dir)) != nullptr) {
            std::string const file(ent->d_name);
            if (file.find(pattern) != std::string::npos) {
                fileName += file;
                found = true;
                std::cout << "Trying to load file " << fileName << std::endl;
                break;
            }
        }
        closedir(dir);

        if (!found) {
            SystemUtils::abort("Could not find training set data for "
                                + SearchEngine::taskName + " in folder "
                                + fileName + ".");
        }
    } else {
        // Report the folder that was actually searched (previously this
        // message wrongly printed loadHeuristicFrom).
        SystemUtils::abort("Could not open directory " + loadTrainingSetFrom
                            + " for loading training set data.");
    }

    std::ifstream input(fileName);
    if (!input) {
        SystemUtils::abort("Could not open training set file " + fileName + ".");
    }

    size_t lineNum = 1;
    // Number of state values per sample; npos until the first accepted line.
    size_t stateSize = std::string::npos;
    for (std::string line; getline(input, line); lineNum++) {
        std::vector<std::string> res;
        StringUtils::split(line, res, ";");

        if (res.size() < 3) {
            std::cout << "Broken training set line: " << lineNum << std::endl;
            continue;
        }

        // strtol instead of std::stoi: a malformed action field skips the
        // line instead of terminating the planner with an uncaught exception.
        char const* actionBegin = res[0].c_str();
        char* actionEnd = nullptr;
        int const actionName =
            static_cast<int>(strtol(actionBegin, &actionEnd, 10));
        if (actionEnd == actionBegin) {
            std::cout << "Broken training set line: " << lineNum << std::endl;
            continue;
        }

        double const reward = atof(res[2].c_str());
        assert(!MathUtils::doubleIsMinusInfinity(reward));

        std::vector<std::string> stateString;
        StringUtils::split(res[1], stateString, " ");
        std::vector<double> state(stateString.size());
        for (size_t i = 0; i < stateString.size(); i++) {
            state[i] = atof(stateString[i].c_str());
        }

        if (stateSize == std::string::npos) {
            stateSize = state.size();
        } else if (stateSize != state.size()) {
            std::cout << "Broken training set line: " << lineNum << std::endl;
            continue;
        }

        dataSet[actionName].push_back(std::make_pair(state, reward));
    }

    // Shuffling dataset
    for (auto& entry : dataSet) {
        std::random_shuffle(entry.second.begin(), entry.second.end());
    }

    std::cout << "Loaded training data from: " << fileName << std::endl;
}

bool OfflineHeuristic::performTraining(double learningSetSize, int epoch) {
    std:: cout << "Training started. Epoch: " << epoch << std::endl;

    double squaredDifference = 0.0;
    std::map<int, double> squaredDifferencePerAction;
    std::map<int, int> squaredDifferencePerActionCounter;
    int trainingSetSize = 0;
    for(auto& actionEntry : dataSet) {
        int actionName = actionEntry.first;

        if (methodType == MethodType::MiniBatchGradientDescent) {
            // Setting size of the mini batch to 10% of the whole training set
            // for one particular action
            MiniBatchGradientDescentApproximation::batchSize = 
                                                    (int) (learningSetSize
                                                    * actionEntry.second.size()
                                                    * 0.1);

            if (MiniBatchGradientDescentApproximation::batchSize == 0) {
                MiniBatchGradientDescentApproximation::batchSize = 100;
            }

            // std::cout << actionEntry.second.size() << std::endl;
            // std::cout << actionEntry.first << std::endl;
            // std::cout << "MiniBatchGradientDescentApproximation::batchSize "; 
            // std::cout << MiniBatchGradientDescentApproximation::batchSize; 
            // std::cout << std::endl;
        }

        approximate(actionName, 
                    epoch,
                    actionEntry.second, 
                    learningSetSize,
                    squaredDifference,
                    squaredDifferencePerAction, 
                    squaredDifferencePerActionCounter);
        trainingSetSize += actionEntry.second.size() * learningSetSize;
    }

    if (loggingOn() /*&& (epoch == getNumberOfEpochs()-1)*/) {
        SystemUtils::log2(getDetails(), 
                        "MSE in training per action for"
                            " iteration #", epoch);
        for (auto& entry : squaredDifferencePerAction) {
            SystemUtils::log2(getDetails(), "MSE training for " 
                + std::to_string(epoch) + 
                "#" + std::to_string(entry.first) + ": ", 
                entry.second / squaredDifferencePerActionCounter[entry.first]);
        }
        
        SystemUtils::log2(getDetails(), "Total MSE in training: ", 
                squaredDifference / trainingSetSize);

        SystemUtils::log2(getDetails(), "Training finished. "
                                        "Number of states evaluated: ",
                                        trainingSetSize);
    }

    newTotalMSE = squaredDifference / trainingSetSize;

    std::cout << "\t\toldTotalMSE: " << oldTotalMSE << std::endl;
    std::cout << "\t\tnewTotalMSE: " << newTotalMSE << std::endl;

    // TODO: Update the early stopping parameter
    if (!MathUtils::doubleIsEqual(oldTotalMSE, -1)  && 
                MathUtils::doubleIsEqual(fabs(oldTotalMSE - newTotalMSE), 
                                            0.00001)) {
        std:: cout << "Training finished. Epoch: " << epoch << std::endl;
        std:: cout << "Early stopping due to change of MSE"
                    " falling below threshold" << std::endl;
        return false;
    } else {
        std:: cout << "Training finished. Epoch: " << epoch << std::endl;
        oldTotalMSE = newTotalMSE;
        return true;
    }

}

// Evaluates the learned heuristic on the test part of the data set.
//
// For every action, the samples after the first learningSetSize fraction are
// compared with the heuristic's prediction. Squared errors are accumulated per
// action and in total, and the results are written to the log when logging is
// enabled.
//
// testingSetSize counts every evaluated sample over ALL actions. It
// previously held only the last action's (truncated) count, which made the
// total MSE wrong.
void OfflineHeuristic::performTesting(double learningSetSize, int epoch) {
    std::cout << "Testing started. Epoch: " << epoch << std::endl;

    double squaredDifference = 0.0;
    std::map<int, double> squaredDifferencePerAction;
    std::map<int, int> squaredDifferencePerActionCounter;
    int testingSetSize = 0;
    for (auto& actionEntry : dataSet) {
        int actionName = actionEntry.first;
        auto& samples = actionEntry.second;

        assert(heuristic.count(actionName));

        // Samples before this index were used for training.
        size_t const firstTestSample =
            static_cast<size_t>(samples.size() * learningSetSize);

        for (size_t i = firstTestSample; i < samples.size(); i++) {
            double const reward = samples[i].second;
            double const calculatedReward = calculateSumOfMultipliers(
                heuristic[actionName], samples[i].first, this);

            if (loggingOn()) {
                SystemUtils::log2(getDetails(), "Actual reward in testing " +
                                    std::to_string(epoch) + "#" +
                                    std::to_string(actionName) +
                                    ": ", reward);
                SystemUtils::log2(getDetails(), "Calculated reward in testing " +
                                    std::to_string(epoch) + "#" +
                                    std::to_string(actionName) +
                                    ": ", calculatedReward);
            }

            double const error =
                (reward - calculatedReward) * (reward - calculatedReward);
            squaredDifference += error;
            // std::map value-initializes missing entries to 0.
            squaredDifferencePerAction[actionName] += error;
            squaredDifferencePerActionCounter[actionName] += 1;
            ++testingSetSize;
        }
    }

    if (loggingOn()) {
        SystemUtils::log2(getDetails(), "MSE in testing per action for"
                            " iteration #", epoch);

        for (auto& entry : squaredDifferencePerAction) {
            SystemUtils::log2(getDetails(), "MSE testing for " + std::to_string(epoch) +
                "#" + std::to_string(entry.first) + ": ",
                entry.second / squaredDifferencePerActionCounter[entry.first]);
        }

        // Without test samples the total MSE is undefined; skip it instead of
        // dividing by zero.
        if (testingSetSize > 0) {
            SystemUtils::log2(getDetails(), "Total MSE in testing: ",
                                    squaredDifference / testingSetSize);
        }

        SystemUtils::log2(getDetails(), "Testing finished. Number of states evaluated: ",
                                testingSetSize);
    }

    std::cout << "Testing finished. Epoch: " << epoch << std::endl;
}

// Applies one command line parameter of the offline heuristic.
// Returns true if param was handled here. Unknown parameters are passed on
// to ProbabilisticSearchEngine. Invalid values abort the planner.
bool OfflineHeuristic::setValueFromString(std::string& param, std::string& value) {
    if (param == "-step") {
        // Step size (learning rate) of the gradient descent
        setStep(atof(value.c_str()));
        return true;
    } else if (param == "-loadHeuristicFrom") {
        // Folder from which a previously learned heuristic is loaded
        setHeuristicLoadLocation(value.c_str());
        return true;
    } else if (param == "-loadTrainingSetFrom") {
        // Folder from which a previously generated training set is loaded
        setTrainingSetLoadLocation(value.c_str());
        return true;
    } else if (param == "-method") {
        // Method that's being applied for learning
        if (value == "StochasticGradientDescent") {
            setMethod(OfflineHeuristic::MethodType::StochasticGradientDescent);
        } else if (value == "GradientDescent") {
            setMethod(OfflineHeuristic::MethodType::GradientDescent);
        } else if (value == "MiniBatchGradientDescent") {
            setMethod(OfflineHeuristic::MethodType::MiniBatchGradientDescent);
        } else {
            SystemUtils::abort("Unknown learning method: " + value);
        }
        return true;
    } else if (param == "-numberOfEpochs") {
        setNumberOfEpochs(atoi(value.c_str()));
        return true;
    } else if (param == "-momentum") {
        // The momentum is a fraction (usually 0.9). Parsing it with atoi
        // silently truncated it to 0.
        setMomentum(atof(value.c_str()));
        return true;
    } else if (param == "-learningDecayRate") {
        // The decay rate is a fraction (default 0.05). Parsing it with atoi
        // silently truncated it to 0.
        setLearningDecayRate(atof(value.c_str()));
        return true;
    } else if (param == "-learnType") {
        if (value == "learnOnly") {
            setLearnType(OfflineHeuristic::LearnType::LearnOnly);
        } else if (value == "learnAndTest") {
            setLearnType(OfflineHeuristic::LearnType::LearnAndTest);
        } else {
            SystemUtils::abort("Unknown learning type: " + value);
        }

        return true;
    } else if (param == "-logging") {
        if (value == "ON") {
            setLogOption(true);
            return true;
        } else if (value == "OFF") {
            setLogOption(false);
            return true;
        }
        return false;
    } else if (param == "-trainingSetGenerator") {
        setTrainingSetGenerator(SearchEngine::fromString(value));
        return true;
    } else if (param == "-outputDir") {
        setOutputDir(value.c_str());
        return true;
    } else if (param == "-combineFeatures") {
        setCombineFeaturesSize(atoi(value.c_str()));

        // 1 means "no combination". NOTE: approximate() currently implements
        // only 1 and 2.
        if (getCombineFeatures() < 1 || getCombineFeatures() > 3) {
            SystemUtils::abort("Number of features that can be combined is at"
                                " least 1 (no combination) and at most 3.");
        }

        return true;
    }
    return ProbabilisticSearchEngine::setValueFromString(param, value);
}

// Sets the step size (learning rate). Values outside [0, 1] abort the planner.
void OfflineHeuristic::setStep(double _step) {
    if (MathUtils::doubleIsGreater(_step, 1) || MathUtils::doubleIsSmaller(_step, 0)) {
        SystemUtils::abort("Step takes value between 0 and 1.");
    }
    step = _step;
}


// Writes the learned heuristic in the format read by loadHeuristic().
//
// Lines starting with '#' are comments. The data lines are:
//      <number of combined features>
//      <number of actions>
//      and for every action:
//      <action index>
//      [w_1,w_2,...,w_k]
//
// Coefficients are written with full double precision, so reloading yields
// the same weights; the default 6 significant digits lost information. An
// action without coefficients is written as "[]". Previously it produced a
// lone "]" because the opening bracket was cut off.
void OfflineHeuristic::storeHeuristic(std::ostream& out) {
    out << "#### " << name << ": ";
    out << SearchEngine::taskName << " ####" << std::endl;
    out << "# " <<  getDetails() << std::endl;
    out << "# Number of features that are combined" << std::endl;
    out << combineFeatures << std::endl;
    out << "# Number of actions" << std::endl;
    out << heuristic.size() << std::endl;
    for (auto const& action : heuristic) {
        out << "# Action: " << std::endl;
        out << action.first << std::endl;

        // Format into a local stream so that the precision of `out` is left
        // untouched for the caller.
        std::ostringstream coeffs;
        coeffs.precision(std::numeric_limits<double>::max_digits10);
        coeffs << "[";
        bool first = true;
        for (auto const& coeff : action.second) {
            if (!first) {
                coeffs << ",";
            }
            coeffs << coeff;
            first = false;
        }
        coeffs << "]\n";

        out << coeffs.str();

        out << "########" << std::endl;
    }
}

// Loading the first heuristic that satisfies all the input data
void OfflineHeuristic::loadHeuristic() {
    std::cout << "Loading heuristic from folder: " << loadHeuristicFrom << std::endl;

    // Check if heuristic was already loaded
    if (heuristic.size() > 0) {
        SystemUtils::abort("Error: heuristic already contains data");
    }


    // List all heuristics from a given folder and one that contains:
    //      - exact instance name (I.e. _academic_advising_inst_mdp__1_)
    //      - exact method name (I.e. _GradientDescent)
    //      - exact number of epochs (I.e. numberOfEpochs=100)
    //      - exact step size (I.e. step=0.1) 

    std::stringstream strStep;
    strStep << step;

    std::cout << "Searching for file that contains config: " << std::endl;
    std::cout << "\t" << SearchEngine::taskName + "_" << std::endl;
    std::cout << "\t" << "_" + getMethodDesc() << std::endl;
    std::cout << "\t" << "numberOfEpochs=" + std::to_string(numberOfEpochs);
    std::cout << " " << "step=" + strStep.str() << std::endl;

    DIR *dir;
    struct dirent *ent;
    std::string fileName(loadHeuristicFrom + "/");
    if ((dir = opendir (loadHeuristicFrom.c_str())) != nullptr) {
        /* print all the files and directories within directory */
        while ((ent = readdir(dir)) != nullptr) {
            std::string file(ent->d_name);
            if (file.find(SearchEngine::taskName + "_") != std::string::npos
                && file.find(getMethodDesc()) != std::string::npos
                && file.find("numberOfEpochs=" + std::to_string(numberOfEpochs)) 
                        != std::string::npos
                && file.find("step=" + strStep.str()) != std::string::npos) {
               fileName += file;

               std::cout << "Trying to load file " << fileName << std::endl;
               break;
            }
        }
        if (fileName[fileName.size() - 1] == '/') {
            SystemUtils::abort("Could not find heuristic with given config.");        
        }
        closedir (dir);
    } else {
        SystemUtils::abort("Could not open directory " + loadHeuristicFrom);
    }

    std::string heuristicDesc;
    if (!SystemUtils::readFile(fileName, heuristicDesc, "#")) {
       SystemUtils::abort("Error: Unable to read problem file: " +
                         fileName);
    }
    std::stringstream desc(heuristicDesc);

    // First line is number features that are combined
    // Second line is number of actions
    // Then follows every action with number of states per action and then it's
    // states and corresponding vector of coefficients i.e.
    //      action_name
    //      [w_1,w_2,w_3,....,w_2n]
    // Where n is the number of features

    int combValue;
    desc >> combValue;
    if (combValue < 1 || combValue > 3) {
        SystemUtils::abort("Wrong combineFeature number.");
    }

    setCombineFeaturesSize(combValue);

    int numActions;
    desc >> numActions;    
    for(unsigned i = 0; i < numActions; i++) {
        int action;
        desc >> action;

        // TODO: Write rule
        if (heuristic.count(action)) {
            SystemUtils::abort("Error: Action " + std::to_string(action) + 
    			         " was already loaded. "
                         "See rules for writing offlineHeuristic file. "
                         "Problematic file: " + fileName);
        }

        std::string coeffsString;
        desc >> coeffsString;
        StringUtils::trim(coeffsString);
        std::vector<double> coeffs;
        parseCoefficientsFromString(coeffsString, coeffs);
        
        // Store coefficients
        heuristic[action] = coeffs;            
    }

    std::cout << "Heuristic " << fileName << " loaded." << std::endl;
}

// Parses a coefficient list of the form "[w_1,w_2,...,w_k]", the format
// written by storeHeuristic(), into coeffs.
// The first and last character (the brackets) are dropped, the rest is split
// at ',' and every token is converted with atof. Any previous content of
// coeffs is replaced.
void OfflineHeuristic::parseCoefficientsFromString(std::string coeffsString,
                                                   std::vector<double>& coeffs) {
    StringUtils::removeFirstAndLastCharacter(coeffsString);

    std::vector<std::string> tokens;
    StringUtils::split(coeffsString, tokens, ",");

    coeffs.clear();
    coeffs.reserve(tokens.size());
    for (std::string const& token : tokens) {
        coeffs.push_back(atof(token.c_str()));
    }
}

// Returns the name of the configured learning method. It is the same string
// accepted by the -method parameter and it appears in the file names built by
// getDetails(). An unknown method aborts; "" is returned if abort returns.
std::string OfflineHeuristic::getMethodDesc() {
    std::string desc;
    switch (methodType) {
        case MethodType::StochasticGradientDescent:
            desc = "StochasticGradientDescent";
            break;
        case MethodType::GradientDescent:
            desc = "GradientDescent";
            break;
        case MethodType::MiniBatchGradientDescent:
            desc = "MiniBatchGradientDescent";
            break;
        default:
            SystemUtils::abort("Unknown method for learning.");
            break;
    }
    return desc;
}

// Returns the name of the configured learn type, the same string accepted by
// the -learnType parameter. An unknown type aborts; "" is returned if abort
// returns.
std::string OfflineHeuristic::getLearnDesc() {
    switch (learnType) {
        case OfflineHeuristic::LearnType::LearnOnly:
            return "learnOnly";
        case OfflineHeuristic::LearnType::LearnAndTest:
            return "learnAndTest";
        default:
            SystemUtils::abort("Unknow learning type.");
            return "";
    }
}

// Returns a description of the learning configuration, used as the heuristic
// file name and as the log name:
//   <name>_<task>_<Method>-<learnType>-step=..-numberOfEpochs=..-momentum=..
//   -combinedFeatures=..
// The string is built on the first call and cached in `details`. Later
// parameter changes (e.g. the step decay during training) do not alter it.
std::string OfflineHeuristic::getDetails() {
    if (!details.empty()) {
        return details;
    }

    std::stringstream description;
    description << SearchEngine::name << "_"
                << SearchEngine::taskName << "_"
                << getMethodDesc() << "-" << getLearnDesc()
                << "-step=" << step
                << "-numberOfEpochs=" << numberOfEpochs
                << "-momentum=" << momentum
                << "-combinedFeatures=" << combineFeatures;

    details = description.str();
    return details;
}

// Returns the row-major position of the feature pair (indexI, indexJ) in a
// flattened numberOfFeatures x numberOfFeatures coefficient vector. This is
// the layout sized by approximate() for combined features.
size_t OfflineHeuristic::comboValue(int indexI,
                                    int indexJ,
                                    size_t numberOfFeatures) {
    size_t const row = static_cast<size_t>(indexI);
    size_t const column = static_cast<size_t>(indexJ);
    return row * numberOfFeatures + column;
}


/******************************************************************
                       Main Search Functions
******************************************************************/
// Estimates the Q-values of all actions in `state` from the learned weights.
//
// Every action with actionsToExpand[index] == index gets the sum of its
// weights selected by the state's fluents, multiplied by the remaining steps.
// All other actions get -max(double).
//
// Unscaled results are cached per state in rewardCache when caching is
// enabled.
//
// The weights are looked up with find() rather than operator[]. This keeps an
// evaluation from inserting empty entries into the learned heuristic; those
// entries would later be written out by storeHeuristic().
void OfflineHeuristic::estimateQValues(State const& state,
                         std::vector<int> const& actionsToExpand,
                         std::vector<double>& qValues) {
    HashMap::iterator it = rewardCache.find(state);
    if (it != rewardCache.end()) {
        qValues = it->second;
    } else {
        assert(qValues.size() == actionsToExpand.size());
        for (size_t index = 0; index < actionsToExpand.size(); ++index) {
            if (actionsToExpand[index] != index) {
                continue;
            }

            // TODO: Re-scale result before the storing
            double result = 0.0;

            auto const weightsIt = heuristic.find(static_cast<int>(index));

            // It might happen that an action is not visited at all during
            // the training (neither with IDS nor with THTS, e.g.
            // triangle_tireworld_inst_mdp__3 action 31). Such an action gets
            // the lowest representable value.
            if (weightsIt == heuristic.end() || weightsIt->second.empty()) {
                result = -std::numeric_limits<double>::max();
                std::cout << "Missing action " << index << std::endl;
            } else {
                auto const& weights = weightsIt->second;

                if (combineFeatures == 1) {
                    // No combination of features: two weights per fluent.
                    // The first is used if the fluent is 0, the second
                    // otherwise. This presumably mirrors the encoding of
                    // parseStateToValues (TODO confirm).
                    size_t j = 0;
                    auto addFluentWeight = [&](double fluentValue) {
                        result += weights[MathUtils::doubleIsEqual(0, fluentValue)
                                          ? j : j + 1];
                        j += 2;
                    };
                    for (double const& ss : state.deterministicStateFluents) {
                        addFluentWeight(ss);
                    }
                    for (double const& ss : state.probabilisticStateFluents) {
                        addFluentWeight(ss);
                    }
                } else if (combineFeatures == 2) {
                    // Combine deterministicStateFluents
                    // and probabilisticStateFluents into one vector
                    std::vector<double> stateValues(
                                    state.deterministicStateFluents.begin(),
                                    state.deterministicStateFluents.end());
                    stateValues.insert(stateValues.end(),
                                    state.probabilisticStateFluents.begin(),
                                    state.probabilisticStateFluents.end());

                    // NOTE(review): the weight of every index pair is added
                    // regardless of the fluent values, so this result only
                    // depends on the number of fluents, not on the state.
                    // The value-dependent selection below is commented out.
                    // TODO confirm the intended encoding.
                    for (size_t indexI = 0;
                         indexI < stateValues.size();
                         indexI += 2) {
                        for (size_t indexJ = indexI + 2;
                             indexJ < stateValues.size();
                             indexJ += 2) {

                            // int factI = ((int)stateValues[indexI] + 1) % 2;
                            // int factJ = ((int)stateValues[indexJ] + 1) % 2;

                            std::ptrdiff_t pos = comboValue(
                                                    indexI, indexJ,
                                                    stateValues.size()
                                                    );

                            result += weights[pos];
                        }
                    }
                }
            }

            qValues[index] = result;
        }

        // Cache the unscaled estimates once, after all actions have been
        // evaluated. Previously this was rewritten on every loop iteration.
        if (cachingEnabled) {
            rewardCache[state] = qValues;
        }
    }

    for (size_t index = 0; index < qValues.size(); ++index) {
        if (actionsToExpand[index] == index) {
            // NOTE(review): for a missing action, -max(double) multiplied by
            // stepsToGo > 1 overflows to -inf. TODO decide whether missing
            // actions should be excluded from this scaling.
            qValues[index] *= static_cast<double>(state.stepsToGo());
            assert(!MathUtils::doubleIsMinusInfinity(qValues[index]));
        } else {
            qValues[index] = -std::numeric_limits<double>::max();
        }
    }
}

/******************************************************************
                       Approximation
******************************************************************/

// If the action has been evaluated before, it is just updated, otherwise it is
// all of the coefficients are initiated to a random value in range [-0.1, 0.1]
void OfflineHeuristic::approximate(
            int actionName,
            int epoch, 
            std::vector<std::pair<std::vector<double>, double>>& actionEntries,
            double learningSetSize,
            double& squaredDifference, 
            std::map<int, double>& squaredDifferencePerAction, 
            std::map<int, int>& squaredDifferencePerActionCounter) {

    if (!heuristic.count(actionName)) {
        if (getCombineFeatures() == 1) {
            std::vector<double> coeffs(dataSet[actionName][0].first.size());
            for (unsigned i = 0; i < coeffs.size(); i++) {
                coeffs[i] = MathUtils::rnd->genDouble(-0.1, 0.1);
            }
            heuristic[actionName] = coeffs;
        } else if (getCombineFeatures() == 2) {
            size_t coeffSize = dataSet[actionName][0].first.size() 
                                * dataSet[actionName][0].first.size();

            std::vector<double> coeffs(coeffSize, 0.0);
            for (size_t i = 0; i < coeffs.size(); i++) {
                coeffs[i] = MathUtils::rnd->genDouble(-0.1, 0.1);
            }

            heuristic[actionName] = coeffs;            
        } else {
            SystemUtils::abort("Number of combined features not implemented.");
        }

    }

    switch(methodType) {
        case MethodType::StochasticGradientDescent: {
            StochasticGradientDescentApproximation::updateCoefficients(
                actionName, // name of the action 
                epoch, // current epoch of approximation
                heuristic[actionName], // coefficients
                actionEntries, // Test data gathered for that function
                learningSetSize, // To know how much from data set to take
                step, // Step
                momentum, // Momentum
                squaredDifference, // for calculating total MSE
                squaredDifferencePerAction, // For calculating MSE per action
                squaredDifferencePerActionCounter,
                this); // name of the log file for storing stats 
            break;
        } case MethodType::GradientDescent: {
            GradientDescentApproximation::updateCoefficients(
                actionName,
                epoch, 
                heuristic[actionName], 
                actionEntries,
                learningSetSize, 
                step, 
                momentum,
                squaredDifference, 
                squaredDifferencePerAction,
                squaredDifferencePerActionCounter,
                this);
            break;
        } case MethodType::MiniBatchGradientDescent: {
            MiniBatchGradientDescentApproximation::updateCoefficients(
                actionName,
                epoch, 
                heuristic[actionName], 
                actionEntries,
                learningSetSize, 
                step, 
                momentum,
                squaredDifference, 
                squaredDifferencePerAction,
                squaredDifferencePerActionCounter,
                this);
            break;            
        } default: {
            SystemUtils::abort("Unknown method for learning.");
            break;
        } 
    };

}
