#include "network_impl.h"
#include "network_trainer.h"

#include "torch/torch.h"

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>



using namespace std;

/**
 * Builds the network input vector for a single state (optionally with an action).
 *
 * Layout: all deterministic state fluents, then all probabilistic state
 * fluents, each truncated to an integer via static_cast<int> before being
 * stored as a double. If appendAction is true, the contents of
 * SearchEngine::actionStates[action].state are appended after the fluents.
 *
 * @param currentState  state whose fluents are encoded
 * @param action        index into SearchEngine::actionStates; only read when appendAction is true
 * @param appendAction  whether to append the action encoding
 * @return the encoded input vector
 */
std::vector<double>  NetImpl::prepareSingleActionInput(State const& currentState, int const& action, bool const& appendAction){
    std::vector<double> encoded;

    // Deterministic fluents come first ...
    for (int det = 0; det < currentState.numberOfDeterministicStateFluents; ++det) {
        int const truncated = static_cast<int>(currentState.deterministicStateFluent(det));
        encoded.push_back(truncated);
    }
    // ... followed by the probabilistic ones.
    for (int prob = 0; prob < currentState.numberOfProbabilisticStateFluents; ++prob) {
        int const truncated = static_cast<int>(currentState.probabilisticStateFluent(prob));
        encoded.push_back(truncated);
    }

    if (!appendAction) {
        return encoded;
    }

    auto const& actionEncoding = SearchEngine::actionStates[action].state;
    encoded.insert(encoded.end(), actionEncoding.begin(), actionEncoding.end());
    return encoded;
}

/**
 * Appends the network input for a single state (optionally with an action) to dataX.
 *
 * Same layout as the value-returning overload: deterministic state fluents,
 * then probabilistic state fluents (each truncated via static_cast<int>),
 * then, if appendAction is true, SearchEngine::actionStates[action].state.
 * Existing contents of dataX are left untouched; new values go at the end.
 *
 * @param currentState  state whose fluents are encoded
 * @param action        index into SearchEngine::actionStates; only read when appendAction is true
 * @param appendAction  whether to append the action encoding
 * @param dataX         [in,out] buffer the encoding is appended to
 */
void NetImpl::prepareSingleActionInput(State const& currentState, int const& action, bool const& appendAction, std::vector<double>& dataX){
    for (int det = 0; det < currentState.numberOfDeterministicStateFluents; ++det) {
        int const truncated = static_cast<int>(currentState.deterministicStateFluent(det));
        dataX.push_back(truncated);
    }
    for (int prob = 0; prob < currentState.numberOfProbabilisticStateFluents; ++prob) {
        int const truncated = static_cast<int>(currentState.probabilisticStateFluent(prob));
        dataX.push_back(truncated);
    }

    if (!appendAction) {
        return;
    }

    auto const& actionEncoding = SearchEngine::actionStates[action].state;
    dataX.insert(dataX.end(), actionEncoding.begin(), actionEncoding.end());
}

/**
 * Builds one input row per selected action for the given state.
 *
 * Every row starts with the state encoding (deterministic then probabilistic
 * fluents, each truncated via static_cast<int>) followed by
 * SearchEngine::actionStates[k].state. Only indices k with actions[k] == k
 * produce a row. Presumably this is the caller's marker for applicable,
 * non-duplicate actions — confirm against the search engine.
 *
 * @param currentState  state to encode
 * @param actions       per-index action markers (see above)
 * @return one row per selected action, in index order
 */
std::vector<std::vector<double>> NetImpl::transformInput(State const& currentState, std::vector<int> const& actions){
    // Encode the state once; it is the shared prefix of every row.
    std::vector<double> statePrefix;
    for (int det = 0; det < currentState.numberOfDeterministicStateFluents; ++det) {
        statePrefix.push_back(static_cast<int>(currentState.deterministicStateFluent(det)));
    }
    for (int prob = 0; prob < currentState.numberOfProbabilisticStateFluents; ++prob) {
        statePrefix.push_back(static_cast<int>(currentState.probabilisticStateFluent(prob)));
    }

    std::vector<std::vector<double>> rows;
    int const actionCount = static_cast<int>(actions.size());
    for (int idx = 0; idx < actionCount; ++idx) {
        if (actions[idx] != idx) {
            continue;
        }
        // Start the row as a copy of the prefix, then extend it in place.
        rows.push_back(statePrefix);
        auto const& actionEncoding = SearchEngine::actionStates[idx].state;
        rows.back().insert(rows.back().end(), actionEncoding.begin(), actionEncoding.end());
    }

    return rows;
}

/**
 * Builds one input row for every index 0..actions.size()-1, regardless of
 * the values stored in actions (only its size is used).
 *
 * Each row is the state encoding (deterministic then probabilistic fluents,
 * each truncated via static_cast<int>) followed by
 * SearchEngine::actionStates[k].state.
 *
 * @param currentState  state to encode
 * @param actions       only its length is read
 * @return actions.size() rows, in index order
 */
std::vector<std::vector<double>> NetImpl::transformInputAllActions(State const& currentState, std::vector<int> const& actions){
    std::vector<double> statePrefix;
    for (int det = 0; det < currentState.numberOfDeterministicStateFluents; ++det) {
        statePrefix.push_back(static_cast<int>(currentState.deterministicStateFluent(det)));
    }
    for (int prob = 0; prob < currentState.numberOfProbabilisticStateFluents; ++prob) {
        statePrefix.push_back(static_cast<int>(currentState.probabilisticStateFluent(prob)));
    }

    std::vector<std::vector<double>> rows;
    int const actionCount = static_cast<int>(actions.size());
    for (int idx = 0; idx < actionCount; ++idx) {
        rows.push_back(statePrefix);
        auto const& actionEncoding = SearchEngine::actionStates[idx].state;
        rows.back().insert(rows.back().end(), actionEncoding.begin(), actionEncoding.end());
    }

    return rows;
}

/**
 * Appends one flattened (state, action) sample per selected action to dataX.
 *
 * For each index k with actions[k] == k, appends the state encoding
 * (deterministic then probabilistic fluents, truncated via static_cast<int>)
 * followed by SearchEngine::actionStates[k].state. The result is a
 * row-major concatenation of the rows transformInput would return.
 *
 * @param currentState  state to encode
 * @param actions       per-index action markers; only indices with actions[k] == k are used
 * @param dataX         [in,out] buffer the samples are appended to
 */
void NetImpl::transformInputFlatten(State const& currentState, std::vector<int> const& actions, std::vector<double>& dataX){
    // The state prefix is identical for every sample, so encode it once.
    std::vector<double> statePrefix;
    for (int det = 0; det < currentState.numberOfDeterministicStateFluents; ++det) {
        statePrefix.push_back(static_cast<int>(currentState.deterministicStateFluent(det)));
    }
    for (int prob = 0; prob < currentState.numberOfProbabilisticStateFluents; ++prob) {
        statePrefix.push_back(static_cast<int>(currentState.probabilisticStateFluent(prob)));
    }

    int const actionCount = static_cast<int>(actions.size());
    for (int idx = 0; idx < actionCount; ++idx) {
        if (actions[idx] != idx) {
            continue;
        }
        auto const& actionEncoding = SearchEngine::actionStates[idx].state;
        dataX.insert(dataX.end(), statePrefix.begin(), statePrefix.end());
        dataX.insert(dataX.end(), actionEncoding.begin(), actionEncoding.end());
    }
}

/**
 * Appends the policy-network input for currentState to dataX.
 *
 * The input is the deterministic state fluents followed by the probabilistic
 * state fluents, each truncated via static_cast<int>. No action encoding is
 * added. Existing contents of dataX are kept.
 *
 * @param currentState  state to encode
 * @param dataX         [in,out] buffer the encoding is appended to
 */
void NetImpl::transformInputPolicyNetwork(State const& currentState, std::vector<double>& dataX){
    for (int det = 0; det < currentState.numberOfDeterministicStateFluents; ++det) {
        int const truncated = static_cast<int>(currentState.deterministicStateFluent(det));
        dataX.push_back(truncated);
    }
    for (int prob = 0; prob < currentState.numberOfProbabilisticStateFluents; ++prob) {
        int const truncated = static_cast<int>(currentState.probabilisticStateFluent(prob));
        dataX.push_back(truncated);
    }
}


/**
 * Appends the horizon-network input for currentState to dataX.
 *
 * Same as the policy-network encoding (deterministic then probabilistic
 * fluents, truncated via static_cast<int>), with currentDepth appended as
 * the final feature. Existing contents of dataX are kept.
 *
 * @param currentState  state to encode
 * @param dataX         [in,out] buffer the encoding is appended to
 * @param currentDepth  search depth appended as the last input value
 */
void NetImpl::transformInputHorizonNetwork(State const& currentState, std::vector<double>& dataX, int const& currentDepth){
    for (int det = 0; det < currentState.numberOfDeterministicStateFluents; ++det) {
        int const truncated = static_cast<int>(currentState.deterministicStateFluent(det));
        dataX.push_back(truncated);
    }
    for (int prob = 0; prob < currentState.numberOfProbabilisticStateFluents; ++prob) {
        int const truncated = static_cast<int>(currentState.probabilisticStateFluent(prob));
        dataX.push_back(truncated);
    }
    dataX.push_back(static_cast<double>(currentDepth));
}

void NetImpl::trainNNBatches(std::vector<double> const& dataX, std::vector<double> const& dataY){
        
    torch::optim::SGD optimizer(
        this->parameters(),
        torch::optim::SGDOptions(learningRate).momentum(0.9));

    // torch::optim::Adam optimizer(
    //     this->parameters(),
    //     torch::optim::AdamOptions(learningRate).betas({0.9, 0.999})
    // );

    // torch::optim::LBFGS optimizer(
    //     this->parameters(),
    //     torch::optim::LBFGSOptions(learningRate).max_iter(40));
    
    int inputSize = SearchEngine::stateFluents.size() + SearchEngine::actionFluents.size();
    int numberOfSamplePoints = dataX.size()/inputSize;
    bool printed = true;
    torch::Tensor inputData = torch::tensor(dataX).clone();
    torch::Tensor targetData = torch::tensor(dataY).clone();
    inputData = torch::reshape(inputData, {numberOfSamplePoints, inputSize});

    auto trainData = CustomDataset(inputData, targetData).map(torch::data::transforms::Stack<>());
    //inputData.set_requires_grad(true);

    auto dataLoader = torch::data::make_data_loader(
        std::move(trainData),
        torch::data::DataLoaderOptions().batch_size(batchSize));

    torch::Tensor loss;

    for(int epoch = 0; epoch < numberOfEpochs; epoch++){
        
        for(auto& batch : *dataLoader){
            
            optimizer.zero_grad();

            if(!printed){
                cout << batch.data << endl;
                cout << batch.target << endl;
                printed = true;
            }
            torch::Tensor prediction = this->forward_custom(batch.data);
            
            //have to reshape prediction for .backward() method
            //see: https://github.com/pytorch/examples/issues/819

            prediction = torch::reshape(prediction, {batch.data.size(0)});
            loss = torch::mse_loss(prediction, batch.target);
            //loss = torch::l1_loss(prediction, batch.target);

            loss.backward();
            optimizer.step();
            
        }                
        
    }
}

/**
 * Trains the policy network on flattened samples.
 *
 * dataX holds consecutive state encodings of SearchEngine::stateFluents.size()
 * values each. dataY holds SearchEngine::numberOfActions targets per sample.
 * A target of -1 marks an output that must not contribute to training: the
 * prediction is replaced by the target there, so its error and gradient are zero.
 *
 * Runs numberOfEpochs passes of mini-batch SGD (momentum 0.9, learningRate,
 * batchSize) minimising the mean squared error. The optimizer is created
 * fresh on every call. (Adam with betas 0.9/0.999 is a possible alternative.)
 */
void NetImpl::trainPolicyNetwork(std::vector<double> const& dataX, std::vector<double> const& dataY){
    torch::optim::SGD optimizer(
        this->parameters(),
        torch::optim::SGDOptions(learningRate).momentum(0.9));

    int inputSize = SearchEngine::stateFluents.size();
    int numberOfSamplePoints = dataX.size()/inputSize;
    int numberOfOutputValues = SearchEngine::numberOfActions;

    // torch::tensor copies the vectors' contents, so no extra clone() is needed.
    torch::Tensor inputData = torch::reshape(torch::tensor(dataX), {numberOfSamplePoints, inputSize});
    torch::Tensor targetData = torch::reshape(torch::tensor(dataY), {numberOfSamplePoints, numberOfOutputValues});

    auto trainData = CustomDataset(inputData, targetData).map(torch::data::transforms::Stack<>());

    auto dataLoader = torch::data::make_data_loader(
        std::move(trainData),
        torch::data::DataLoaderOptions().batch_size(batchSize));

    for(int epoch = 0; epoch < numberOfEpochs; epoch++){
        for(auto& batch : *dataLoader){
            optimizer.zero_grad();

            torch::Tensor prediction = this->forward_custom(batch.data);
            // Reshape so prediction matches the (batch, actions) target layout,
            // see: https://github.com/pytorch/examples/issues/819
            prediction = torch::reshape(prediction, {batch.target.size(0), batch.target.size(1)});

            // Mask outputs whose target is -1 by copying the target into the
            // prediction. torch::where is recorded by autograd (no gradient flows
            // to masked entries) and is vectorised. The old per-element accessor
            // write bypassed autograd and issued one .item() call per entry.
            prediction = torch::where(batch.target.eq(-1), batch.target, prediction);

            torch::Tensor loss = torch::mse_loss(prediction, batch.target);
            loss.backward();
            optimizer.step();
        }
    }
}

/**
 * Trains the horizon network on flattened samples.
 *
 * dataX holds consecutive inputs of SearchEngine::stateFluents.size() + 1
 * values each (state encoding plus search depth). dataY holds
 * SearchEngine::numberOfActions targets per sample. A target of -1 marks an
 * output that must not contribute to training: the prediction is replaced by
 * the target there, so its error and gradient are zero.
 *
 * Runs numberOfEpochs passes of mini-batch SGD (momentum 0.9, learningRate,
 * batchSize) minimising the L1 loss (MSE is the commented-out alternative
 * used by the policy network). The optimizer is created fresh on every call.
 */
void NetImpl::trainHorizonNetwork(std::vector<double> const& dataX, std::vector<double> const& dataY){
    torch::optim::SGD optimizer(
        this->parameters(),
        torch::optim::SGDOptions(learningRate).momentum(0.9));

    // +1 for the depth feature appended after the state fluents.
    int inputSize = SearchEngine::stateFluents.size()+1;
    int numberOfSamplePoints = dataX.size()/inputSize;
    int numberOfOutputValues = SearchEngine::numberOfActions;

    // torch::tensor copies the vectors' contents, so no extra clone() is needed.
    torch::Tensor inputData = torch::reshape(torch::tensor(dataX), {numberOfSamplePoints, inputSize});
    torch::Tensor targetData = torch::reshape(torch::tensor(dataY), {numberOfSamplePoints, numberOfOutputValues});

    auto trainData = CustomDataset(inputData, targetData).map(torch::data::transforms::Stack<>());

    auto dataLoader = torch::data::make_data_loader(
        std::move(trainData),
        torch::data::DataLoaderOptions().batch_size(batchSize));

    for(int epoch = 0; epoch < numberOfEpochs; epoch++){
        for(auto& batch : *dataLoader){
            optimizer.zero_grad();

            torch::Tensor prediction = this->forward_custom(batch.data);
            // Reshape so prediction matches the (batch, actions) target layout.
            prediction = torch::reshape(prediction, {batch.target.size(0), batch.target.size(1)});

            // Mask outputs whose target is -1 by copying the target into the
            // prediction. torch::where is recorded by autograd (no gradient flows
            // to masked entries) and is vectorised. The old per-element accessor
            // write bypassed autograd and issued one .item() call per entry.
            prediction = torch::where(batch.target.eq(-1), batch.target, prediction);

            torch::Tensor loss = torch::l1_loss(prediction, batch.target);
            loss.backward();
            optimizer.step();
        }
    }
}


/**
 * Predicts the value of one (state, action) pair with the Q-value network.
 *
 * The input is the state encoding followed by the action encoding
 * (prepareSingleActionInput with appendAction = true). The first element
 * of the network output is returned.
 *
 * @param currentState  state to evaluate
 * @param action        index into SearchEngine::actionStates
 * @param reward        [out] receives the predicted value
 */
void NetImpl::predictQValue(State const& currentState, int const& action, double& reward){
    // Inference only: disable autograd so no computation graph is built for
    // this forward pass. The previous grad mode is restored when the guard
    // goes out of scope.
    torch::NoGradGuard noGrad;

    vector<double> trainingDataForState = prepareSingleActionInput(currentState, action, true);

    // torch::tensor copies the data, so no extra clone() is needed.
    torch::Tensor inputStream = torch::tensor(trainingDataForState);
    torch::Tensor prediction = this->forward_custom(inputStream);
    reward = prediction[0].item<double>();
}

/**
 * Predicts the value of one action in a state with the policy network.
 *
 * The input is the state encoding only (prepareSingleActionInput with
 * appendAction = false). The network output element at index action is
 * returned.
 *
 * @param currentState  state to evaluate
 * @param action        index of the output element to read
 * @param reward        [out] receives the predicted value
 */
void NetImpl::predictQValuePolicyNetwork(State const& currentState, int const& action, double& reward){
    // Inference only: disable autograd for this forward pass.
    torch::NoGradGuard noGrad;

    vector<double> trainingDataForState = prepareSingleActionInput(currentState, action, false);

    // torch::tensor copies the data, so no extra clone() is needed.
    torch::Tensor inputStream = torch::tensor(trainingDataForState);
    torch::Tensor prediction = this->forward_custom(inputStream);
    reward = prediction[action].item<double>();
}

/**
 * Predicts the value of one action in a state at a given depth with the
 * horizon network.
 *
 * The input is the state encoding (no action) followed by depth as the last
 * feature. The network output element at index action is returned.
 *
 * @param currentState  state to evaluate
 * @param action        index of the output element to read
 * @param reward        [out] receives the predicted value
 * @param depth         search depth appended to the input
 */
void NetImpl::predictQValueHorizonNetwork(State const& currentState, int const& action, double& reward, int const& depth){
    // Inference only: disable autograd for this forward pass.
    torch::NoGradGuard noGrad;

    vector<double> trainingDataForState = prepareSingleActionInput(currentState, action, false);
    trainingDataForState.push_back(static_cast<double>(depth));

    // torch::tensor copies the data, so no extra clone() is needed.
    torch::Tensor inputStream = torch::tensor(trainingDataForState);
    torch::Tensor prediction = this->forward_custom(inputStream);
    reward = prediction[action].item<double>();
}

/**
 * Predicts a Q-value for every index 0..actions.size()-1 with the Q-value network.
 *
 * One forward pass is run per action on the rows built by
 * transformInputAllActions (state encoding plus action encoding). The first
 * output element is stored in qValues[actionIndex].
 *
 * @param currentState  state to evaluate
 * @param actions       only its length is used (all indices are evaluated)
 * @param qValues       [out] must already hold at least actions.size() entries
 */
void NetImpl::predictQValues(State const& currentState, std::vector<int> const& actions, std::vector<double>& qValues){
    // Inference only: disable autograd for all forward passes below.
    torch::NoGradGuard noGrad;

    vector<vector<double>> dataX = transformInputAllActions(currentState, actions);
    for (size_t actionIndex = 0; actionIndex < dataX.size(); actionIndex++){
        const std::vector<double>& dataXi = dataX[actionIndex];
        // torch::tensor copies the data, so no extra clone() is needed.
        torch::Tensor inputStream = torch::tensor(dataXi);
        torch::Tensor prediction = this->forward_custom(inputStream);
        qValues[actionIndex] = prediction[0].item<double>();
    }
}

/**
 * Predicts values for all actions with one forward pass of the policy network.
 *
 * The input is the state encoding only. For every index i in
 * 0..actions.size()-1, output element i is written to qValues[i]. The
 * contents of actions are not inspected, so values are written even for
 * inapplicable actions.
 *
 * @param currentState  state to evaluate
 * @param actions       only its length is used
 * @param qValues       [out] must already hold at least actions.size() entries
 */
void NetImpl::predictQValuesPolicyNetwork(State const& currentState, std::vector<int> const& actions, std::vector<double>& qValues){
    // Inference only: disable autograd for this forward pass.
    torch::NoGradGuard noGrad;

    vector<double> dataX;
    transformInputPolicyNetwork(currentState, dataX);

    // torch::tensor copies the data, so no extra clone() is needed.
    torch::Tensor inputStream = torch::tensor(dataX);
    torch::Tensor prediction = this->forward_custom(inputStream);

    for (size_t actionIndex = 0; actionIndex < actions.size(); actionIndex++)
    {
        qValues[actionIndex] = prediction[actionIndex].item<double>();
    }
}


/**
 * Predicts values for the selected actions with one forward pass of the
 * horizon network.
 *
 * The input is the state encoding followed by depth. For every index i with
 * actions[i] == i, output element i is written to qValues[i]. Other entries
 * of qValues are left unchanged.
 *
 * @param currentState  state to evaluate
 * @param actions       per-index action markers; only indices with actions[i] == i are filled
 * @param qValues       [out] must already hold at least actions.size() entries
 * @param depth         search depth appended to the input
 */
void NetImpl::predictQValuesHorizonNetwork(State const& currentState, std::vector<int> const& actions, std::vector<double>& qValues, int const& depth){
    // Inference only: disable autograd for this forward pass.
    torch::NoGradGuard noGrad;

    vector<double> dataX;
    // depth is already an int; the former (double) cast only converted it back.
    transformInputHorizonNetwork(currentState, dataX, depth);

    // torch::tensor copies the data, so no extra clone() is needed.
    torch::Tensor inputStream = torch::tensor(dataX);
    torch::Tensor prediction = this->forward_custom(inputStream);

    int const actionCount = static_cast<int>(actions.size());
    for (int actionIndex = 0; actionIndex < actionCount; actionIndex++)
    {
        // Compare as int so a negative marker (e.g. -1) is never mistaken for an index.
        if(actions[actionIndex] == actionIndex){
            qValues[actionIndex] = prediction[actionIndex].item<double>();
        }
    }
}

/**
 * Moves the first 10% of flattened Q-value samples into a test set.
 *
 * Each sample occupies inputSize consecutive values in dataX and one value in
 * dataY. The first numberOfSamples / 10 samples are appended to testX/testY
 * and erased from dataX/dataY, which then hold the training samples.
 *
 * The held-out count is clamped to the samples actually present in dataX and
 * dataY. An inconsistent numberOfSamples therefore can no longer move
 * iterators past end() (previously undefined behavior).
 *
 * @param dataX, dataY      [in,out] full data on entry, training data on exit
 * @param testX, testY      [in,out] test samples are appended here
 * @param numberOfSamples   number of samples the caller believes dataX holds
 * @param inputSize         number of values per sample in dataX
 */
void NetImpl::createTrainTestDataQvalue(std::vector<double>& dataX, std::vector<double>& dataY, std::vector<double>& testX, std::vector<double>& testY, int numberOfSamples, int inputSize){
    std::size_t const rowWidth = inputSize > 0 ? static_cast<std::size_t>(inputSize) : 0;

    std::size_t testSetSize = numberOfSamples > 0 ? static_cast<std::size_t>(numberOfSamples / 10) : 0;
    testSetSize = std::min(testSetSize, dataY.size());
    if (rowWidth > 0) {
        testSetSize = std::min(testSetSize, dataX.size() / rowWidth);
    }

    // The held-out samples form a contiguous prefix, so one range copy per vector suffices.
    auto const xEnd = dataX.begin() + static_cast<std::ptrdiff_t>(testSetSize * rowWidth);
    auto const yEnd = dataY.begin() + static_cast<std::ptrdiff_t>(testSetSize);
    testX.insert(testX.end(), dataX.begin(), xEnd);
    testY.insert(testY.end(), dataY.begin(), yEnd);

    //dataX and dataY become the trainsets
    dataX.erase(dataX.begin(), xEnd);
    dataY.erase(dataY.begin(), yEnd);
}

/**
 * Moves the first 10% of flattened policy samples into a test set.
 *
 * Each sample occupies inputSize consecutive values in dataX and outputSize
 * consecutive values in dataY. The first numberOfSamples / 10 samples are
 * appended to testX/testY and erased from dataX/dataY, which then hold the
 * training samples.
 *
 * The held-out count is clamped to the samples actually present in dataX and
 * dataY. An inconsistent numberOfSamples therefore can no longer move
 * iterators past end() (previously undefined behavior).
 *
 * @param dataX, dataY      [in,out] full data on entry, training data on exit
 * @param testX, testY      [in,out] test samples are appended here
 * @param numberOfSamples   number of samples the caller believes dataX holds
 * @param inputSize         number of values per sample in dataX
 * @param outputSize        number of values per sample in dataY
 */
void NetImpl::createTrainTestDataPolicy(std::vector<double>& dataX, std::vector<double>& dataY, std::vector<double>& testX, std::vector<double>& testY, int numberOfSamples, int inputSize, int outputSize){
    std::size_t const inWidth = inputSize > 0 ? static_cast<std::size_t>(inputSize) : 0;
    std::size_t const outWidth = outputSize > 0 ? static_cast<std::size_t>(outputSize) : 0;

    std::size_t testSetSize = numberOfSamples > 0 ? static_cast<std::size_t>(numberOfSamples / 10) : 0;
    if (inWidth > 0) {
        testSetSize = std::min(testSetSize, dataX.size() / inWidth);
    }
    if (outWidth > 0) {
        testSetSize = std::min(testSetSize, dataY.size() / outWidth);
    }

    // The held-out samples form a contiguous prefix, so one range copy per vector suffices.
    auto const xEnd = dataX.begin() + static_cast<std::ptrdiff_t>(testSetSize * inWidth);
    auto const yEnd = dataY.begin() + static_cast<std::ptrdiff_t>(testSetSize * outWidth);
    testX.insert(testX.end(), dataX.begin(), xEnd);
    testY.insert(testY.end(), dataY.begin(), yEnd);

    //dataX and dataY become the trainsets
    dataX.erase(dataX.begin(), xEnd);
    dataY.erase(dataY.begin(), yEnd);
}