#include "thts.h"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <vector>

#include "action_selection.h"
#include "backup_function.h"
#include "initializer.h"
#include "outcome_selection.h"
#include "recommendation_function.h"
#include "prost_planner.h"

#include "utils/system_utils.h"

/******************************************************************
                     Search Engine Creation
******************************************************************/

// Creates a THTS search engine called _name with default settings.
// Action selection, outcome selection, backup function and initializer start
// out unset (nullptr) and must be provided (e.g. through setValueFromString)
// before learn() is called, which aborts otherwise. The recommendation
// function defaults to ExpectedBestArmRecommendation.
THTS::THTS(std::string _name)
    : ProbabilisticSearchEngine(_name),
      // Compared against SystemUtils::getRAMUsedByThis() in monitorRAMUsage();
      // the unit is not visible here (presumably KB, i.e. ~2 GB) — TODO confirm.
      ramLimit(2097152),
      cachingEnabled(true),
      actionSelection(nullptr),
      outcomeSelection(nullptr),
      backupFunction(nullptr),
      initializer(nullptr),
      recommendationFunction(nullptr),
      currentRootNode(nullptr),
      chosenOutcome(nullptr),
      tipNodeOfTrial(nullptr),
      // One state slot per steps-to-go value 0..horizon (states is indexed by
      // steps-to-go in initStep/initTrial).
      states(SearchEngine::horizon + 1),
      stepsToGoInCurrentState(SearchEngine::horizon),
      stepsToGoInNextState(SearchEngine::horizon - 1),
      appliedActionIndex(-1),
      trialReward(0.0),
      currentTrial(0),
      initializedDecisionNodes(0),
      lastUsedNodePoolIndex(0),
      terminationMethod(THTS::TIME),
      maxNumberOfTrials(0),
      numberOfNewDecisionNodesPerTrial(1),
      numberOfRuns(0),
      cacheHits(0),
      accumulatedNumberOfStepsToGoInFirstSolvedRootState(0),
      firstSolvedFound(false),
      accumulatedNumberOfTrialsInRootState(0),
      accumulatedNumberOfSearchNodesInRootState(0) {
    setMaxNumberOfNodes(24000000);
    // Presumably seconds per step — TODO confirm against setTimeout().
    setTimeout(1.0);
    // Default recommendation; can be replaced with "-rec".
    setRecommendationFunction(new ExpectedBestArmRecommendation(this));
}

// Parses a single THTS command-line option.
//
// @param param  option name, e.g. "-act", "-T", "-meta"
// @param value  option value as written on the command line
// @return true if the option was recognized and its value is valid. Options
//         that THTS does not know are forwarded to
//         SearchEngine::setValueFromString.
bool THTS::setValueFromString(std::string& param, std::string& value) {
    // Check if this parameter encodes an ingredient
    if (param == "-act") {
        setActionSelection(ActionSelection::fromString(value, this));

        return true;
    } else if (param == "-out") {
        setOutcomeSelection(OutcomeSelection::fromString(value, this));

        return true;
    } else if (param == "-backup") {
        setBackupFunction(BackupFunction::fromString(value, this));

        return true;
    } else if (param == "-init") {
        setInitializer(Initializer::fromString(value, this));
        return true;
    } else if (param == "-rec") {
        setRecommendationFunction(
            RecommendationFunction::fromString(value, this));
        return true;
    }

    if (param == "-T") {
        if (value == "TIME") {
            setTerminationMethod(THTS::TIME);
            return true;
        } else if (value == "TRIALS") {
            setTerminationMethod(THTS::NUMBER_OF_TRIALS);
            return true;
        } else if (value == "TIME_AND_TRIALS") {
            setTerminationMethod(THTS::TIME_AND_NUMBER_OF_TRIALS);
            return true;
        } else {
            return false;
        }
    } else if (param == "-r") {
        setMaxNumberOfTrials(atoi(value.c_str()));
        return true;
    } else if (param == "-ndn") {
        if (value == "H") {
            setNumberOfNewDecisionNodesPerTrial(SearchEngine::horizon);
        } else {
            setNumberOfNewDecisionNodesPerTrial(atoi(value.c_str()));
        }
        return true;
    } else if (param == "-node-limit") {
        setMaxNumberOfNodes(atoi(value.c_str()));
        return true;
    }

    // Metareasoning: "-meta N" with N in 1..6 enables stopping-rule variant N.
    // "-meta 0" is accepted and leaves metareasoning untouched, as before.
    // Any other value is rejected, just like an unknown "-T" value. Previously
    // a typo such as "-meta 7" was silently accepted and had no effect.
    if (param == "-meta") {
        if (value == "0") {
            return true;
        }
        if (value.size() == 1 && value[0] >= '1' && value[0] <= '6') {
            enableMetaReasoning();
            setMetaReasoningVersion(value[0] - '0');
            return true;
        }
        return false;
    }

    // Time thresholds for the metareasoning stopping rules. std::stod throws
    // std::invalid_argument / std::out_of_range on malformed input.
    if (param == "-tmin") {
        setTMin(std::stod(value));
        return true;
    }

    if (param == "-tmax") {
        setTMax(std::stod(value));
        return true;
    }

    return SearchEngine::setValueFromString(param, value);
}

// Replaces the action selection ingredient. THTS takes ownership of
// _actionSelection and deletes the previously owned instance.
// Passing the currently owned pointer is a no-op; before this guard it
// deleted the object and then kept the dangling pointer.
void THTS::setActionSelection(ActionSelection* _actionSelection) {
    if (_actionSelection != actionSelection) {
        delete actionSelection;  // deleting nullptr is a no-op
        actionSelection = _actionSelection;
    }
}

// Replaces the outcome selection ingredient. THTS takes ownership of
// _outcomeSelection and deletes the previously owned instance.
// Passing the currently owned pointer is a no-op (it used to be deleted and
// left dangling).
void THTS::setOutcomeSelection(OutcomeSelection* _outcomeSelection) {
    if (_outcomeSelection != outcomeSelection) {
        delete outcomeSelection;  // deleting nullptr is a no-op
        outcomeSelection = _outcomeSelection;
    }
}

// Replaces the backup function ingredient. THTS takes ownership of
// _backupFunction and deletes the previously owned instance.
// Passing the currently owned pointer is a no-op (it used to be deleted and
// left dangling).
void THTS::setBackupFunction(BackupFunction* _backupFunction) {
    if (_backupFunction != backupFunction) {
        delete backupFunction;  // deleting nullptr is a no-op
        backupFunction = _backupFunction;
    }
}

// Replaces the initializer ingredient. THTS takes ownership of _initializer
// and deletes the previously owned instance.
// Passing the currently owned pointer is a no-op (it used to be deleted and
// left dangling).
void THTS::setInitializer(Initializer* _initializer) {
    if (_initializer != initializer) {
        delete initializer;  // deleting nullptr is a no-op
        initializer = _initializer;
    }
}

// Replaces the recommendation function ingredient. THTS takes ownership of
// _recommendationFunction and deletes the previously owned instance.
// Passing the currently owned pointer is a no-op (it used to be deleted and
// left dangling).
void THTS::setRecommendationFunction(
    RecommendationFunction* _recommendationFunction) {
    if (_recommendationFunction != recommendationFunction) {
        delete recommendationFunction;  // deleting nullptr is a no-op
        recommendationFunction = _recommendationFunction;
    }
}

/******************************************************************
                 Search Engine Administration
******************************************************************/

// Disables caching in all five THTS ingredients and then in the SearchEngine
// base class.
// NOTE(review): every ingredient pointer is dereferenced without a null
// check. If this can be reached before all ingredients have been set (e.g.
// while options are still being parsed), it dereferences nullptr — confirm
// the call order guarantees they are set.
// NOTE(review): this does not touch THTS::cachingEnabled (monitorRAMUsage
// clears it); confirm whether SearchEngine::disableCaching covers that.
void THTS::disableCaching() {
    actionSelection->disableCaching();
    outcomeSelection->disableCaching();
    backupFunction->disableCaching();
    initializer->disableCaching();
    recommendationFunction->disableCaching();
    SearchEngine::disableCaching();
}

// Runs the learning phase of every ingredient.
// All five ingredients are mandatory; if any is still unset the planner is
// aborted through SystemUtils::abort before any learning happens.
void THTS::learn() {
    bool const ingredientsComplete =
        actionSelection != nullptr && outcomeSelection != nullptr &&
        backupFunction != nullptr && initializer != nullptr &&
        recommendationFunction != nullptr;

    if (!ingredientsComplete) {
        SystemUtils::abort(
            "Action selection, outcome selection, backup "
            "function, initializer, and recommendation function "
            "must be defined in a THTS search engine!");
    }

    // Forward the learning phase to each ingredient in a fixed order,
    // bracketed by progress messages.
    std::cout << name << ": learning..." << std::endl;
    actionSelection->learn();
    outcomeSelection->learn();
    backupFunction->learn();
    initializer->learn();
    recommendationFunction->learn();
    std::cout << name << ": ...finished" << std::endl;
}

/******************************************************************
                 Initialization of search phases
******************************************************************/

// Per-round initialization. estimateBestActions calls this when the root
// state has SearchEngine::horizon steps to go. It clears the "first solved
// root state" flag used for statistics and forwards initRound() to every
// ingredient.
void THTS::initRound() {
    firstSolvedFound = false;

    actionSelection->initRound();
    outcomeSelection->initRound();
    backupFunction->initRound();
    initializer->initRound();
    recommendationFunction->initRound();
}

// Per-step initialization: places the root state into the state slot that
// matches its (possibly capped) steps-to-go value, rewinds the trial counters
// and builds a fresh root search node.
void THTS::initStep(State const& _rootState) {
    PDState rootState(_rootState);

    // The search never looks deeper than maxSearchDepth. A root that is
    // further from the horizon is treated as if it had only maxSearchDepth
    // steps to go.
    bool const depthIsCapped = rootState.stepsToGo() > maxSearchDepth;
    if (depthIsCapped) {
        maxSearchDepthForThisStep = maxSearchDepth;
    } else {
        maxSearchDepthForThisStep = rootState.stepsToGo();
    }

    states[maxSearchDepthForThisStep].setTo(rootState);
    if (depthIsCapped) {
        // setTo copied the real steps-to-go value; overwrite it with the cap.
        states[maxSearchDepthForThisStep].stepsToGo() =
            maxSearchDepthForThisStep;
    }
    assert(states[maxSearchDepthForThisStep].stepsToGo() ==
           maxSearchDepthForThisStep);

    // Trials start at the root; prepare the slot of its successor.
    stepsToGoInNextState = maxSearchDepthForThisStep - 1;
    stepsToGoInCurrentState = maxSearchDepthForThisStep;
    states[stepsToGoInNextState].reset(stepsToGoInNextState);

    // Step-dependent counters start from scratch.
    cacheHits = 0;
    currentTrial = 0;

    // Discard the previous search nodes and create a new root node.
    currentRootNode = createRootNode();

    std::cout << name << ": Maximal search depth set to "
              << maxSearchDepthForThisStep << std::endl
              << std::endl;
}

// Prepares a new trial: forgets what the previous trial recorded, rewinds the
// steps-to-go counters to the root, and lets the action selection, outcome
// selection, backup function and initializer prepare for the trial.
inline void THTS::initTrial() {
    // Trial-dependent bookkeeping from the previous trial.
    tipNodeOfTrial = nullptr;
    trialReward = 0.0;
    initializedDecisionNodes = 0;

    // Every trial starts at the root again: rewind the counters and clear the
    // slot of the first successor state.
    stepsToGoInCurrentState = maxSearchDepthForThisStep;
    stepsToGoInNextState = maxSearchDepthForThisStep - 1;
    states[stepsToGoInNextState].reset(stepsToGoInNextState);

    // Ingredient hooks (the recommendation function has no per-trial hook).
    actionSelection->initTrial();
    outcomeSelection->initTrial();
    backupFunction->initTrial();
    initializer->initTrial();
}

// Moves one step deeper within the current trial. Both steps-to-go counters
// are decremented, and the state slot of the new successor is reset
// (presumably so that the next transition can be written into it).
inline void THTS::initTrialStep() {
    --stepsToGoInCurrentState;
    --stepsToGoInNextState;
    states[stepsToGoInNextState].reset(stepsToGoInNextState);
}

/******************************************************************
                       Main Search Functions
******************************************************************/

void THTS::estimateBestActions(State const& _rootState,
                               std::vector<int>& bestActions) {
    assert(bestActions.empty());

    stopwatch.reset();

    // Init round (if this is the first call in a round)
    if (_rootState.stepsToGo() == SearchEngine::horizon) {
        initRound();
    }

    // Init step (this function is currently only called once per step) TODO:
    // maybe we should call initRound, initStep and printStats from "outside"
    // such that we can also use this as a heuristic without generating too much
    // output
    initStep(_rootState);

    // Check if there is an obviously optimal policy (as, e.g., in the last step
    // or in a reward lock)
    int uniquePolicyOpIndex = getUniquePolicy();
    if (uniquePolicyOpIndex != -1) {
        std::cout << "Returning unique policy: ";
        SearchEngine::actionStates[uniquePolicyOpIndex].printCompact(std::cout);
        std::cout << std::endl << std::endl;
        bestActions.push_back(uniquePolicyOpIndex);
        currentRootNode = nullptr;
        printStats(std::cout, (_rootState.stepsToGo() == 1));
        return;
    }

    //Metareasoning
//    std::vector<std::vector<double>> deltaQValues;
//    std::vector<std::pair<double, double>> deltaQIntervals;
    //Paper
    std::vector<double> currentDeltaQ;
    std::vector<double> nextDeltaQ;
    std::vector<double> currentQ;
    std::vector<double> lineM;
    std::vector<double> lineQ;
    std::vector<double> probs;
    std::vector<double> means;
    int trialNumber = 0;
    int applicableActionCount = 0;
    int k = 1;
    int thoughtCount = 0;
    int sameActionCount = 0;
    int sameActionLimit = 5;

    // Start the main loop that starts trials until some termination criterion
    // is fullfilled
    while (moreTrials()) {
//        std::cout<<"TEMP: starting trial..."<<std::endl;
        // std::cout <<
        // "---------------------------------------------------------" <<
        // std::endl;
        // std::cout << "TRIAL " << (currentTrial+1) << std::endl;
        // std::cout <<
        // "---------------------------------------------------------" <<
        // std::endl;

        monitorRAMUsage();

        // Metareasoning
//        if (metaReasoningEnabled) {
//            currentQVal = currentRootNode->getExpectedRewardEstimate();
//        }

//        std::cout<<"TEMP: visiting decision node..."<<std::endl;
        visitDecisionNode(currentRootNode);
//        std::cout<<"TEMP: visited decision node"<<std::endl;
        ++currentTrial;
        ++trialNumber;
//        std::cout<<"TEMP: updated counters"<<std::endl;

        if (metaReasoningEnabled) {
            //Metareasoning
            //calculate delta Q values

            //deltaQValues[appliedActionInRoot].size()==0
            if (currentTrial == 1) {
                //find actual action number
                int count = 0;
                for (int i = 0; i < currentRootNode->children.size(); i++) {
                    if (currentRootNode->children[i]) {
                        count++;
                    }
                }
                applicableActionCount = count;
//                deltaQValues.resize(currentRootNode->children.size(), std::vector<double>());
//                deltaQIntervals.resize(currentRootNode->children.size(), std::make_pair(0.0, 0.0));
                currentDeltaQ.resize(currentRootNode->children.size());
                nextDeltaQ.resize(currentRootNode->children.size());
                currentQ.resize(currentRootNode->children.size());
                lineM.resize(currentRootNode->children.size());
                lineQ.resize(currentRootNode->children.size());
                probs.resize(currentRootNode->children.size());
                means.resize(currentRootNode->children.size());
//                std::cout<<"TEMP: count = "<<count<<std::endl;
                //continue;
                //deltaQ = 0;
            }
//            std::cout<<"TEMP: updating current delta Q..."<<std::endl;
            assert(appliedActionInRoot < currentDeltaQ.size());
            
            currentDeltaQ[appliedActionInRoot] =
                    currentRootNode->children[appliedActionInRoot]->getExpectedRewardEstimate() -
                    currentQ[appliedActionInRoot];

//            std::cout << "Trial " << currentTrial << " Number " << trialNumber << " Applied Action: ";
//            actionStates[appliedActionInRoot].printCompact(std::cout);
//            std::cout << std::endl << "Current Q: " << currentQ[appliedActionInRoot] << std::endl;
//            std::cout << "Next Q: " << currentRootNode->children[appliedActionInRoot]->getExpectedRewardEstimate()
//                      << std::endl;
//            std::cout << "delta Q: " << currentDeltaQ[appliedActionInRoot] << std::endl;
            currentQ[appliedActionInRoot] = currentRootNode->children[appliedActionInRoot]->getExpectedRewardEstimate();
//            std::cout<<"TEMP: updated current q"<<std::endl;

            //Save deltaQValues
//            (deltaQValues[appliedActionInRoot]).push_back(deltaQ);

//            //Calculate nextDeltaQ
//            std::random_device rd;
//            std::mt19937 gen(rd());
//            std::uniform_real_distribution<> dis(0.0, 1.0);
//            double rho = dis(gen);
//            nextDeltaQ[appliedActionInRoot] = rho * currentDeltaQ[appliedActionInRoot];
//
//            //calculate line segment
//            double l_0 = currentQ[appliedActionInRoot];
//            double l_1 = currentQ[appliedActionInRoot] + nextDeltaQ[appliedActionInRoot];
//            lineM[appliedActionInRoot] = l_1 - l_0;
//            lineQ[appliedActionInRoot] = l_0;

            if (currentTrial == applicableActionCount) {
                trialNumber = 0;
            }

            if ((trialNumber * k) == applicableActionCount) {
                //sample rho
                std::random_device rd;
                std::mt19937 gen(rd());
                std::uniform_real_distribution<> dis(0.0, 1.0);
                double rho = dis(gen);

                for(int i=0; i<currentDeltaQ.size(); i++){
                    if(currentRootNode->children[i]){
                        //Calculate nextDeltaQ
                        nextDeltaQ[i] = rho * currentDeltaQ[i];

                        //calculate line segment
                        double l_0 = currentQ[i];
                        double l_1 = currentQ[i] + nextDeltaQ[i];
                        lineM[i] = l_1 - l_0;
                        lineQ[i] = l_0;
                    }
                }

                //scale line segments
                double highest_value = -std::numeric_limits<double>::max();
                double lowest_value = std::numeric_limits<double>::max();
                for(int i = 0; i<lineM.size(); i++){
                    double l_0 = lineQ[i];
                    double l_1 = lineM[i]+lineQ[i];
                    if(l_0>highest_value){
                        highest_value = l_0;
                    }
                    if(l_0<lowest_value){
                        lowest_value = l_0;
                    }
                    if(l_1>highest_value){
                        highest_value = l_1;
                    }
                    if(l_1<lowest_value){
                        lowest_value = l_1;
                    }
                }
//                std::cout << "---------------------------------------------------------" << std::endl;
//                std::cout<<"Lowest Value before scaling: "<<lowest_value<<std::endl;
//                std::cout<<"Highest Value before scaling: "<<highest_value<<std::endl;
//                std::cout<<"Line Segments:"<<std::endl;
                highest_value -= lowest_value;
                for(int i=0; i<lineM.size(); i++){
                    double l_0 = lineQ[i];
                    double l_1 = lineM[i]+lineQ[i];
                    l_0 = (l_0-lowest_value);
                    l_1 = (l_1-lowest_value);
                    if(highest_value!=0){
                        l_0 = l_0/highest_value;
                        l_1 = l_1/highest_value;
                    }
                    lineM[i] = l_1-l_0;
                    lineQ[i] = l_0;
//                    actionStates[i].printCompact(std::cout);
//                    std::cout<<"Slope: "<<lineM[i]<<" l_0: "<<lineQ[i]<<std::endl;
                }

                //calculate best starting line
                double max = lineQ[0];
                int bestLine = 0;
                int currentMaxIndex = 0;
                for (int i = 1; i < lineQ.size(); i++) {
                    if (lineQ[i] > max) {
                        max = lineQ[i];
                        bestLine = i;
                        currentMaxIndex = i;
                    }
                }

                //check if another line is equally good
                bool continueThinking = false;
                for (int i = 0; i<lineQ.size(); i++){
                    if(i!=bestLine && MathUtils::doubleIsEqual(lineQ[i], lineQ[bestLine])
                            && MathUtils::doubleIsEqual(lineM[i], lineM[bestLine])){
                        continueThinking = true;
                    }
                }
                if(continueThinking){
                    if(sameActionCount<sameActionLimit){
                        sameActionCount++;
                        trialNumber = 0;
                        thoughtCount++;
//                        std::cout << "---------------------------------------------------------" << std::endl;
//                        std::cout << "Continue thinking: same line for best actions" << std::endl;
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        //reset
                        std::fill(probs.begin(), probs.end(),0);
                        std::fill(means.begin(), means.end(),0);
                        continue;
                    }
                }
//                std::cout<<"Best starting line: ";
//                actionStates[bestLine].printCompact(std::cout);
//                std::cout<<std::endl;

                //find all intersections and calculate probabilities
                double x_low = 0;
                int newLine = bestLine;
                while (x_low < 1) {
                    double x_high = 1;
                    double y = lineM[bestLine] + lineQ[bestLine];
                    for (int i = 0; i < lineM.size(); i++) {
                        double xIntersect = (lineQ[i] - lineQ[bestLine])/(lineM[bestLine] - lineM[i]) ;
                        double yIntersect = lineM[bestLine] * xIntersect + lineQ[bestLine];
                        if (x_low < xIntersect && xIntersect < x_high) {
                            x_high = xIntersect;
                            y = yIntersect;
                            newLine = i;
                        }
                    }
                    probs[bestLine] = x_high - x_low;
                    means[bestLine] = (lineQ[bestLine] + y) / 2;
                    x_low = x_high;
                    bestLine = newLine;
                }

                //calcuate expected Q value
                double qThink = 0;
                for (int i = 0; i < probs.size(); i++) {
                    qThink += probs[i] * means[i];
                }

                //find max Q value
//                int maxIndex = -1;
//                double maxQ = -std::numeric_limits<double>::max();
//                for (int i = 0; i < currentQ.size(); i++) {
//                    if (currentQ[i] > maxQ) {
//                        maxQ = currentQ[i];
//                        maxIndex = i;
//                    }
//                }

                //calculate qAct
                double qAct = (lineM[currentMaxIndex] + 2 * lineQ[currentMaxIndex]) / 2;

                //calculate cThink
                double cThink = 0;
                double timePerStep = (ProstPlanner::remainingTimeTotal-stopwatch())/ProstPlanner::remainingStepsTotal;
                if(metaReasoningVersion == 3 || metaReasoningVersion == 5){
                    if(metaReasoningVersion == 5){
                        t_max = ProstPlanner::tMax;
//                        std::cout<<"TMax: "<<t_max<<std::endl;
                    }
                    if(timePerStep > t_max){
                        cThink = 0;
                    } else if(timePerStep < t_min){
                        cThink = 1;
                    } else {
                        cThink = t_max+t_min-timePerStep;
                    }
                } else if (metaReasoningVersion == 4 || metaReasoningVersion == 6){
                    if(metaReasoningVersion == 6){
                        t_max = ProstPlanner::tMax;
//                        std::cout<<"TMax: "<<t_max<<std::endl;
                    }
                    if(timePerStep < t_min){
                        cThink = 1;
                    } else {
                        cThink = t_max+t_min-timePerStep;
                    }
                }


//                std::cout << "---------------------------------------------------------" << std::endl;
//                std::cout <<"Time elapsed: "<<stopwatch()<<std::endl;
//                std::cout << "qThink: " << qThink << std::endl;
//                std::cout << "qAct:   " << qAct << std::endl;
//                std::cout << "cThink: " << cThink << std::endl;
//                std::cout << "Time Remaining: " << ProstPlanner::remainingTimeTotal-stopwatch() << std::endl;
//                std::cout << "Steps Remaining: " << ProstPlanner::remainingStepsTotal << std::endl;
//                std::cout << "Probability of current best action: " << probs[currentMaxIndex] << std::endl;
//                std::cout << "Probability: " << std::endl;
//                for (int i = 0; i < probs.size(); i++) {
//                    actionStates[i].printCompact(std::cout);
//                    std::cout << probs[i] << std::endl;
//                }
//                std::cout << "Averages over best section: " << std::endl;
//                for (int i = 0; i < probs.size(); i++) {
//                    actionStates[i].printCompact(std::cout);
//                    std::cout << means[i] << std::endl;
//                }
//                std::cout << "Next DeltaQ: " << std::endl;
//                for (int i = 0; i < probs.size(); i++) {
//                    actionStates[i].printCompact(std::cout);
//                    std::cout << nextDeltaQ[i] << std::endl;
//                }
//                std::cout << "Current Q: " << std::endl;
//                for (int i = 0; i < probs.size(); i++) {
//                    actionStates[i].printCompact(std::cout);
//                    std::cout << currentQ[i] << std::endl;
//                }

                //evaluate
                if(metaReasoningVersion == 1){
                    if (qAct >= qThink/*-cThink*/ /*|| (probs[maxIndex])>0.5*/) {
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        std::cout << "Thought "<<thoughtCount<<" cycle(s)." << std::endl;
                        std::cout << "Stop thinking" << std::endl;
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        break;
                    } else {
                        trialNumber = 0;
                        thoughtCount++;
//                        std::cout << "---------------------------------------------------------" << std::endl;
//                        std::cout << "Continue thinking" << std::endl;
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        //reset
                        std::fill(probs.begin(), probs.end(),0);
                        std::fill(means.begin(), means.end(),0);
                    }
                } else if (metaReasoningVersion == 2) {
                    if (MathUtils::doubleIsGreater(stopwatch(),t_min) && qAct >= qThink /*|| (probs[maxIndex])>0.5*/) {
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        std::cout << "Thought "<<thoughtCount<<" cycle(s)." << std::endl;
                        std::cout << "Stop thinking" << std::endl;
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        break;
                    } else {
                        trialNumber = 0;
                        thoughtCount++;
//                        std::cout << "---------------------------------------------------------" << std::endl;
//                        std::cout << "Continue thinking" << std::endl;
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        //reset
                        std::fill(probs.begin(), probs.end(),0);
                        std::fill(means.begin(), means.end(),0);
                    }
                } else if (metaReasoningVersion == 3 || metaReasoningVersion == 4 || metaReasoningVersion == 5 || metaReasoningVersion == 6){
                    if (MathUtils::doubleIsGreater(stopwatch(),t_min) && qAct >= qThink-cThink /*|| (probs[maxIndex])>0.5*/) {
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        std::cout << "Thought "<<thoughtCount<<" cycle(s)." << std::endl;
                        std::cout << "Stop thinking" << std::endl;
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        break;
                    } else {
                        trialNumber = 0;
                        thoughtCount++;
//                        std::cout << "---------------------------------------------------------" << std::endl;
//                        std::cout << "Continue thinking" << std::endl;
//                        std::cout << "---------------------------------------------------------" << std::endl;
                        //reset
                        std::fill(probs.begin(), probs.end(),0);
                        std::fill(means.begin(), means.end(),0);
                    }
                }
            }
//            std::cout<<"TEMP: evaluation done"<<std::endl;
        }
//        std::cout<<"TEMP: trial done"<<std::endl;
        //calculate mean
//            double sum = 0;
//            for (int j = 0; j < deltaQValues[appliedActionInRoot].size(); j++) {
//                sum += deltaQValues[appliedActionInRoot][j];
//            }
//            double mean = sum / deltaQValues[appliedActionInRoot].size();
//
//            //calculate variance
//            sum = 0;
//            for (int j = 0; j < deltaQValues[appliedActionInRoot].size(); j++) {
//                sum += pow((deltaQValues[appliedActionInRoot][j] - mean), 2);
//            std::cout << "Value "<<j<<": " << deltaQValues[appliedActionInRoot][j] << std::endl;
//            std::cout << "Diff "<<j<<": " << (deltaQValues[appliedActionInRoot][j] - mean) << std::endl;
//            std::cout << "Times "<<j<<": " << (deltaQValues[appliedActionInRoot][j] - mean)*(deltaQValues[appliedActionInRoot][j] - mean) << std::endl;
//            std::cout << "Squared " << pow((deltaQValues[appliedActionInRoot][j] - mean), 2) << std::endl;
//            std::cout << "Running Sum " << sum << std::endl;
//            }
//            double variance = sum / deltaQValues[appliedActionInRoot].size();
//        std::cout << "Sum " << sum << std::endl;
//        std::cout << "Size " << deltaQValues[appliedActionInRoot].size() << std::endl;
//
            //update delta Q interval
//            deltaQIntervals[appliedActionInRoot].first = mean - variance;
//            deltaQIntervals[appliedActionInRoot].second = mean + variance;
//
//            std::cout <<
//                      "---------------------------------------------------------" <<
//                      std::endl;
//            std::cout << "Trial " << currentTrial << std::endl;
//            std::cout << "Q1 " << currentQVal << std::endl;
//            std::cout << "Q2 " << currentRootNode->getExpectedRewardEstimate() << std::endl;
//            std::cout << "Delta Q " << deltaQ << std::endl;
//            std::cout << "Mean " << mean << std::endl;
//            std::cout << "Variance " << variance << std::endl;
//            std::cout << "Index " << appliedActionInRoot << std::endl;
//            for (int i = 0; i < deltaQIntervals.size(); i++) {
//                if (currentRootNode->children[i]) {
//                    actionStates[i].printCompact(std::cout);
//                    std::cout << " [" << deltaQIntervals[i].first << ", " << deltaQIntervals[i].second << "]" << std::endl;
//                }
//            }
//            std::cout <<
//                      "---------------------------------------------------------" <<
//            std::endl;
//
//            //Break condition
//            if(currentTrial >= (deltaQIntervals.size()*2)+1) {
//                bool end = false;
//                double low = deltaQIntervals[0].first;
//                double high = deltaQIntervals[0].second;
//                for (int j = 0; j < deltaQIntervals.size(); j++) {
//                    if (deltaQIntervals[j].first > low && deltaQIntervals[j].second > high) {
//                        low = deltaQIntervals[j].first;
//                        high = deltaQIntervals[j].second;
//                        end = true;
//                    }
//                }
//                if (!end) {
//                    end = true;
//                    low = deltaQIntervals[0].first;
//                    high = deltaQIntervals[0].second;
//                    for (int j = 0; j < deltaQIntervals.size(); j++) {
//                        if ((deltaQIntervals[j]).first > low || deltaQIntervals[j].second > high) {
//                            end = false;
//                        }
//                    }
//                }
//                if (end) {
//                    break;
//                }
                    //new break condition
//                    bool end = false;
//                    for(int i = 0; i<deltaQIntervals.size(); i++){
//                        bool best = true;
//                        double low = deltaQIntervals[i].first;
//                        for(int j = 0; j<deltaQIntervals.size(); j++){
//                            if(i!=j && low<=deltaQIntervals[j].second){
//                                best = false;
//                            }
//                        }
//                        if(best){
//                            end = true;
//                            break;
//                        }
//                    }
//            }
//        }

        // for(unsigned int i = 0; i < currentRootNode->children.size(); ++i) {
        //     if(currentRootNode->children[i]) {
        //         SearchEngine::actionStates[i].print(std::cout);
        //         std::cout << std::endl;
        //         currentRootNode->children[i]->print(std::cout, "  ");
        //     }
        // }
        // assert(currentTrial != 100);
    }

    recommendationFunction->recommend(currentRootNode, bestActions);
    assert(!bestActions.empty());

    // Update statistics
    ++numberOfRuns;

    if (currentRootNode->solved && !firstSolvedFound) {
        // TODO: This is the first root state that was solved, so everything
        // that could happen in the future is also solved. We should (at least
        // in this case) make sure that we keep the tree and simply follow the
        // optimal policy.
        firstSolvedFound = true;
        accumulatedNumberOfStepsToGoInFirstSolvedRootState +=
            _rootState.stepsToGo();
    }

    if (_rootState.stepsToGo() == SearchEngine::horizon) {
        accumulatedNumberOfTrialsInRootState += currentTrial;
        accumulatedNumberOfSearchNodesInRootState += lastUsedNodePoolIndex;
    }

    // Print statistics
    std::cout << "Search time: " << stopwatch << std::endl;
    printStats(std::cout, (_rootState.stepsToGo() == 1));
}

// Switches off all caching once the process uses more RAM than ramLimit.
//
// This only happens once: after the limit has been exceeded, cachingEnabled
// is false and subsequent calls return immediately (without querying the RAM
// usage). Disabling affects the applicable-action cache, every deterministic,
// probabilistic and determinized CPF, the reward CPF, all action
// preconditions and the caches of ProbabilisticSearchEngine.
void THTS::monitorRAMUsage() {
    bool const ramLimitExceeded =
        cachingEnabled && (SystemUtils::getRAMUsedByThis() > ramLimit);
    if (!ramLimitExceeded) {
        return;
    }

    cachingEnabled = false;
    SearchEngine::cacheApplicableActions = false;

    for (size_t fluent = 0; fluent < State::numberOfDeterministicStateFluents;
         ++fluent) {
        SearchEngine::deterministicCPFs[fluent]->disableCaching();
    }

    for (size_t fluent = 0; fluent < State::numberOfProbabilisticStateFluents;
         ++fluent) {
        SearchEngine::probabilisticCPFs[fluent]->disableCaching();
        SearchEngine::determinizedCPFs[fluent]->disableCaching();
    }

    SearchEngine::rewardCPF->disableCaching();

    for (auto const& precondition : SearchEngine::actionPreconditions) {
        precondition->disableCaching();
    }

    ProbabilisticSearchEngine::disableCaching();
}

bool THTS::moreTrials() {
    // Check memory constraints and solvedness
    if (currentRootNode->solved ||
        (lastUsedNodePoolIndex >= maxNumberOfNodes)) {
        return false;
    }

    if (currentTrial == 0) {
        return true;
    }

    // Check selected termination criterion
    switch (terminationMethod) {
    case THTS::TIME:
        if (MathUtils::doubleIsGreater(stopwatch(), timeout)) {
            return false;
        }
        break;
    case THTS::NUMBER_OF_TRIALS:
        if (currentTrial == maxNumberOfTrials) {
            return false;
        }
        break;
    case THTS::TIME_AND_NUMBER_OF_TRIALS:
        if (MathUtils::doubleIsGreater(stopwatch(), timeout) ||
            (currentTrial == maxNumberOfTrials)) {
            return false;
        }
        break;
    }

    return true;
}

// Visits a decision node as part of the current trial (recursively).
//
// At the root, a new trial is started with initTrial(). At any other node,
// the trial is advanced with initTrialStep(), and it ends right here if
// currentStateIsSolved() can determine the value of the current state without
// further search. Otherwise, the node is initialized if necessary. Then, if
// continueTrial() agrees, the node:
//  1. selects an action,
//  2. samples a successor state,
//  3. recurses into the resulting chance node(s),
//  4. is backed up on the way back.
// If the trial stops at this node, trialReward is set to the node's expected
// reward estimate.
void THTS::visitDecisionNode(SearchNode* node) {
    if (node == currentRootNode) {
        initTrial();
    } else {
        // Continue trial (i.e., set next state to be the current)
        initTrialStep();

        // Check if there is a "special" reason to stop this trial (currently,
        // this is the case if the state value of the current state is cached,
        // if it is a reward lock or if there is only one step left).
        if (currentStateIsSolved(node)) {
            // Remember the first node of this trial at which the trial
            // stopped or left the initialized part of the tree (presumably
            // tipNodeOfTrial is reset by initTrial() -- confirm)
            if (!tipNodeOfTrial) {
                tipNodeOfTrial = node;
            }
            return;
        }
    }

    // Initialize node if necessary
    if (!node->initialized) {
        if (!tipNodeOfTrial) {
            tipNodeOfTrial = node;
        }

        initializer->initialize(node, states[stepsToGoInCurrentState]);

        // The root node is not counted as an initialized decision node
        if (node != currentRootNode) {
            ++initializedDecisionNodes;
        }
    }

    // std::cout << std::endl << std::endl << "Current state is: " << std::endl;
    // states[stepsToGoInCurrentState].printCompact(std::cout);
    // std::cout << "Reward is " << node->immediateReward << std::endl;

    // Determine if we continue with this trial
    if (continueTrial(node)) {
        // Select the action that is simulated
        appliedActionIndex = actionSelection->selectAction(node);
        assert(node->children[appliedActionIndex]);
        assert(!node->children[appliedActionIndex]->solved);
        if (node == currentRootNode) {
            appliedActionInRoot = appliedActionIndex;
        }

        // std::cout << "Chosen action is: ";
        // SearchEngine::actionStates[appliedActionIndex].printCompact(std::cout);
        // std::cout << std::endl;

        // Sample successor state
        calcSuccessorState(states[stepsToGoInCurrentState], appliedActionIndex,
                           states[stepsToGoInNextState]);

        // std::cout << "Sampled PDState is " << std::endl;
        // states[stepsToGoInNextState].printPDStateCompact(std::cout);
        // std::cout << std::endl;

        // Fix every probabilistic state fluent whose distribution in the
        // successor state is deterministic to its only value, and remember
        // the index of the last fluent that is still non-deterministic (-1 if
        // all of them are deterministic).
        lastProbabilisticVarIndex = -1;
        for (unsigned int i = 0; i < State::numberOfProbabilisticStateFluents;
             ++i) {
            if (states[stepsToGoInNextState]
                    .probabilisticStateFluentAsPD(i)
                    .isDeterministic()) {
                states[stepsToGoInNextState].probabilisticStateFluent(i) =
                    states[stepsToGoInNextState]
                        .probabilisticStateFluentAsPD(i)
                        .values[0];
            } else {
                lastProbabilisticVarIndex = i;
            }
        }

        // Start outcome selection with the first probabilistic variable
        chanceNodeVarIndex = 0;

        // Continue trial with chance nodes. If no fluent has to be sampled,
        // the successor is reached through a dummy chance node with a single
        // child.
        if (lastProbabilisticVarIndex < 0) {
            visitDummyChanceNode(node->children[appliedActionIndex]);
        } else {
            visitChanceNode(node->children[appliedActionIndex]);
        }

        // Backup this node
        backupFunction->backupDecisionNode(node);
        // Add the immediate reward of reaching this node to the reward
        // collected below it
        trialReward += node->immediateReward;

        // If the backup function labeled the node as solved, we store the
        // result for the associated state in case we encounter it somewhere
        // else in the tree in the future. The state is looked up via
        // node->stepsToGo because stepsToGoInCurrentState has been changed by
        // the recursive visit.
        if (node->solved) {
            if (cachingEnabled &&
                ProbabilisticSearchEngine::stateValueCache.find(
                    states[node->stepsToGo]) ==
                    ProbabilisticSearchEngine::stateValueCache.end()) {
                ProbabilisticSearchEngine::stateValueCache
                    [states[node->stepsToGo]] =
                        node->getExpectedFutureRewardEstimate();
            }
        }
    } else {
        // The trial is finished
        trialReward = node->getExpectedRewardEstimate();
    }
}

// Checks whether the value of the current state
// (states[stepsToGoInCurrentState]) can be determined without continuing the
// trial below node.
//
// If so, trialReward is set to that value, node is backed up as a leaf with
// it, node's immediate reward is added to trialReward, and true is returned.
// The following cases are tested in this order:
//  1. Only one step is left: the optimal final reward is computed directly.
//  2. The state's value is already stored in the state value cache (this is
//     counted as a cache hit).
//  3. node has no children and the state is a reward lock: the reward of
//     action 0 is collected in each remaining step. If caching is enabled,
//     the node's future reward estimate is also stored in the cache.
// Returns false if none of these cases applies.
bool THTS::currentStateIsSolved(SearchNode* node) {
    if (stepsToGoInCurrentState == 1) {
        // This node is a leaf (there is still a last decision, though, but that
        // is taken care of by calcOptimalFinalReward)

        calcOptimalFinalReward(states[1], trialReward);
        backupFunction->backupDecisionNodeLeaf(node, trialReward);
        trialReward += node->immediateReward;

        return true;
    } else if (ProbabilisticSearchEngine::stateValueCache.find(
                   states[stepsToGoInCurrentState]) !=
               ProbabilisticSearchEngine::stateValueCache.end()) {
        // This state has already been solved before
        trialReward = ProbabilisticSearchEngine::stateValueCache
            [states[stepsToGoInCurrentState]];
        backupFunction->backupDecisionNodeLeaf(node, trialReward);
        trialReward += node->immediateReward;

        ++cacheHits;
        return true;
    } else if (node->children.empty() &&
               isARewardLock(states[stepsToGoInCurrentState])) {
        // This state is a reward lock, i.e. a goal or a state that is such that
        // no matter which action is applied we'll always get the same reward

        // Because every action yields the same reward, the state's value is
        // the reward of action 0 times the number of remaining steps
        calcReward(states[stepsToGoInCurrentState], 0, trialReward);
        trialReward *= stepsToGoInCurrentState;
        backupFunction->backupDecisionNodeLeaf(node, trialReward);
        trialReward += node->immediateReward;

        // The cache lookup in the branch above failed, so the state cannot be
        // cached yet
        if (cachingEnabled) {
            assert(ProbabilisticSearchEngine::stateValueCache.find(
                       states[stepsToGoInCurrentState]) ==
                   ProbabilisticSearchEngine::stateValueCache.end());
            ProbabilisticSearchEngine::stateValueCache
                [states[stepsToGoInCurrentState]] =
                    node->getExpectedFutureRewardEstimate();
        }
        return true;
    }
    return false;
}

// Visits a chance node of the current trial (recursively).
//
// Chance nodes branch over one probabilistic state fluent at a time. First,
// chanceNodeVarIndex is advanced past fluents whose distribution in the
// successor state (states[stepsToGoInNextState]) is deterministic. Then
// outcomeSelection picks the child for that fluent. If this is the last
// non-deterministic fluent (lastProbabilisticVarIndex), the successor state's
// hash keys are computed and the trial continues in the chosen decision node.
// Otherwise, the trial continues with the chance node for the next fluent.
// On the way back, node is backed up with trialReward.
//
// The skip loop relies on a non-deterministic fluent existing at or after
// chanceNodeVarIndex. visitDecisionNode only calls this function when
// lastProbabilisticVarIndex >= 0.
void THTS::visitChanceNode(SearchNode* node) {
    while (states[stepsToGoInNextState]
               .probabilisticStateFluentAsPD(chanceNodeVarIndex)
               .isDeterministic()) {
        ++chanceNodeVarIndex;
    }

    // NOTE(review): selectOutcome presumably also writes the chosen value
    // into states[stepsToGoInNextState]; the hash keys below are only
    // computed once all fluents have been chosen -- confirm in
    // OutcomeSelection.
    chosenOutcome = outcomeSelection->selectOutcome(
        node, states[stepsToGoInNextState], chanceNodeVarIndex,
        lastProbabilisticVarIndex);

    if (chanceNodeVarIndex == lastProbabilisticVarIndex) {
        // All probabilistic fluents are fixed: the successor state is complete
        State::calcStateFluentHashKeys(states[stepsToGoInNextState]);
        State::calcStateHashKey(states[stepsToGoInNextState]);

        visitDecisionNode(chosenOutcome);
    } else {
        ++chanceNodeVarIndex;
        visitChanceNode(chosenOutcome);
    }
    backupFunction->backupChanceNode(node, trialReward);
}

// Visits the chance node of an action whose successor state is already fully
// determined (no probabilistic fluent has to be sampled).
//
// Such a chance node has exactly one child, a decision node with probability
// 1.0 that is created on the first visit. The trial continues in that child,
// and node is backed up with trialReward afterwards.
void THTS::visitDummyChanceNode(SearchNode* node) {
    // Nothing is left to sample, so the successor state's hash keys can be
    // computed right away
    State::calcStateFluentHashKeys(states[stepsToGoInNextState]);
    State::calcStateHashKey(states[stepsToGoInNextState]);

    if (node->children.empty()) {
        SearchNode* successor = createDecisionNode(1.0);
        node->children.push_back(successor);
    }
    assert(node->children.size() == 1);

    SearchNode* onlyChild = node->children.front();
    visitDecisionNode(onlyChild);
    backupFunction->backupChanceNode(node, trialReward);
}

/******************************************************************
                      Root State Analysis
******************************************************************/

// Checks whether the current root state has an action that can be chosen
// without search.
//
// Returns:
//  - the optimal final action if only one step is left,
//  - the first action i with actionsToExpand[i] == i if the state is a reward
//    lock (any action yields the same reward there),
//  - the only applicable action if there is exactly one,
//  - -1 otherwise, i.e. if there is more than one applicable action and
//    search is required.
int THTS::getUniquePolicy() {
    if (stepsToGoInCurrentState == 1) {
        std::cout << "Returning the optimal last action!" << std::endl;
        return getOptimalFinalActionIndex(states[1]);
    }

    std::vector<int> actionsToExpand =
        getApplicableActions(states[stepsToGoInCurrentState]);

    if (isARewardLock(states[stepsToGoInCurrentState])) {
        std::cout << "Current root state is a reward lock state!" << std::endl;
        states[stepsToGoInCurrentState].print(std::cout);
        // Use a signed index so that it is compared with the (signed) entries
        // of actionsToExpand without a signed/unsigned conversion, and so that
        // it matches the int return type
        int const numberOfActions = static_cast<int>(actionsToExpand.size());
        for (int actionIndex = 0; actionIndex < numberOfActions;
             ++actionIndex) {
            if (actionsToExpand[actionIndex] == actionIndex) {
                return actionIndex;
            }
        }

        assert(false);
    }

    std::vector<int> applicableActionIndices =
        getIndicesOfApplicableActions(states[stepsToGoInCurrentState]);
    assert(!applicableActionIndices.empty());

    if (applicableActionIndices.size() == 1) {
        std::cout << "Only one reasonable action in current root state!"
                  << std::endl;
        return applicableActionIndices[0];
    }

    // There is more than one applicable action
    return -1;
}

/******************************************************************
                        Memory management
******************************************************************/

// Prepares the node pool for a new search and returns the root node, which
// occupies pool slot 0.
//
// The children vectors of all previously allocated nodes are released. The
// pool is filled front to back, so the first empty slot marks the end of the
// allocated nodes. Swapping with an empty vector (rather than calling
// clear()) also frees the vector's capacity. The root gets probability 1.0,
// stepsToGoInCurrentState remaining steps and an immediate reward of 0.
SearchNode* THTS::createRootNode() {
    for (size_t poolIndex = 0; poolIndex < nodePool.size(); ++poolIndex) {
        SearchNode* pooled = nodePool[poolIndex];
        if (!pooled) {
            break;
        }
        if (!pooled->children.empty()) {
            std::vector<SearchNode*>().swap(pooled->children);
        }
    }

    SearchNode* root = nodePool[0];
    if (!root) {
        root = new SearchNode(1.0, stepsToGoInCurrentState);
        nodePool[0] = root;
    } else {
        root->reset(1.0, stepsToGoInCurrentState);
    }
    root->immediateReward = 0.0;

    // Slot 0 is taken by the root, so the next node comes from slot 1
    lastUsedNodePoolIndex = 1;
    return root;
}

// Takes the next node from the node pool as a decision node.
//
// The node gets probability prob and stepsToGoInNextState remaining steps.
// An already allocated node in that slot is reused via reset(); otherwise a
// new node is allocated and stored in the pool. The node's immediate reward
// is the reward of applying appliedActionIndex in the current state.
SearchNode* THTS::createDecisionNode(double const& prob) {
    assert(lastUsedNodePoolIndex < nodePool.size());

    SearchNode*& slot = nodePool[lastUsedNodePoolIndex];
    if (!slot) {
        slot = new SearchNode(prob, stepsToGoInNextState);
    } else {
        slot->reset(prob, stepsToGoInNextState);
    }
    SearchNode* const decisionNode = slot;

    calcReward(states[stepsToGoInCurrentState], appliedActionIndex,
               decisionNode->immediateReward);

    ++lastUsedNodePoolIndex;
    return decisionNode;
}

// Takes the next node from the node pool as a chance node.
//
// The node gets probability prob and stepsToGoInCurrentState remaining
// steps. An already allocated node in that slot is reused via reset();
// otherwise a new node is allocated and stored in the pool.
SearchNode* THTS::createChanceNode(double const& prob) {
    assert(lastUsedNodePoolIndex < nodePool.size());

    SearchNode*& slot = nodePool[lastUsedNodePoolIndex];
    if (!slot) {
        slot = new SearchNode(prob, stepsToGoInCurrentState);
    } else {
        slot->reset(prob, stepsToGoInCurrentState);
    }
    SearchNode* const chanceNode = slot;

    ++lastUsedNodePoolIndex;
    return chanceNode;
}

/******************************************************************
                       Parameter Setter
******************************************************************/

// Sets the maximal search depth of this engine and passes it on to the
// initializer if one has been set.
void THTS::setMaxSearchDepth(int _maxSearchDepth) {
    SearchEngine::setMaxSearchDepth(_maxSearchDepth);

    if (initializer == nullptr) {
        return;
    }
    initializer->setMaxSearchDepth(_maxSearchDepth);
}

/******************************************************************
                            Print
******************************************************************/

// Prints the statistics of the last search to out.
//
// Output includes:
//  - the generic search engine statistics,
//  - trial, created-node and cache-hit counters plus the statistics of the
//    action selection, outcome selection and backup function (only if at
//    least one trial was performed),
//  - the initializer's statistics (indented by two extra spaces),
//  - the root node and the Q-value estimates of all root actions that have a
//    child node,
//  - if printRoundStats is true, the values accumulated over the round.
// Lines written here are prefixed with indent, so the block can be nested in
// other output.
void THTS::printStats(std::ostream& out, bool const& printRoundStats,
                      std::string indent) const {
    SearchEngine::printStats(out, printRoundStats, indent);

    if (currentTrial > 0) {
        out << indent << "Performed trials: " << currentTrial << std::endl;
        out << indent << "Created SearchNodes: " << lastUsedNodePoolIndex
            << std::endl;
        out << indent << "Cache Hits: " << cacheHits << std::endl;
        actionSelection->printStats(out, indent);
        outcomeSelection->printStats(out, indent);
        backupFunction->printStats(out, indent);
    }
    if (initializer) {
        initializer->printStats(out, printRoundStats, indent + "  ");
    }

    if (currentRootNode) {
        out << std::endl << indent << "Root Node: " << std::endl;
        // Indent the root's value like the Q-value lines below
        out << indent;
        currentRootNode->print(out);
        // This header was previously printed without indent, unlike every
        // other line of this function
        out << std::endl << indent << "Q-Value Estimates: " << std::endl;
        for (unsigned int i = 0; i < currentRootNode->children.size(); ++i) {
            if (currentRootNode->children[i]) {
                out << indent;
                SearchEngine::actionStates[i].printCompact(out);
                out << ": ";
                currentRootNode->children[i]->print(out);
            }
        }
    }

    if (printRoundStats) {
        out << std::endl << indent << "ROUND FINISHED" << std::endl;
        out << indent << "Accumulated number of remaining steps in first "
                         "solved root state: "
            << accumulatedNumberOfStepsToGoInFirstSolvedRootState << std::endl;
        out << indent << "Accumulated number of trials in root state: "
            << accumulatedNumberOfTrialsInRootState << std::endl;
        out << indent << "Accumulated number of search nodes in root state: "
            << accumulatedNumberOfSearchNodesInRootState << std::endl;
    }
}
