# laura maria engist, 2025
# script to create a plot with costs per epoch based on the data from wandb

import matplotlib.pyplot as plt
import os
import pandas as pd

class CostPerEpoch:
    """Build and plot cost-per-epoch curves from Weights & Biases CSV exports.

    The class holds no state; its methods are helpers that turn a CSV export
    into an ``{epoch: cost}`` mapping and draw one scatter plot per trial.
    """

    def __init__(self):
        pass

    def get_x_and_y_values(self, cost_per_epoch_dict):
        """Split an ``{epoch: cost}`` mapping into two parallel lists.

        Args:
            cost_per_epoch_dict: mapping from epoch to cost.

        Returns:
            Tuple ``(x_values, y_values)``: the epochs and their costs, in the
            mapping's iteration order.
        """
        x_values = list(cost_per_epoch_dict)  # epochs
        y_values = [cost_per_epoch_dict[epoch] for epoch in x_values]  # costs
        return x_values, y_values

    def get_index_best_cost(self, y_values):
        """Find the position and value of the best (minimum) cost.

        Args:
            y_values: non-empty sequence of costs.

        Returns:
            Tuple ``(min_cost_index, min_cost)``. On ties the first occurrence
            wins.

        Raises:
            ValueError: if ``y_values`` is empty.
        """
        costs = list(y_values)
        if not costs:
            raise ValueError("y_values must contain at least one cost")

        # Seed with +inf rather than a fixed number so that costs of any
        # magnitude are handled; NaN entries never compare lower and are
        # therefore skipped.
        min_cost = float("inf")
        min_cost_index = None
        for index, y_value in enumerate(costs):
            if y_value < min_cost:
                min_cost = y_value
                min_cost_index = index

        if min_cost_index is None:
            # Nothing compared below +inf (e.g. all entries are NaN or inf):
            # fall back to the first entry so the result is a real data point.
            return 0, costs[0]
        return min_cost_index, min_cost

    def create_line_plot(self, title, x_values, y_values, file_plot_svg, file_plot_png):
        """Draw the cost per epoch as a scatter plot, save it, and show it.

        The best (minimum) cost is drawn in a separate colour and listed in
        the legend as 'Best Cost'.

        Args:
            title: plot title.
            x_values: epochs (also used as x tick positions).
            y_values: costs, parallel to ``x_values``.
            file_plot_svg: output path for the SVG file.
            file_plot_png: output path for the PNG file.

        Raises:
            ValueError: if ``y_values`` is empty.
        """
        min_cost_index, min_cost = self.get_index_best_cost(y_values)
        print(f"min cost: {min_cost}")

        other_x = [x for i, x in enumerate(x_values) if i != min_cost_index]
        other_y = [y for i, y in enumerate(y_values) if i != min_cost_index]

        # A dedicated figure per call keeps points from earlier trials from
        # piling up on the same axes.
        fig, ax = plt.subplots()
        ax.scatter(other_x, other_y, color='#648fff', s=30)
        ax.scatter(x_values[min_cost_index], min_cost, color='#dc267f', s=30, label='Best Cost')
        ax.set_xlabel('epoch')
        ax.set_ylabel('cost')
        ax.set_xticks(x_values)
        ax.set_title(title)
        ax.legend()

        # Save before showing: once the window opened by show() is closed the
        # figure is released, and saving afterwards produces an empty image.
        fig.savefig(file_plot_svg)
        fig.savefig(file_plot_png)
        plt.show()
        plt.close(fig)

    def create_for_all_trials(self, cost_per_epoch_dicts, nn_model, path_dir):
        """Create and save one plot per trial.

        Args:
            cost_per_epoch_dicts: mapping ``{trial_index: {epoch: cost}}``.
                The trial index is used in the plot title and file names.
            nn_model: prefix for the plot title (title is
                ``"<nn_model>_<trial_index>"``).
            path_dir: directory in which
                ``cost_per_epoch_trial<trial_index>.svg`` / ``.png`` are
                written.
        """
        # Iterate over the actual keys instead of assuming they are 1..n.
        for index, trial_costs in cost_per_epoch_dicts.items():
            x_values, y_values = self.get_x_and_y_values(trial_costs)
            title = nn_model + "_" + str(index)
            file_name_svg = "cost_per_epoch_trial" + str(index) + ".svg"
            file_name_png = "cost_per_epoch_trial" + str(index) + ".png"
            self.create_line_plot(
                title,
                x_values,
                y_values,
                os.path.join(path_dir, file_name_svg),
                os.path.join(path_dir, file_name_png),
            )

    def create_dic_from_csv(self, csv_path):
        """Create an ``{epoch: cost}`` dictionary from a Weights & Biases CSV export.

        Only columns 0 and 4 of the file are read. Column 0 supplies the
        epochs; column 4 supplies the metric, which is converted to a cost as
        ``1 - metric / 3``.
        NOTE(review): the variable name suggests column 4 is a sum of three
        AUC values (hence the division by 3) - confirm against the export.

        Args:
            csv_path: path to the CSV file.

        Returns:
            Dictionary mapping each epoch to its cost. If an epoch occurs
            more than once, the last row wins.
        """
        csv_file = pd.read_csv(csv_path, usecols=[0, 4])
        epochs = csv_file.iloc[:, 0].values  # file column 0
        aucs = csv_file.iloc[:, 1].values  # file column 4 (position 1 after usecols)
        # Pair values row by row; indexing aucs by the epoch value would only
        # be correct when the epochs are exactly 0, 1, ..., n-1.
        return {epoch: 1 - (auc / 3) for epoch, auc in zip(epochs, aucs)}

# RUN THE CODE
# Configuration: fill in the three variables below before running the script.
path_to_csv_file = "" # csv file exported from weights and biases with all costs per epochs TODO: set this variable
title_plot = "" # TODO: set this variable
path_directory_for_plots = "" # TODO: set this variable

# Only plot when executed as a script, so that importing this module (e.g. to
# reuse CostPerEpoch) does not read files or open plot windows.
if __name__ == "__main__":
    if not path_to_csv_file:
        # Fail early with a clear message instead of an opaque pandas error.
        raise SystemExit("path_to_csv_file is not set: edit the TODO variables in this script before running it")
    cost_per_epoch = CostPerEpoch()
    cost_per_epoch_dict1 = cost_per_epoch.create_dic_from_csv(path_to_csv_file)
    # Trials are keyed by their 1-based trial number, which ends up in the
    # plot title and the output file names.
    cost_per_epoch_dict = {1: cost_per_epoch_dict1}
    cost_per_epoch.create_for_all_trials(cost_per_epoch_dict, title_plot, path_directory_for_plots)

# if multiple trials (define path_to_csv_file2 the same way as path_to_csv_file), use instead:
#cost_per_epoch_dict1 = cost_per_epoch.create_dic_from_csv(path_to_csv_file)
#cost_per_epoch_dict2 = cost_per_epoch.create_dic_from_csv(path_to_csv_file2)
#cost_per_epoch_dicts = {1: cost_per_epoch_dict1, 2: cost_per_epoch_dict2}
#cost_per_epoch.create_for_all_trials(cost_per_epoch_dicts, title_plot, path_directory_for_plots)