# laura maria engist, 2025
# script to create a visualization of all incumbents vs baseline

import matplotlib.pyplot as plt
import json
import os

class AllIncumbentsVsBaseline:
    """Plot every incumbent of a SMAC run against a baseline cost.

    Incumbents are read from the ``"trajectory"`` list of the intensifier
    JSON file written by SMAC. Each incumbent is drawn as a point
    (trial number vs. cost), the cheapest one is highlighted, and the
    baseline cost is drawn as a dashed horizontal line.
    """

    def __init__(self, baseline, intensifier):
        """Store the plot inputs.

        Args:
            baseline: cost value drawn as the horizontal reference line.
            intensifier: path to the intensifier JSON file created by SMAC.
        """
        self.baseline = baseline
        self.intensifier = intensifier

    def get_x_and_y_values(self, incumbents):
        """Split an incumbents mapping into parallel x and y lists.

        Args:
            incumbents: mapping of trial number (or configuration) -> cost.

        Returns:
            ``(x_values, y_values)``: the keys and the corresponding values of
            *incumbents*, in the mapping's iteration order.
        """
        x_values = list(incumbents)  # trial number or configuration
        y_values = [incumbents[key] for key in x_values]  # cost
        return x_values, y_values

    def get_index_best_cost(self, y_values):
        """Return the position and value of the lowest cost.

        Ties resolve to the first occurrence. NaN entries are never selected,
        because every ``<`` comparison with NaN is false.

        Args:
            y_values: sequence of cost values.

        Returns:
            ``(min_cost_index, min_cost)``. For an empty sequence this is
            ``(0, inf)``, so callers must check for emptiness themselves.
        """
        # Start from +inf instead of a fixed "large" number so costs of any
        # magnitude (including >= 100) are handled correctly.
        min_cost = float("inf")
        min_cost_index = 0
        for index, y_value in enumerate(y_values):
            if y_value < min_cost:
                min_cost = y_value
                min_cost_index = index
        return min_cost_index, min_cost

    def get_incumbents(self, max_cost=1.0):
        """Read all incumbents from the intensifier JSON file.

        Each entry of the file's ``"trajectory"`` list contributes its
        ``"trial"`` as key and the first element of its ``"costs"`` as value.
        Entries whose cost is not strictly below *max_cost* are skipped.

        Args:
            max_cost: exclusive upper bound on kept costs (default ``1.0``,
                the previously hard-coded filter).

        Returns:
            dict mapping trial -> cost. If several entries share the same
            trial, the last one wins.
        """
        with open(self.intensifier, "r", encoding="utf-8") as f:
            data = json.load(f)
        intensifier_all_incumbents = {}
        for entry in data["trajectory"]:
            # Alternative key: entry["config_ids"][0] (plot by configuration).
            trial = entry["trial"]
            cost = entry["costs"][0]
            if cost < max_cost:
                intensifier_all_incumbents[trial] = cost
        return intensifier_all_incumbents

    def create_line_plot(self, file_plot_svg, file_plot_png, best_label=None):
        """Plot all incumbents vs. the baseline, save the figure, then show it.

        Despite the name, this is a scatter plot. Every incumbent is a blue
        point, the cheapest one is pink, and the baseline is a dashed
        horizontal line. The only x tick is placed under the best incumbent.

        Args:
            file_plot_svg: path the SVG version is written to.
            file_plot_png: path the PNG version is written to.
            best_label: text of the x tick under the best incumbent. Defaults
                to that incumbent's x value (its trial number). Pass e.g. the
                matching configuration id to label it by configuration.

        Raises:
            ValueError: if no incumbent passes the cost filter.
        """
        x_values, y_values = self.get_x_and_y_values(self.get_incumbents())
        if not x_values:
            raise ValueError(f"no incumbents to plot in {self.intensifier!r}")
        min_cost_index, min_cost = self.get_index_best_cost(y_values)
        best_x = x_values[min_cost_index]

        # All incumbents except the best one, which is drawn separately.
        other_x = [x for i, x in enumerate(x_values) if i != min_cost_index]
        other_y = [y for i, y in enumerate(y_values) if i != min_cost_index]

        fig = plt.figure()  # fresh figure so repeated calls don't overlay
        plt.scatter(other_x, other_y, color='#648fff', s=30)
        # The label is only displayed if a legend is drawn; none is drawn here.
        plt.scatter(best_x, min_cost, color='#dc267f', s=30, label='Best Cost')
        plt.axhline(y=self.baseline, color='#ffb000', linestyle='--', linewidth=1)
        plt.xlabel('configuration')  # or trial
        plt.ylabel('cost')
        if best_label is None:
            best_label = str(best_x)
        plt.xticks([best_x], [best_label])
        # Save BEFORE show(): with interactive backends show() blocks until the
        # window is closed and the figure is then destroyed, so saving
        # afterwards would write an empty image.
        plt.savefig(file_plot_svg)
        plt.savefig(file_plot_png)
        plt.show()
        plt.close(fig)

# Example usage (fill in the placeholder values below before running).
# Guarded so that importing this module does not try to open a file.
if __name__ == "__main__":
    baseline_cost = 0  # baseline cost value
    intensifier = ""  # path to intensifier json file created by SMAC
    files_directory = ""  # directory to save the plots
    file_plot_svg = "....svg"  # svg file name, joined onto files_directory
    file_plot_png = "....png"  # png file name, joined onto files_directory
    all_incumbents_vs_baseline = AllIncumbentsVsBaseline(baseline_cost, intensifier)
    all_incumbents_vs_baseline.create_line_plot(
        os.path.join(files_directory, file_plot_svg),
        os.path.join(files_directory, file_plot_png),
    )