#include "network_trainer.h"

#include "utils/system_utils.h"
#include <sstream>

using namespace std;


// Stores the given network on disk and echoes it to stdout.
//
// The target file is "<modulePath>/<depth>.pt"; the network is written with
// torch::save. Despite the spelling of the name, this function *saves* a model.
//
// net   - network to serialise
// depth - depth the network belongs to; only used to build the file name
void NetworkTrainer::safeModel(Net net, int depth){
    const string fileName = to_string(depth) + ".pt";
    const string target = modulePath + "/" + fileName;

    torch::save(net, target);

    // Print the network that was just written.
    cout << net << endl;
}

// Restores a network from "<modulePath>/<depth>.pt" using torch::load.
//
// net   - network that receives the stored contents (modified in place)
// depth - depth whose file should be read; only used to build the file name
void NetworkTrainer::loadModel(Net& net, int depth){
    const string fileName = to_string(depth) + ".pt";
    const string source = modulePath + "/" + fileName;

    torch::load(net, source);
}

std::vector<Net> NetworkTrainer::loadModels(int inputSize, int numberOfHiddenLayers, int breadthOfHiddenLayers, int numberOfEpochs, int batchSize, double learningRate, int numberOfActions, int depthBound){
    std::vector<Net> nets;
    
    for(int depth = 2; depth <= depthBound; depth++){
        Net net;
        if(numberOfActions==0){
            net = Net(inputSize, numberOfHiddenLayers, breadthOfHiddenLayers, numberOfEpochs, batchSize, learningRate);
        }else{
            net = Net(inputSize, numberOfHiddenLayers, breadthOfHiddenLayers, numberOfEpochs, batchSize, learningRate, numberOfActions);
        }
        if(SystemUtils::fileAvailable(modulePath+"/"+to_string(depth)+".pt")){
            loadModel(net, depth);
            nets.push_back(net);
        }else{
            cout << "depth " << depth << " is not available" << endl;
            break;
        }
    }
    
    cout << "NETS " << nets.size()+1 << "/" << SearchEngine::horizon << endl;

    if(nets.size() == 0){
        SystemUtils::abort("There are no nets provided!");
    }
    return nets;
}


void NetworkTrainer::parseStateFile(){

    string fileName;
    fileName = "../../../simulated_states/states_" + SearchEngine::taskName;
    SearchEngine::trainingSet.clear();
    string problemDesc;
    if (!SystemUtils::readFile(fileName, problemDesc, "#")) {
        SystemUtils::abort("Error: Unable to read problem file: " +
                           fileName);
    }
    stringstream desc(problemDesc);
    int numberOfTrainingStates;
    desc >> numberOfTrainingStates;
    for (size_t i = 0; i < numberOfTrainingStates; ++i) {
        vector<double> valuesOfDeterministicStateFluents(
            State::numberOfDeterministicStateFluents);
        for (size_t j = 0; j < State::numberOfDeterministicStateFluents; ++j) {
            desc >> valuesOfDeterministicStateFluents[j];
        }

        vector<double> valuesOfProbabilisticStateFluents(
            State::numberOfProbabilisticStateFluents);
        for (size_t j = 0; j < State::numberOfProbabilisticStateFluents; ++j) {
            desc >> valuesOfProbabilisticStateFluents[j];
        }

        State trainingState(valuesOfDeterministicStateFluents,
                            valuesOfProbabilisticStateFluents,
                            SearchEngine::horizon);
        State::calcStateFluentHashKeys(trainingState);
        State::calcStateHashKey(trainingState);
        SearchEngine::trainingSet.push_back(trainingState);
    }
}

void NetworkTrainer::parseStateFile2(){

    string fileName;
    fileName = "../../../state_files/states_" + SearchEngine::taskName;
    SearchEngine::trainingSet.clear();
    string problemDesc;
    if (!SystemUtils::readFile(fileName, problemDesc, "#")) {
        SystemUtils::abort("Error: Unable to read problem file: " +
                           fileName);
    }
    stringstream desc(problemDesc);
    int numberOfTrainingStates;
    desc >> numberOfTrainingStates;
    for (size_t i = 0; i < numberOfTrainingStates; ++i) {
        vector<double> valuesOfDeterministicStateFluents(
            State::numberOfDeterministicStateFluents);
        for (size_t j = 0; j < State::numberOfDeterministicStateFluents; ++j) {
            desc >> valuesOfDeterministicStateFluents[j];
        }

        vector<double> valuesOfProbabilisticStateFluents(
            State::numberOfProbabilisticStateFluents);
        for (size_t j = 0; j < State::numberOfProbabilisticStateFluents; ++j) {
            desc >> valuesOfProbabilisticStateFluents[j];
        }

        State trainingState(valuesOfDeterministicStateFluents,
                            valuesOfProbabilisticStateFluents,
                            SearchEngine::horizon);
        State::calcStateFluentHashKeys(trainingState);
        State::calcStateHashKey(trainingState);
        SearchEngine::trainingSet.push_back(trainingState);
    }
}

void NetworkTrainer::safeStates(){
    stringstream ss;
    string fileName = "../../../simulated_states/states_" + SearchEngine::taskName;
    int numberOfTrainingStates = SearchEngine::trainingSet.size();
    ss << numberOfTrainingStates << endl;

    for(State const& currentState : SearchEngine::trainingSet){
        ss << currentState.toCompactStringForFile() << endl;
    }
    
    SystemUtils::writeFile(fileName, ss.str());

    cout << "finished generation of trainingset" << endl;

    exit(0);
}

// Runs a Q-value network over a data set and prints the resulting L1 loss.
//
// net   - network to evaluate; its forward_custom method is used
// dataX - flattened inputs; every sample occupies
//         (number of state fluents + number of action fluents) entries
// dataY - one target value per sample
//
// Nothing is returned; the loss tensor of every batch is written to stdout.
void NetworkTrainer::evaluateQValueNetwork(Net& net, vector<double> const& dataX, vector<double> const& dataY){
    // Width of one input row and the number of rows contained in dataX.
    const int featuresPerSample = SearchEngine::stateFluents.size() + SearchEngine::actionFluents.size();
    const int sampleCount = dataX.size()/featuresPerSample;

    torch::Tensor samples = torch::reshape(torch::tensor(dataX).clone(),
                                           {sampleCount, featuresPerSample});
    torch::Tensor targets = torch::tensor(dataY).clone();

    auto dataset = CustomDataset(samples, targets).map(torch::data::transforms::Stack<>());

    // The batch size equals the number of targets, so the whole set is
    // processed as a single batch.
    auto loader = torch::data::make_data_loader(
        std::move(dataset),
        torch::data::DataLoaderOptions().batch_size(dataY.size()));

    for(auto& batch : *loader){
        torch::Tensor output = net.get()->forward_custom(batch.data);

        // Flatten the network output to one value per sample so that its shape
        // matches the target tensor.
        // see: https://github.com/pytorch/examples/issues/819
        torch::Tensor flatOutput = output.reshape({batch.data.size(0)});

        cout << torch::l1_loss(flatOutput, batch.target) << endl;
    }
}


// Runs a policy network over a data set and prints the resulting L1 loss.
//
// net   - network to evaluate; its forward_custom method is used
// dataX - flattened inputs; every sample occupies (number of state fluents)
//         entries
// dataY - flattened targets; every sample occupies
//         SearchEngine::numberOfActions entries. A target of -1 marks an
//         output that must not contribute to the loss.
//
// Nothing is returned; the loss tensor of every batch is written to stdout.
void NetworkTrainer::evaluatePolicyNetwork(Net& net, vector<double> const& dataX, vector<double> const& dataY){

    int inputSize = SearchEngine::stateFluents.size();
    int numberOfSamplePoints = dataX.size()/inputSize;
    int numberOfOutputValues = SearchEngine::numberOfActions;

    torch::Tensor inputData = torch::tensor(dataX).clone();
    torch::Tensor targetData = torch::tensor(dataY).clone();
    inputData = torch::reshape(inputData, {numberOfSamplePoints, inputSize});
    targetData = torch::reshape(targetData, {numberOfSamplePoints, numberOfOutputValues});

    auto trainData = CustomDataset(inputData, targetData).map(torch::data::transforms::Stack<>());

    // dataY holds numberOfOutputValues entries per sample, so this batch size
    // is never smaller than the number of samples and the loader delivers the
    // whole set as a single batch.
    auto dataLoader = torch::data::make_data_loader(
        std::move(trainData),
        torch::data::DataLoaderOptions().batch_size(dataY.size()));

    for(auto& batch : *dataLoader){
        torch::Tensor prediction = net.get()->forward_custom(batch.data);
        // Bring the network output into the same shape as the target tensor.
        // see: https://github.com/pytorch/examples/issues/819
        prediction = torch::reshape(prediction, {batch.target.size(0), batch.target.size(1)});

        // Wherever the target is -1 the prediction is replaced by the target
        // itself, so that entry adds exactly zero to the loss. Done as one
        // tensor operation instead of extracting every element individually.
        prediction = torch::where(batch.target == -1, batch.target, prediction);

        //torch::Tensor loss = torch::mse_loss(prediction, batch.target);
        torch::Tensor loss = torch::l1_loss(prediction, batch.target);

        cout << loss << endl;
    }
}