#include "increasing_horizon_NN.h"
#include "utils/system_utils.h"
#include "initializer.h"

#include "torch/torch.h"

using namespace std;

// Parses one "-option value" pair of the IHNN search-engine configuration.
// Returns true if the option was recognised (and applied), false otherwise.
// Integer options use atoi (malformed input yields 0); "-learning" uses stod.
bool IncreasingHorizonNN::setValueFromString(std::string& param, std::string& value) {
    cout << "Setting " << param << " to " << value << endl;

    // Evaluated lazily so that only integer-valued options ever call atoi.
    auto const asInt = [&value]() { return atoi(value.c_str()); };

    if (param == "-hidden")             { setHiddenLayer(asInt());              return true; }
    if (param == "-breadth")            { setLayerBreadth(asInt());             return true; }
    if (param == "-epochs")             { setNumberOfEpochs(asInt());           return true; }
    if (param == "-batch")              { setBatchSize(asInt());                return true; }
    if (param == "-learning")           { setLearningRate(stod(value.c_str())); return true; }
    if (param == "-train")              { setNetworkPreparation(asInt());       return true; }
    if (param == "-threshold")          { setSampleExpansionThreshold(asInt()); return true; }
    if (param == "-path")               { setModulePath(value.c_str());         return true; }
    if (param == "-nntype")             { setNetworkType(value.c_str());        return true; }
    if (param == "-version")            { setVersion(value.c_str());            return true; }
    if (param == "-bound")              { setHorizonBound(asInt());             return true; }
    if (param == "-loadPreparedStates") { loadStates(asInt());                  return true; }
    if (param == "-loadAllStates")      { loadAllStates(asInt());               return true; }
    if (param == "-createTrainingSet")  { safeTrainingSet(asInt());             return true; }

    // Unknown option: let the caller report or delegate it.
    return false;
}

void IncreasingHorizonNN::initSession(){

    cout << "Init Session..." << endl;
    //./build.py --debug -j 3 && ./prost.py --debug elevators_inst_mdp__1 --parser-options "-rewardInterval ZERO_TO_INF" "[PROST -s 1 -se [UCTStar -init [Expand -h [IHNN -hidden 1 -breadth 1 -batch 50 -learning 0.01 -epochs 20 -nntype policy -train 1 -version 1.0]]]]"

    //./build.py --debug -j 3 && ./prost.py --debug elevators_inst_mdp__1 --parser-options " -trainingSimulations 1000 -trainingSetSize 10000 -rewardInterval ZERO_TO_INF" "[PROST -s 1 -se [UCTStar -init [Expand -h [IHNN -hidden 2 -breadth 3 -batch 100 -learning 0.01 -epochs 100 -nntype policy -train 1 -version 1.0]]]]"
    
    //./build.py --debug -j 3 && ./prost.py --debug elevators_inst_mdp__2 --parser-options "-rewardInterval ZERO_TO_INF" "[PROST -s 1 -se [UCTStar -init [Expand -h [IHNN -hidden 1 -breadth 1 -batch 500 -learning 0.01 -epochs 100 -nntype policy -train 0 -loadPreparedStates 1 -version 1.1_standardSet]]]]"

    stopwatch.reset();    
    
    cout << "SET " << trainingSet.size() << endl;
    //shuffle trainingSet to ensure that we get different states for test/train sets
    random_shuffle(trainingSet.begin(), trainingSet.end());

    if(networkPreparation){
        switch(networkType){
            case NetworkType::QValueNetwork:
            {   
                
                netConfig = "hidden_" + to_string(numberOfHiddenLayer) + "-breadth_" + to_string(layerBreadth*stateFluents.size()) + "-epochs_" + to_string(numberOfEpochs) + "-batch_" + to_string(batchSize) + "-lr_" + to_string(learningRate);
                version = "Q-Value_" + version;
                cout << "Generate Q-Value Network ..." << endl;
                setInputLayerSize(SearchEngine::stateFluents.size() + SearchEngine::actionFluents.size());
                Net initNN = Net(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate);
                trainer = NetworkTrainer(initNN, networkType, modulePath, netConfig, version);
                
                nets_impl.emplace_back(initNN);

                initQValueNetwork();
                trainQValueNetworks();
                cout << "seconds: " << stopwatch << endl;
                exit(0);    
            }
            case NetworkType::PolicyNetwork:
            {   
                netConfig = "hidden_" + to_string(numberOfHiddenLayer) + "-breadth_" + to_string(layerBreadth*stateFluents.size()) + "-epochs_" + to_string(numberOfEpochs) + "-batch_" + to_string(batchSize) + "-lr_" + to_string(learningRate);
                version = "Policy_" + version;
                cout << "Generate Policy Network ..." << endl;
                setInputLayerSize(SearchEngine::stateFluents.size());
                Net initNN = Net(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions);
                trainer = NetworkTrainer(initNN, networkType, modulePath, netConfig, version);
                nets_impl.emplace_back(initNN);
                cout << initNN << endl;

                initPolicyNetwork();
                trainPolicyNetworks();
                cout << "seconds: " << stopwatch << endl;
                exit(0);
            }
            case NetworkType::BoundedQvalue:
            {   
                cout << "There is no NN preparation offered for this type of Network: bounded" << endl;
                exit(0);
            }
            case NetworkType::BoundedPolicy:
            {   
                cout << "There is no NN preparation offered for this type of Network: bounded" << endl;
                exit(0);
            }
            case NetworkType::HorizonNetwork:
            {   
                netConfig = "hidden_" + to_string(numberOfHiddenLayer) + "-breadth_" + to_string(layerBreadth*stateFluents.size()) + "-epochs_" + to_string(numberOfEpochs) + "-batch_" + to_string(batchSize) + "-lr_" + to_string(learningRate);
                version = "Horizon_" + version;
                cout << "Generate Horizon Network ..." << endl;
                setInputLayerSize(SearchEngine::stateFluents.size()+1);
                Net initNN = Net(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions);
                trainer = NetworkTrainer(initNN, networkType, modulePath, netConfig, version);
                nets_impl.emplace_back(initNN);
                cout << initNN << endl;

                initHorizonNetwork();
                trainHorizonNetwork();
                cout << "seconds: " << stopwatch << endl;
                exit(0);
            }
        }
    }else{
        switch(networkType){
            case NetworkType::QValueNetwork:
            {
                netConfig = "hidden_" + to_string(numberOfHiddenLayer) + "-breadth_" + to_string(layerBreadth*stateFluents.size()) + "-epochs_" + to_string(numberOfEpochs) + "-batch_" + to_string(batchSize) + "-lr_" + to_string(learningRate);
                version = "Q-Value_" + version;
                setBatchSize(numberOfActions);
                setInputLayerSize(SearchEngine::stateFluents.size() + SearchEngine::actionFluents.size());
                Net initNN = Net(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate);
                trainer = NetworkTrainer(initNN, networkType, modulePath, netConfig, version);
                nets_impl = trainer.loadModels(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, 0, horizon);
                break;
            }
            case NetworkType::PolicyNetwork:
            {
                netConfig = "hidden_" + to_string(numberOfHiddenLayer) + "-breadth_" + to_string(layerBreadth*stateFluents.size()) + "-epochs_" + to_string(numberOfEpochs) + "-batch_" + to_string(batchSize) + "-lr_" + to_string(learningRate);
                
                version = "Policy_" + version;
                setInputLayerSize(SearchEngine::stateFluents.size());
                Net initNN = Net(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions);
                trainer = NetworkTrainer(initNN, networkType, modulePath, netConfig, version);
                nets_impl = trainer.loadModels(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions, horizon);
                break;
            }
            case NetworkType::BoundedQvalue:
            {
                netConfig = "hidden_" + to_string(numberOfHiddenLayer) + "-breadth_" + to_string(layerBreadth*stateFluents.size()) + "-epochs_" + to_string(numberOfEpochs) + "-batch_" + to_string(batchSize) + "-lr_" + to_string(learningRate);
                version = "Q-Value_" + version;
                setInputLayerSize(SearchEngine::stateFluents.size() + SearchEngine::actionFluents.size());
                Net initNN = Net(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate);
                trainer = NetworkTrainer(initNN, networkType, modulePath, netConfig, version);
                nets_impl = trainer.loadModels(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, 0, horizonBound);
                break;
            }
            case NetworkType::BoundedPolicy:
            {
                netConfig = "hidden_" + to_string(numberOfHiddenLayer) + "-breadth_" + to_string(layerBreadth*stateFluents.size()) + "-epochs_" + to_string(numberOfEpochs) + "-batch_" + to_string(batchSize) + "-lr_" + to_string(learningRate);
                version = "Policy_" + version;
                setInputLayerSize(SearchEngine::stateFluents.size());
                Net initNN = Net(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions);
                trainer = NetworkTrainer(initNN, networkType, modulePath, netConfig, version);
                nets_impl = trainer.loadModels(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions, horizonBound);
                break;
            }
            case NetworkType::HorizonNetwork:
            {   
                netConfig = "hidden_" + to_string(numberOfHiddenLayer) + "-breadth_" + to_string(layerBreadth*stateFluents.size()) + "-epochs_" + to_string(numberOfEpochs) + "-batch_" + to_string(batchSize) + "-lr_" + to_string(learningRate);
                version = "Horizon_" + version;
                setInputLayerSize(SearchEngine::stateFluents.size()+1);
                Net initNN = Net(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions);
                trainer = NetworkTrainer(initNN, networkType, modulePath, netConfig, version);
                trainer.loadModel(initNN, horizon);
                nets_impl.push_back(initNN);
                break;
            }
        }
    }
}

// Estimates the Q-value of a single action in `state`.
// With one step to go this is the immediate reward. Otherwise the network for
// the remaining horizon is queried, using the same index convention as
// estimateQValues: nets_impl[k] serves k+2 remaining steps. Beyond the last
// trained horizon, the deepest network is used and its output is scaled
// linearly with the remaining steps. The horizon network is a single net that
// takes the remaining steps as an input.
void IncreasingHorizonNN::estimateQValue(State const& state, int actionIndex,
                                        double& qValue){
    int const stepsToGo = state.stepsToGo();
    assert(stepsToGo > 0);

    if(stepsToGo == 1){
        calcReward(state, actionIndex, qValue);
        return;
    }

    assert(!nets_impl.empty());

    if(networkType == NetworkType::HorizonNetwork){
        nets_impl.back().get()->predictQValueHorizonNetwork(state, actionIndex, qValue, stepsToGo);
        return;
    }

    // The previous version indexed stepsToGo-3 (out of bounds for 2 steps to go)
    // and ended in an unconditional assert(false).
    bool const beyondTrained = static_cast<size_t>(stepsToGo - 1) > nets_impl.size();
    auto& net = beyondTrained ? nets_impl.back() : nets_impl[stepsToGo - 2];

    // Policy-style networks use a state-only input encoding, so they need their own predictor.
    bool const policyHead = networkType == NetworkType::PolicyNetwork
                         || networkType == NetworkType::BoundedPolicy;
    if(policyHead){
        net.get()->predictQValuePolicyNetwork(state, actionIndex, qValue);
    }else{
        net.get()->predictQValue(state, actionIndex, qValue);
    }

    if(beyondTrained){
        qValue = qValue/(nets_impl.size()+1) * stepsToGo;
    }
}


// Fills qValues[i] for every applicable action i (actionsToExpand[i] == i).
// With fewer than two steps to go, the immediate reward is used. Otherwise the
// network for the remaining horizon is queried: nets_impl[k] is indexed for
// k+2 remaining steps. Past the deepest trained horizon, nets_impl.back() is
// used and its outputs are rescaled linearly by stepsToGo/(nets_impl.size()+1).
void IncreasingHorizonNN::estimateQValues(State const& _rootState,
                                 std::vector<int> const& actionsToExpand,
                                 std::vector<double>& qValues){
    assert(nets_impl.size() > 0);

    int const stepsLeft = _rootState.stepsToGo();

    if(stepsLeft < 2){
        for (size_t idx = 0; idx < actionsToExpand.size(); ++idx) {
            if (actionsToExpand[idx] == idx) {
                calcReward(_rootState, idx, qValues[idx]);
            }
        }
        return;
    }

    bool policyHead = false;
    switch (networkType)
    {
        case NetworkType::HorizonNetwork:
            // A single network conditioned on the remaining steps.
            nets_impl.back().get()->predictQValuesHorizonNetwork(_rootState, actionsToExpand, qValues, stepsLeft);
            return;
        case NetworkType::PolicyNetwork:
        case NetworkType::BoundedPolicy:
            policyHead = true;
            break;
        case NetworkType::QValueNetwork:
        case NetworkType::BoundedQvalue:
            policyHead = false;
            break;
        default:
            return;
    }

    bool const beyondTrained = static_cast<size_t>(stepsLeft - 1) > nets_impl.size();
    auto& net = beyondTrained ? nets_impl.back() : nets_impl[stepsLeft - 2];

    if(policyHead){
        net.get()->predictQValuesPolicyNetwork(_rootState, actionsToExpand, qValues);
    }else{
        net.get()->predictQValues(_rootState, actionsToExpand, qValues);
    }

    if(beyondTrained){
        for(double& value : qValues){
            value = value/(nets_impl.size()+1) * stepsLeft;
        }
    }
}

// Computes depth-2 Q-value targets used as training data for the initial network.
// Writes exact short-horizon Q-values into qValues for every applicable action
// (actionsToExpand[a] == a).
// - stepsToGo == 1: the immediate reward.
// - any other value is treated as 2 steps: the immediate reward plus, for each
//   successor outcome, its probability times the best immediate reward over the
//   successor's applicable actions (floored at 0, since maxQ starts at 0).
// Outcomes are enumerated with expand() when getNumberOfPDSuccessors(threshold)
// is truthy and sampled with sampleOutcomes(threshold) otherwise.
void IncreasingHorizonNN::calcActualReward(State const& state, std::vector<int> const& actionsToExpand, std::vector<double>& qValues, int stepsToGo){

    if(stepsToGo == 1){
        for (size_t i = 0; i < actionsToExpand.size(); ++i) {
            if (actionsToExpand[i] == static_cast<int>(i)) {
                calcReward(state, i, qValues[i]);
            }
        }
        return;
    }

    double reward;

    for(int actionIndex = 0; actionIndex < static_cast<int>(actionsToExpand.size()); actionIndex++){
        if(actionsToExpand[actionIndex] != actionIndex){
            continue;
        }

        PDState currentPDState(state);
        vector<pair<PDState, double>> outcomePairs;

        calcSuccessorState(state, actionIndex, currentPDState);

        if(currentPDState.getNumberOfPDSuccessors(threshold)){
            outcomePairs = currentPDState.expand();
        }else{
            outcomePairs = currentPDState.sampleOutcomes(threshold);
        }

        calcReward(state, actionIndex, reward);
        for(pair<PDState, double> const& singlePair : outcomePairs){

            double maxQ = 0;
            double currQValue;
            vector<int> currentActions = getApplicableActions(singlePair.first);

            for(int i = 0; i < static_cast<int>(currentActions.size()); i++){
                // Applicability must be tested against the successor's own
                // actions, not the root state's actionsToExpand.
                if(currentActions[i] == i){
                    calcReward(singlePair.first, i, currQValue);
                    maxQ = max(maxQ, currQValue);
                }
            }
            reward += singlePair.second * maxQ;
        }
        qValues[actionIndex] = reward;
    }
}


void IncreasingHorizonNN::initQValueNetwork(){

    cout << "Initialize NN_2..." << endl;
    vector<double> dataX;
    vector<double> dataY;

    vector<double> testX;
    vector<double> testY;

    for (State const& currState : SearchEngine::trainingSet){
        vector<int> actions = getApplicableActions(currState);
        vector<double> rewards(actions.size(), 0.0);
        monitorRAMUsage();
        calcActualReward(currState, actions, rewards, 2);

        for(int currAction = 0; currAction < actions.size(); currAction++){
            if(actions[currAction] == currAction){
                dataY.push_back(rewards[currAction]);
            }
        }
        NetImpl::transformInputFlatten(currState,actions, dataX);
    }

    NetImpl::createTrainTestDataQvalue(dataX, dataY, testX, testY, dataX.size()/inputLayerSize, inputLayerSize);
    Net deeperNet(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate);

    deeperNet->trainNNBatches(dataX, dataY);
    nets_impl.clear();
    nets_impl.push_back(deeperNet);

    trainer.safeModel(deeperNet, nets_impl.size()+1);
    
    cout << "trainingset size: " << dataY.size() << endl;
    cout << "Loss of Training: ";
    NetworkTrainer::evaluateQValueNetwork(deeperNet, dataX, dataY);
    cout << "testset size: " << testY.size() << endl;
    cout << "Loss of Test: ";
    NetworkTrainer::evaluateQValueNetwork(deeperNet, testX, testY);
    cout << "Init training finished..." << endl;
}

void IncreasingHorizonNN::initPolicyNetwork(){

    cout << "Initialize NN_2..." << endl;
    vector<double> dataX;
    vector<double> dataY;
    
    vector<double> testX;
    vector<double> testY;
    

    for (State const& currState : SearchEngine::trainingSet){
        vector<int> actions = getApplicableActions(currState);
        vector<double> rewards(actions.size(), 0.0);
        monitorRAMUsage();
        calcActualReward(currState, actions, rewards, 2);
            
        for(int currAction = 0; currAction < actions.size(); currAction++){
            if(actions[currAction] == currAction){
                dataY.push_back(rewards[currAction]);
            }else{
                //value doesnt matter because it is replaced by target value since action is not applicable/reasonable anyway and should not affect loss
                dataY.push_back(-1);
            }
        }
        NetImpl::transformInputPolicyNetwork(currState, dataX);
    }
    //nets_impl.back().get()->forward_custom(dataX);
    //nets_impl.back().get()->validationData = std::make_pair(dataX, dataY);

    //nets_impl.back().get()->trainPolicyNetwork(dataX, dataY);
    
    NetImpl::createTrainTestDataPolicy(dataX, dataY, testX, testY, dataX.size()/inputLayerSize, inputLayerSize, numberOfActions);
    Net deeperNet(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions);

    deeperNet->trainPolicyNetwork(dataX, dataY);

    nets_impl.clear();
    nets_impl.push_back(deeperNet);
    
    trainer.safeModel(deeperNet, nets_impl.size()+1);

    cout << "trainingset size: " << dataX.size()/inputLayerSize << endl;
    cout << "Loss of Training: ";
    NetworkTrainer::evaluatePolicyNetwork(deeperNet, dataX, dataY);
    cout << "testset size: " << testX.size()/inputLayerSize << endl;
    cout << "Loss of Test: ";
    NetworkTrainer::evaluatePolicyNetwork(deeperNet, testX, testY);

    cout << "Init training finished..." << endl;
}

void IncreasingHorizonNN::initHorizonNetwork(){

    cout << "Initialize NN..." << endl;
    vector<double> dataX;
    vector<double> dataY;

    for (State const& state : SearchEngine::trainingSet){
        vector<int> actions = getApplicableActions(state);
        vector<double> rewards(actions.size(), 0.0);
        monitorRAMUsage();
        calcActualReward(state, actions, rewards, 2);
        for(int actionIndex = 0; actionIndex < actions.size(); actionIndex++){
            if(actions[actionIndex] == actionIndex){
                dataY.push_back(rewards[actionIndex]);
            }else{
                //value doesnt matter because it is replaced by target value since action is not applicable/reasonable anyway and should not affect loss
                dataY.push_back(-1);
            }
        }
        NetImpl::transformInputHorizonNetwork(state, dataX, 2);
    }
    nets_impl.back().get()->validationData = std::make_pair(dataX, dataY);
    
    nets_impl.back().get()->trainHorizonNetwork(dataX, dataY);
    
    trainer.safeModel(nets_impl.back(), nets_impl.size()+1);
    cout << "Init training finished..." << endl;
}

// Grows nets_impl by one Q-value network per iteration until it holds
// horizon-1 networks (nets_impl[k] serves k+2 remaining steps).
// For the next horizon (nets_impl.size()+2), each applicable action a of each
// training state s gets the target
//   R(s,a) + sum over outcomes s' of P(s') * max(0, max_{a'} Q_prev(s',a'))
// where Q_prev is nets_impl.back() and a' ranges over the actions applicable
// in s'. Outcomes are enumerated with expand() when
// getNumberOfPDSuccessors(threshold) is truthy, and sampled otherwise.
// The new net is set up via trainer.loadModel(..., nets_impl.size()+1)
// (presumably the previous horizon's saved weights, i.e. a warm start — confirm),
// trained, appended, saved under the new nets_impl.size()+1, and evaluated on
// the train and held-out splits.
void IncreasingHorizonNN::trainQValueNetworks(){
    
    while(nets_impl.size() < horizon-1){

        cout << "start learning horizon: " << nets_impl.size()+2 << endl;
        cout << "---------------------------------------------------------------------" << endl;
        
        // NOTE(review): these reservation sizes are size_t products narrowed to
        // int; they may not fit for very large training sets.
        vector<double> dataX;
        int dataXsize = (SearchEngine::stateFluents.size()+SearchEngine::actionFluents.size())*SearchEngine::actionStates.size()*SearchEngine::trainingSet.size();
        dataX.reserve(dataXsize);
        vector<double> dataY;
        int dataYsize = SearchEngine::actionStates.size()*SearchEngine::trainingSet.size();
        dataY.reserve(dataYsize);

        vector<double> testX;
        vector<double> testY;
        double reward;


        for(State const& currState : SearchEngine::trainingSet){
            std::vector<int> actions = getApplicableActions(currState);
            monitorRAMUsage();
            
            // One target per applicable action, in action order, matching the
            // rows produced by transformInputFlatten below.
            for(int currAction = 0; currAction < actions.size(); currAction++){
                std::vector<std::pair<PDState, double>> outcomePairs;    

                if(actions[currAction] == currAction){
                    PDState currentPDState(currState);
                    
                    calcSuccessorState(currState, currAction, currentPDState);
                    
                    // Exact expansion when the outcome count is acceptable,
                    // otherwise a threshold-sized sample.
                    if(currentPDState.getNumberOfPDSuccessors(threshold)){
                        outcomePairs = currentPDState.expand();
                    }else{
                        outcomePairs = currentPDState.sampleOutcomes(threshold);
                    }
                    
                    calcReward(currState, currAction, reward);
                    
                    for(pair<PDState, double> const& singlePair : outcomePairs){
                        
                        // maxQ starts at 0, so the successor value is floored at 0.
                        double maxQ = 0;
                        double currQValue;
                        vector<int> currentActions = getApplicableActions(singlePair.first);
                        
                        for(int actionIndex = 0; actionIndex < currentActions.size(); actionIndex++){
                            if(currentActions[actionIndex] == actionIndex){
                                
                                nets_impl.back().get()->predictQValue(singlePair.first, actionIndex, currQValue);
                                maxQ = max(maxQ, currQValue);
                            }
                        }
                        
                        //max of qValues
                        reward += singlePair.second * maxQ;
                    }
                    dataY.push_back(reward);
                }
            }
            NetImpl::transformInputFlatten(currState, actions, dataX);
        }

        NetImpl::createTrainTestDataQvalue(dataX, dataY, testX, testY, dataY.size(), inputLayerSize);
        Net deeperNet(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate);
        trainer.loadModel(deeperNet, nets_impl.size()+1);
        deeperNet.get()->trainNNBatches(dataX, dataY);

        nets_impl.push_back(deeperNet);
        trainer.safeModel(deeperNet, nets_impl.size()+1);

        cout << "trainingset size: " << dataY.size() << endl;
        cout << "Loss of Training: ";
        NetworkTrainer::evaluateQValueNetwork(deeperNet, dataX, dataY);
        cout << "testset size: " << testY.size() << endl;
        cout << "Loss of Test: ";
        NetworkTrainer::evaluateQValueNetwork(deeperNet, testX, testY);


        
        
        cout << "Finished Training with horizon: " << nets_impl.size()+1 << endl;
    }
    return;
}

// Policy-style counterpart of trainQValueNetworks: grows nets_impl until it
// holds horizon-1 networks. Each training state contributes one state-only
// input row (transformInputPolicyNetwork) and one target per action slot:
// for applicable actions,
//   R(s,a) + sum over outcomes s' of P(s') * max(0, max_{a'} Q_prev(s',a'))
// with Q_prev = nets_impl.back() via predictQValuePolicyNetwork, and -1 as a
// placeholder for inapplicable actions. The new net is set up via
// trainer.loadModel(..., nets_impl.size()+1) (presumably a warm start — confirm),
// trained, appended, saved under the new nets_impl.size()+1, and evaluated.
void IncreasingHorizonNN::trainPolicyNetworks(){
    
    
    while(nets_impl.size() < horizon-1){
        cout << "start learning horizon: " << nets_impl.size()+2 << endl;
        cout << "---------------------------------------------------------------------" << endl;
        
        // NOTE(review): the dataX reservation reuses the Q-value formula
        // (state + action fluents per action) although policy rows are
        // state-only; it over-reserves. The size_t products are narrowed to int.
        vector<double> dataX;
        int dataXsize = (SearchEngine::stateFluents.size()+SearchEngine::actionFluents.size())*SearchEngine::actionStates.size()*SearchEngine::trainingSet.size();
        dataX.reserve(dataXsize);
        vector<double> dataY;
        int dataYsize = SearchEngine::actionStates.size()*SearchEngine::trainingSet.size();
        dataY.reserve(dataYsize);

        
        vector<double> testX;
        vector<double> testY;
        double reward;

        for(State const& currState : SearchEngine::trainingSet){
            std::vector<int> actions = getApplicableActions(currState);
            monitorRAMUsage();

            for(int currAction = 0; currAction < actions.size(); currAction++){
                std::vector<std::pair<PDState, double>> outcomePairs;    
                if(actions[currAction] == currAction){
                    PDState currentPDState(currState);
                    
                    calcSuccessorState(currState, currAction, currentPDState);
                    
                    // Exact expansion when the outcome count is acceptable,
                    // otherwise a threshold-sized sample.
                    if(currentPDState.getNumberOfPDSuccessors(threshold)){
                        outcomePairs = currentPDState.expand();
                    }else{
                        outcomePairs = currentPDState.sampleOutcomes(threshold);
                    }
                    
                    calcReward(currState, currAction, reward);
                    for(pair<PDState, double> const& singlePair : outcomePairs){
                        
                        // maxQ starts at 0, so the successor value is floored at 0.
                        double maxQ = 0;
                        double currQValue;
                        vector<int> currentActions = getApplicableActions(singlePair.first);
                        for(int actionIndex = 0; actionIndex < currentActions.size(); actionIndex++){
                            if(currentActions[actionIndex] == actionIndex){
                                nets_impl.back().get()->predictQValuePolicyNetwork(singlePair.first, actionIndex, currQValue);
                                maxQ = max(maxQ, currQValue);
                            }
                        }
                        
                        //max of qValues
                        reward += singlePair.second * maxQ;

                    }
                    dataY.push_back(reward);
                }else{
                    // Placeholder target for an inapplicable action slot.
                    dataY.push_back(-1);
                }
            }
            NetImpl::transformInputPolicyNetwork(currState, dataX);
        }
        
        NetImpl::createTrainTestDataPolicy(dataX, dataY, testX, testY, dataX.size()/inputLayerSize, inputLayerSize, numberOfActions);
        Net deeperNet(inputLayerSize, numberOfHiddenLayer, layerBreadth, numberOfEpochs, batchSize, learningRate, numberOfActions);
        trainer.loadModel(deeperNet, nets_impl.size()+1);
        
        deeperNet.get()->trainPolicyNetwork(dataX, dataY);
        
        nets_impl.push_back(deeperNet);
        trainer.safeModel(deeperNet, nets_impl.size()+1);

        cout << "trainingset size: " << dataX.size()/inputLayerSize << endl;
        cout << "Loss of Training: ";
        NetworkTrainer::evaluatePolicyNetwork(deeperNet, dataX, dataY);
        cout << "testset size: " << testX.size()/inputLayerSize << endl;
        cout << "Loss of Test: ";
        NetworkTrainer::evaluatePolicyNetwork(deeperNet, testX, testY);
        
        
        cout << "Finished Training with horizon: " << nets_impl.size()+1 << endl;
    }

    return;
}

// Continues training the single horizon network in nets_impl.back() for
// depths 3..horizon, bootstrapping each depth from the same network.
// For depth d, each applicable action a of a training state s gets the target
//   R(s,a) + sum over outcomes s' of P(s') * max(0, max_{a'} N(s',a',d-1))
// and inapplicable action slots get -1. Inputs are encoded together with d
// (transformInputHorizonNetwork(currState, dataX, depth)). After each depth,
// the network is saved with trainer.safeModel(..., horizon), so the same index
// is rewritten on every iteration.
void IncreasingHorizonNN::trainHorizonNetwork(){

    
    for(int depth = 3; depth <= horizon; ++depth){
        cout << "start learning horizon: " << depth << endl;
        cout << "---------------------------------------------------------------------" << endl;
        
        // NOTE(review): the size_t products are narrowed to int for reserve().
        vector<double> dataX;
        int dataXsize = (SearchEngine::stateFluents.size()+SearchEngine::actionFluents.size())*SearchEngine::actionStates.size()*SearchEngine::trainingSet.size();
        dataX.reserve(dataXsize);
        vector<double> dataY;
        int dataYsize = SearchEngine::actionStates.size()*SearchEngine::trainingSet.size();
        dataY.reserve(dataYsize);
        double reward;
        
        for(State const& currState : SearchEngine::trainingSet){
            std::vector<int> actions = getApplicableActions(currState);
            monitorRAMUsage();
            for(int currAction = 0; currAction < actions.size(); currAction++){
                std::vector<std::pair<PDState, double>> outcomePairs;    
                if(actions[currAction] == currAction){
                    PDState currentPDState(currState);
                    
                    calcSuccessorState(currState, currAction, currentPDState);
                    
                    // Exact expansion when the outcome count is acceptable,
                    // otherwise a threshold-sized sample.
                    if(currentPDState.getNumberOfPDSuccessors(threshold)){
                        outcomePairs = currentPDState.expand();
                    }else{
                        outcomePairs = currentPDState.sampleOutcomes(threshold);
                    }
                    
                    calcReward(currState, currAction, reward);
                    for(pair<PDState, double> const& singlePair : outcomePairs){
                        
                        // maxQ starts at 0, so the successor value is floored at 0.
                        double maxQ = 0;
                        double currQValue;
                        vector<int> currentActions = getApplicableActions(singlePair.first);
                        for(int actionIndex = 0; actionIndex < currentActions.size(); actionIndex++){
                            if(currentActions[actionIndex] == actionIndex){
                                
                                // Successors have one step less to go.
                                nets_impl.back().get()->predictQValueHorizonNetwork(singlePair.first, actionIndex, currQValue, depth-1);
                                
                                maxQ = max(maxQ, currQValue);
                            }
                        }
                        //max of qValues
                        reward += singlePair.second * maxQ;
                    }
                    dataY.push_back(reward);
                }else{
                    // Placeholder target for an inapplicable action slot.
                    dataY.push_back(-1);
                }
            }
            
            NetImpl::transformInputHorizonNetwork(currState, dataX, depth);
        }
        
        nets_impl.back().get()->trainHorizonNetwork(dataX, dataY);
        trainer.safeModel(nets_impl.back(), horizon);
        
        
        cout << "Finished Training with horizon: " << depth<< endl;
    }

    return;
}


// Turns off every cache (this initializer's, applicable actions, all CPFs,
// the reward CPF and the action preconditions) once the process's memory use,
// as reported by SystemUtils::getRAMUsedByThis(), exceeds a fixed ceiling.
// It runs at most once, because cachingEnabled is cleared by disableCaching()
// (presumably — confirm).
void IncreasingHorizonNN::monitorRAMUsage() {
    // Presumably KB (2097152 KB = 2 GiB) — TODO confirm the unit returned by
    // SystemUtils::getRAMUsedByThis().
    constexpr int ramCeiling = 2097152;

    if (!cachingEnabled || !(SystemUtils::getRAMUsedByThis() > ramCeiling)) {
        return;
    }

    disableCaching();
    SearchEngine::cacheApplicableActions = false;

    for (size_t fluent = 0; fluent < State::numberOfDeterministicStateFluents; ++fluent) {
        SearchEngine::deterministicCPFs[fluent]->disableCaching();
    }
    for (size_t fluent = 0; fluent < State::numberOfProbabilisticStateFluents; ++fluent) {
        SearchEngine::probabilisticCPFs[fluent]->disableCaching();
        SearchEngine::determinizedCPFs[fluent]->disableCaching();
    }

    SearchEngine::rewardCPF->disableCaching();

    for (size_t precondition = 0; precondition < SearchEngine::actionPreconditions.size(); ++precondition) {
        SearchEngine::actionPreconditions[precondition]->disableCaching();
    }
    cout << "CACHING DISABLED" << endl;
}

void IncreasingHorizonNN::evaluateTask(){

    cout << "numberOfActions: " <<SearchEngine::numberOfActions << endl;
    cout << "horizon: " << SearchEngine::horizon << endl;
    cout << "trainingset-size: " << SearchEngine::trainingSet.size() << endl;
    
    int totalNumberOfAppliedActions = 0;
    

    int currentThreshold = threshold;
    double tmp = 0;
    double result = 0;

    cout << "deterministic_fluents: " << trainingSet[0].numberOfDeterministicStateFluents << endl;
    cout << "probabilistic_fluents: " << trainingSet[0].numberOfProbabilisticStateFluents << endl;

    int exceededThreshold  = 0;

    for(State s : trainingSet){
        

        
        int appliedActionsPerState = 0;
        double outcomes = 0;
        vector<int> currentActions = getApplicableActions(s);
        
        for(int actionIndex = 0; actionIndex < numberOfActions; actionIndex++){
            if(currentActions[actionIndex] == actionIndex){
                appliedActionsPerState++;
                PDState currentPDState(s); 
                calcSuccessorState(s, actionIndex, currentPDState);
                
                if (!currentPDState.getNumberOfPDSuccessors(currentThreshold)){
                    exceededThreshold++;
                    outcomes += currentThreshold;
                }else{
                    tmp = 1;
                    for(int index = 0; index < currentPDState.numberOfProbabilisticStateFluents; index++){
                        tmp *= currentPDState.probabilisticStateFluentAsPD(index).size();
                    }
                    assert(tmp/currentThreshold < 1);
                    outcomes = outcomes + tmp;
                }
            }   
        }
        assert(outcomes > 0);
        result += outcomes/appliedActionsPerState;

        if (appliedActionsPerState>0) {
            totalNumberOfAppliedActions += appliedActionsPerState;
        }
    }
    
    result = result / trainingSet.size();
    cout  << "-------------------------------------" << endl;
    cout << "average number of outcomes:         " << result << endl;
    cout << "given the threshold:               " << currentThreshold << endl;
    cout << "total number of applied actions:  " << totalNumberOfAppliedActions << endl;
    cout << "number of training states:         " << trainingSet.size() << endl;
    cout << "number of actions:                 " << numberOfActions << endl;
    cout << "threshold exceeded:                " << exceededThreshold << endl;
}