#ifndef OFFLINE_HEURISTIC_H
#define OFFLINE_HEURISTIC_H

#include "search_engine.h"

#include <cassert>
#include <map>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

class State;

// Search engine whose heuristic parameters (per-key coefficient vectors stored
// in `heuristic`) are learned offline from a training set (`dataSet`) using one
// of the gradient-descent variants in MethodType, and which then uses them to
// estimate Q-values.
class OfflineHeuristic : public ProbabilisticSearchEngine {
public:
    // Optimisation method used to fit the heuristic parameters.
    enum MethodType {
        GradientDescent,
        MiniBatchGradientDescent,
        StochasticGradientDescent
        // NeuralNetwork
    };

    // Whether all collected states are used for training (LearnOnly) or a
    // fraction is held out for testing (LearnAndTest).
    enum LearnType {
        LearnOnly,
        LearnAndTest
    };

    OfflineHeuristic();

    // This is called initially to learn parameter values from a random training
    // set.
    void learn() override;

    // Trains on the fraction `learningSetSize` of the collected states and
    // tests on the remaining (1 - learningSetSize) fraction.
    // NOTE(review): assumes learningSetSize is a fraction in [0, 1] — confirm
    // against the implementation.
    void learnFromTrainingSet(double learningSetSize);

    // Set parameters from command line
    bool setValueFromString(std::string& param, std::string& value) override;

    // Start the search engine to calculate best actions
    // void estimateBestActions(State const& _rootState,
    //                          std::vector<int>& bestActions) override;

    // Estimating the Q-value of a single action is not supported by this
    // engine. Debug builds abort here; with NDEBUG the assert compiles away and
    // qValue is left unmodified.
    void estimateQValue(State const& /*state*/, int /*actionIndex*/,
                        double& /*qValue*/) override {
        assert(false);
    }

    // Start the search engine to estimate the Q-values of all applicable
    // actions
    void estimateQValues(State const& state,
                         std::vector<int> const& actionsToExpand,
                         std::vector<double>& qValues) override;

    // Runs one approximation pass for `actionName` over `actionEntries`
    // (feature vector, target value) pairs and accumulates squared errors into
    // the output parameters.
    void approximate(int actionName,
            int epoch, 
            std::vector<std::pair<std::vector<double>, double>>& actionEntries,
            double learningSetSize,
            double& squaredDifference, 
            std::map<int, double>& squaredDifferencePerAction, 
            std::map<int, int>& squaredDifferencePerActionCounter);

    void setMethod(MethodType _method) {
        methodType = _method;
    }

    void setStep(double _step);

    void setHeuristicLoadLocation(std::string _loc) {
        loadHeuristicFrom = std::move(_loc);
    }

    void setTrainingSetLoadLocation(std::string _loc) {
        loadTrainingSetFrom = std::move(_loc);
    }

    void setNumberOfEpochs(int _numberOfEpochs) {
        numberOfEpochs = _numberOfEpochs;
    }

    void setCombineFeaturesSize(int _combineFeatures) {
        combineFeatures = _combineFeatures;
    }

    int getCombineFeatures() const {
        return combineFeatures;
    }

    void setLearnType(LearnType _learnType) {
        learnType = _learnType;
    }

    void setMomentum(double _momentum) {
        momentum = _momentum;
    }

    void setLearningDecayRate(double _learningDecayRate) {
        learningDecayRate = _learningDecayRate;
    }

    void setLogOption(bool _log) {
        log = _log;
    }

    void setOutputDir(std::string _outputDir) {
        outputDir = std::move(_outputDir);
    }

    // Installs the engine used to generate the training set. The previously
    // installed generator is deleted, so this object treats the generator as
    // owned. Re-installing the pointer that is already held is a no-op; without
    // the guard it would be deleted and then kept as a dangling pointer.
    // NOTE(review): no destructor is declared here to release the last
    // generator, and implicit copies of this object would share the pointer —
    // confirm the intended ownership model.
    void setTrainingSetGenerator(SearchEngine* _trainingSetGenerator) {
        if (_trainingSetGenerator == trainingSetGenerator) {
            return;
        }
        delete trainingSetGenerator; // deleting a null pointer is a no-op
        trainingSetGenerator = _trainingSetGenerator;
    }
    
    // void setEstimateSource(bool ids, bool thts) {
    //     idsEstimate = ids;
    //     thtsEstimate = thts;
    // }

    bool loggingOn() const {
        return log;
    }

    int getNumberOfEpochs() const {
        return numberOfEpochs;
    }

    void storeHeuristic(std::ostream& out) override;
    void loadHeuristic();

    void parseCoefficientsFromString(std::string coeffsString,
				     std::vector<double>& coeffs);
    
    std::string getDetails() override;
    std::string getLearnDesc();
    std::string getMethodDesc();
    
    void outputTrainingData();

    // Training data: for each int key (presumably an action index — confirm),
    // a list of (feature vector, target value) samples.
    std::map<int, std::vector<std::pair<std::vector<double>, double>>> dataSet;
    // Learned coefficient vector per int key (same keying as dataSet).
    std::map<int, std::vector<double>> heuristic;

    // It is the same for every action so that is the reason we only need vector 
    // std::unordered_map<std::string, int> weightsMapping;
    size_t comboValue(int indexI, int indexJ, size_t numberOfFeatures);

private:
    SearchEngine* trainingSetGenerator; // Deleted when replaced via setTrainingSetGenerator
    MethodType methodType; // Method
    double step; // How fast to the solution
    std::string loadHeuristicFrom; // Path from where to load heuristic
    std::string loadTrainingSetFrom; // Path from where to load training data
    int numberOfEpochs; // How many times estimations are done
    LearnType learnType; // All the states for learning or some also for testing
    double momentum; // Helps escaping local minima
    double learningDecayRate; // Decreases the learning step over time
    std::string details; // heuristic description
    int combineFeatures; // Number of features that can be combined for approx

    void storeTrainingData(double learningSetSize);
    void loadTrainingData(); 
    bool performTraining(double learningSetSize, int epoch);
    void performTesting(double learningSetSize, int epoch);
    bool log;

    std::string outputDir;

    // Caching
    typedef std::unordered_map<State, std::vector<double>,
                               State::HashWithoutRemSteps,
                               State::EqualWithoutRemSteps>
        HashMap;
    static HashMap rewardCache;

    double minReward, maxReward;
    double oldTotalMSE, newTotalMSE;
};

#endif
