//
// Created by badfer00 on 29.10.19.
//

#ifndef PROST_SOGBOFA_SEARCH_H
#define PROST_SOGBOFA_SEARCH_H

#include "search_engine.h"
#include <iostream>
#include "utils/string_utils.h"
#include "utils/system_utils.h"

// Eigen includes
#include <eigen3/Eigen/Core>

using namespace Eigen;

// autodiff include
#include <autodiff/reverse.hpp>
#include <autodiff/reverse/eigen.hpp>
#include <stopwatch.h>

using namespace autodiff;

class SogbofaSearch : public DeterministicSearchEngine {
public:
    SogbofaSearch();

    // Set parameters from command line
    bool setValueFromString(std::string &param, std::string &value) override;

    // Start the search engine to calculate best actions
    void estimateBestActions(State const &_rootState,
                             std::vector<int> &bestActions) override;

    // Start the search engine to estimate the Q-value of a single action
    void estimateQValue(State const &state, int /*actionIndex*/,
                        double &qValue) override {
        qValue = initialValue * (double) state.stepsToGo();
    }

    // Start the search engine to estimate the Q-values of all applicable
    // actions
    void estimateQValues(State const &state,
                         std::vector<int> const &actionsToExpand,
                         std::vector<double> &qValues) override;

    // Parameter setter
    virtual void setInitialValue(double const &_initialValue) {
        initialValue = _initialValue;
    }

    virtual void setPenalty(double _penalty) {
        penalty = _penalty;
    }

    virtual void setForward(bool _forward) {
        forward = _forward;
    }

    virtual void setHighestMarginalProbability(bool _hmp) {
        highestMarginalProbability = _hmp;
    }

    virtual void setConformant(bool _conformant) {
        conformant = _conformant;
    }

    virtual void setMinUpdates(int _minUpdates) {
        minUpdates = _minUpdates;
    }

    virtual void setVerbose(int _verbose) {
        verbose = _verbose;
    }

    virtual void setThreshold(bool _threshold) {
        threshold = _threshold;
    }

    virtual void setGradientSteps(int _gradientSteps) {
        maxGradientSteps = _gradientSteps;
    }

    virtual void setProjection(int _projection) {
        projection = _projection;
    }

    // Reset statistic variables
    void resetStats() override;


    // Print
    void printStats(std::ostream &out, bool const& printRoundStats,
                    std::string indent="") const override;

    void printVar(std::ostream &out, var &node) const;

    // Caching
    typedef std::unordered_map<State, std::vector<double>,
            State::HashWithoutRemSteps,
            State::EqualWithoutRemSteps>
            HashMap;
    static HashMap rewardCache;

//    std::ofstream &output;


private:
    // Forward
    void initialize(State const &_rootState, VectorXdual &af);

    bool gradientAscent(VectorXdual &actions, double &Q, VectorXdual &af);

    bool gradientAscentHeuristic(double &Q, VectorXdual &af, VectorXdual &afConformant);

    dual qFunction(const VectorXdual& af);

    dual qFunctionHeuristic(const VectorXdual& af, const VectorXdual & afConformant);

    dual qFunctionVisualization(const VectorXdual &af);

    dual qFunctionHeuristicVisualization(const VectorXdual &af, const VectorXdual& af_conformant);

    VectorXdual updateCPFs(const VectorXdual& sf, const VectorXdual& af);

    void randomRestart(const State &_rootState, VectorXdual &af);

    double timeGradientCalculation(VectorXdual& af);

    double findStepSize(VectorXdual& af, VectorXd& u);

    void projectActions(VectorXdual &af);

    void projectActionsByLayer(VectorXdual &af);

    std::vector<int> sampleConcreteAction(VectorXdual &actions, State const &_rootState);

    std::vector<int> sampleConcreteActionMeta(VectorXdual &actions, State const &_rootState);

    // Example
//    dual f();
//    VectorXdual x;
//    VectorXdual p;


    // Reverse
//    bool findActionsRev(VectorXvar& actions, double& Q, VectorXvar& sf, VectorXvar& af);
//    var qFunctionRev(const VectorXvar& sf, const VectorXvar& af);
//    static VectorXvar updateCPFsRev(VectorXvar& sf, VectorXvar& af);
//    void projectActionsRev(VectorXvar &af);
//    static std::vector<int> sampleConcreteActionRev(VectorXvar &actions);
//    static std::vector<int> sampleConcreteActionMetaRev(VectorXvar &actions);

    // Printing
    void printVar(std::ostream &out, var &node, int &counter) const;

    static void printFluents(std::ostream &out, VectorXvar &sf, VectorXvar &af);

    static void printFluents(std::ostream &out, VectorXdual &sf, VectorXdual &af);

    void printActionFluents(const VectorXdual &af) const ;

    void printGradientStep();

    void printRandomRestart();

    void printBestAction(std::vector<int> &bestActions) const;

    void printGradients(const VectorXdual &af, const dual &q, const VectorXd &dqda) const;

    void printLayer(const dual &reward, const dual &Q, VectorXdual &sf_prime, VectorXdual &af_prime, int layer) const;

    // Parameter
    bool forward;
    bool highestMarginalProbability;
    bool conformant;
    int minUpdates;
    int verbose;
    int maxGradientSteps;
    int projection;
    double threshold;
    double penalty;

    double totalGradientSteps = 0;
    double totalGradientOptimizations = 0;
    int maxSearchDepthForThisStep;

    // Temporary
    double initialValue;

    // Calculation
    VectorXdual sf_input;

    // Stats
    int tree_size;
    int restart_counter;
    int gradientCounter;
    int cacheHits;

    // The stopwatch used for timeout check
    Stopwatch stopwatch;

    bool findActions(VectorXvar &actions, double &Q, VectorXvar &sf_var, VectorXvar &af_var, VectorXdual &sf_dual, VectorXdual &af_dual);

    void printDetailedStats(std::ostream &out, const bool &printRoundStats, std::string indent) const;

    void printActions(const VectorXdual &af) const;

    VectorXdual updateCPFsVisualize(const VectorXdual &sf, const VectorXdual &af);
};

#endif //PROST_SOGBOFA_SEARCH_H
