#ifndef APPROXIMATION_H
#define APPROXIMATION_H

#include "states.h" // Could be done with forward declaration?
#include "offline_heuristic.h"

#include <vector>
#include <string>
#include <map>

/**
 * Stochastic gradient descent with a batch size of 1.
 *
 * The difference between the real and the calculated value is applied to the
 * coefficients after every single check. Per the original author's note, the
 * error function used is abs(calc - real).
 */
class StochasticGradientDescentApproximation {
    // One training sample = (feature values, target value); the target is
    // presumably the "real" value the calculation is compared against.
    using ActionEntryList = std::vector<std::pair<std::vector<double>, double>>;

public:
    /**
     * Updates `coefficients` for one action from its training entries.
     *
     * @param actionName      Integer identifier of the action being trained.
     * @param epoch           Current training epoch.
     * @param coefficients    Coefficient vector, modified in place.
     * @param actionEntries   Training entries for this action (non-const
     *                        reference: the definition may modify it).
     * @param learningSetSize Size of the learning set (declared as double).
     * @param step            Step size of each update (presumably the learning
     *                        rate).
     * @param momentum        Unnamed here and presumably ignored by this
     *                        algorithm; accepted so all approximation classes
     *                        share one signature.
     * @param squaredDifference                 In/out accumulator, presumably
     *                                          of squared error — confirm.
     * @param squaredDifferencePerAction        In/out map, presumably keyed
     *                                          by action id.
     * @param squaredDifferencePerActionCounter In/out map, presumably counting
     *                                          accumulated samples per action.
     * @param offlineHeuristic Optional; defaults to nullptr.
     */
    static void updateCoefficients(int actionName,
                                   int epoch,
                                   std::vector<double>& coefficients,
                                   ActionEntryList& actionEntries,
                                   double learningSetSize,
                                   double step,
                                   double /*momentum*/,
                                   double& squaredDifference,
                                   std::map<int, double>& squaredDifferencePerAction,
                                   std::map<int, int>& squaredDifferencePerActionCounter,
                                   OfflineHeuristic* offlineHeuristic = nullptr);
};

/**
 * Gradient descent whose derivative is taken on the MSE (mean squared error)
 * cost function. The mini-batch variant is documented as working on part of
 * the training set; this one presumably uses every entry per update —
 * confirm in the .cpp.
 */
class GradientDescentApproximation {
    // One training sample = (feature values, target value).
    using ActionEntryList = std::vector<std::pair<std::vector<double>, double>>;

public:
    /**
     * Updates `coefficients` for one action from its training entries.
     *
     * @param actionName      Integer identifier of the action being trained.
     * @param epoch           Current training epoch.
     * @param coefficients    Coefficient vector, modified in place.
     * @param actionEntries   Training entries for this action (non-const
     *                        reference: the definition may modify it).
     * @param learningSetSize Size of the learning set (declared as double).
     * @param step            Step size of each update (presumably the learning
     *                        rate).
     * @param momentum        Unnamed here and presumably ignored by this
     *                        algorithm; accepted so all approximation classes
     *                        share one signature.
     * @param squaredDifference                 In/out accumulator, presumably
     *                                          of squared error — confirm.
     * @param squaredDifferencePerAction        In/out map, presumably keyed
     *                                          by action id.
     * @param squaredDifferencePerActionCounter In/out map, presumably counting
     *                                          accumulated samples per action.
     * @param offlineHeuristic Optional; defaults to nullptr.
     */
    static void updateCoefficients(int actionName,
                                   int epoch,
                                   std::vector<double>& coefficients,
                                   ActionEntryList& actionEntries,
                                   double learningSetSize,
                                   double step,
                                   double /*momentum*/,
                                   double& squaredDifference,
                                   std::map<int, double>& squaredDifferencePerAction,
                                   std::map<int, int>& squaredDifferencePerActionCounter,
                                   OfflineHeuristic* offlineHeuristic = nullptr);
};


/**
 * Mini-batch gradient descent with momentum: the derivative is taken on the
 * MSE (mean squared error) cost function of a part of the training set.
 */
class MiniBatchGradientDescentApproximation {
    // One training sample = (feature values, target value).
    using ActionEntryList = std::vector<std::pair<std::vector<double>, double>>;

public:
    /**
     * Updates `coefficients` for one action from its training entries.
     *
     * @param actionName      Integer identifier of the action being trained.
     * @param epoch           Current training epoch.
     * @param coefficients    Coefficient vector, modified in place.
     * @param actionEntries   Training entries for this action (non-const
     *                        reference: the definition may modify it).
     * @param learningSetSize Size of the learning set (declared as double).
     * @param step            Step size of each update (presumably the learning
     *                        rate).
     * @param momentum        Momentum factor; this is the only approximation
     *                        class that names this parameter.
     * @param squaredDifference                 In/out accumulator, presumably
     *                                          of squared error — confirm.
     * @param squaredDifferencePerAction        In/out map, presumably keyed
     *                                          by action id.
     * @param squaredDifferencePerActionCounter In/out map, presumably counting
     *                                          accumulated samples per action.
     * @param offlineHeuristic Optional; defaults to nullptr.
     */
    static void updateCoefficients(int actionName,
                                   int epoch,
                                   std::vector<double>& coefficients,
                                   ActionEntryList& actionEntries,
                                   double learningSetSize,
                                   double step,
                                   double momentum,
                                   double& squaredDifference,
                                   std::map<int, double>& squaredDifferencePerAction,
                                   std::map<int, int>& squaredDifferencePerActionCounter,
                                   OfflineHeuristic* offlineHeuristic = nullptr);

    // Number of entries per mini-batch (inferred from the name — confirm).
    // Declared only: exactly one .cpp file must provide its definition.
    static int batchSize;
};

/******************************************************************
                       General use functions
******************************************************************/
double calculateSumOfMultipliers(std::vector<double> coefficients, 
                                 std::vector<double> stateValues,
                                 OfflineHeuristic* offlineHeuristic);
std::vector<double> parseStateToValues(State state);
bool compareStates(State const& lhs, std::vector<double> const& rhs);
void logStats(int actionName,
            int epoch,
            std::vector<double>& coefficients,
            std::vector<std::pair<std::vector<double>, double>>& actionEntries,
            double learningSetSize,
            OfflineHeuristic* offlineHeuristic, 
            double& squaredDifference,
            std::map<int, double>& squaredDifferencePerAction, 
            std::map<int, int>& squaredDifferencePerActionCounter);

#endif
