#ifndef NETWORK_TRAINER_H
#define NETWORK_TRAINER_H

#include <cstdlib>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "states.h"
#include "network_impl.h"



/// Holds a collection of networks plus the on-disk directory their models
/// are written to / read from.
///
/// The directory layout is
///   <modulePath><benchmarkGroup>/<version>/<taskName>/<netConfig>
/// where taskName is read from SearchEngine::taskName at construction time,
/// so that static member must already be set when a trainer is constructed
/// with the value constructor.
class NetworkTrainer{

public:
    /// Creates an empty trainer with no networks and an empty module path.
    /// networkType is value-initialized (see member declaration) so it is
    /// never read uninitialized.
    NetworkTrainer() = default;

    /// Creates a trainer seeded with a copy of `_initNet` and prepares
    /// (creates on disk) the model directory derived from the arguments.
    ///
    /// @param _initNet     network copied into the trainer's network list
    /// @param _networkType kind of network being trained
    /// @param _modulePath  base path prefix; concatenated verbatim, so it is
    ///                     expected to end with '/' (NOTE(review): confirm callers)
    /// @param _netConfig   last path component (network configuration name)
    /// @param _version     version path component
    NetworkTrainer(Net& _initNet, NetworkType _networkType, std::string _modulePath, std::string _netConfig, std::string _version)
        : networkType(_networkType)
    {
        nets.push_back(_initNet);
        // The string parameters are sinks: hand them over instead of copying again.
        setupModulePath(std::move(_modulePath), std::move(_netConfig), std::move(_version));
    }

    // Defined out of line. NOTE(review): presumably "save" model; the name is
    // kept as-is because callers and the definition depend on it.
    void safeModel(Net net, int depth);
    void loadModel(Net& net, int depth);

    std::vector<Net> loadModels(int inputSize, int numberOfHiddenLayers, int breadthOfHiddenLayers, int numberOfEpochs, int batchSize, double learningRate, int numberOfActions, int depthBound);

    static void parseStateFile();
    static void parseStateFile2();
    static void safeStates();

    static void evaluateNetworks(std::vector<Net>& networks);
    static void evaluateQValueNetwork(Net& network, std::vector<double> const& dataX, std::vector<double> const& dataY);
    static void evaluatePolicyNetwork(Net& network, std::vector<double> const& dataX, std::vector<double> const& dataY);

    /// Builds `modulePath` from the arguments and SearchEngine::taskName, then
    /// creates that directory (including parents) via `mkdir -p`.
    ///
    /// The benchmark group is the prefix of taskName before the first "_inst";
    /// if "_inst" does not occur, find() yields npos and substr() keeps the
    /// whole task name.
    ///
    /// A failure to create the directory is reported on stderr but is not
    /// fatal; modulePath is set either way.
    void setupModulePath(std::string _modulePath, std::string _netConfig, std::string _version){
        const std::string delimiter = "_inst";
        const auto& taskName = SearchEngine::taskName;
        const std::string benchmarkGroup = taskName.substr(0, taskName.find(delimiter));

        modulePath = _modulePath + benchmarkGroup + "/" + _version + "/" + taskName + "/" + _netConfig;

        // NOTE(review): the path is interpolated unquoted into a shell command,
        // so paths containing spaces or shell metacharacters will misbehave.
        // Kept unquoted to preserve shell expansion (e.g. '~') that callers may
        // rely on; consider std::filesystem::create_directories if C++17 is available.
        const int status = std::system(("mkdir -p " + modulePath).c_str());
        // std::system returns -1 only if no shell could be run; a failing
        // `mkdir -p` yields a nonzero exit status, so any nonzero value is a failure.
        if (status != 0)
        {
            std::cerr << "mkdir failed (status " << status << "): " << modulePath << '\n';
        }
    }

private:

    std::vector<Net> nets;            // networks managed by this trainer
    NetworkType networkType{};        // value-initialized so the default ctor leaves no indeterminate state
    std::string modulePath;           // directory for model files; empty until setupModulePath runs

};

#endif