#ifndef INCREASING_HORIZON_NN_H
#define INCREASING_HORIZON_NN_H

#include <cstdlib>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "search_engine.h"
#include "states.h"
#include "network.h"
#include "network_trainer.h"

#include "utils/stopwatch.h"

// Probabilistic search engine registered under the name "IHNN".
//
// It provides Q-value estimation for a state (estimateQValue /
// estimateQValues) and hooks to initialise and train three kinds of networks
// (Q-value, policy, horizon). The method bodies of the non-inline members are
// defined outside this header. Most inline setters below exist so that
// command-line parameters can be forwarded into the engine's configuration.
class IncreasingHorizonNN : public ProbabilisticSearchEngine{
public:
    IncreasingHorizonNN()
    :ProbabilisticSearchEngine("IHNN"){}

    // Set parameters from command line
    bool setValueFromString(std::string& param, std::string& value);

    // Notify the search engine that the session starts
    void initSession();


    // Start the search engine to estimate the Q-value of a single action
    void estimateQValue(State const& state, int actionIndex,
                                double& qValue);

    // Start the search engine to estimate the Q-values of all applicable
    // actions
    void estimateQValues(State const& _rootState,
                                 std::vector<int> const& actionsToExpand,
                                 std::vector<double>& qValues);

    // Statistics output is intentionally a no-op for this engine.
    void printRoundStatistics(std::string /*indent*/) const {}
    void printStepStatistics(std::string /*indent*/) const {}

    // Network construction for each supported network kind.
    void initQValueNetwork();
    void initPolicyNetwork();
    void initHorizonNetwork();

    // Network training for each supported network kind.
    void trainQValueNetworks();
    void trainPolicyNetworks();
    void trainHorizonNetwork();

    // Writes into qValues the reward obtained from `state` for the actions in
    // actionsToExpand, looking `stepsToGo` steps ahead (defined elsewhere).
    void calcActualReward(State const& state,
                            std::vector<int> const& actionsToExpand,
                            std::vector<double>& qValues,
                            int stepsToGo);

    void monitorRAMUsage();

    void evaluateTask();

    //Parameter setter
    void setHiddenLayer(int _numberOfHiddenLayer){
        numberOfHiddenLayer = _numberOfHiddenLayer;
    }
    void setLayerBreadth(int _layerBreadth){
        layerBreadth = _layerBreadth;
    }
    void setNumberOfEpochs(int _numberOfEpochs){
        numberOfEpochs = _numberOfEpochs;
    }

    void setInputLayerSize(int _inputLayerSize){
        inputLayerSize = _inputLayerSize;
    }

    // Sets the training batch size. The sentinel value kWholeTrainingSetBatch
    // (99) means "use the whole training set as one batch". The size is
    // read from `trainingSet` at the moment this setter runs, so it is a
    // snapshot: if the training set is (re)loaded afterwards, batchSize is
    // not updated. NOTE(review): `trainingSet` is not declared in this class;
    // presumably it comes from one of the included project headers — confirm.
    void setBatchSize(int _batchSize){
        if(_batchSize == kWholeTrainingSetBatch){
            // Explicit narrowing: container sizes are unsigned and may not
            // fit an int in principle.
            batchSize = static_cast<int>(trainingSet.size());
        }else{
            batchSize = _batchSize;
        }
    }

    void setLearningRate(double _learningRate){
        learningRate = _learningRate;
    }

    // Enables network preparation when passed 1. Any other value leaves the
    // current setting unchanged (it never switches preparation off).
    void setNetworkPreparation(int _preparation){
        if(_preparation == 1){
            networkPreparation = true;
        }
    }

    void setSampleExpansionThreshold(int _threshold){
        threshold = _threshold;
    }

    void setModulePath(std::string _modulePath){
        modulePath = std::move(_modulePath);
    }

    // Maps a network-type name to NetworkType. Accepted names: "qvalue",
    // "policy", "bounded_qvalue", "bounded_policy", "horizon". An unknown
    // name is a configuration error: it is reported on stderr and the
    // process terminates with a failure status (previously it exited with
    // status 0, which made the failure look like success to callers).
    void setNetworkType(std::string _networkType){

        if(_networkType == "qvalue"){
            networkType = NetworkType::QValueNetwork;
        }else if(_networkType == "policy"){
            networkType = NetworkType::PolicyNetwork;
        }else if(_networkType == "bounded_qvalue"){
            networkType = NetworkType::BoundedQvalue;
        }else if(_networkType == "bounded_policy"){
            networkType = NetworkType::BoundedPolicy;
        }else if(_networkType == "horizon"){
            networkType = NetworkType::HorizonNetwork;
        }else{
            std::cerr << "NetworkType undefined, training aborted..." << std::endl;
            std::exit(EXIT_FAILURE);
        }
    }

    void setVersion(std::string _version){
        version = std::move(_version);
    }

    void setHorizonBound(int _bound){
        horizonBound = _bound;
    }

    // When input == 1, asks NetworkTrainer to parse the state file. The call
    // is made without a trainer instance, so it is presumably static.
    void loadStates(int input){
        if(input == 1){
            NetworkTrainer::parseStateFile();
        }
    }

    // When input == 1, calls NetworkTrainer::parseStateFile2 (the alternate
    // state-file parser); presumably static as well.
    void loadAllStates(int input){
        if(input == 1){
            NetworkTrainer::parseStateFile2();
        }
    }

    // When input == 1, calls NetworkTrainer::safeStates. NOTE(review): the
    // name suggests "save"; kept as-is for interface compatibility.
    void safeTrainingSet(int input){
        if(input == 1){
            NetworkTrainer::safeStates();
        }
    }

private:
    // Batch-size sentinel accepted by setBatchSize: use the whole training set.
    static constexpr int kWholeTrainingSetBatch = 99;

    std::vector<Network> nets;
    std::vector<Net> nets_impl;

    Stopwatch stopwatch;
    int inputLayerSize = 1;
    int numberOfHiddenLayer = 1;
    int layerBreadth = 1;
    int numberOfEpochs = 1;
    int batchSize = 1;
    int threshold = 100;
    int horizonBound = 10;
    double learningRate = 0.1;
    bool networkPreparation = false;
    std::string modulePath = "src/search/trained_networks/";
    std::string netConfig;
    std::string version;

    NetworkTrainer trainer;
    NetworkType networkType = NetworkType::QValueNetwork;

};


#endif