//
// Created by badfer00 on 29.10.19.
//

#include "sogbofa_search.h"
#include <iostream>
#include <iomanip>
#include "utils/random.h"
#include "utils/string_utils.h"
#include "utils/system_utils.h"
#include <cmath>

#include <limits>
#include <chrono>
#include <fstream>

using namespace std;

/******************************************************************
                     Search Engine Creation
******************************************************************/

// Definition of the class-level cache of heuristic estimates. estimateQValues()
// looks states up in it and, when caching is enabled, stores the freshly
// computed per-action values (before they are scaled by the steps to go).
SogbofaSearch::HashMap SogbofaSearch::rewardCache;

// Registers the engine under the name "Sogbofa" and sets the default
// configuration. Most of these values can be overridden afterwards through
// setValueFromString().
SogbofaSearch::SogbofaSearch()
        : DeterministicSearchEngine("Sogbofa"),
          forward(true),
          highestMarginalProbability(true),
          conformant(false),
          minUpdates(500),
          verbose(0),
          maxGradientSteps(0),
          projection(2),
          threshold(0.1),
          penalty(0.0),
          initialValue(0.0) {}

/**
 * Applies one command line option to this engine.
 *
 * Options handled here: -cp, -f, -hmp, -c, -u, -v, -th, -g and -p. The value
 * is converted with atof() for -cp and -th and with atoi() for the others.
 * Any other option is forwarded to SearchEngine::setValueFromString().
 *
 * @param param  option name including the leading dash
 * @param value  textual value of the option
 * @return true if the option was handled here, otherwise whatever the base
 *         class implementation returns
 */
bool SogbofaSearch::setValueFromString(string &param, string &value) {
    const char *const text = value.c_str();
    bool handledHere = true;

    if (param == "-cp") {
        setPenalty(atof(text));
    } else if (param == "-f") {
        setForward(atoi(text));
    } else if (param == "-hmp") {
        setHighestMarginalProbability(atoi(text));
    } else if (param == "-c") {
        setConformant(atoi(text));
    } else if (param == "-u") {
        setMinUpdates(atoi(text));
    } else if (param == "-v") {
        setVerbose(atoi(text));
    } else if (param == "-th") {
        setThreshold(atof(text));
    } else if (param == "-g") {
        setGradientSteps(atoi(text));
    } else if (param == "-p") {
        setProjection(atoi(text));
    } else {
        handledHere = false;
    }

    // Unknown options are left to the base class.
    return handledHere || SearchEngine::setValueFromString(param, value);
}

/******************************************************************
                       Main Search Function

    Sample Config: ./prost.py elevators_inst_mdp__1  [PROST -s 1 -se [Sogbofa -sd -1 -s -1 -f 1 -hmp 1 -c 1 -cp 1000.0 -u 500]]
    Arguments:
        -sd     Search Depth
                    int > 0, or -1 for dynamic (Only with -i)
        -s      Step Size
                    double > 0, or -1 for dynamic as per paper, -2 for only line search on [0,1], -3 for line search once at the beginning
        -c      Conformant
                    bool, enables the conformant procedure
        -cp     Constraint Penalty
                    double, the weight of the penalty to be added to the reward if constraints are violated
        -hmp    Highest Marginal Probability Action Sampling
                    bool, enables metagamed concrete action sampling as per the original sogbofa
        -u      minUpdates
                    int, the minimal number of gradient actions to be updated during dynamic search depth calculation, only necessary with -sd -1
    Unused:
        -f      Forward mode
                    reverse mode is not used at the moment, this is always used

******************************************************************/

/**
 * Performs gradient ascent with random restarts for as long as the stopwatch
 * stays below the timeout, then samples a concrete action from the best
 * action vector recorded by gradientAscent().
 *
 * @param _rootState   state for which an action is requested
 * @param bestActions  output; overwritten with the sampled action(s)
 */
void SogbofaSearch::estimateBestActions(
        State const &_rootState, std::vector<int> &bestActions) {

    // Verbosity levels 2 and 3 both print the progress banners.
    const auto chatty = [this]() { return verbose == 2 || verbose == 3; };

    // Working vector (af), best vector found so far (actions) and its value (Q)
    VectorXdual af(SearchEngine::actionFluents.size());
    VectorXdual actions(SearchEngine::actionFluents.size());
    double Q = -std::numeric_limits<double>::max();

    // TODO: Automatic calls to initialize / reset stats / printStats?
    // Sets up state fluents and the search depth, and resizes af if needed
    initialize(_rootState, af);
    resetStats();

    // One iteration of this loop is one random restart
    while (MathUtils::doubleIsGreater(timeout, stopwatch())) {
        randomRestart(_rootState, af);
        if (chatty()) {
            cout << "..................................." << endl;
            cout << "Random Restart " << restart_counter << endl;
            cout << "..................................." << endl;
        }
        ++restart_counter;
        gradientCounter = 0;

        if (verbose == 2) {
            const size_t perLayer = SearchEngine::actionFluents.size();
            cout << "Actions:" << endl;
            for (size_t pos = 0; pos < af.size(); ++pos) {
                if (pos % perLayer == 0) {
                    cout << "-----------" << endl;
                }
                // Entries beyond the first layer belong to conformant actions
                const bool firstLayer = pos < perLayer;
                cout << (firstLayer ? "Action " : "CA ") << pos << (firstLayer ? " " : ": ")
                     << actionFluents[pos % perLayer]->name << " = " << af[pos] << endl;
            }
        }

        // Gradient steps until convergence or until time runs out
        bool converged = false;
        while (!converged && MathUtils::doubleIsGreater(timeout, stopwatch())) {
            ++gradientCounter;

            if (chatty()) {
                cout << "+++++++++++++++++++++++++++++++++++." << endl;
                cout << "Gradient Step " << gradientCounter << endl;
                cout << "+++++++++++++++++++++++++++++++++++" << endl;
            }

            converged = gradientAscent(actions, Q, af);

            if (chatty() && converged) {
                cout << "CONVERGED **************************" << endl;
            }
        }

        if (chatty() && !converged) {
            cout << "TIMEOUT!!!" << endl;
        }
    }

    // Turn the best action vector into a concrete action
    if (highestMarginalProbability) {
        bestActions = sampleConcreteActionMeta(actions, _rootState);
    } else {
        bestActions = sampleConcreteAction(actions, _rootState);
    }

    //printBestAction(bestActions);
    printDetailedStats(cout, true, "");
    resetStats();
}

/******************************************************************
                     Auxiliary functions
******************************************************************/

/**
 * Prepares one planning step: restarts the stopwatch, copies the fluents of
 * the root state into sf_input, fixes maxSearchDepthForThisStep and, in
 * conformant mode, resizes af so that it holds one set of action fluents per
 * layer.
 *
 * @param _rootState  state the search starts from
 * @param af          action fluent vector; resized in conformant mode and
 *                    overwritten by a random restart when the search depth is
 *                    determined dynamically
 */
void SogbofaSearch::initialize(State const &_rootState, VectorXdual &af) {

    stopwatch.reset();

    // Input of the Q function: deterministic fluents first, then the
    // probabilistic ones
    sf_input.resize(SearchEngine::stateFluents.size());
    for (size_t det = 0; det < State::numberOfDeterministicStateFluents; ++det) {
        sf_input[det] = _rootState.deterministicStateFluent(det);
    }
    for (size_t prob = 0; prob < State::numberOfProbabilisticStateFluents; ++prob) {
        sf_input[State::numberOfDeterministicStateFluents + prob] = _rootState.probabilisticStateFluent(prob);
    }

    const int horizon = _rootState.stepsToGo();

    if (maxSearchDepth != -1) {
        // Fixed search depth, capped by the remaining horizon
        maxSearchDepthForThisStep = min(horizon, maxSearchDepth);
    } else {
        // -1 requests a dynamically chosen search depth.
        // TODO: What is a good min Search depth
        maxSearchDepthForThisStep = min(5, horizon);

        // A sample gradient step needs a properly sized, initialized input
        if (conformant) {
            af.resize(SearchEngine::actionFluents.size() * maxSearchDepthForThisStep);
        }
        randomRestart(_rootState, af);
        const int requiredUpdates = minUpdates;

        // timeGradientCalculation() reports milliseconds; the division
        // presumably converts to the unit of timeout/stopwatch - TODO confirm.
        // The estimate for the next depth d+1 is t + t / d.
        double estimate = timeGradientCalculation(af) / 1000;
        estimate += estimate / maxSearchDepthForThisStep;

        // Deepen while the required number of gradient updates is still
        // estimated to fit into the remaining time
        while (MathUtils::doubleIsSmaller(requiredUpdates * estimate, timeout - stopwatch()) &&
               maxSearchDepthForThisStep < horizon) {
            ++maxSearchDepthForThisStep;
            estimate += estimate / maxSearchDepthForThisStep;
        }
    }

    // In conformant mode af stores the action fluents of every layer
    if (conformant) {
        af.resize(SearchEngine::actionFluents.size() * maxSearchDepthForThisStep);
    }

    // A one-off step size search (step size options -3 / -4) used to be
    // performed here; it is currently disabled.
}

void SogbofaSearch::randomRestart(const State &_rootState, VectorXdual &af) {
    // Set actions to a random action
    vector<int> applicableActionIndices = getIndicesOfApplicableActions(_rootState);

    int randAction = MathUtils::rnd->randomElement(applicableActionIndices);

    for (size_t index = 0; index < actionFluents.size(); index++) {
        af[index] = actionStates[randAction].state[index];
        assert(!isnan(af[index].val));
    }

    // actions in higher layers are initialized uniformly
    if (conformant) {
        for (size_t level = 1; level < maxSearchDepthForThisStep; level++) {
            for (size_t index = 0; index < actionFluents.size(); index++) {
                af[index + actionFluents.size() * level] = 1.0 / actionFluents.size();
                assert(!isnan(af[index].val));
            }
        }
    }
}

/**
 * Measures how long one gradient evaluation of qFunction() takes.
 *
 * A monotonic clock is used because the result is an elapsed interval;
 * high_resolution_clock may be an alias of the adjustable system clock.
 *
 * @param af  point at which the gradient is evaluated
 * @return elapsed time in milliseconds
 */
double SogbofaSearch::timeGradientCalculation(VectorXdual &af) {
    const auto started = std::chrono::steady_clock::now();
    dual q;
    // The gradient itself is discarded; only the duration matters here.
    VectorXd dqda = gradient([&](auto x) { return qFunction(x); }, wrt(af), at(af), q);
    const auto done = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::milli> elapsed = done - started;
    return elapsed.count();
}

/******************************************************************
                        Gradient Ascent
******************************************************************/

/**
 * Performs one gradient ascent step on the action fluents.
 *
 * The Q function and its gradient are evaluated at af, af is moved along the
 * gradient by the step size returned by findStepSize() and then passed to
 * projectActionsByLayer(). If the evaluated Q value beats the best one so
 * far, Q and actions are updated.
 *
 * @param actions  in/out: first-layer action fluents recorded as best so far
 * @param Q        in/out: best Q value seen so far
 * @param af       in/out: current action fluents (all layers)
 * @return true if no entry of af changed by threshold or more in this step
 */
bool SogbofaSearch::gradientAscent(VectorXdual &actions, double &Q, VectorXdual &af) {

    // Prints one line of the separator used around the verbose headings.
    const auto starRule = [](const char *lead) {
        cout << lead << "*************************************************" << endl;
    };

    if (verbose == 2) {
        starRule("\n");
        cout << "Gradient Step " << gradientCounter << endl;
        starRule("");
    }

    // Gradient of the Q function at af; verbose level 1 selects the variant
    // that additionally writes the computation graph to disk.
    dual q;
    VectorXd dqda;
    if (verbose != 1) {
        dqda = gradient([&](auto x) { return qFunction(x); }, wrt(af), at(af), q);
    } else {
        dqda = gradient([&](auto x) { return qFunctionVisualization(x); }, wrt(af), at(af), q);
    }
    //verbose = 2;

    // Prints one value per gradient entry, grouped by layer. The first layer
    // is labelled "Action", deeper (conformant) layers "CA".
    const auto dumpByLayer = [&](const auto &values) {
        const size_t perLayer = SearchEngine::actionFluents.size();
        for (size_t pos = 0; pos < dqda.size(); ++pos) {
            if (pos % perLayer == 0) {
                cout << "Layer: " << pos / perLayer << " ----------------------------------------"
                     << endl;
            }
            cout << (pos < perLayer ? "Action " : "CA ") << pos << ":\t" << setw(35)
                 << actionFluents[pos % perLayer]->name << "\t"
                 << values[pos]
                 << endl;
        }
    };

    if (verbose == 2) {
        starRule("\n");
        cout << "Gradients:" << endl;
        starRule("");
        dumpByLayer(dqda);
        cout << "Q:  " << q << endl;
    }

    // Step size first, then keep a copy of the point the step starts from
    const double alpha = findStepSize(af, dqda);
    VectorXdual old_af = af;
    for (size_t pos = 0; pos < af.size(); pos++) {
        af[pos] += alpha * dqda[pos];
    }

    if (verbose == 2) {
        starRule("\n");
        cout << "Updated Actions:" << endl;
        starRule("");
        dumpByLayer(af);
    }

    // Project to legal region
    projectActionsByLayer(af);

    if (verbose == 2) {
        starRule("\n");
        cout << "Projected Actions:" << endl;
        starRule("");
        dumpByLayer(af);

        cout << "THIS Q: " << q.val << endl;
        cout << "BEST Q: " << Q << endl;

        cout << "with " << endl;
        const size_t firstLayer = SearchEngine::actionFluents.size();
        for (size_t pos = 0; pos < actions.size() && pos < firstLayer; ++pos) {
            cout << "Action " << pos << " " << SearchEngine::actionFluents[pos]->name << " = " << actions[pos]
                 << endl;
        }
    }

    // Save best action state and Q-value.
    // NOTE(review): q was evaluated at the point before the update, whereas
    // the vector stored here is the updated and projected af - confirm that
    // pairing them is intended.
    if (MathUtils::doubleIsGreater(q.val, Q)) {
        Q = q.val;
        actions = af.head(actionFluents.size());
    }

    // Stopping criterion: largest absolute change of any entry in this step
    double maxDiff = -std::numeric_limits<double>::max();
    for (int pos = 0; pos < af.size(); ++pos) {
        const double change = std::abs(af[pos].val - old_af[pos].val);
        if (change > maxDiff) {
            maxDiff = change;
        }
    }

    if (verbose == 2 || verbose == 3) {
        cout << "Diff: " << maxDiff << endl;
    }
    return MathUtils::doubleIsSmaller(maxDiff, threshold);
}

/******************************************************************
                         Q function
******************************************************************/

/**
 * Evaluates the Q value of an action fluent assignment by propagating the
 * state fluents in sf_input through maxSearchDepthForThisStep layers and
 * summing the reward of every layer.
 *
 * Layer 0 uses the first actionFluents.size() entries of af. For deeper
 * layers the action fluents are taken from the matching segment of af in
 * conformant mode, and are otherwise all set to 1 / actionFluents.size().
 * When penalty is non-zero, penalty * (1 - precondition) is added for every
 * action precondition, evaluated on layer 0 only.
 *
 * Only the state and action vectors of the current layer are kept, and the
 * uniform action vector is built once instead of once per layer.
 *
 * @param af  action fluents (one layer, or all layers in conformant mode)
 * @return the accumulated Q value
 */
dual SogbofaSearch::qFunction(const VectorXdual &af) {

    const size_t numActionFluents = actionFluents.size();

    // Verbose helper: one line per state fluent, deterministic ones first.
    const auto printStateFluents = [&](const VectorXdual &values) {
        for (size_t i = 0; i < deterministicCPFs.size(); i++) {
            cout << "Sf\t" << setw(35) << deterministicCPFs[i]->name << "\t" << values[i] << endl;
        }
        for (size_t i = 0; i < probabilisticCPFs.size(); i++) {
            cout << "Sf\t" << setw(35) << probabilisticCPFs[i]->name << "\t"
                 << values[deterministicCPFs.size() + i] << endl;
        }
    };

    if (verbose == 2) {
        cout << "\n*************************************************" << endl;
        cout << "Q Function" << endl;
        cout << "*************************************************" << endl;
        cout << "\nLayer: 0 ----------------------------------------" << endl;
    }

    // State fluents of the current layer (copied, sf_input stays untouched)
    VectorXdual sf = sf_input;
    if (verbose == 2) {
        printStateFluents(sf);
    }

    // Action fluents of the current layer
    VectorXdual layerAf = af.head(numActionFluents);
    if (verbose == 2) {
        for (size_t i = 0; i < numActionFluents; i++) {
            cout << "Af\t" << setw(35) << actionFluents[i]->name << "\t" << layerAf[i] << endl;
        }
    }

    // Reward of the input layer
    dual reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffFwd(sf, layerAf);

    // Constraint penalties, weighted by the penalty parameter
    // TODO should the constraint penalty only apply on the input layer?
    // NOTE(review): with a positive penalty this term increases Q when a
    // precondition evaluates below 1 - confirm the intended sign.
    dual cp = 0.0;
    if (!MathUtils::doubleIsEqual(penalty, 0.0)) {
        for (auto ap : SearchEngine::actionPreconditions) {
            cp += penalty * (1 - ap->formula->evaluateForAutodiffFwd(sf, layerAf));
        }
    }

    // Q value of the input layer
    dual Q = reward + cp;

    if (verbose == 2) {
        cout << "q0\t" << Q << endl;
    }

    // Without the conformant procedure every deeper layer uses the same
    // uniform action fluents, so the vector is built once up front.
    VectorXdual uniformAf;
    if (!conformant) {
        uniformAf.resize(numActionFluents);
        const double weight = 1.0 / numActionFluents;
        for (size_t index = 0; index < numActionFluents; ++index) {
            uniformAf[index] = weight;
        }
    }

    for (int layer = 1; layer < maxSearchDepthForThisStep; layer++) {

        if (verbose == 2) {
            cout << "\nLayer: " << layer << " ----------------------------------------" << endl;
        }

        // Next state fluents from the previous layer's state and actions
        sf = updateCPFs(sf, layerAf);
        if (verbose == 2) {
            printStateFluents(sf);
        }

        // Action fluents of this layer
        if (conformant) {
            // conformant action fluents are already given
            layerAf = af.segment(layer * numActionFluents, numActionFluents);
        } else {
            layerAf = uniformAf;
        }

        if (verbose == 2) {
            for (size_t i = 0; i < numActionFluents; i++) {
                cout << "Af C\t" << setw(35) << actionFluents[i]->name << "\t" << layerAf[i] << endl;
            }
        }

        // Accumulate the reward of this layer
        reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffFwd(sf, layerAf);
        Q += reward;

        if (verbose == 2) {
            cout << "q" << layer << " " << Q << endl;
        }
    }
    //verbose = 0;
    assert(!isnan(Q.val));
    return Q;
}

//TODO: Check
// Variant of the Q function that additionally writes the evaluated
// expressions as DOT digraphs: the reward terms of all layers go to "q.dot",
// the constraint penalty terms to "cp.dot". gradientAscent() selects this
// variant when verbose == 1.
//
// tree_size is reset to 0 and then used as the running node id: it is passed
// to every evaluateForAutodiffFwdVis() call and incremented here for each
// node written directly by this function.
//
// NOTE: the function currently ends in assert(false), so with assertions
// enabled it aborts after the files have been written.
//
// af: action fluents (one layer, or all layers in conformant mode)
// Returns the accumulated Q value.
dual SogbofaSearch::qFunctionVisualization(const VectorXdual &af) {

    // Visualization: open the graph for the Q value and restart node numbering
    std::ofstream out("q.dot");
    out << "digraph Q {" << endl;
    tree_size = 0;

    // copy input as original values should not be changed
    VectorXdual sf_prime = sf_input;
    VectorXdual af_prime = af.head(actionFluents.size());

    // tree size to estimate dual nodes
    //tree_size += sf_input.size() + af.size();
    // Visualization: root node "Q" with an edge to the node that the reward
    // evaluation below is handed as its starting id
    out << "n" << tree_size << " [label=\"Q\"]" << endl;
    int qCounter = tree_size;
    tree_size++;
    out << "n" << qCounter <<  " -> n" << tree_size << endl;

    // Calculate reward for the input layer
    dual reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffFwdVis(sf_prime, af_prime, tree_size, out);

    // Calculate the constraint penalties for the Q value using the penalty parameters
    // TODO should the constraint penalty only apply on the input layer?
    // The penalty graph goes to its own file; its root is a "+" node.
    std::ofstream cpOut("cp.dot");
    cpOut << "digraph constraintPenalties {" << endl;
    cpOut << "n" << tree_size << " [label=\"+\"]" << endl;
    int cpCounter = tree_size;
    dual cp = 0.0;
    if (!MathUtils::doubleIsEqual(penalty, 0.0)) {
        for (auto ap : SearchEngine::actionPreconditions) {
            // One "*" node per precondition below the "+" root ...
            tree_size++;
            cpOut << "n" << cpCounter <<  " -> n" << tree_size << endl;
            cpOut << "n" << tree_size << " [label=\"*\"]" << endl;
            int timesCounter = tree_size;
            // ... whose first child is a node labelled with the penalty weight ...
            tree_size++;
            cpOut << "n" << timesCounter <<  " -> n" << tree_size << endl;
            cpOut << "n" << tree_size << " [label=\"" << penalty << "\"]" << endl;
            // ... and whose second child is a "!" node above the precondition
            tree_size++;
            cpOut << "n" << timesCounter <<  " -> n" << tree_size << endl;
            cpOut << "n" << tree_size << " [label=\"!\"]" << endl;
            tree_size++;
            cpOut << "n" << tree_size - 1 <<  " -> n" << tree_size << endl;
            cp += penalty * (1 - ap->formula->evaluateForAutodiffFwdVis(sf_prime, af_prime, tree_size, cpOut));
        }
    }
    cpOut << "}" << endl;

    // Calculate the Q value for the input layer
    dual Q = reward + cp;

    // Build layers up to the search depth
    for (int layer = 1; layer < maxSearchDepthForThisStep; layer++) {

        // Update next state fluents
        sf_prime = updateCPFsVisualize(sf_prime, af_prime);

        // Update next action fluents
        if (!conformant) {
            // The uniform values are written once at layer 1 and af_prime is
            // then reused unchanged for all deeper layers
            if (layer == 1) {
                for (size_t index = 0; index < actionFluents.size(); ++index) {
                    af_prime[index] = 1.0 / actionFluents.size();
                }
            }
        } else {
            // Conformant mode: take this layer's segment of the input vector
            af_prime = af.segment(layer * actionFluents.size(), actionFluents.size());
        }

        // Calculate reward and Q value
        // Visualization: every layer's reward hangs directly below the Q root
        out << "n" << qCounter <<  " -> n" << tree_size << endl;
        reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffFwdVis(sf_prime, af_prime, tree_size, out);
        Q += reward;
    }
    // Visualization: close the Q graph
    out << "}" << endl;
    // Unconditional stop: aborts here whenever assertions are enabled
    assert(false);
    assert(!isnan(Q.val));
    return Q;
}

/******************************************************************
                       Heuristic

    Sample Config: ./prost.py elevators_inst_mdp__1  [PROST -s 1 -se [THTS -act [UCB1] -out [UMC] -backup [PB] -init [Expand -h [Sogbofa -sd  5 -c 0 -s 0.5]]]]
    Arguments:
        -sd     Search Depth
                    int > 0, or -1 for dynamic (Only with -i)
        -s      Step Size
                    double > 0, or -1 for dynamic as per paper, -2 for only line search on [0,1], -3 for line search once at the beginning
        -c      Conformant
                    bool, enables the conformant procedure

******************************************************************/

/**
 * Heuristic entry point: estimates a Q value for every action that is to be
 * expanded in the given state.
 *
 * If the state is found in rewardCache, the cached per-step values are scaled
 * by state.stepsToGo() and all other entries are set to the lowest double.
 * Otherwise each action with actionsToExpand[index] == index is evaluated:
 * in conformant mode the action fluents of the deeper layers start out
 * uniform and are optimized with at most maxGradientSteps calls of
 * gradientAscentHeuristic() before a final forward propagation; without it
 * the action is propagated once through qFunction(). The result is divided
 * by the search depth, cached if caching is enabled, and finally scaled by
 * state.stepsToGo().
 *
 * @param state            state to evaluate
 * @param actionsToExpand  entry index equals index for actions to evaluate
 * @param qValues          output; one entry per action, entries of actions
 *                         that are not evaluated are left untouched on a
 *                         cache miss
 */
void SogbofaSearch::estimateQValues(
        State const &state, vector<int> const &actionsToExpand,
        vector<double> &qValues) {
    HashMap::iterator it = rewardCache.find(state);
    if (it != rewardCache.end()) {
        // Already cached
        ++cacheHits;
        assert(qValues.size() == it->second.size());
        for (size_t index = 0; index < qValues.size(); ++index) {
            if (actionsToExpand[index] == index) {
                qValues[index] =
                        it->second[index] * state.stepsToGo();
            } else {
                qValues[index] = -std::numeric_limits<double>::max();
            }
        }
        return;
    }

    // Verbose helper: framed list of the first-layer action fluents.
    const auto printActionFluents = [&](const char *title, const VectorXdual &values) {
        cout << "\n******************************************************************" << endl;
        cout << title << endl;
        for (size_t i = 0; i < values.size(); ++i) {
            cout << "Action " << i << " " << SearchEngine::actionFluents[i]->name << " = " << values[i]
                 << endl;
        }
        cout << "******************************************************************\n" << endl;
    };

    // Calculate heuristic value
    // Initialize state fluents
    sf_input.resize(SearchEngine::stateFluents.size());
    for (size_t index = 0; index < State::numberOfDeterministicStateFluents; ++index) {
        sf_input[index] = state.deterministicStateFluent(index);
    }
    for (size_t index = 0; index < State::numberOfProbabilisticStateFluents; ++index) {
        sf_input[State::numberOfDeterministicStateFluents + index] = state.probabilisticStateFluent(index);
    }

    // Set the search depth
    maxSearchDepthForThisStep = std::min(maxSearchDepth, state.stepsToGo());

    for (size_t index = 0; index < qValues.size(); ++index) {
        if (actionsToExpand[index] != index) {
            continue;
        }

        // Action fluents of the action under evaluation
        ActionState const &action = SearchEngine::actionStates[index];
        const auto fluentsPerLevel = action.state.size();
        VectorXdual af(fluentsPerLevel);
        for (size_t i = 0; i < fluentsPerLevel; ++i) {
            af[i] = action.state[i];
        }

        if (verbose == 2) {
            printActionFluents("Initial Action:", af);
        }

        double q = -std::numeric_limits<double>::max();

        // Calculate q
        if (conformant) {

            // Conformant action fluents for layers 1 .. depth - 1, all
            // initialized to the same weight
            assert(!isnan(1.0 / fluentsPerLevel));

            VectorXdual afConformant(fluentsPerLevel * (maxSearchDepthForThisStep - 1));
            double weights = 1.0 / fluentsPerLevel;

            for (size_t level = 0; level < maxSearchDepthForThisStep - 1; level++) {
                for (size_t i = 0; i < fluentsPerLevel; i++) {
                    const size_t slot = i + fluentsPerLevel * level;
                    afConformant[slot] = weights;
                    // Check the entry that was just written (the original
                    // checked afConformant[i] of the first level again).
                    assert(!isnan(afConformant[slot].val));
                }
            }

            if (verbose == 2) {
                cout << "Initial Conformant Action:" << endl;
                for (size_t i = 0; i < afConformant.size(); ++i) {
                    if (i % SearchEngine::actionFluents.size() == 0) {
                        cout << "Level " << i / actionFluents.size() << "-----------" << endl;
                    }
                    cout << "CA " << i << ": " << actionFluents[i % actionFluents.size()]->name << " = "
                         << afConformant[i]
                         << endl;
                }
            }

            // Optimize conformant action fluents
            // TODO: convergence does not appear to work as a termination criterion, fixed number of iterations better?
            gradientCounter = 0;
            totalGradientOptimizations += 1.0;
            bool converged = false;
            while (!converged && gradientCounter < maxGradientSteps) {
                gradientCounter++;
                converged = gradientAscentHeuristic(q, af, afConformant);
            }
            totalGradientSteps += (double)gradientCounter;

            if (verbose == 2) {
                cout << "Q Value: " << q << endl;
                cout << "\n*************************************************" << endl;
                cout << "Q Last Forward Propagation..." << endl;
                cout << "*************************************************" << endl;
            }

            // Last step: propagation only, with the optimized fluents
            q = qFunctionHeuristic(af, afConformant).val;

            //verbose = 2;

            if (verbose == 2) {
                printActionFluents("Actions (as before):", af);
                cout << "Q Value: " << q << endl;
                // Debug aid: waits for a character on stdin before continuing
                getchar();
            }

        } else {
            // Propagate the Q-value
            q = qFunction(af).val;
        }

        // Return and cache
        if (verbose == 2) {
            cout << "\n Q VALUE " << q << endl;
            cout << "\n======================================== " << endl;
        }
        // Stored per step so that the cached value can be rescaled later
        qValues[index] = q / maxSearchDepthForThisStep;
    }

    if (cachingEnabled) {
        rewardCache[state] = qValues;
    }

    for (size_t index = 0; index < qValues.size(); ++index) {
        if (actionsToExpand[index] == index) {
            qValues[index] *= state.stepsToGo();
        }
    }
}

/******************************************************************
                        Gradient Ascent
******************************************************************/

bool SogbofaSearch::
gradientAscentHeuristic(double &Q, VectorXdual &af, VectorXdual &afConformant) {

    if (verbose == 2) {
        cout << "\n*************************************************" << endl;
        cout << "Gradient Step " << gradientCounter << endl;
        cout << "*************************************************" << endl;
    }

    // Calculate Gradients
    dual q;
    VectorXd dqda = gradient([&](auto x, auto y) { return qFunctionHeuristic(x, y); }, wrt(afConformant), at(af, afConformant), q);
    //verbose = 2;

    assert(dqda.size() == (maxSearchDepthForThisStep-1)*actionFluents.size());

    //// Print Gradients of Conformant Actions ////
    if (verbose == 2){
        cout << "\n*************************************************" << endl;
        cout << "Gradients of Conformant Actions:" << endl;
        cout << "*************************************************" << endl;
        for (size_t i = 0; i < dqda.size(); ++i) {
            if(i%SearchEngine::actionFluents.size() == 0) cout <<"Layer: "<<i/actionFluents.size() + 1<<" ----------------------------------------"<< endl;
            cout << "CA " << i << ":\t"<< setw(35) << actionFluents[i%actionFluents.size()]->name << "\t" << dqda[i]
                 << endl;
        }
    }


    // Set step size (no dynamic option)
    double alpha = stepSize;

    // Update actions
    VectorXdual old_afConformant = afConformant;
    for (size_t i = 0; i < afConformant.size(); i++) {
            afConformant[i] += alpha * dqda[i];
    }

    //// Print Updated Conformant Actions ////
    if (verbose == 2) {
        cout << "\n*************************************************" << endl;
        cout << "Updated Conformant Actions:" << endl;
        cout << "*************************************************" << endl;
        for (size_t i = 0; i < afConformant.size(); ++i) {
            if (i % SearchEngine::actionFluents.size() == 0)
                cout << "Layer: " << i / actionFluents.size() + 1 << " ----------------------------------------" << endl;
            cout << "CA " << i << ":\t" << setw(35) << actionFluents[i % actionFluents.size()]->name << "\t"<< afConformant[i]
                 << endl;

        }
    }

    // Project to legal region
    // TODO: should actions be projected?
    projectActionsByLayer(afConformant);

    //// Print Projected Conformant Actions ////
    if (verbose == 2){
        cout << "\n*************************************************" << endl;
        cout << "Projected Conformant Actions:" << endl;
        cout << "*************************************************" << endl;
        for (size_t i = 0; i < afConformant.size(); ++i) {
            if(i%SearchEngine::actionFluents.size() == 0) cout <<"Layer: "<< i/actionFluents.size() + 1 <<" ----------------------------------------"<< endl;
            cout << "CA " << i << ":\t"<< setw(35)<< actionFluents[i%actionFluents.size()]->name << "\t" << afConformant[i]
                     << endl;
        }
    }

    // Save best Q-value
    if (MathUtils::doubleIsGreater(q.val, Q)) {
        Q = q.val;
    }

    // Stopping criterion
    //double diff = static_cast<double>((afConformant - old_afConformant).lpNorm<1>());
    double maxDiff = -std::numeric_limits<double>::max();
    for (int i = 0; i < afConformant.size(); ++i) {
        double diff = abs(afConformant[i].val - old_afConformant[i].val);
        if(diff > maxDiff){
            maxDiff = diff;
        }
    }
    return MathUtils::doubleIsSmaller(maxDiff, threshold);
}

/******************************************************************
                         Q function
******************************************************************/

/**
 * Evaluates the accumulated reward of a rollout of maxSearchDepthForThisStep layers.
 *
 * Layer 0 uses sf_input and af; every later layer uses the state produced by
 * updateCPFs and the matching segment of afConformant.
 *
 * @param af           action fluents applied in layer 0
 * @param afConformant action fluents of layers 1.., stored layer after layer
 * @return sum of the reward formula over all layers
 */
dual SogbofaSearch::
qFunctionHeuristic(const VectorXdual &af, const VectorXdual &afConformant) {
    const bool trace = (verbose == 2);

    // Trace helper: one line per state fluent, deterministic ones first.
    auto dumpStateFluents = [&](const VectorXdual &values) {
        for (size_t k = 0; k < deterministicCPFs.size(); k++) {
            cout << "Sf\t" << setw(35) << deterministicCPFs[k]->name << "\t" << values[k] << endl;
        }
        for (size_t k = 0; k < probabilisticCPFs.size(); k++) {
            cout << "Sf\t" << setw(35) << probabilisticCPFs[k]->name << "\t"
                 << values[deterministicCPFs.size() + k] << endl;
        }
    };

    // Trace helper: one line per action fluent, prefixed with `tag`.
    auto dumpActionFluents = [&](const char *tag, const VectorXdual &values) {
        for (size_t k = 0; k < actionFluents.size(); k++) {
            cout << tag << setw(35) << actionFluents[k]->name << "\t" << values[k] << endl;
        }
    };

    if (trace) {
        cout << "\n*************************************************" << endl;
        cout << "Q Function" << endl;
        cout << "*************************************************" << endl;
        cout << "\nLayer: 0 ----------------------------------------" << endl;
    }

    // State fluents per layer, starting with the input state
    vector<VectorXdual> stateLayers;
    stateLayers.push_back(sf_input);
    if (trace) {
        dumpStateFluents(stateLayers[0]);
    }

    // Action fluents per layer, starting with the given first-layer actions
    vector<VectorXdual> actionLayers;
    actionLayers.push_back(af);
    if (trace) {
        dumpActionFluents("Af\t", actionLayers[0]);
    }

    // Reward of layer 0 starts the accumulated Q-value
    dual reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffFwd(stateLayers.back(), actionLayers.back());
    dual Q = reward;
    if (trace) {
        cout << "q0\t" << Q << endl;
    }

    int layer = 1;
    while (layer < maxSearchDepthForThisStep) {
        if (trace) {
            cout << "\nLayer: " << layer << " ----------------------------------------" << endl;
        }

        // Successor state from the previous layer's state and actions
        stateLayers.push_back(updateCPFs(stateLayers[layer - 1], actionLayers[layer - 1]));
        if (trace) {
            dumpStateFluents(stateLayers[layer]);
        }

        // Conformant actions belonging to this layer
        actionLayers.push_back(afConformant.segment((layer - 1) * actionFluents.size(), actionFluents.size()));
        if (trace) {
            dumpActionFluents("Af C\t", actionLayers[layer]);
        }

        // Add this layer's reward
        reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffFwd(stateLayers.back(), actionLayers.back());
        Q += reward;
        if (trace) {
            cout << "q" << layer << " " << Q << endl;
        }

        layer++;
    }

    assert(!isnan(Q.val));
    return Q;
}

//TODO: Check
/**
 * Variant of the Q-function rollout that additionally writes a Graphviz
 * description of the evaluation to "q.dot" (the state transitions are written
 * by updateCPFsVisualize).
 *
 * NOTE: the function ends in assert(false), so it aborts whenever assertions
 * are enabled.
 *
 * @param af            action fluents applied in layer 0
 * @param af_conformant action fluents of layers 1.., stored layer after layer
 * @return sum of the reward formula over all layers
 */
dual SogbofaSearch::
qFunctionHeuristicVisualization(const VectorXdual &af, const VectorXdual &af_conformant) {

    // Open the graph and restart node numbering
    std::ofstream out("q.dot");
    out << "digraph Q {" << endl;
    tree_size = 0;

    VectorXdual currentState = sf_input;
    VectorXdual currentActions = af;
    tree_size += sf_input.size() + af.size();

    // Root node "Q" and the edge to the first reward subtree
    out << "n" << tree_size << " [label=\"Q\"]" << endl;
    const int rootNode = tree_size;
    tree_size++;
    out << "n" << rootNode << " -> n" << tree_size << endl;

    // Reward of layer 0
    dual reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffFwdVis(currentState, currentActions, tree_size, out);
    tree_size++;

    // Accumulated Q-value
    dual Q = reward;
    tree_size++;

    int layer = 1;
    while (layer < maxSearchDepthForThisStep) {
        // Successor state fluents
        currentState = updateCPFsVisualize(currentState, currentActions);

        // Conformant actions of this layer
        currentActions = af_conformant.segment((layer - 1) * actionFluents.size(), actionFluents.size());

        // Edge from the root to this layer's reward subtree, then the reward itself
        out << "n" << rootNode << " -> n" << tree_size << endl;
        reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffFwdVis(currentState, currentActions, tree_size, out);
        Q += reward;
        tree_size++;

        layer++;
    }

    // Close the graph
    out << "}" << endl;
    assert(false);
    assert(!isnan(Q.val));
    return Q;
}

/******************************************************************
                       Auxiliary functions
******************************************************************/

/**
 * Determines the step size for an update of `af` along direction `u`.
 *
 * If stepSize is not -1 it is returned unchanged. Otherwise a line search is
 * run: the interval (0, 1/max|u_i|] is sampled at 10 equidistant points, each
 * candidate is projected with projectActionsByLayer and scored with qFunction.
 * When the smallest sample of a level is the new best, the search zooms into
 * (0, spacing] and repeats, for at most 5 levels.
 *
 * The best step is tracked as an absolute value. Tracking only the sample
 * index would combine an index found at one level with the spacing of a later
 * level and return a step that never produced the best Q-value.
 *
 * @param af current action fluents
 * @param u  gradient / update direction (indexed in parallel to af)
 * @return the configured step size, 0.0 if all gradients are zero, otherwise
 *         the sampled step with the highest Q-value
 */
double SogbofaSearch::findStepSize(VectorXdual &af, VectorXd &u) {

    // Use given step size
    if (!MathUtils::doubleIsEqual(stepSize, -1.0)) return stepSize;

    // Find the largest absolute gradient and check whether all gradients vanish
    double u_max = -std::numeric_limits<double>::max();
    bool zero = true;
    for (int index = 0; index < u.size(); index++) {
        const double magnitude = std::abs(u[index]);
        if (MathUtils::doubleIsGreater(magnitude, u_max)) {
            u_max = magnitude;
        }
        if (!MathUtils::doubleIsEqual(u[index], 0.0)) {
            zero = false;
        }
    }

    // Step size is irrelevant
    if (zero) return 0.0;

    // Upper bound so that variables don't get pushed too far
    assert(!MathUtils::doubleIsEqual(u_max, 0.0));
    double alpha_max = 1 / u_max;

    // TODO: better values than original sogbofa?
    const int levels = 5;
    const int alphas = 10;

    bool found = false;
    double best_step = 0.0;
    double best_q = -std::numeric_limits<double>::max();

    for (int l = 0; l < levels; l++) {
        const double spacing = alpha_max / alphas;
        // Index of the sample that improved best_q in THIS level (-1: none did)
        int level_best = -1;

        for (int alpha = 1; alpha <= alphas; alpha++) {
            const double step = spacing * alpha;

            // Candidate update for this step, projected to the legal region
            VectorXdual a(af.size());
            for (int index = 0; index < a.size(); index++) {
                a[index] = af[index] + step * u[index];
            }
            projectActionsByLayer(a);

            // Score the candidate
            const double q = qFunction(a).val;
            if (MathUtils::doubleIsGreater(q, best_q)) {
                best_q = q;
                best_step = step;
                level_best = alpha;
                found = true;
            }
        }

        // Only zoom in when the smallest sample of this level was the best;
        // stop otherwise (including when this level brought no improvement).
        if (level_best != 1) {
            break;
        }
        alpha_max = spacing;
    }

    assert(found);
    if (!found) return 0.0;
    return best_step;
}

/**
 * Projects the action vector `af` in place.
 *
 * 1. If any entry lies outside [0,1], all entries are rescaled linearly with
 *    (x - min) / (max - min); if all entries are equal they are set to
 *    1 / actionFluents.size().
 * 2. While the sum of the entries exceeds numberOfConcurrentActions, the excess
 *    is removed by subtracting the same amount from every non-zero entry;
 *    entries that would become negative are clipped to 0 and the remaining
 *    excess is redistributed in the next pass.
 *
 * The sum and the number of non-zero entries are recomputed after every pass,
 * so the loop condition always compares the actual sum against the limit.
 *
 * @param af action fluents to project
 */
void SogbofaSearch::projectActions(VectorXdual &af) {

    // Find boundaries of actions
    double min = std::numeric_limits<double>::max();
    double max = -std::numeric_limits<double>::max();
    for (size_t index = 0; index < af.size(); index++) {
        if (af[index] > max) {
            max = af[index].val;
        }
        if (af[index] < min) {
            min = af[index].val;
        }
    }

    // Project actions to [0,1]
    if (MathUtils::doubleIsSmaller(min, 0.0) || MathUtils::doubleIsGreater(max, 1.0)) {
        const bool allEqual = MathUtils::doubleIsEqual(max, min);
        for (size_t index = 0; index < af.size(); index++) {
            if (allEqual) {
                // Trivial projection
                af[index] = 1.0 / actionFluents.size();
            } else {
                af[index] = (af[index] - min) / (max - min);
            }
        }
    }

    // Action Constraints
    // TODO: Sum constraint for number of allowed actions needed?
    double sum = 0.0;
    int nonZeroEntries = 0;
    for (size_t index = 0; index < af.size(); index++) {
        sum += af[index].val;
        if (!MathUtils::doubleIsEqual(af[index].val, 0.0)) {
            nonZeroEntries++;
        }
    }

    // Project actions to adhere to Sum(a_i) <= numberOfConcurrentActions
    while (MathUtils::doubleIsGreater(sum, numberOfConcurrentActions) && nonZeroEntries > 0) {
        const double diff = (sum - numberOfConcurrentActions) / nonZeroEntries;

        // Recount from scratch so both values describe the vector after this pass
        sum = 0.0;
        nonZeroEntries = 0;
        for (size_t index = 0; index < af.size(); index++) {
            if (MathUtils::doubleIsGreaterOrEqual(af[index].val - diff, 0.0)) {
                // Shrink value by necessary amount
                af[index] -= diff;
            } else {
                // Cannot be shrunk by the full amount; the remainder stays in
                // `sum` and is distributed over the other entries next pass
                af[index] = 0.0;
            }
            sum += af[index].val;
            if (!MathUtils::doubleIsEqual(af[index].val, 0.0)) {
                nonZeroEntries++;
            }
        }
    }
}

/**
 * Projects `af` in place, treating it as consecutive layers of
 * actionFluents.size() entries and projecting every layer independently.
 *
 * Per layer:
 * 1. If any entry lies outside [0,1], the layer is rescaled linearly with
 *    (x - min) / (max - min); if all entries are equal they are set to 1.0.
 * 2. projection == 2: if the layer sum exceeds maxNumberOfActions, every entry
 *    is divided by sum / maxNumberOfActions.
 *    projection == 1: while the layer sum exceeds numberOfConcurrentActions,
 *    the excess is subtracted evenly from the non-zero entries of the layer;
 *    entries that would become negative are clipped to 0. Sum and non-zero
 *    count are recomputed per pass and refer to this layer only.
 *    Any other value of `projection` skips step 2.
 *
 * @param af action fluents of one or more layers
 */
void SogbofaSearch::projectActionsByLayer(VectorXdual &af) {

    const size_t fluentsPerLayer = actionFluents.size();
    // Without action fluents there is nothing to project (and no valid layer count)
    if (fluentsPerLayer == 0) {
        return;
    }

    const size_t layerCount = af.size() / fluentsPerLayer;
    for (size_t layer = 0; layer < layerCount; ++layer) {
        const size_t first = layer * fluentsPerLayer;
        const size_t last = first + fluentsPerLayer;

        // Find boundaries of actions in this layer
        double min = std::numeric_limits<double>::max();
        double max = -std::numeric_limits<double>::max();
        for (size_t index = first; index < last; index++) {
            if (af[index] > max) {
                max = af[index].val;
            }
            if (af[index] < min) {
                min = af[index].val;
            }
        }

        // Project actions to [0,1]
        if (MathUtils::doubleIsSmaller(min, 0.0) || MathUtils::doubleIsGreater(max, 1.0)) {
            const bool allEqual = MathUtils::doubleIsEqual(max, min);
            for (size_t index = first; index < last; index++) {
                if (allEqual) {
                    af[index] = 1.0;
                } else {
                    af[index] = (af[index] - min) / (max - min);
                }
            }
        }

        if (projection == 2) {
            // Project actions to the sum constraint by division
            double sum = 0.0;
            for (size_t index = first; index < last; index++) {
                assert(MathUtils::doubleIsGreaterOrEqual(af[index].val, 0.0));
                sum += af[index].val;
            }
            assert(MathUtils::doubleIsGreaterOrEqual(sum, 0.0));

            // Fix max number of actions
            if (MathUtils::doubleIsGreater(sum, SearchEngine::maxNumberOfActions)) {
                // Checked before dividing by it
                assert(SearchEngine::maxNumberOfActions != 0);
                const double scale = sum / SearchEngine::maxNumberOfActions;
                assert(MathUtils::doubleIsGreater(scale, 1.0));
                for (size_t index = first; index < last; index++) {
                    af[index] /= scale;
                }
            }
        } else if (projection == 1) {
            // Sum and number of non-zero entries of THIS layer
            double sum = 0.0;
            int nonZeroEntries = 0;
            for (size_t index = first; index < last; index++) {
                sum += af[index].val;
                if (!MathUtils::doubleIsEqual(af[index].val, 0.0)) {
                    nonZeroEntries++;
                }
            }

            // Project actions to adhere to Sum(a_i) <= numberOfConcurrentActions
            while (MathUtils::doubleIsGreater(sum, numberOfConcurrentActions) &&
                   nonZeroEntries > 0) {
                const double diff = (sum - numberOfConcurrentActions) / nonZeroEntries;

                // Recount from scratch so both values describe the layer after this pass
                sum = 0.0;
                nonZeroEntries = 0;
                for (size_t index = first; index < last; index++) {
                    if (MathUtils::doubleIsGreaterOrEqual(af[index].val - diff, 0.0)) {
                        // Shrink value by necessary amount
                        af[index] -= diff;
                    } else {
                        // Cannot be shrunk by the full amount; the remainder stays
                        // in `sum` and is handled in the next pass
                        af[index] = 0.0;
                    }
                    sum += af[index].val;
                    if (!MathUtils::doubleIsEqual(af[index].val, 0.0)) {
                        nonZeroEntries++;
                    }
                }
            }
        }
    }
}

/**
 * Picks the applicable action state closest to the relaxed action vector.
 *
 * For every action state that actionIsApplicable accepts in _rootState, the
 * L1 distance between `actions` and the action state's entries is computed;
 * the index of the action state with the smallest distance is returned.
 *
 * @param actions    relaxed action fluent values
 * @param _rootState state used for the applicability check
 * @return a single-element vector with the chosen action index, or -1 when no
 *         action state is applicable
 */
vector<int> SogbofaSearch::sampleConcreteAction(VectorXdual &actions, State const &_rootState) {

    int min_index = -1;
    double min_diff = std::numeric_limits<double>::max();

    // Iterate by reference: copying every action state per iteration is unnecessary
    for (const auto &as : actionStates) {
        if (!actionIsApplicable(as, _rootState)) {
            continue;
        }

        // L1 distance between the relaxed actions and this action state
        double diff = 0.0;
        for (size_t j = 0; j < actions.size(); j++) {
            diff += std::abs(actions[j].val - as.state[j]);
        }

        if (MathUtils::doubleIsSmaller(diff, min_diff)) {
            min_index = as.index;
            min_diff = diff;
        }
    }

    vector<int> res;
    res.push_back(min_index);
    return res;
}

/**
 * Greedily rounds the relaxed action vector and maps it to a concrete action.
 *
 * Up to maxNumberOfActions times, the currently largest entry of `actions` is
 * selected as long as it is greater than 1 / actions.size(); the selected
 * entry is set to 1 in the target vector and overwritten with -1 in `actions`
 * (so `actions` is modified). The target vector is then passed to
 * sampleConcreteAction.
 *
 * @param actions    relaxed action fluent values; selected entries are set to -1
 * @param _rootState state forwarded to sampleConcreteAction
 * @return result of sampleConcreteAction for the greedy target vector
 */
vector<int> SogbofaSearch::sampleConcreteActionMeta(VectorXdual &actions, State const &_rootState) {

    // TODO: is it a problem that action states are still used?
    VectorXdual target(actions.size());
    target.fill(0.0);
    const double uniformShare = 1.0 / actions.size();

    for (int picked = 0; picked < maxNumberOfActions; ++picked) {
        // Locate the currently largest entry
        int bestIndex = -1;
        double bestValue = -std::numeric_limits<double>::max();
        for (size_t pos = 0; pos < actions.size(); pos++) {
            if (MathUtils::doubleIsGreater(actions[pos].val, bestValue)) {
                bestValue = actions[pos].val;
                bestIndex = pos;
            }
        }

        // Stop once no entry is more likely than a uniform distribution
        if (!MathUtils::doubleIsGreater(bestValue, uniformShare)) {
            break;
        }

        target[bestIndex] = 1.0;
        // TODO: is this necessary so the same value isn't picked all the time?
        actions[bestIndex] = -1.0;
    }

    return sampleConcreteAction(target, _rootState);
}

/**
 * Computes the successor state fluents for state `sf` under actions `af`.
 *
 * Every deterministic CPF formula is evaluated first, followed by every
 * probabilistic one; the results are stored in that order. Each value is
 * asserted to lie in [0,1].
 *
 * @param sf current state fluent values
 * @param af action fluent values
 * @return vector of SearchEngine::stateFluents.size() successor values
 */
VectorXdual SogbofaSearch::updateCPFs(VectorXdual const &sf, VectorXdual const &af) {
    VectorXdual successor(SearchEngine::stateFluents.size());

    // Deterministic fluents occupy the front of the vector
    for (size_t k = 0; k < State::numberOfDeterministicStateFluents; ++k) {
        successor[k] = SearchEngine::deterministicCPFs[k]->formula->evaluateForAutodiffFwd(sf, af);
        assert(MathUtils::doubleIsGreaterOrEqual(successor[k].val, 0.0));
        assert(MathUtils::doubleIsSmallerOrEqual(successor[k].val, 1.0));
    }

    // Probabilistic fluents follow directly after the deterministic ones
    for (size_t k = 0; k < State::numberOfProbabilisticStateFluents; ++k) {
        const size_t slot = State::numberOfDeterministicStateFluents + k;
        successor[slot] = SearchEngine::probabilisticCPFs[k]->formula->evaluateForAutodiffFwd(sf, af);
        assert(MathUtils::doubleIsGreaterOrEqual(successor[slot].val, 0.0));
        assert(MathUtils::doubleIsSmallerOrEqual(successor[slot].val, 1.0));
    }

    return successor;
}

/**
 * Computes the successor state fluents like updateCPFs (without the range
 * assertions) and writes a Graphviz description of the evaluated formulas to
 * "sf.dot". The file is opened anew on every call. Node ids are taken from,
 * and advance, the member tree_size.
 *
 * @param sf current state fluent values
 * @param af action fluent values
 * @return vector of SearchEngine::stateFluents.size() successor values
 */
VectorXdual SogbofaSearch::updateCPFsVisualize(VectorXdual const &sf, VectorXdual const &af) {
    std::ofstream out("sf.dot");
    out << "digraph stateFluents {" << endl;
    out << "n" << tree_size << " [label=\"State fluent transitions last layer\"]" << endl;
    const int rootId = tree_size;

    // Emits the label node of one fluent (hung invisibly below the root) and
    // the edge to the node where its formula subtree will start.
    auto emitFluentNode = [&](const auto &fluentName) {
        tree_size++;
        out << "n" << rootId << " -> n" << tree_size << " [style=invis]" << endl;
        out << "n" << tree_size << " [label=\"" << fluentName << "\"]" << endl;
        tree_size++;
        out << "n" << tree_size - 1 << " -> n" << tree_size << endl;
    };

    VectorXdual successor(SearchEngine::stateFluents.size());

    // Deterministic fluents first
    for (size_t k = 0; k < State::numberOfDeterministicStateFluents; ++k) {
        emitFluentNode(deterministicCPFs[k]->name);
        successor[k] = SearchEngine::deterministicCPFs[k]->formula->evaluateForAutodiffFwdVis(sf, af, tree_size, out);
    }

    // Probabilistic fluents are stored behind the deterministic ones
    for (size_t k = 0; k < State::numberOfProbabilisticStateFluents; ++k) {
        emitFluentNode(probabilisticCPFs[k]->name);
        const size_t slot = State::numberOfDeterministicStateFluents + k;
        successor[slot] = SearchEngine::probabilisticCPFs[k]->formula->evaluateForAutodiffFwdVis(sf, af, tree_size, out);
    }

    out << "}" << endl;
    return successor;
}

/******************************************************************
                       Print Functions
******************************************************************/

/** Resets the gradient, restart and tree-size counters to zero. */
void SogbofaSearch::resetStats() {
    gradientCounter = 0;
    restart_counter = 0;
    tree_size = 0;
}

/**
 * Prints an evaluated Q-value, the action input it was evaluated at and the
 * gradient dQ/da to standard output.
 *
 * @param af   action fluent values used as input
 * @param q    evaluated Q-value
 * @param dqda gradient of Q with respect to the actions
 */
void SogbofaSearch::printGradients(const VectorXdual &af, const dual &q,
                                   const VectorXd &dqda) const {
    std::cout << "Evaluated output: " << q << endl
              << "\nParameter Input: \n" << af << endl
              << "\nGradients dQ/da: \n" << dqda << endl;
}

/**
 * Prints the index of the first entry of `bestActions` and the entries of the
 * corresponding action state to standard output.
 *
 * An empty vector or an index outside of actionStates is reported instead of
 * being dereferenced.
 *
 * @param bestActions action indices; only the first entry is printed
 */
void SogbofaSearch::printBestAction(vector<int> &bestActions) const {
    std::cout << "Best Action:" << endl;
    if (bestActions.empty()) {
        std::cout << "(none)" << endl;
        return;
    }

    const int best = bestActions.front();
    std::cout << best << endl;

    // Guard against indices that do not refer to an action state (e.g. -1)
    if (best < 0 || static_cast<size_t>(best) >= actionStates.size()) {
        std::cout << "(no matching action state)" << endl;
        return;
    }

    std::cout << "with action state:" << endl;
    const auto &state = actionStates[best].state;
    for (size_t index = 0; index < state.size(); index++) {
        std::cout << state[index] << endl;
    }
}

/**
 * Prints one rollout layer to standard output: a layer header, the fluents
 * (via printFluents), the layer reward and the accumulated Q-value.
 *
 * @param reward   reward of this layer
 * @param Q        Q-value accumulated so far
 * @param sf_prime state fluent values of this layer
 * @param af_prime action fluent values of this layer
 * @param layer    layer number shown in the header
 */
void SogbofaSearch::printLayer(const dual &reward, const dual &Q, VectorXdual &sf_prime, VectorXdual &af_prime, int layer) const {
    std::cout << "\nLayer " << layer << ": ------------------------------\n" << endl;

    printFluents(cout, sf_prime, af_prime);

    std::cout << "Reward: = " << reward << endl
              << "Current Q: = " << Q << endl;
}

void SogbofaSearch::printRandomRestart() {
    std::cout << "\n********************************************" << endl;
    std::cout << "Random Restart " << this->restart_counter++ << "..." << endl;
    std::cout << "********************************************\n" << endl;
}

/**
 * Writes a short statistics summary to `out`: search depth, step size and the
 * average number of gradient steps per gradient optimization.
 *
 * The average is computed in floating point and reported as 0 when no
 * optimization has been counted, instead of dividing by zero.
 *
 * @param out             stream to write to
 * @param printRoundStats if true, a note that round stats are omitted is written
 * @param indent          prefix written in front of every line
 */
void SogbofaSearch::printStats(ostream &out, bool const &printRoundStats,
                               string indent) const {
    if (printRoundStats) {
        out << indent << "Round Stats omitted" << endl;
    }
    out << indent << "Maximal search depth: " << maxSearchDepthForThisStep << endl;
    out << indent << "Step Size: " << stepSize << endl;

    // Guard against an empty sample
    double averageSteps = 0.0;
    if (totalGradientOptimizations != 0) {
        averageSteps = static_cast<double>(totalGradientSteps) / totalGradientOptimizations;
    }
    out << indent << "Average Gradient Steps: " << averageSteps << endl;
}

/**
 * Writes the SearchEngine statistics followed by search depth, step size,
 * restart count, gradient count and the per-restart averages of tree size and
 * gradient count to `out`.
 *
 * Averages are computed in floating point and reported as 0 while
 * restart_counter is zero (its value after resetStats), instead of dividing
 * by zero.
 *
 * @param out             stream to write to
 * @param printRoundStats forwarded to SearchEngine::printStats
 * @param indent          prefix written in front of every line
 */
void SogbofaSearch::printDetailedStats(ostream &out, bool const &printRoundStats,
                               string indent) const {
    // Average that is well defined for an empty sample
    auto average = [](double total, double count) {
        return count == 0.0 ? 0.0 : total / count;
    };

    SearchEngine::printStats(out, printRoundStats, indent);
    out << indent << "Maximal search depth: " << maxSearchDepthForThisStep << endl;
    out << indent << "Step Size: " << stepSize << endl;
    out << indent << "Average tree_size: " << average(tree_size, restart_counter) << endl;
    out << indent << "Random Restarts: " << restart_counter << endl;
    out << indent << "Total Gradients Calculated: " << gradientCounter << endl;
    out << indent << "Average Gradients: " << average(gradientCounter, restart_counter) << endl;
}

/**
 * Writes the expression tree of `node` as a complete Graphviz digraph to `out`.
 * Node numbering starts at 0.
 *
 * @param out  stream to write to
 * @param node root of the expression tree
 */
void SogbofaSearch::printVar(ostream &out, var &node) const {
    out << "digraph G {" << endl;

    int nextNodeId = 0;
    printVar(out, node, nextNodeId);

    out << "}" << endl;
}

/**
 * Writes all state fluent values followed by all action fluent values to
 * `out`, one line per fluent with its index and name.
 *
 * @param out stream to write to
 * @param sf  state fluent values
 * @param af  action fluent values
 */
void SogbofaSearch::printFluents(ostream &out, VectorXvar &sf, VectorXvar &af) {
    // Shared line format for both sections
    auto dumpSection = [&out](const char *title, const char *tag, auto &values, const auto &fluents) {
        out << title << endl;
        for (size_t pos = 0; pos < values.size(); ++pos) {
            out << tag << pos << " " << fluents[pos]->name << " = " << values[pos]
                << endl;
        }
    };

    dumpSection("State Fluents:", "State ", sf, SearchEngine::stateFluents);
    dumpSection("Action Fluents:", "Action ", af, SearchEngine::actionFluents);
}

/**
 * Writes all state fluent values followed by all action fluent values to
 * `out`, one line per fluent with its index and name.
 *
 * @param out stream to write to
 * @param sf  state fluent values
 * @param af  action fluent values
 */
void SogbofaSearch::printFluents(ostream &out, VectorXdual &sf, VectorXdual &af) {
    out << "State Fluents:" << endl;
    size_t pos = 0;
    while (pos < sf.size()) {
        out << "State " << pos << " " << SearchEngine::stateFluents[pos]->name << " = " << sf[pos]
            << endl;
        ++pos;
    }

    out << "Action Fluents:" << endl;
    pos = 0;
    while (pos < af.size()) {
        out << "Action " << pos << " " << SearchEngine::actionFluents[pos]->name << " = " << af[pos]
            << endl;
        ++pos;
    }
}


/**
 * Prints every entry of `af` to standard output. The first
 * actionFluents.size() entries are printed with the action fluent's name, all
 * further entries are labelled "Conformant Action".
 *
 * @param af action fluent values
 */
void SogbofaSearch::printActionFluents(const VectorXdual &af) const {
    cout << "\nActions: " << endl;

    const size_t namedCount = actionFluents.size();
    for (size_t pos = 0; pos < af.size(); pos++) {
        if (pos >= namedCount) {
            // Entries past the named fluents
            cout << "Conformant Action " << pos << " = " << af[pos]
                 << endl;
            continue;
        }
        cout << "Action " << pos << " " << actionFluents[pos]->name << " = " << af[pos]
             << endl;
    }
}

/**
 * Recursively writes the Graphviz nodes and edges of the expression tree
 * rooted at `node` to `out`.
 *
 * `counter` is the id of the node written for `node`; it is advanced for every
 * child that is visited, so after the call it holds the last id used.
 * Parameter, variable and constant expressions are leaves; negation has one
 * child; +, -, * and / have a left and a right child. Other expression types
 * produce no output.
 *
 * Example output:
 *   digraph G {
 *           n0 [label="*"]
 *           n1 [label="+"]
 *           n3 [label="v1"]
 *           n4 [label="v2"]
 *           n2 [label="a1"]
 *           n0 -> n1
 *           n0 -> n2
 *           n1 -> n3
 *           n1 -> n4
 *   }
 *
 * @param out     stream to write to
 * @param node    expression to print
 * @param counter in/out node id counter
 */
void SogbofaSearch::printVar(ostream &out, var &node, int &counter) const {
    const auto *expr = node.expr.get();

    // Writes a binary operator node with the given label, then both operands.
    auto printBinary = [&](const auto *parent, const char *label) {
        out << "    n" << counter << " [label=\"" << label << "\"]" << endl;
        const int id = counter;
        // left child
        counter++;
        var left = parent->l;
        out << "    n" << id << " -> n" << counter << endl;
        printVar(out, left, counter);
        // right child
        counter++;
        var right = parent->r;
        out << "    n" << id << " -> n" << counter << endl;
        printVar(out, right, counter);
    };

    if (dynamic_cast<const ParameterExpr *>(expr)) {
        // Parameter expression: leaf
        out << "    n" << counter << " [label=\"p\"]" << endl;
    } else if (dynamic_cast<const VariableExpr *>(expr)) {
        // Variable expression: leaf, labelled with its value
        out << "    n" << counter << " [label=\"v=" << node << "\"]" << endl;
    } else if (dynamic_cast<const ConstantExpr *>(expr)) {
        // Constant expression: leaf
        out << "    n" << counter << " [label=\"c\"]" << endl;
    } else if (const auto *negation = dynamic_cast<const NegativeExpr *>(expr)) {
        // Negation: single child
        out << "    n" << counter << " [label=\"neg\"]" << endl;
        counter++;
        out << "    n" << counter - 1 << " -> n" << counter << endl;
        var child = negation->x;
        printVar(out, child, counter);
    } else if (const auto *addition = dynamic_cast<const AddExpr *>(expr)) {
        printBinary(addition, "+");
    } else if (const auto *subtraction = dynamic_cast<const SubExpr *>(expr)) {
        printBinary(subtraction, "-");
    } else if (const auto *multiplication = dynamic_cast<const MulExpr *>(expr)) {
        printBinary(multiplication, "*");
    } else if (const auto *division = dynamic_cast<const DivExpr *>(expr)) {
        // Must be cast to DivExpr: a cast to MulExpr yields nullptr here
        printBinary(division, "/");
    }
}

/******************************************************************
                       Auxiliary functions: Reverse
******************************************************************/
//bool
//SogbofaSearch::gradientAscent(VectorXvar &actions, double &Q, VectorXvar &sf_var, VectorXvar &af_var, VectorXdual &sf_dual,
//                                VectorXdual &af_dual) {
//
//    // Calculate Gradients
//    var q_rev = qFunctionRev(sf_var, af_var);
//    dual q_fwd;
//    std::cout << "\nCalculating Gradients......................." << endl;
//    auto started = std::chrono::high_resolution_clock::now();
//    VectorXd dqda_rev = gradient(q_rev, af_var);
//    auto done = std::chrono::high_resolution_clock::now();
//
//    // Print derivatives
//    std::cout << "Gradients Calculated in "
//              << std::chrono::duration_cast<std::chrono::seconds>(done - started).count()
//              << " seconds." << endl;
//    std::cout << "Evaluated output: " << q_rev << endl;
//    std::cout << "\nParameter Input: \n" << af_var << endl;
//    std::cout << "\nGradients dQ/da: \n" << dqda_rev << endl;
//
//    started = std::chrono::high_resolution_clock::now();
//    VectorXd dqda_fwd = gradient(qFunction, wrt(af_dual), at(sf_dual, af_dual, maxSearchDepth, , tre), q_fwd);
//    done = std::chrono::high_resolution_clock::now();
//
//    // Print derivatives
//    std::cout << "Gradients Calculated in "
//              << std::chrono::duration_cast<std::chrono::seconds>(done - started).count()
//              << " seconds." << endl;
//    std::cout << "Evaluated output: " << q_fwd << endl;
//    std::cout << "\nParameter Input: \n" << af_dual << endl;
//    std::cout << "\nGradients dQ/da: \n" << dqda_fwd << endl;
//
//    for (size_t index = 0; index < dqda_fwd.size(); index++) {
//        assert(MathUtils::doubleIsEqual(dqda_rev[index], dqda_fwd[index]));
//    }
//    // Print Q graph
//    //std::ofstream out("output/output.dot");
//    //printVar(cout, q);
//
//    // Update actions
//    VectorXvar old_af = af_var;
//    for (size_t index = 0; index < af_var.size(); index++) {
//        af_var[index] += SearchEngine::stepSize * dqda_rev[index];
//    }
//
//    // Project to legal region
//    projectActionsRev(af_var);
//
//    // Save best action state and Q-value
//    if (MathUtils::doubleIsGreater(q_rev.expr->val, Q)) {
//        Q = q_rev.expr->val;
//        actions = af_var;
//        std::cout << "\nBest Actions updated:" << endl;
//    } else {
//        std::cout << "\nBest Actions remain:" << endl;
//    }
//    std::cout << actions << endl;
//    std::cout << "Q = " << Q << endl;
//
//
//    // Stopping criterion
//    double diff = static_cast<double>((af_var - old_af).lpNorm<1>());
//    if (MathUtils::doubleIsSmaller(diff, 0.1)) {
//        std::cout << "\nConverged! (diff = " << diff << ")" << endl;
//        return true;
//    }
//    std::cout << "\nNot Converged... (diff = " << diff << ")" << endl;
//    return false;
//}

//bool SogbofaSearch::findActionsRev(VectorXvar &actions, double &Q, VectorXvar &sf, VectorXvar &af) {
//
//    // Calculate Gradients
//    var q = qFunctionRev(sf, af);
//    //std::cout << "\nCalculating Gradients......................." << endl;
//    //auto started = std::chrono::high_resolution_clock::now();
//    VectorXd dqda = gradient(q, af);
//    //auto done = std::chrono::high_resolution_clock::now();
//
//    // Print derivatives
//    //std::cout << "Gradients Calculated in "
//    //          << std::chrono::duration_cast<std::chrono::seconds>(done - started).count()
//    //          << " seconds." << endl;
//    //std::cout << "Evaluated output: " << q << endl;
//    //std::cout << "\nParameter Input: \n" << af << endl;
//    //std::cout << "\nGradients dQ/da: \n" << dqda << endl;
//
//    // Print Q graph
//    //std::ofstream out("output/output.dot");
//    //printVar(cout, q);
//
//    // Update actions
//    VectorXvar old_af = af;
//    for (size_t index = 0; index < af.size(); index++) {
//        af[index] += SearchEngine::stepSize * dqda[index];
//    }
//
//    // Project to legal region
//    projectActionsRev(af);
//
//    // Save best action state and Q-value
//    if (MathUtils::doubleIsGreater(q.expr->val, Q)) {
//        Q = q.expr->val;
//        actions = af;
//        //std::cout << "\nBest Actions updated:" << endl;
//    } else {
//        //std::cout << "\nBest Actions remain:" << endl;
//    }
//    //std::cout << actions << endl;
//    //std::cout << "Q = " << Q << endl;
//
//
//    // Stopping criterion
//    double diff = static_cast<double>((af - old_af).lpNorm<1>());
//    if (MathUtils::doubleIsSmaller(diff, 0.1)) {
//        //std::cout << "\nConverged! (diff = " << diff << ")" << endl;
//        return true;
//    }
//    //std::cout << "\nNot Converged... (diff = " << diff << ")" << endl;
//    return false;
//}
//
//var SogbofaSearch::qFunctionRev(const VectorXvar &sf, const VectorXvar &af) {
//
//    VectorXvar sf_prime = sf;
//    VectorXvar af_prime = af;
//
//    var Q = 0.0;
//
//    // Calculate reward
//    var reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffRev(sf_prime, af_prime);
//    var cp = 0.0;
//    for (auto ap : SearchEngine::actionPreconditions) {
//        cp += penalty * (1 - ap->formula->evaluateForAutodiffRev(sf_prime, af_prime));
//    }
//
//    Q += reward + cp;
//
//    /**Printouts Input Layer*******************************************/
//    std::cout << "\nInput Layer " << ": ------------------------------\n" << endl;
//    printFluents(cout, sf_prime, af_prime);
//    std::cout << "Reward: = " << reward << endl;
//    std::cout << "Current Q: = " << Q << endl;
//
//    /**Loop over Layers*******************************************/
//    for (int layer = 1; layer < maxSearchDepth; layer++) {
//
//        std::cout << "\nLayer " << layer << ": -----------------------------------\n" << endl;
//
//        // Update next states
//        sf_prime = updateCPFsRev(sf_prime, af_prime);
//
//        // Update af to uniform
//        if (layer == 1) {
//            for (size_t index = 0; index < actionFluents.size(); ++index) {
//                af_prime[index] = 1.0 / actionFluents.size();
//            }
//        }
//
//        // Calculate reward
//        reward = SearchEngine::rewardCPF->formula->evaluateForAutodiffRev(sf_prime, af_prime);
//        Q += reward;
//
//        /**Printouts*******************************************************/
//        printFluents(cout, sf_prime, af_prime);
//        std::cout << "Reward: = " << reward << endl;
//        std::cout << "Current Q: = " << Q << endl;
//    }
//    return Q;
//}
//
//VectorXvar SogbofaSearch::updateCPFsRev(VectorXvar &sf, VectorXvar &af) {
//    // Update next states
//    VectorXvar sf_prime(SearchEngine::stateFluents.size());
//    for (size_t index = 0; index < State::numberOfDeterministicStateFluents; ++index) {
//        sf_prime[index] = SearchEngine::deterministicCPFs[index]->formula->evaluateForAutodiffRev(sf, af);
//    }
//    for (size_t index = 0; index < State::numberOfProbabilisticStateFluents; ++index) {
//        sf_prime[State::numberOfDeterministicStateFluents +
//                 index] = SearchEngine::probabilisticCPFs[index]->formula->evaluateForAutodiffRev(sf, af);
//    }
//    return sf_prime;
//}
//
//void SogbofaSearch::projectActionsRev(VectorXvar &af) {
//    // Print unprojected actions
//    //std::cout << "\nUnprojected Actions: " << endl;
//    for (size_t index = 0; index < af.size(); index++) {
//        if (index < actionFluents.size()) {
//            //std::cout << "Action " << index << " " << SearchEngine::actionFluents[index]->name << " = " << af[index]
//            //          << endl;
//        } else {
//            //std::cout << "Conformant Action " << index << " = " << af[index]
//            //          << endl;
//        }
//    }
//
//    // find boundaries of actions
//    double min = std::numeric_limits<double>::max();
//    double max = -std::numeric_limits<double>::max();
//    for (size_t index = 0; index < af.size(); index++) {
//        if (af[index] > max) {
//            max = af[index].expr->val;
//        }
//        if (af[index] < min) {
//            min = af[index].expr->val;
//        }
//    }
//
//    // Project actions to [0,1]
//    if (MathUtils::doubleIsSmaller(min, 0.0) || MathUtils::doubleIsGreater(max, 1.0)) {
//        for (size_t index = 0; index < af.size(); index++) {
//            af[index] = (af[index] - min) / (max - min);
//        }
//    }
//
//    // TODO: Sum constraint for number of allowed actions needed?
//    // Preparation for action constraints
//    double sum = 0.0;
//    double nonZeroEntries = af.size();
//
//    for (size_t index = 0; index < af.size(); index++) {
//        sum += af[index].expr->val;
//        if (MathUtils::doubleIsEqual(af[index].expr->val, 0.0)) {
//            nonZeroEntries--;
//        }
//    }
//
//    // Project actions to adhere to Sum(a_i) <= numberOfConcurrentActions
//    while (MathUtils::doubleIsGreater(sum, numberOfConcurrentActions) &&
//           MathUtils::doubleIsGreater(nonZeroEntries, 0.0)) {
//        double diff = (sum - numberOfConcurrentActions) / nonZeroEntries;
//        sum = 0.0;
//        for (size_t index = 0; index < af.size(); index++) {
//            if (MathUtils::doubleIsGreaterOrEqual(af[index].expr->val - diff, 0.0)) {
//                af[index] -= diff;
//            } else {
//                sum += diff - af[index].expr->val;
//                af[index] = 0.0;
//                nonZeroEntries--;
//            }
//        }
//    }
//
//    // Print projected actions
//    //std::cout << "\nProjected Actions: " << endl;
//    for (size_t index = 0; index < af.size(); index++) {
//        if (index < actionFluents.size()) {
//            //std::cout << "Action " << index << " " << SearchEngine::actionFluents[index]->name << " = " << af[index]
//            //          << endl;
//        } else {
//            //std::cout << "Conformant Action " << index << " = " << af[index]
//            //          << endl;
//        }
//    }
//}
//
//vector<int> SogbofaSearch::sampleConcreteActionRev(VectorXvar &actions) {
//    // sample concrete action as min distance to grounded action states
//    vector<int> res;
//    int min = -1;
//    double min_diff = std::numeric_limits<double>::max();
//    for (auto as : actionStates) {
//        assert(as.state.size() == actions.size());
//        double diff = 0.0;
//        for (size_t j = 0; j < actions.size(); j++) {
//            diff += abs(actions[j].expr->val - as.state[j]);
//        }
//        if (MathUtils::doubleIsSmaller(diff, min_diff)) {
//            min = as.index;
//            min_diff = diff;
//        }
//    }
//    res.push_back(min);
//    return res;
//}
//
//vector<int> SogbofaSearch::sampleConcreteActionMetaRev(VectorXvar &actions) {
//    VectorXvar actionState(actions.size());
//    actionState.fill(0.0);
//    double threshold = 1.0 / actions.size();
//    int counter = 0;
//    double max_val = std::numeric_limits<double>::max();
//    while (counter < numberOfConcurrentActions && MathUtils::doubleIsGreater(max_val, threshold)) {
//        int max_index = -1.0;
//        max_val = -std::numeric_limits<double>::max();
//        for (size_t index = 0; index < actions.size(); index++) {
//            if (MathUtils::doubleIsGreater(actions[index].expr->val, max_val)) {
//                max_val = actions[index].expr->val;
//                max_index = index;
//            }
//        }
//        if (MathUtils::doubleIsGreater(max_val, threshold)) {
//            actionState[max_index] = 1.0;
//        }
//        counter++;
//    }
//    return sampleConcreteActionRev(actionState);
//}
