#ifndef CUSTOM_DATASET_H
#define CUSTOM_DATASET_H

#include "torch/torch.h"

/// @brief Map-style dataset over a pair of in-memory tensors.
///
/// Sample `i` is the pair (`dataX[i]`, `dataY[i]`): both tensors are indexed
/// along their leading dimension, so the two tensors must have the same
/// number of rows.
class CustomDataset : public torch::data::Dataset<CustomDataset>
{
    private:
        torch::Tensor inputFeatures;
        torch::Tensor target;
        // Number of samples (leading dimension of inputFeatures). Kept as
        // size_t rather than int so large datasets are not truncated.
        size_t dataSize = 0;
    public:
        /// @brief Wrap feature and target tensors as a dataset.
        /// @param dataX Input features; samples are taken along dim 0.
        /// @param dataY Targets; samples are taken along dim 0.
        /// @throws c10::Error if either tensor is 0-dimensional, or if the
        ///         two tensors differ in their leading dimension.
        CustomDataset(torch::Tensor const& dataX, torch::Tensor const& dataY)
            : inputFeatures(dataX),
              target(dataY),
              dataSize(static_cast<size_t>(dataX.size(0)))
        {
            // Fail fast on misaligned inputs: otherwise get() would either
            // throw mid-epoch or silently ignore trailing target rows.
            TORCH_CHECK(dataY.size(0) == dataX.size(0),
                        "CustomDataset: dataX has ", dataX.size(0),
                        " rows but dataY has ", dataY.size(0), " rows");
        }

        /// @brief Return the sample at @p index.
        /// @param index Row index along dim 0, expected in [0, size()).
        /// @return Example whose `data` is `dataX[index]` and whose `target`
        ///         is `dataY[index]`; both are views into the stored tensors.
        torch::data::Example<> get(size_t index) override {
            const auto row = static_cast<int64_t>(index);
            return {inputFeatures[row], target[row]};
        }

        /// @brief Number of samples, i.e. the leading dimension of the inputs.
        torch::optional<size_t> size() const override {
            return dataSize;
        }
};

#endif