# laura maria engist, 2025
# script to create a visualization of all trials vs baseline

import json
import matplotlib.pyplot as plt
import os

class AllTrialsVsBaseline:
    """Plot the cost of every trial in a SMAC runhistory file against a baseline.

    Args:
        baseline: Cost value drawn as a dashed horizontal reference line.
        runhistory: Path to the runhistory JSON file created by SMAC.
    """

    def __init__(self, baseline, runhistory):
        self.baseline = baseline
        self.runhistory = runhistory

    def get_runhistory_all_configs(self, max_cost=1.0):
        """Read the runhistory file and map each trial to its cost.

        Args:
            max_cost: Exclusive upper bound; entries with ``cost >= max_cost``
                are dropped. Defaults to 1.0.

        Returns:
            dict mapping ``entry["trial"]`` to ``entry["cost"]`` for every entry
            in ``data["data"]`` whose cost is below ``max_cost``, in file order.
            If several entries share the same trial key, the last one wins.
        """
        runhistory_all_configs = {}
        with open(self.runhistory, "r", encoding="utf-8") as f:
            data = json.load(f)
        for entry in data["data"]:
            # Use entry["config_id"] instead to key the plot by configuration.
            trial = entry["trial"]
            cost = entry["cost"]
            if cost < max_cost:
                runhistory_all_configs[trial] = cost
        return runhistory_all_configs

    def get_x_and_y_values(self, runhistory_all_configs):
        """Split a ``{trial: cost}`` mapping into parallel x and y lists.

        Returns:
            ``(x_values, y_values)``: x values are the mapping's keys (trial
            number or configuration), y values the matching costs, in the
            mapping's iteration order.
        """
        x_values = list(runhistory_all_configs.keys())
        y_values = list(runhistory_all_configs.values())
        return x_values, y_values

    def get_index_best_cost(self, y_values):
        """Return ``(index, value)`` of the lowest cost in *y_values*.

        Ties resolve to the first occurrence.

        Raises:
            ValueError: if *y_values* is empty.
        """
        # No magic sentinel: min() works for any cost range, including costs >= 100.
        best = min(enumerate(y_values), key=lambda pair: pair[1], default=None)
        if best is None:
            raise ValueError("y_values is empty; there is no best cost to find")
        return best

    def create_line_plot(self, file_plot_svg, file_plot_png, max_cost=1.0):
        """Plot all trial costs, highlight the best one, and draw the baseline.

        The figure is saved to *file_plot_svg* and *file_plot_png*, then shown.

        Args:
            file_plot_svg: Output path for the SVG image.
            file_plot_png: Output path for the PNG image.
            max_cost: Exclusive cost threshold forwarded to
                ``get_runhistory_all_configs`` (default 1.0).

        Raises:
            ValueError: if no trial in the runhistory is below ``max_cost``.
        """
        x_values, y_values = self.get_x_and_y_values(
            self.get_runhistory_all_configs(max_cost=max_cost)
        )
        if not y_values:
            raise ValueError(
                f"no trials with cost < {max_cost} found in {self.runhistory!r}"
            )
        min_cost_index, min_cost = self.get_index_best_cost(y_values)
        best_x = x_values[min_cost_index]

        # A dedicated figure keeps repeated calls from drawing on top of each other.
        fig, ax = plt.subplots()
        ax.plot(x_values, y_values, color='#648fff', linewidth=1, label='All Configs')
        ax.scatter(best_x, min_cost, color='#dc267f', s=30, label='Best Cost')
        ax.axhline(y=self.baseline, color='#ffb000', linestyle='--', linewidth=1,
                   label='Baseline')
        ax.set_xlabel('configuration')  # or trial
        ax.set_ylabel('cost')
        # Only the best trial's position is labelled on the x axis.
        ax.set_xticks([best_x])
        ax.set_xticklabels([str(best_x)])
        ax.legend()

        # Save before show(): once the shown window is closed the figure may be
        # destroyed, which would make later savefig() calls write a blank image.
        fig.savefig(file_plot_svg)
        fig.savefig(file_plot_png)
        plt.show()
        plt.close(fig)

# Example usage: fill in the placeholder values below before running.
# Guarded so that importing this module does not try to open an empty path.
if __name__ == "__main__":
    baseline_cost = 0  # baseline cost value
    runhistory = ""  # path to runhistory json file created by SMAC
    files_directory = ""  # directory to save the plots
    file_plot_svg = "....svg"  # svg file name, joined with files_directory
    file_plot_png = "....png"  # png file name, joined with files_directory
    all_trials_vs_baseline = AllTrialsVsBaseline(baseline_cost, runhistory)
    all_trials_vs_baseline.create_line_plot(
        os.path.join(files_directory, file_plot_svg),
        os.path.join(files_directory, file_plot_png),
    )