#ifndef NETWORK_H
#define NETWORK_H

#include "torch/torch.h"
#include <torch/script.h>

#include <vector>
#include <iostream>

#include "search_engine.h"
#include "custom_dataset.h"

class Network: public torch::nn::Module{
    
public:
    Network(){}; //default constructor

    Network(int64_t N, int64_t M, int setting){
        // //parameters -> record gradients "trainable weights"
        // W = register_parameter("W", torch::randn({N, M}));
        // //buffers -> store mean and variances for batch normalization
        // b = register_parameter("b", torch::randn({M}));
        if(setting == 1){
            std::cout << "Create 1-Layer NN" << std::endl;
            fc1 = register_module("fc1", torch::nn::Linear(N, M));
        }else{
            std::cout << "Create 5-Layer NN" << std::endl;
            fc1 = register_module("fc1", torch::nn::Linear(N, 100));
            fc2 = register_module("fc2", torch::nn::Linear(100, 100));
            fc3 = register_module("fc3", torch::nn::Linear(100, 100));
            fc4 = register_module("fc4", torch::nn::Linear(100, 100));
            fc5 = register_module("fc5", torch::nn::Linear(100, M));
        }
    }

    //custom NN parametrized from command line
    Network(int inputSize, int numberOfHiddenLayers, int breadthOfHiddenLayers, int _numberOfEpochs, int _batchSize, double _learningRate){

            //std::cout << "Create custom NN" << std::endl;
            
            this->numberOfEpochs = _numberOfEpochs;
            this->batchSize = _batchSize;
            this->learningRate = _learningRate;

            breadthOfHiddenLayers = inputSize;

            fc1 = register_module("fc1", torch::nn::Linear(inputSize, breadthOfHiddenLayers));
            linearLayers.push_back(fc1);
            for(int layer = 0; layer < numberOfHiddenLayers; layer++){
                torch::nn::Linear fc{nullptr};
                fc = register_module("fc" + std::to_string(layer+2), torch::nn::Linear(breadthOfHiddenLayers, breadthOfHiddenLayers));
                linearLayers.push_back(fc);
            }
            fc2 = register_module("fc" + std::to_string(numberOfHiddenLayers+2), torch::nn::Linear(breadthOfHiddenLayers, 1));
    
    }

    torch::Tensor forward_single(torch::Tensor input){
        //return torch::addmm(b, input, W);
        input = fc1->forward(input);//torch::relu(fc1->forward(input));
        return input;
    }

    torch::Tensor forward_custom(torch::Tensor input){
        
        for(torch::nn::Linear layer : linearLayers){
            input = torch::relu(layer->forward(input));
        }
        input = fc2->forward(input);

        return input;
    }

    torch::Tensor forward_full(torch::Tensor input){
        //return torch::addmm(b, input, W);
        //input = torch::relu(fc1->forward(input));
        //     // Use one of many tensor manipulation functions.
        input = torch::relu(fc1->forward(input));
        input = torch::relu(fc2->forward(input));
        input = torch::relu(fc3->forward(input));
        input = torch::relu(fc4->forward(input));
        input = fc5->forward(input);
        //input = torch::log_softmax(fc3->forward(input), /*dim=*/1);

        return input;
    }

    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}, fc4{nullptr}, fc5{nullptr};
    std::vector<torch::nn::Linear> linearLayers;

    //starts linear regression example
    static void testrun();

    //create first NN and train by depth of 2
    void initNN();
    //prepare trainingdata using current State
    void iterativeNetworkTraining();

    void trainNN(std::vector<std::vector<double>> const& dataX, std::vector<double> const& dataY);
    void trainNNBatches(std::vector<double> const& dataX, std::vector<double> const& dataY);

    void predictQValue(State const& currentState, int const& action, double& reward);
    void predictQValues(State const& currentState, std::vector<int> const& actions, std::vector<double>& qValues);

    static void evaluateNetworks(std::vector<Network> const& networks);

    static std::vector<double> prepareSingleActionInput(State const& currentState, int const& action);
    static std::vector<std::vector<double>> transformInput(State const& currentState, std::vector<int> const& actions);
    static std::vector<double> transformInputFlatten(State const& currentState, std::vector<int> const& actions);
    
private: 
        int numberOfEpochs;
        int batchSize;
        double learningRate;

public:
        std::pair<std::vector<double>, std::vector<double>> validationData;
};

#endif