#ifndef NETWORK_IMPL_H
#define NETWORK_IMPL_H

#include "torch/torch.h"
#include <torch/script.h>

#include <vector>
#include <iostream>

#include "search_engine.h"
#include "custom_dataset.h"

/// Tags naming the kinds of network this code base distinguishes.
/// The enumerator values are written out explicitly; they follow declaration
/// order starting at 0.
enum class NetworkType {
    QValueNetwork  = 0,
    PolicyNetwork  = 1,
    BoundedQvalue  = 2,
    BoundedPolicy  = 3,
    HorizonNetwork = 4,
};

struct NetImpl: public torch::nn::Module {
    
public:
    NetImpl(){};

    //constructor for Q-value network
    NetImpl(int inputSize, int numberOfHiddenLayers, int breadthOfHiddenLayers, int _numberOfEpochs, int _batchSize, double _learningRate){
            
            this->numberOfEpochs = _numberOfEpochs;
            this->batchSize = _batchSize;
            this->learningRate = _learningRate;

            breadthOfHiddenLayers = inputSize*breadthOfHiddenLayers;

            fc1 = register_module("fc1", torch::nn::Linear(inputSize, breadthOfHiddenLayers));
            linearLayers.push_back(fc1);
            for(int layer = 0; layer < numberOfHiddenLayers; layer++){
                torch::nn::Linear fc{nullptr};
                fc = register_module("fc" + std::to_string(layer+2), torch::nn::Linear(breadthOfHiddenLayers, breadthOfHiddenLayers));
                linearLayers.push_back(fc);
            }
            fc2 = register_module("fc" + std::to_string(numberOfHiddenLayers+2), torch::nn::Linear(breadthOfHiddenLayers, 1));
    
    }

    //constructor for policy network
    NetImpl(int inputSize, int numberOfHiddenLayers, int breadthOfHiddenLayers, int _numberOfEpochs, int _batchSize, double _learningRate, int numberOfActions){
            
            this->numberOfEpochs = _numberOfEpochs;
            this->batchSize = _batchSize;
            this->learningRate = _learningRate;
            
            breadthOfHiddenLayers = inputSize*breadthOfHiddenLayers;
            
            fc1 = register_module("fc1", torch::nn::Linear(inputSize, breadthOfHiddenLayers));
            linearLayers.push_back(fc1);
            for(int layer = 0; layer < numberOfHiddenLayers; layer++){
                torch::nn::Linear fc{nullptr};
                fc = register_module("fc" + std::to_string(layer+2), torch::nn::Linear(breadthOfHiddenLayers, breadthOfHiddenLayers));
                linearLayers.push_back(fc);
            }
            fc2 = register_module("fc" + std::to_string(numberOfHiddenLayers+2), torch::nn::Linear(breadthOfHiddenLayers, numberOfActions));
    }

    torch::Tensor forward_custom(torch::Tensor input){
        
        for(torch::nn::Linear layer : linearLayers){
            input = torch::relu(layer->forward(input));
        }
        input = fc2->forward(input);

        return input;
    }

    torch::nn::Linear fc1{nullptr}, fc2{nullptr};
    std::vector<torch::nn::Linear> linearLayers;

    void trainNNBatches(std::vector<double> const& dataX, std::vector<double> const& dataY);
    void trainPolicyNetwork(std::vector<double> const& dataX, std::vector<double> const& dataY);
    void trainHorizonNetwork(std::vector<double> const& dataX, std::vector<double> const& dataY);

    void predictQValue(State const& currentState, int const& action, double& reward);
    void predictQValuePolicyNetwork(State const& currentState, int const& action, double& reward);
    void predictQValueHorizonNetwork(State const& currentState, int const& action, double& reward, int const& depth);

    void predictQValues(State const& currentState, std::vector<int> const& actions, std::vector<double>& qValues);
    void predictQValuesPolicyNetwork(State const& currentState, std::vector<int> const& actions, std::vector<double>& qValues);
    void predictQValuesHorizonNetwork(State const& currentState, std::vector<int> const& actions, std::vector<double>& qValues, int const& depth);

    static std::vector<double> prepareSingleActionInput(State const& currentState, int const& action, bool const& appendAction);
    static void prepareSingleActionInput(State const& currentState, int const& action, bool const& appendAction, std::vector<double>& dataX);

    static std::vector<std::vector<double>> transformInput(State const& currentState, std::vector<int> const& actions);
    static std::vector<std::vector<double>> transformInputAllActions(State const& currentState, std::vector<int> const& actions);
    static void transformInputFlatten(State const& currentState, std::vector<int> const& actions, std::vector<double>& dataX);
    static void transformInputPolicyNetwork(State const& currentState, std::vector<double>& dataX);
    static void transformInputHorizonNetwork(State const& currentState, std::vector<double>& dataX, int const& currentDepth);

    static void createTrainTestDataQvalue(std::vector<double>& dataX, std::vector<double>& dataY, std::vector<double>& testX, std::vector<double>& testY, int numberOfSamples, int inputSize);
    static void createTrainTestDataPolicy(std::vector<double>& dataX, std::vector<double>& dataY, std::vector<double>& testX, std::vector<double>& testY, int numberOfSamples, int inputSize, int outputSize);

private: 
        int numberOfEpochs;
        int batchSize;
        double learningRate;

public:
        std::pair<std::vector<double>, std::vector<double>> validationData;
};

// Declares the `Net` module-holder type for NetImpl via LibTorch's TORCH_MODULE
// macro (per the LibTorch documentation, a handle class that wraps a shared
// pointer to NetImpl and forwards member access with `->`).
TORCH_MODULE(Net);

#endif