#include "online_heuristic.h"

#include "approximation.h"
#include "iterative_deepening_search.h"
#include "utils/random.h"
#include "utils/system_utils.h"
#include "utils/string_utils.h"

#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <stdlib.h>
#include <sstream>
#include <vector>

/******************************************************************
            OnlineHeuristic Search Engine Creation
******************************************************************/

// Default configuration of the online heuristic: stochastic gradient descent
// with a step of 0.5, a single epoch per training run, no momentum, no
// learning-rate decay, training triggered once 100 samples were collected for
// an action (see addToDataSet) and logging switched off. The cached details
// string starts empty and is built lazily by getDetails().
// NOTE(review): rewardType is not initialized in this list; it is only set
// through the "-rewardType" parameter in setValueFromString -- confirm that
// the header provides a default, otherwise it is indeterminate until then.
OnlineHeuristic::OnlineHeuristic()
    : ProbabilisticSearchEngine("OnlineHeuristic"),
      methodType(MethodType::StochasticGradientDescent),
      step(0.5),
      numberOfEpochs(1),
      momentum(0),
      learningDecayRate(0),
      details(""),
      trainingSetSize(100),
      log(false) {}

// Builds an OnlineHeuristic from a textual description of the form
// "[OnlineHeuristic <param> <value> ...]".
//
// desc  - the description; it is consumed (modified in place) while parsing.
// _thts - stored in the thts member of the created engine.
//
// Aborts via SystemUtils::abort if the description does not start with
// "OnlineHeuristic" or if a parameter/value pair is rejected by
// setValueFromString. The result is allocated with new; presumably the caller
// takes ownership -- confirm against the call sites.
OnlineHeuristic* OnlineHeuristic::fromString(std::string& desc, THTS* _thts) {
    // Strip surrounding whitespace and the enclosing brackets.
    StringUtils::trim(desc);
    assert(desc[0] == '[' && desc[desc.size() - 1] == ']');
    StringUtils::removeFirstAndLastCharacter(desc);
    StringUtils::trim(desc);

    OnlineHeuristic* engine = nullptr;

    const bool startsWithEngineName = (desc.find("OnlineHeuristic") == 0);
    if (!startsWithEngineName) {
        SystemUtils::abort("Unknown online heuristic");
    } else {
        // Drop the 15 characters of the engine name, keep the parameters.
        desc.erase(0, 15);
        engine = new OnlineHeuristic();
    }

    assert(engine);
    StringUtils::trim(desc);

    // Consume the remaining text one parameter/value pair at a time.
    std::string key;
    std::string val;
    while (!desc.empty()) {
        key.clear();
        val.clear();
        StringUtils::nextParamValuePair(desc, key, val);

        const bool accepted = engine->setValueFromString(key, val);
        if (!accepted) {
            SystemUtils::abort("Unused parameter value pair: " + key + " / " +
                               val);
        }
    }

    engine->thts = _thts;
    return engine;
}

// Runs one training pass (epoch) over the samples collected for a single
// action and, if logging is enabled, reports the resulting errors.
//
// learningSetSize - forwarded to approximate(); also used as the factor that
//                   scales the number of samples into the reported number of
//                   evaluated states.
// epoch           - index of the current epoch (forwarded and logged).
// actionState     - key of the action in dataSet that is trained; if there
//                   is no entry for it, no training happens.
//
// Side effects: decays the member step by learningDecayRate, shuffles the
// samples of the action in place and updates the heuristic via approximate().
void OnlineHeuristic::performTraining(double learningSetSize, 
                                      int epoch,
                                      int actionState) {
    // Slow the learning down a little with every training pass.
    step -= learningDecayRate * step;

    // Initialized so the log output below is well defined even when nothing
    // was trained.
    double squaredDifference = 0.0;
    std::map<int, double> squaredDifferencePerAction;
    std::map<int, int> squaredDifferencePerActionCounter;
    // Deliberately not called trainingSetSize: that would shadow the member.
    int statesEvaluated = 0;

    // Keys are unique, so a direct lookup replaces scanning every action.
    auto actionEntry = dataSet.find(actionState);
    if (actionEntry != dataSet.end()) {
        auto& samples = actionEntry->second;

        // Shuffle the samples so that the update order differs per epoch.
        // NOTE(review): std::random_shuffle was removed in C++17; switch to
        // std::shuffle with the project's random source when the language
        // standard is raised.
        std::random_shuffle(samples.begin(), samples.end());

        approximate(actionEntry->first,
                    epoch,
                    samples,
                    learningSetSize,
                    squaredDifference,
                    squaredDifferencePerAction,
                    squaredDifferencePerActionCounter);
        statesEvaluated = static_cast<int>(samples.size() * learningSetSize);
    }

    if (loggingOn()) {
        SystemUtils::log2(getDetails(), 
                        "MSE in training per action for iteration #", epoch);
        for (auto& entry : squaredDifferencePerAction) {
            SystemUtils::log2(getDetails(), "MSE training for " 
                + std::to_string(epoch) + 
                "#" + std::to_string(entry.first) + ": ", 
                entry.second / squaredDifferencePerActionCounter[entry.first]);
        }

        // Report 0 instead of dividing by zero when nothing was evaluated.
        const double totalMse =
            statesEvaluated > 0 ? squaredDifference / statesEvaluated : 0.0;
        SystemUtils::log2(getDetails(), "Total MSE in training: ", totalMse);

        SystemUtils::log2(getDetails(), "Training finished. "
                                        "Number of states evaluated: ",
                                        statesEvaluated);
    }
}

// Applies a single command line style parameter to this engine.
//
// param - name of the parameter, including the leading dash.
// value - textual value of the parameter.
//
// Returns true if the pair was consumed here, false if "-logging" was given
// a value other than ON/OFF, and otherwise whatever
// ProbabilisticSearchEngine::setValueFromString returns for the pair.
// Unknown values of "-method" and "-rewardType" abort the program.
bool OnlineHeuristic::setValueFromString(std::string& param, std::string& value) {
    if (param == "-step") {
        // Learning rate; range-checked by setStep.
        setStep(atof(value.c_str()));
        return true;
    } else if (param == "-method") {
        // Method that is applied for learning.
        if (value == "StochasticGradientDescent") {
            setMethod(OnlineHeuristic::MethodType::StochasticGradientDescent);
        } else if (value == "GradientDescent") { 
            setMethod(OnlineHeuristic::MethodType::GradientDescent);
        } else {
            SystemUtils::abort("Unknown learning method: " + value);
        }
        return true;
    } else if (param == "-numberOfEpochs") {
        setNumberOfEpochs(atoi(value.c_str()));
        return true;
    } else if (param == "-momentum") {
        // Parsed as a floating point number: atoi would truncate fractional
        // values such as 0.9 to 0.
        setMomentum(atof(value.c_str()));
        return true;
    } else if (param == "-learningDecayRate") {
        // Parsed as a floating point number for the same reason; the rate is
        // multiplied with the step in performTraining.
        setLearningDecayRate(atof(value.c_str()));
        return true;
    } else if (param == "-rewardType") {
        if (value == "TRIAL") {
            setRewardType(OnlineHeuristic::RewardType::TrialReward);
        } else if (value == "QVALUE") { 
            setRewardType(OnlineHeuristic::RewardType::QValue);
        } else {
            SystemUtils::abort("Unknown reward type: " + value);
        }
        return true;
    } else if (param == "-trainingSetSize") {
        // Number of samples per action that triggers a training run.
        setTrainingSetSize(atoi(value.c_str()));
        return true;
    } else if (param == "-logging") {
        if (value == "ON") {
            setLogOption(true);
            return true;
        } else if (value == "OFF") {
            setLogOption(false);
            return true;
        }
        // Neither ON nor OFF: report the pair as not consumed.
        return false;
    }    
    return ProbabilisticSearchEngine::setValueFromString(param, value);
}

// Sets the learning step. Values that MathUtils reports as greater than 1 or
// smaller than 0 are rejected by aborting the program.
void OnlineHeuristic::setStep(double _step) {
    const bool aboveRange = MathUtils::doubleIsGreater(_step, 1);
    const bool belowRange = MathUtils::doubleIsSmaller(_step, 0);
    if (aboveRange || belowRange) {
        // The message used to end with a stray apostrophe.
        SystemUtils::abort("Step takes value between 0 and 1.");
    }
    step = _step;
}

// Writes the learned heuristic to out in a line based text format: a header
// with the engine name, task name and details, the number of actions, and
// for every action its key followed by the coefficients as a comma separated
// list in square brackets and a "########" separator line.
//
// An action without coefficients is written as "[]" (the previous string
// trimming produced a lone "]" in that case).
void OnlineHeuristic::storeHeuristic(std::ostream& out) {
    out << "#### " << name << ": ";
    out << SearchEngine::taskName << " ####" << std::endl;
    out << "# " <<  getDetails() << std::endl;
    out << "# Number of actions" << std::endl;
    out << heuristic.size() << std::endl;
    for (const auto& action : heuristic) {
        out << "# Action: " << std::endl;
        out << action.first << std::endl;

        // The list is assembled in a separate stream with default formatting
        // so that formatting flags set on out do not affect the coefficients.
        std::ostringstream coefficients;
        coefficients << "[";
        bool isFirst = true;
        for (const auto& coeff : action.second) {
            if (!isFirst) {
                coefficients << ",";
            }
            coefficients << coeff;
            isFirst = false;
        }
        coefficients << "]\n";

        out << coefficients.str();

        out << "########" << std::endl;
    }
}

// Converts a textual coefficient list into numbers.
//
// coeffsString - the list; its first and last character are dropped before
//                the rest is split at commas.
// coeffs       - output; any previous content is discarded and replaced by
//                one value (converted with atof) per token.
void OnlineHeuristic::parseCoefficientsFromString(std::string coeffsString,
                                                  std::vector<double>& coeffs) {
    StringUtils::removeFirstAndLastCharacter(coeffsString);

    std::vector<std::string> tokens;
    StringUtils::split(coeffsString, tokens, ",");

    coeffs.clear();
    coeffs.reserve(tokens.size());
    for (const std::string& token : tokens) {
        coeffs.push_back(atof(token.c_str()));
    }
}

// Returns the textual name of the configured learning method. Any other
// value of methodType aborts the program; the empty string after the abort
// call only satisfies the return type.
std::string OnlineHeuristic::getMethodDesc() {
    if (methodType == MethodType::StochasticGradientDescent) {
        return "StochasticGradientDescent";
    }
    if (methodType == MethodType::GradientDescent) {
        return "GradientDescent";
    }

    SystemUtils::abort("Unknown method for learning.");
    return "";
}

// Returns a "rewardType=<name>" description of the configured reward type.
// Any other value of rewardType aborts the program; the empty string after
// the abort call only satisfies the return type.
std::string OnlineHeuristic::getRewardDesc() {
    if (rewardType == RewardType::TrialReward) {
        return "rewardType=TrialReward";
    }
    if (rewardType == RewardType::QValue) {
        return "rewardType=QValue";
    }

    SystemUtils::abort("Unknown rewardType for learning.");
    return "";
}

// Returns a string made of the engine name, task name, learning method,
// reward type, step and number of epochs. The string is built on the first
// call and cached in details, so later changes of these values are not
// reflected in it.
std::string OnlineHeuristic::getDetails() {
    if (details.empty()) {
        const std::string methodPart = getMethodDesc();
        const std::string rewardPart = getRewardDesc();

        std::stringstream label;
        label << SearchEngine::name << "_"
              << SearchEngine::taskName << "_"
              << methodPart << "_"
              << rewardPart
              << "-step=" << step
              << "-numberOfEpochs=" << numberOfEpochs;
        details = label.str();
    }
    return details;
}

// Records one training sample for an action and trains that action's
// heuristic once enough samples were collected.
//
// state       - state the sample was observed in; stored as the value vector
//               produced by parseStateToValues.
// actionState - key of the action the sample belongs to.
// trialReward - target value used when rewardType is TrialReward.
// qValue      - target value used when rewardType is QValue.
//
// With TrialReward every call appends a sample. With QValue a state that was
// already collected for the action only gets its target value refreshed; it
// is no longer appended a second time. Any other reward type aborts.
void OnlineHeuristic::addToDataSet(State state, int actionState, 
                            double const& trialReward, double const& qValue) {

    if (rewardType == RewardType::TrialReward) {
        // Trial rewards are collected for every visited state.
        std::pair<std::vector<double>, double> sample = 
                std::make_pair(parseStateToValues(state), trialReward);
        dataSet[actionState].push_back(sample);
    } else if (rewardType == RewardType::QValue) {
        auto& samples = dataSet[actionState];

        // Refresh the value of an already collected state instead of storing
        // a duplicate. The former 'continue' at the end of the loop body had
        // no effect, so the state was updated and appended again.
        bool alreadyCollected = false;
        for (auto& sample : samples) {
            if (compareStates(state, sample.first)) {
                sample.second = qValue;
                alreadyCollected = true;
                break;
            }
        }

        if (!alreadyCollected) {
            std::pair<std::vector<double>, double> sample = 
                    std::make_pair(parseStateToValues(state), qValue);
            samples.push_back(sample);
        }
    } else {
        SystemUtils::abort("Reward type is not set.");
    }

    // Once the configured number of samples is collected for an action, the
    // action is trained for numberOfEpochs epochs and its samples are dropped.
    if (dataSet[actionState].size() == trainingSetSize) {
        // Signed counter: a non-positive numberOfEpochs yields no epochs
        // rather than a wrapped-around unsigned comparison.
        for (int epoch = 0; epoch < numberOfEpochs; ++epoch) {
            performTraining(1, epoch, actionState);
        }
        dataSet[actionState].clear();
    }
}

/******************************************************************
                       Main Search Functions
******************************************************************/
// Estimates a Q-value with the learned heuristic for every action that is
// marked for expansion.
//
// state           - state the estimates are computed for.
// actionsToExpand - an action is evaluated iff the entry at its index equals
//                   that index.
// qValues         - output, same size as actionsToExpand; only the entries
//                   of evaluated actions are written.
//
// A heuristic entry is expected to exist for every evaluated action (checked
// by an assertion only).
void OnlineHeuristic::estimateQValues(State const& state,
                         std::vector<int> const& actionsToExpand,
                         std::vector<double>& qValues) {
    assert(qValues.size() == actionsToExpand.size());

    // The state is the same for all actions, so it is converted only once
    // instead of once per evaluated action.
    const auto stateValues = parseStateToValues(state);

    for (size_t index = 0; index < actionsToExpand.size(); ++index) {
        // Compare as int to avoid the implicit signed/unsigned conversion.
        const int action = static_cast<int>(index);
        if (actionsToExpand[index] != action) {
            continue;
        }

        assert(heuristic.count(action));
        qValues[index] = calculateSumOfMultipliers(heuristic[action],
                                                   stateValues,
                                                   nullptr);
    }
}


/******************************************************************
                       Approximation
******************************************************************/
// Updates the coefficients of one action's heuristic from a batch of samples.
//
// If the action already has coefficients they are updated in place.
// Otherwise one coefficient per feature of the first sample is created and
// initialized with MathUtils::rnd->genDouble(-0.1, 0.1) before the update.
// If there are no coefficients yet and no samples to size them from, the
// call returns without doing anything.
//
// actionName      - key of the action in heuristic.
// epoch           - index of the current epoch (forwarded).
// actionEntries   - samples as (feature values, target value) pairs.
// learningSetSize - forwarded; tells the update how much of the samples to
//                   use.
// squaredDifference, squaredDifferencePerAction,
// squaredDifferencePerActionCounter
//                 - error accumulators that are forwarded to the update and
//                   used by the caller to report the MSE.
//
// An unknown methodType aborts the program.
void OnlineHeuristic::approximate(
            int actionName,
            int epoch, 
            std::vector<std::pair<std::vector<double>, double>>& actionEntries,
            double learningSetSize,
            double& squaredDifference, 
            std::map<int, double>& squaredDifferencePerAction, 
            std::map<int, int>& squaredDifferencePerActionCounter) {

    if (!heuristic.count(actionName)) {
        // The samples passed in determine the number of features; indexing
        // dataSet[actionName][0] was undefined for an empty sample list.
        if (actionEntries.empty()) {
            return;
        }

        std::vector<double> coeffs(actionEntries.front().first.size());
        for (double& coeff : coeffs) {
            coeff = MathUtils::rnd->genDouble(-0.1, 0.1);
        }
        heuristic[actionName] = coeffs;
    }

    switch(methodType) {
        case MethodType::StochasticGradientDescent: {
            StochasticGradientDescentApproximation::updateCoefficients(
                actionName, // key of the action
                epoch, // current epoch of the approximation
                heuristic[actionName], // coefficients to update
                actionEntries, // samples gathered for the action
                learningSetSize, // how much of the samples to take
                step, // learning step
                momentum, // momentum
                squaredDifference, // for calculating the total MSE
                squaredDifferencePerAction, // for calculating the MSE per action
                squaredDifferencePerActionCounter
                ); 
            break;
        } case MethodType::GradientDescent: {
            GradientDescentApproximation::updateCoefficients(
                actionName,
                epoch, 
                heuristic[actionName], 
                actionEntries,
                learningSetSize, 
                step, 
                momentum,
                squaredDifference, 
                squaredDifferencePerAction,
                squaredDifferencePerActionCounter);
            break;            
        } default: {
            SystemUtils::abort("Unknown method for learning.");
            break;
        } 
    };

}
