#ifndef ONLINE_HEURISTIC_H
#define ONLINE_HEURISTIC_H

#include "search_engine.h"

#include <cassert>
#include <iosfwd>
#include <map>
#include <string>
#include <utility>
#include <vector>

class State;
class THTS;

/**
 * Search engine that fits a heuristic from training samples gathered during
 * search (see addToDataSet / approximate) and answers Q-value queries with it
 * (estimateQValues).
 *
 * The fitted model is stored in `heuristic`; the raw samples in `dataSet`.
 * Behaviour is configured through the setters below or, from the command
 * line, through setValueFromString / fromString.
 */
class OnlineHeuristic : public ProbabilisticSearchEngine {
public:
    // Optimisation method used to fit the heuristic.
    enum MethodType {
        StochasticGradientDescent,
        GradientDescent
        // NeuralNetwork
    };

    // Which observed value is used as the training target.
    // NOTE(review): presumably selects between the trialReward and qValue
    // arguments of addToDataSet -- confirm in the implementation.
    enum RewardType {
        TrialReward,
        QValue
    };

    OnlineHeuristic();

    // Called initially to learn parameter values from a random training set.
    // This engine does nothing at this point (empty body).
    void learn() override {}

    // Sets a single parameter from its command-line representation.
    // Returns whether `param` was recognised.
    bool setValueFromString(std::string& param, std::string& value) override;

    // Creates an OnlineHeuristic component from its textual description.
    static OnlineHeuristic* fromString(std::string& desc, THTS* _thts);

    // Start the search engine to calculate best actions
    // void estimateBestActions(State const& _rootState,
    //                          std::vector<int>& bestActions) override;

    // Single-action Q-value estimation is not supported by this engine:
    // calling it fails the assertion. If NDEBUG is defined, the assertion is
    // compiled out and qValue is left unmodified.
    void estimateQValue(State const& /*state*/, int /*actionIndex*/,
                        double& /*qValue*/) override {
        assert(false);
    }

    // Estimates the Q-values of all actions listed in actionsToExpand for
    // `state` and writes them to qValues.
    void estimateQValues(State const& state,
                         std::vector<int> const& actionsToExpand,
                         std::vector<double>& qValues) override;

    // Runs one fitting pass for `actionName` over `actionEntries`. The error
    // is accumulated into the output parameters squaredDifference,
    // squaredDifferencePerAction and squaredDifferencePerActionCounter.
    void approximate(int actionName,
                     int epoch,
                     std::vector<std::pair<std::vector<double>, double>>& actionEntries,
                     double learningSetSize,
                     double& squaredDifference,
                     std::map<int, double>& squaredDifferencePerAction,
                     std::map<int, int>& squaredDifferencePerActionCounter);

    void setMethod(MethodType _method) {
        methodType = _method;
    }

    void setStep(double _step);

    void setNumberOfEpochs(int _numberOfEpochs) {
        numberOfEpochs = _numberOfEpochs;
    }

    void setRewardType(RewardType _rewardType) {
        rewardType = _rewardType;
    }

    void setMomentum(double _momentum) {
        momentum = _momentum;
    }

    void setLearningDecayRate(double _learningDecayRate) {
        learningDecayRate = _learningDecayRate;
    }

    void setTrainingSetSize(int _trainingSetSize) {
        trainingSetSize = _trainingSetSize;
    }

    void setLogOption(bool _log) {
        log = _log;
    }

    // Read-only accessors; const so they are usable on const objects.
    bool loggingOn() const {
        return log;
    }

    RewardType getRewardType() const {
        return rewardType;
    }

    // Writes the learned heuristic to `out`.
    void storeHeuristic(std::ostream& out) override;

    // Parses a textual list of coefficients into `coeffs`.
    void parseCoefficientsFromString(std::string coeffsString,
                                     std::vector<double>& coeffs);

    // Records one training sample observed for `state` and `actionState`.
    void addToDataSet(State state, int actionState,
                      double const& trialReward, double const& qValue);

    std::string getDetails() override;
    std::string getMethodDesc();
    std::string getRewardDesc();

    // Collected training samples.
    // NOTE(review): presumably keyed by action index, each entry being a
    // (feature vector, target value) pair -- confirm in the implementation.
    std::map<int, std::vector<std::pair<std::vector<double>, double>>> dataSet;

    // Learned model.
    // NOTE(review): presumably one coefficient vector per action index --
    // confirm in the implementation.
    std::map<int, std::vector<double>> heuristic;

    // NOTE(review): appears to be a non-owning back-pointer (no destructor is
    // declared in this class) -- confirm. Null until assigned.
    THTS* thts{nullptr};

private:
    // All scalars are value-initialised here so that no member is ever read
    // indeterminate. The out-of-line constructor and the setters still
    // provide the real configuration.
    MethodType methodType{};     // Optimisation method
    double step{};               // Step size of the updates ("how fast to the solution")
    std::string loadFrom;        // Path from which to load a heuristic
    int numberOfEpochs{};        // How many training passes are performed
    RewardType rewardType{};     // Which value (trial reward or Q-value) is the training target
    double momentum{};           // Helps escape local minima
    double learningDecayRate{};  // Decreases the learning step over time
    std::string details;         // Heuristic description
    int trainingSetSize{};       // How many entries to collect before training

    void performTraining(double learningSetSize, int epoch, int actionState);
    bool log{false};             // Whether logging output is enabled
};

#endif
