#include "network.h"
//#include "search_engine.h"

#include "torch/torch.h"

#include <algorithm>
#include <iostream> 
#include <memory>
#include <vector>



using namespace std;



// Sanity check: trains a fresh Network(1,1,1) with SGD on four samples of
// y = 2x for 100 epochs and prints the progress plus a prediction for x = 5.
void Network::testrun(){

    Network net2(1,1,1);

    torch::optim::SGD optimizer (net2.parameters(), /*learning rate*/0.01);

    // The training samples never change, so they live outside the epoch loop.
    float data_X[] = {1.,2.,3.,4.};
    float data_Y[] = {2.,4.,6.,8.};
    const size_t numSamples = sizeof(data_X)/sizeof(*data_X);
    cout << numSamples << endl;

    for(int epoch = 0; epoch < 100; epoch++){
        const bool verbose = (epoch % 10) == 0;

        for(size_t i = 0; i < numSamples; i++){
            // from_blob does not copy: x_input/y_input must outlive x/y,
            // which they do because everything is scoped to this iteration.
            float x_input[] = {data_X[i]};
            float y_input[] = {data_Y[i]};
            torch::Tensor x = torch::from_blob(x_input, {1,1});
            torch::Tensor y = torch::from_blob(y_input, {1,1});

            // Clear gradients before this sample's backward pass so stale
            // gradients never leak into the update.
            optimizer.zero_grad();

            torch::Tensor prediction = net2.forward_single(x);

            // All debug output belongs to the "every 10th epoch" branch
            // (previously only the first line was guarded).
            if(verbose){
                cout << "epoch: ----------- " << epoch << endl;
                cout << "xvalue: " << x << endl;
                cout << "yvalue: " << y << endl;
                cout << "prediciton: " << prediction << endl;
            }

            torch::Tensor loss = torch::mse_loss(prediction, y);

            loss.backward();
            optimizer.step();
        }

        {
            // Inference only: no autograd graph needed for the probe value.
            torch::NoGradGuard noGrad;
            float test_value[] = {5};
            cout << "Prediction for 5:" << net2.forward_single(torch::from_blob(test_value, {1,1})) << endl;
        }
    }
}



void Network::initNN(){

    cout << "Initialize NN_2..." << endl;

    int horizon = 2;
    vector<double> dataX;
    vector<double> dataY;
    string value = "[IPC2014]";
    SearchEngine* thtsTrainData = SearchEngine::fromString(value);
    thtsTrainData->setMaxSearchDepth(horizon);
    thtsTrainData->initSession();

    for (State const& state : SearchEngine::trainingSet) {
        thtsTrainData->initStep(state);
        vector<int> actions = thtsTrainData->getApplicableActions(state);
    
        vector<double> reward(actions.size(), 0.0);
        thtsTrainData->estimateQValues(state, actions, reward);

        for (int i = 0; i < actions.size(); ++i) {
            if (actions[i] == i) {
                dataY.push_back(reward[i]);
            }
        }

        vector<double> trainingDataForState = transformInputFlatten(state, actions);
        dataX.insert(dataX.end(), trainingDataForState.begin(), trainingDataForState.end());
    }

    this->validationData = std::make_pair(dataX, dataY);
    trainNNBatches(dataX, dataY);

    delete thtsTrainData;

    cout << "Init training finished..." << endl;
    
}

// Encodes one (state, action) pair as a flat feature vector laid out as
// [deterministic fluents..., probabilistic fluents..., action-state vars...].
// Fluent values are truncated to int before being stored as double.
std::vector<double>  Network::prepareSingleActionInput(State const& currentState, int const& action){
    std::vector<double> encoded;

    auto const detCount = currentState.numberOfDeterministicStateFluents;
    for (int idx = 0; idx < detCount; ++idx) {
        encoded.push_back(static_cast<int>(currentState.deterministicStateFluent(idx)));
    }

    auto const probCount = currentState.numberOfProbabilisticStateFluents;
    for (int idx = 0; idx < probCount; ++idx) {
        encoded.push_back(static_cast<int>(currentState.probabilisticStateFluent(idx)));
    }

    auto const& actionVars = SearchEngine::actionStates[action].state;
    encoded.insert(encoded.end(), actionVars.begin(), actionVars.end());

    return encoded;
}

// Builds one feature row per selected action (those with actions[k] == k).
// Each row is [deterministic fluents..., probabilistic fluents...,
// action-state vars of action k...]. Fluent values are truncated to int
// before being stored as double.
std::vector<std::vector<double>> Network::transformInput(State const& currentState, std::vector<int> const& actions){
    // Prefix shared by every row: the state fluent values.
    std::vector<double> prefix;
    for (int d = 0; d < currentState.numberOfDeterministicStateFluents; ++d) {
        prefix.push_back(static_cast<int>(currentState.deterministicStateFluent(d)));
    }
    for (int p = 0; p < currentState.numberOfProbabilisticStateFluents; ++p) {
        prefix.push_back(static_cast<int>(currentState.probabilisticStateFluent(p)));
    }

    std::vector<std::vector<double>> rows;
    int index = 0;
    for (int const actionId : actions) {
        if (actionId == index) {
            auto const& actionVars = SearchEngine::actionStates[index].state;
            std::vector<double> row(prefix);
            row.insert(row.end(), actionVars.begin(), actionVars.end());
            rows.push_back(std::move(row));
        }
        ++index;
    }

    return rows;
}

// Flattened variant of transformInput. For every selected action
// (actions[k] == k) it appends [state fluents..., action-state vars of k...]
// to a single vector, so the result holds one row per selected action laid
// end to end. Fluent values are truncated to int before being stored as double.
std::vector<double> Network::transformInputFlatten(State const& currentState, std::vector<int> const& actions){
    vector<double> stateVars;

    for (int i = 0; i < currentState.numberOfDeterministicStateFluents; i++) {
        stateVars.push_back(static_cast<int>(currentState.deterministicStateFluent(i)));
    }
    for (int j = 0; j < currentState.numberOfProbabilisticStateFluents; j++) {
        stateVars.push_back(static_cast<int>(currentState.probabilisticStateFluent(j)));
    }

    vector<double> result;
    for (size_t k = 0; k < actions.size(); k++) {
        if (actions[k] == static_cast<int>(k)) {
            // Append directly into result; no per-action temporary copy.
            auto const& actionVars = SearchEngine::actionStates[k].state;
            result.insert(result.end(), stateVars.begin(), stateVars.end());
            result.insert(result.end(), actionVars.begin(), actionVars.end());
        }
    }

    return result;
}

void Network::trainNN(std::vector<std::vector<double>> const& dataX, std::vector<double> const& dataY){
    cout << "start trainNN" << endl;
    cout << "dataX " << dataX.size();
    cout << "-------------------------------------------------" << endl;
    torch::optim::SGD optimizer (this->parameters(), /*learning rate*/0.01);  //0.0001
    
    std::vector<double> expectedOutcome;
    torch::Tensor runningLoss;

    int counter;
    for(int epoch = 0; epoch < numberOfEpochs; epoch++){
        runningLoss = torch::zeros(1);
        counter = 0;
        for (const std::vector<double>& dataXi : dataX) {

            torch::Tensor inputStream = torch::tensor(dataXi).clone();
            torch::Tensor prediction;
 
            prediction = this->forward_custom(inputStream);
            

            expectedOutcome.push_back(dataY[counter]);
            torch::Tensor outcome = torch::tensor(expectedOutcome).clone();

            torch::Tensor loss = torch::mse_loss(prediction, outcome);
            runningLoss += loss.item();
            
            loss.backward();

            optimizer.step();
            optimizer.zero_grad();
            assert(dataY[counter] >= 0);

            if((epoch%10) == 0){
                cout << "epoch: ----------- " << epoch << endl;
                cout << "yvalue: " << dataY[counter] << endl;
                cout << "prediciton: " << prediction << endl;
                cout << runningLoss << endl;
                
            }

            expectedOutcome.clear();
            counter++;

        }
    }
}

void Network::trainNNBatches(std::vector<double> const& dataX, std::vector<double> const& dataY){
    // torch::optim::SGDOptions().momentum(0.9);
    torch::optim::SGD optimizer (this->parameters(), learningRate);
    // double momentum=0.9;
    // torch::optim::SGD optimizer (this->parameters(), learningRate, momentum);
    //torch::optim::SGDOptions()
    int inputSize = SearchEngine::stateFluents.size() + SearchEngine::actionFluents.size();
    int numberOfSamplePoints = dataX.size()/inputSize;

    torch::Tensor inputData = torch::tensor(dataX).clone();
    torch::Tensor targetData = torch::tensor(dataY).clone();
    inputData = torch::reshape(inputData, {numberOfSamplePoints, inputSize});

    auto trainData = CustomDataset(inputData, targetData).map(torch::data::transforms::Stack<>());
    //inputData.set_requires_grad(true);

    auto dataLoader = torch::data::make_data_loader(
        std::move(trainData),
        torch::data::DataLoaderOptions().batch_size(batchSize));

    for(int epoch = 0; epoch < numberOfEpochs; epoch++){

        for(auto& batch : *dataLoader){
            
            optimizer.zero_grad();

            torch::Tensor prediction = this->forward_custom(batch.data);
            //have to reshape prediction for .backward() method
            //see: https://github.com/pytorch/examples/issues/819
            prediction = torch::reshape(prediction, {batch.data.size(0)});
            torch::Tensor loss = torch::mse_loss(prediction, batch.target);

            // cout << prediction << endl;
            // cout << batch.target << endl;
            // cout << loss << ""<<endl;
            assert(!(isnan(prediction[0].item<double>())));
            //assert(loss.item<double>() < 100000);
            //assert(prediction[0].item<double>() > 0.0);
            // assert(loss.item<double>() > 0);
            loss.backward();
            optimizer.step();
            
        }
    }
}


// Predicts the Q-value of a single (state, action) pair and writes it to
// `reward`. The input encoding comes from prepareSingleActionInput, and the
// first element of the network output is taken as the prediction.
void Network::predictQValue(State const& currentState, int const& action, double& reward){
    // Inference only: skip building an autograd graph.
    torch::NoGradGuard noGrad;

    vector<double> trainingDataForState = prepareSingleActionInput(currentState, action);

    torch::Tensor inputStream = torch::tensor(trainingDataForState);
    torch::Tensor prediction = this->forward_custom(inputStream);
    reward = prediction[0].item<double>();
}

// Predicts one Q-value per row produced by transformInput (one row per
// action with actions[k] == k) and APPENDS them to qValues in row order.
// Existing contents of qValues are not cleared. A diagnostic is printed
// for any NaN prediction, and the NaN is still appended.
void Network::predictQValues(State const& currentState, std::vector<int> const& actions, std::vector<double>& qValues){
    // Inference only: skip building an autograd graph.
    torch::NoGradGuard noGrad;

    // Iterate the encoded rows directly instead of copying them into a
    // second container first.
    for (const std::vector<double>& dataXi : transformInput(currentState, actions)) {

        torch::Tensor inputStream = torch::tensor(dataXi);
        torch::Tensor prediction = this->forward_custom(inputStream);

        double out = prediction[0].item<double>();
        if(isnan(out)){
            cout << "NAN VALUE IN NN-PREDICTION" << endl;
        }
        qValues.push_back(out);
    }
}

// Evaluates each trained network on its own stored validationData and prints
// the MSE loss. networks[i] is reported as "depth_<i+2>". Networks without
// validation data are reported as skipped instead of building a data loader
// with batch size 0.
void Network::evaluateNetworks(std::vector<Network> const& networks){

    cout << "Evaluate networks..." << endl;
    cout << "Netsize inside evaluation: " <<  networks.size()+2 << endl;

    // Pure evaluation: no autograd graph is needed for predictions or loss.
    torch::NoGradGuard noGrad;

    int inputSize = SearchEngine::stateFluents.size() + SearchEngine::actionFluents.size();
    for(size_t i = 0; i < networks.size(); i++){
        // Copied as before; presumably forward_custom is non-const and cannot
        // be called through the const reference (TODO confirm in network.h).
        Network net = networks[i];

        if (inputSize <= 0 || net.validationData.first.empty() || net.validationData.second.empty()) {
            cout << "loss of NN depth_" << i+2 << ": skipped (no validation data)" << endl;
            continue;
        }

        int numberOfSamplePoints = net.validationData.first.size()/inputSize;

        torch::Tensor inputData = torch::tensor(net.validationData.first).clone();
        torch::Tensor targetData = torch::tensor(net.validationData.second).clone();
        inputData = torch::reshape(inputData, {numberOfSamplePoints, inputSize});

        auto evalData = CustomDataset(inputData, targetData).map(torch::data::transforms::Stack<>());

        // Batch size is the number of validation targets, i.e. normally the
        // whole validation set in one batch.
        auto dataLoader = torch::data::make_data_loader(
            std::move(evalData),
            torch::data::DataLoaderOptions().batch_size(net.validationData.second.size()));

        for(auto& batch : *dataLoader){
            torch::Tensor prediction = net.forward_custom(batch.data);
            // Flatten to one value per sample to match batch.target
            // (see: https://github.com/pytorch/examples/issues/819).
            prediction = torch::reshape(prediction, {batch.data.size(0)});
            torch::Tensor loss = torch::mse_loss(prediction, batch.target);

            cout << "loss of NN depth_" << i+2 << ": " << loss << endl;
        }
    }
    cout << "evaluation finished ..." << endl;

}