#include "approximation.h"

#include "utils/math_utils.h"
#include "utils/system_utils.h"

#include <numeric>
#include <iostream>

// Momentum memory, keyed by action (the `actionName` argument of the
// updateCoefficients implementations in this file). For each action it holds
// one value per coefficient: the update term from the previous step, which the
// next step scales by `momentum` and subtracts from the coefficients again.
// Shared by all approximation strategies in this file and not synchronized.
std::map<int, std::vector<double>> lastChange;


/******************************************************************
                Stochastic Gradient Descent 
******************************************************************/
/**
 * One epoch of stochastic gradient descent on the squared error of a single
 * action's reward model.
 *
 * The leading `actionEntries.size() * learningSetSize` entries (never more
 * than actionEntries.size()) are visited in order. The coefficients are
 * updated after every single entry. A momentum term re-applies each
 * coefficient's previous step, scaled by `momentum`. That previous step is
 * kept in the global `lastChange[actionName]`.
 *
 * @param actionName     key into `lastChange` and the per-action statistics
 * @param epoch          current epoch number; only forwarded to logStats
 * @param coefficients   model weights, updated in place
 * @param actionEntries  (encoded features, observed reward) training pairs
 * @param learningSetSize multiplier on actionEntries.size() choosing how many
 *                       leading entries are trained on (presumably a fraction
 *                       in [0, 1] — confirm with callers)
 * @param step           learning rate applied to the prediction error
 * @param momentum       weight of each coefficient's previous step
 * @param squaredDifference, squaredDifferencePerAction,
 *        squaredDifferencePerActionCounter  error accumulators, filled by
 *                       logStats after the update
 * @param offlineHeuristic supplies the feature-combination mode (1 or 2) and
 *                       the pair-to-coefficient mapping; dereferenced for
 *                       every trained entry
 */
void StochasticGradientDescentApproximation::updateCoefficients(
            int actionName,
            int epoch,
            std::vector<double>& coefficients,
            std::vector<std::pair<std::vector<double>, double>>& actionEntries, 
            double learningSetSize,
            double step, 
            double momentum,
            double& squaredDifference,
            std::map<int, double>& squaredDifferencePerAction, 
            std::map<int, int>& squaredDifferencePerActionCounter,
            OfflineHeuristic* offlineHeuristic) { 

    // Momentum memory for this action. It is (re)sized whenever it does not
    // match the coefficient vector, so the indexing below stays in bounds.
    std::vector<double>& previousStep = lastChange[actionName];
    if (previousStep.size() != coefficients.size()) {
        previousStep.assign(coefficients.size(), 0.0);
    }

    // Train on the leading entries. The explicit size check stops a
    // learningSetSize above 1 from reading past the end of actionEntries.
    const double trainingLimit = actionEntries.size() * learningSetSize;
    for (size_t i = 0; i < actionEntries.size() && i < trainingLimit; ++i) {
        assert(!coefficients.empty());

        const std::vector<double>& features = actionEntries[i].first;
        const double calculatedReward = calculateSumOfMultipliers(
                                            coefficients, features,
                                            offlineHeuristic);
        const double difference = calculatedReward - actionEntries[i].second;
        // Prediction error scaled by the learning rate.
        const double delta = difference * step;

        if (offlineHeuristic->getCombineFeatures() == 1) {
            for (size_t it = 0; it < coefficients.size(); ++it) {
                // This weight's own gradient step; zero when its feature is 0.
                const double gradientStep = features[it] * delta;
                coefficients[it] -= gradientStep
                                    + momentum * previousStep[it];
                // Remember the per-weight step (not the shared `delta`), so
                // momentum never moves a weight whose feature was inactive.
                previousStep[it] = gradientStep;
            }
        } else if (offlineHeuristic->getCombineFeatures() == 2) {
            // Every pair of even feature indices (indexI < indexJ) is mapped to
            // one coefficient by comboValue. The update treats that pair's
            // feature value as 1.
            // NOTE(review): the position depends only on the indices and the
            // feature count, not on the feature values. Confirm comboValue is
            // meant to ignore the state.
            for (size_t indexI = 0; indexI < features.size(); indexI += 2) {
                for (size_t indexJ = indexI + 2;
                        indexJ < features.size();
                        indexJ += 2) {

                    const std::ptrdiff_t pos = offlineHeuristic->comboValue(
                        indexI, indexJ,
                        features.size()
                    );

                    assert(pos > 0);
                    assert(static_cast<size_t>(pos) < coefficients.size());

                    coefficients[pos] -= delta
                                        + momentum * previousStep[pos];
                    previousStep[pos] = delta;
                }
            }
        } else {
            SystemUtils::abort("Given number of combined"
                                " features not supported");
        }
    }

    logStats(actionName, 
                epoch, 
                coefficients,
                actionEntries,
                learningSetSize, 
                offlineHeuristic,
                squaredDifference, 
                squaredDifferencePerAction, 
                squaredDifferencePerActionCounter);
}

/******************************************************************
                       Gradient Descent
******************************************************************/
/**
 * One full-batch gradient descent step on the squared error of a single
 * action's reward model.
 *
 * The gradient is accumulated over the leading
 * `actionEntries.size() * learningSetSize` entries (never more than
 * actionEntries.size()). It is then averaged over the entries actually
 * visited and applied once, together with a momentum term. The momentum term
 * re-applies each coefficient's previous step-scaled gradient, kept in the
 * global `lastChange[actionName]`.
 *
 * @param actionName     key into `lastChange` and the per-action statistics
 * @param epoch          current epoch number; only forwarded to logStats
 * @param coefficients   model weights, updated in place
 * @param actionEntries  (encoded features, observed reward) training pairs
 * @param learningSetSize multiplier on actionEntries.size() choosing how many
 *                       leading entries are used (presumably a fraction in
 *                       [0, 1] — confirm with callers)
 * @param step           learning rate applied to the averaged gradient
 * @param momentum       weight of each coefficient's previous step
 * @param squaredDifference, squaredDifferencePerAction,
 *        squaredDifferencePerActionCounter  error accumulators, filled by
 *                       logStats after the update
 * @param offlineHeuristic supplies the feature-combination mode (1 or 2) and
 *                       the pair-to-coefficient mapping
 */
void GradientDescentApproximation::updateCoefficients(
            int actionName,
            int epoch,
            std::vector<double>& coefficients,
            std::vector<std::pair<std::vector<double>, double>>& actionEntries, 
            double learningSetSize,
            double step, 
            double momentum,
            double& squaredDifference,
            std::map<int, double>& squaredDifferencePerAction, 
            std::map<int, int>& squaredDifferencePerActionCounter,
            OfflineHeuristic* offlineHeuristic) { 

    // Momentum memory for this action. It is (re)sized whenever it does not
    // match the coefficient vector, so the indexing below stays in bounds.
    std::vector<double>& previousStep = lastChange[actionName];
    if (previousStep.size() != coefficients.size()) {
        previousStep.assign(coefficients.size(), 0.0);
    }

    // Sum of the per-entry gradients. It is averaged once the number of
    // visited entries is known.
    std::vector<double> derivative(coefficients.size(), 0.0);
    const double trainingLimit = actionEntries.size() * learningSetSize;
    size_t trained = 0;
    for (; trained < actionEntries.size() && trained < trainingLimit;
            ++trained) {
        const std::vector<double>& features = actionEntries[trained].first;
        const double rewardDiff = calculateSumOfMultipliers(coefficients,
                                                            features,
                                                            offlineHeuristic)
                                  - actionEntries[trained].second;

        if (offlineHeuristic->getCombineFeatures() == 1) {
            for (size_t i = 0; i < coefficients.size(); ++i) {
                derivative[i] += rewardDiff * features[i];
            }
        } else if (offlineHeuristic->getCombineFeatures() == 2) {
            // Every pair of even feature indices (indexI < indexJ) is mapped to
            // one coefficient by comboValue. That pair's feature value is
            // treated as 1.
            // NOTE(review): the position does not depend on the feature values.
            // Confirm comboValue is meant to ignore the state.
            for (size_t indexI = 0; indexI < features.size(); indexI += 2) {
                for (size_t indexJ = indexI + 2;
                        indexJ < features.size();
                        indexJ += 2) {

                    const std::ptrdiff_t pos = offlineHeuristic->comboValue(
                        indexI, indexJ,
                        features.size()
                    );

                    derivative[pos] += rewardDiff;
                }
            }
        } else {
            SystemUtils::abort("Given number of combined"
                                " features not supported");
        }
    }

    // Average over the entries actually visited. This count can differ from
    // the (possibly fractional) trainingLimit.
    const double inverseCount =
        trained > 0 ? 1.0 / static_cast<double>(trained) : 0.0;

    for (size_t i = 0; i < coefficients.size(); ++i) {
        const double gradientStep = step * inverseCount * derivative[i];
        coefficients[i] -= gradientStep
                        + momentum * previousStep[i];

        // Store the step-scaled gradient, so the momentum term is scaled by the
        // learning rate just like the gradient term.
        previousStep[i] = gradientStep;
    }

    logStats(actionName, 
                epoch, 
                coefficients,
                actionEntries,
                learningSetSize, 
                offlineHeuristic,
                squaredDifference, 
                squaredDifferencePerAction, 
                squaredDifferencePerActionCounter);
}

/******************************************************************
                    MiniBatch Gradient Descent
******************************************************************/
// Number of training entries per mini-batch used by
// MiniBatchGradientDescentApproximation::updateCoefficients (default 100).
// TODO: make this a parameter of the heuristic.
// The batch size is meant to be updated to 10% of the training set where
// possible. Nothing in this file assigns it, so presumably a caller does;
// verify.
int MiniBatchGradientDescentApproximation::batchSize = 100; 
/**
 * One epoch of mini-batch gradient descent on the squared error of a single
 * action's reward model.
 *
 * The leading `actionEntries.size() * learningSetSize` entries (never more
 * than actionEntries.size()) are split into consecutive batches of `batchSize`
 * entries; a non-positive batchSize is treated as 1. After each batch, the
 * coefficients move against that batch's averaged gradient. A momentum term
 * re-applies each coefficient's previous step-scaled gradient, kept in the
 * global `lastChange[actionName]`.
 *
 * @param actionName     key into `lastChange` and the per-action statistics
 * @param epoch          current epoch number; only forwarded to logStats
 * @param coefficients   model weights, updated in place
 * @param actionEntries  (encoded features, observed reward) training pairs
 * @param learningSetSize multiplier on actionEntries.size() choosing how many
 *                       leading entries are used (presumably a fraction in
 *                       [0, 1] — confirm with callers)
 * @param step           learning rate applied to each batch gradient
 * @param momentum       weight of each coefficient's previous step
 * @param squaredDifference, squaredDifferencePerAction,
 *        squaredDifferencePerActionCounter  error accumulators, filled by
 *                       logStats after the update
 * @param offlineHeuristic supplies the feature-combination mode (1 or 2) and
 *                       the pair-to-coefficient mapping
 */
void MiniBatchGradientDescentApproximation::updateCoefficients(
            int actionName,
            int epoch,
            std::vector<double>& coefficients,
            std::vector<std::pair<std::vector<double>, double>>& actionEntries, 
            double learningSetSize,
            double step, 
            double momentum,
            double& squaredDifference,
            std::map<int, double>& squaredDifferencePerAction, 
            std::map<int, int>& squaredDifferencePerActionCounter,
            OfflineHeuristic* offlineHeuristic) { 

    // Momentum memory for this action. It is (re)sized whenever it does not
    // match the coefficient vector, so the indexing below stays in bounds.
    std::vector<double>& previousStep = lastChange[actionName];
    if (previousStep.size() != coefficients.size()) {
        previousStep.assign(coefficients.size(), 0.0);
    }

    // A non-positive batch size would never advance through the entries.
    const size_t effectiveBatchSize =
        MiniBatchGradientDescentApproximation::batchSize > 0
            ? static_cast<size_t>(
                  MiniBatchGradientDescentApproximation::batchSize)
            : 1;

    // Same entry selection as the other strategies: every index below the
    // (possibly fractional) limit, never past the end of actionEntries.
    const double trainingLimit = actionEntries.size() * learningSetSize;

    size_t start = 0;
    while (start < actionEntries.size() && start < trainingLimit) {
        // Sum of the gradients of the current batch [start, end).
        std::vector<double> derivative(coefficients.size(), 0.0);

        size_t end = start;
        while (end - start < effectiveBatchSize
                && end < actionEntries.size()
                && end < trainingLimit) {
            const std::vector<double>& features = actionEntries[end].first;
            const double rewardDiff = calculateSumOfMultipliers(coefficients,
                                                            features,
                                                            offlineHeuristic)
                                      - actionEntries[end].second;

            if (offlineHeuristic->getCombineFeatures() == 1) {
                for (size_t i = 0; i < coefficients.size(); ++i) {
                    derivative[i] += rewardDiff * features[i];
                }
            } else if (offlineHeuristic->getCombineFeatures() == 2) {
                // Every pair of even feature indices (indexI < indexJ) is mapped
                // to one coefficient by comboValue. That pair's feature value is
                // treated as 1.
                // NOTE(review): the position does not depend on the feature
                // values. Confirm comboValue is meant to ignore the state.
                for (size_t indexI = 0; indexI < features.size(); indexI += 2) {
                    for (size_t indexJ = indexI + 2;
                            indexJ < features.size();
                            indexJ += 2) {

                        const std::ptrdiff_t pos = offlineHeuristic->comboValue(
                            indexI, indexJ,
                            features.size()
                        );

                        derivative[pos] += rewardDiff;
                    }
                }
            } else {
                SystemUtils::abort("Given number of combined"
                                    " features not supported");
            }
            ++end;
        }

        // end > start: the outer condition admitted `start`, and the batch
        // size is at least 1, so the batch is never empty.
        const double inverseBatchSize = 1.0 / static_cast<double>(end - start);
        for (size_t i = 0; i < coefficients.size(); ++i) {
            const double gradientStep = step * inverseBatchSize * derivative[i];
            coefficients[i] -= gradientStep
                + momentum * previousStep[i];

            // Store the step-scaled gradient, so the momentum term is scaled by
            // the learning rate just like the gradient term.
            previousStep[i] = gradientStep;
        }

        start = end;
    }

    logStats(actionName, 
                epoch, 
                coefficients,
                actionEntries,
                learningSetSize, 
                offlineHeuristic,
                squaredDifference, 
                squaredDifferencePerAction, 
                squaredDifferencePerActionCounter);
}

/******************************************************************
                       General use functions
******************************************************************/

/*
 * Predicts a reward as a weighted sum over an encoded state.
 *
 * With no heuristic (the OnlineHeuristic case) or in combine mode 1, this is
 * the dot product of `coefficients` and `stateValues`, which must have equal
 * length. In combine mode 2 it is the sum of the coefficients that
 * offlineHeuristic->comboValue selects for every pair of even feature indices
 * (first < second). Any other mode aborts.
 *
 * NOTE(review): in mode 2 the selected positions depend only on the indices
 * and on stateValues.size(), not on the values in stateValues. Confirm that
 * comboValue is intended to ignore the state.
 * NOTE(review): both vectors are taken by value, so every call copies them.
 * Switching to const references also requires changing the matching
 * declaration (presumably in approximation.h).
 */
double calculateSumOfMultipliers(std::vector<double> coefficients,
                                std::vector<double> stateValues,
                                OfflineHeuristic* offlineHeuristic) {
    const bool plainLinear = offlineHeuristic == nullptr
                             || offlineHeuristic->getCombineFeatures() == 1;
    if (plainLinear) {
        assert(coefficients.size() == stateValues.size());
        return std::inner_product(coefficients.begin(), coefficients.end(),
                                  stateValues.begin(), 0.0);
    }

    if (offlineHeuristic->getCombineFeatures() != 2) {
        SystemUtils::abort("Calculating reward with this number of features"
                            " not implemented.");
        return 0;
    }

    const size_t featureCount = stateValues.size();
    double sum = 0.0;
    for (size_t first = 0; first < featureCount; first += 2) {
        for (size_t second = first + 2; second < featureCount; second += 2) {
            const std::ptrdiff_t slot =
                offlineHeuristic->comboValue(first, second, featureCount);
            sum += coefficients[slot];
        }
    }
    return sum;
}


/*
 * Encodes a state as one pair of values per fluent: [1, 0] when the fluent
 * equals 0 (per MathUtils::doubleIsEqual) and [0, 1] otherwise. Deterministic
 * fluents come first, then probabilistic ones, so the result has
 * 2 * (number of fluents) entries.
 */
std::vector<double> parseStateToValues(State state) {
    std::vector<double> encoded;
    encoded.reserve(2 * (state.deterministicStateFluents.size()
                         + state.probabilisticStateFluents.size()));

    auto appendPair = [&encoded](double fluent) {
        const bool isZero = MathUtils::doubleIsEqual(0, fluent);
        encoded.push_back(isZero ? 1.0 : 0.0);
        encoded.push_back(isZero ? 0.0 : 1.0);
    };

    for (double fluent : state.deterministicStateFluents) {
        appendPair(fluent);
    }
    for (double fluent : state.probabilisticStateFluents) {
        appendPair(fluent);
    }
    return encoded;
}

/*
 * Compares a state against an encoded value vector laid out like the output
 * of parseStateToValues: two slots per fluent, deterministic fluents first.
 *
 * Fluents are examined in order, and fluent k is checked against slots 2k and
 * 2k+1:
 *   - if rhs[2k] < fluent (MathUtils::doubleIsSmaller), the result is false;
 *   - else if fluent < rhs[2k+1], the result is true;
 *   - otherwise the next fluent decides.
 * When no fluent decides, the result is false.
 *
 * NOTE(review): this compares raw fluent values with encoded slot values,
 * which looks unusual for an ordering. Confirm the intended semantics with
 * the callers.
 */
bool compareStates(State const& lhs, std::vector<double> const& rhs) {
    assert(rhs.size() 
                == (lhs.deterministicStateFluents.size() 
                    + lhs.probabilisticStateFluents.size()) * 2);

    // Returns 0 for "false", 1 for "true", and -1 when this fluent is
    // undecided.
    auto verdictFor = [&rhs](double fluent, size_t slot) -> int {
        if (MathUtils::doubleIsSmaller(rhs[slot], fluent)) {
            return 0;
        }
        if (MathUtils::doubleIsSmaller(fluent, rhs[slot + 1])) {
            return 1;
        }
        return -1;
    };

    size_t slot = 0;
    for (double fluent : lhs.deterministicStateFluents) {
        const int verdict = verdictFor(fluent, slot);
        if (verdict != -1) {
            return verdict == 1;
        }
        slot += 2;
    }
    for (double fluent : lhs.probabilisticStateFluents) {
        const int verdict = verdictFor(fluent, slot);
        if (verdict != -1) {
            return verdict == 1;
        }
        slot += 2;
    }
    return false;
}

/*
 * Evaluates the model on the training entries after an update and
 * accumulates the squared errors.
 *
 * For each of the leading `actionEntries.size() * learningSetSize` entries
 * (never more than actionEntries.size()), the reward is predicted with the
 * current coefficients. When the heuristic is non-null and logging is on, the
 * actual and predicted rewards are written via SystemUtils::log2. The squared
 * error is added to `squaredDifference`. It is also added to
 * `squaredDifferencePerAction[actionName]`, and the matching counter in
 * `squaredDifferencePerActionCounter` is incremented. When the action has no
 * total yet, both the total and the counter are (re)started at this entry.
 */
void logStats(int actionName,
            int epoch,
            std::vector<double>& coefficients,
            std::vector<std::pair<std::vector<double>, double>>& actionEntries,
            double learningSetSize,
            OfflineHeuristic* offlineHeuristic, 
            double& squaredDifference,
            std::map<int, double>& squaredDifferencePerAction, 
            std::map<int, int>& squaredDifferencePerActionCounter) {

    // The explicit size check stops a learningSetSize above 1 from reading
    // past the end of actionEntries.
    const double evaluationLimit = actionEntries.size() * learningSetSize;
    for (size_t i = 0; i < actionEntries.size() && i < evaluationLimit; ++i) {
        const double calculatedReward = 
                            calculateSumOfMultipliers(coefficients, 
                                    actionEntries[i].first, 
                                    offlineHeuristic);
        const double reward = actionEntries[i].second;

        // Restricting this output to the final epoch
        // (epoch == offlineHeuristic->getNumberOfEpochs() - 1) is disabled.
        if (offlineHeuristic && offlineHeuristic->loggingOn()) {
            SystemUtils::log2(offlineHeuristic->getDetails(), 
                                "Actual reward in training " + 
                                std::to_string(epoch) + "#" + 
                                std::to_string(actionName) +
                                ": ", reward);
            SystemUtils::log2(offlineHeuristic->getDetails(), 
                                "Calculated reward in training " + 
                                std::to_string(epoch) + "#" + 
                                std::to_string(actionName) +
                                ": ", calculatedReward);
        }

        const double error = reward - calculatedReward;
        const double squaredError = error * error;
        squaredDifference += squaredError;

        std::map<int, double>::iterator actionTotal =
                squaredDifferencePerAction.find(actionName);
        if (actionTotal != squaredDifferencePerAction.end()) {
            actionTotal->second += squaredError;
            squaredDifferencePerActionCounter[actionName] += 1;
        } else {
            squaredDifferencePerAction[actionName] = squaredError;
            squaredDifferencePerActionCounter[actionName] = 1;
        }
    }
}
