# laura maria engist, 2025
# script to prepare what is needed to evaluate the results of an experiment
# start this script after the experiment has finished via the command line

import json
import sys
import re
import os
import matplotlib.pyplot as plt

# Both directories are required command-line arguments; stop with a usage hint
# instead of a bare IndexError when one of them is missing.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <directory_experiment> <directory_output_files>")

directory_experiment = sys.argv[1] # path to directory of the experiment
directory_output_files = sys.argv[2] # path to directory with all output files
runhistory = os.path.join(directory_output_files, "runhistory.json") # path to runhistory file
# Only the trailing ".json" extension is replaced. The pattern is escaped and
# anchored so that a directory name such as "results_json" is left untouched.
runhistory_converted = re.sub(r"\.json$", "_conv.json", runhistory) # path to converted runhistory file
intensifier = os.path.join(directory_output_files, "intensifier.json") # path to intensifier file

'''
Convert the runhistory json file into a format that is easier to work with
'''
# Read the raw runhistory completely before anything is written.
with open(runhistory, "r") as f:
    data = json.load(f)

configs = data.get("configs", {})
config_origins = data.get("config_origins", {})

# Every trial (a dict in the source file) becomes a flat list whose positions
# are fixed: config_id, instance, seed, budget, cost, time, status, starttime,
# endtime, additional_info. "instance" and "budget" are optional and fall back
# to None; all other keys must be present.
converted_data = {
    "stats": {
        key: data["stats"][key] for key in ("submitted", "finished", "running")
    },
    "data": [
        [
            trial["config_id"],
            trial.get("instance", None),
            trial["seed"],
            trial.get("budget", None),
            trial["cost"],
            trial["time"],
            trial["status"],
            trial["starttime"],
            trial["endtime"],
            trial["additional_info"],
        ]
        for trial in data["data"]
    ],
    "configs": configs,
    "config_origins": config_origins,
}

# Store the converted runhistory next to the original file.
with open(runhistory_converted, "w") as f:
    json.dump(converted_data, f, indent=4)

'''
Get all configurations from runhistory json file
'''
def get_runhistory_all_configs():
    '''
    Return a dict {config_id: cost} for all runhistory trials with a cost below 1.0.

    Trials are taken in file order. If a config id occurs more than once,
    the cost of its last trial below 1.0 is kept.
    '''
    with open(runhistory, "r") as f:
        trials = json.load(f)["data"]

    # Read id and cost of every trial first, then keep only the cheap ones.
    id_cost_pairs = ((trial["config_id"], trial["cost"]) for trial in trials)
    return {
        config_id: cost
        for config_id, cost in id_cost_pairs
        if cost < 1.0
    }

''' 
Get id of the final incumbent from intensifier json file
'''
def get_id_incumbent():
    '''
    Return the first entry of "incumbent_ids" in the intensifier json file.

    Returns None if the list of incumbent ids is empty.
    '''
    with open(intensifier, "r") as f:
        incumbent_ids = json.load(f)["incumbent_ids"]
    return next(iter(incumbent_ids), None)

'''
Create all incumbents manually from runhistory json file
'''
def create_all_incumbents_manually():
    '''
    Rebuild the sequence of intermediate incumbents from the runhistory.

    Walks through the configurations returned by get_runhistory_all_configs()
    in their stored order and keeps every configuration whose cost is strictly
    lower than the best cost seen so far (starting at 1.0).
    Prints the result and returns it as a dict {config_id: cost}.
    '''
    best_cost = 1.0
    incumbents = {}
    for config_id, cost in get_runhistory_all_configs().items():
        if not (cost < best_cost):
            continue  # no improvement, the incumbent stays the same
        incumbents[config_id] = cost
        best_cost = cost
    print(f"all incumbents: {incumbents}")
    return incumbents

'''
Create scatter or 3D plot of the exploration of the parameter space
dimension: 2 or 3
colormapping: True or False
'''
def create_plot(go, ge, build, costs, dimension, colormapping):
    '''
    Save one scatter plot of the explored parameter space as png and svg
    into the experiment directory.

    go, ge: coordinates of the configurations (y and x axis)
    build: z coordinates (0.0 for "False", 1.0 for "True"), only used for dimension 3
    costs: one cost per configuration, only used if colormapping is set
    dimension: 2 (ge/go scatter plot) or 3 (ge/go/build 3D plot)
    colormapping: True to colour the points by cost and add a colorbar, else False

    Raises ValueError if dimension is neither 2 nor 3.
    '''
    if dimension not in (2, 3):
        raise ValueError(f"dimension must be 2 or 3, got {dimension!r}")

    colored = bool(colormapping)
    gap_ticks = list(range(1, 12))

    # The file name encodes which of the four plot variants is written.
    file_stem = "exploration_of_parameter_space"
    if dimension == 2:
        file_stem += "_gap_penalties"
    if colored:
        file_stem += "_colormap"

    fig = plt.figure(figsize=(10, 7))
    try:
        if dimension == 2:
            ax = fig.add_subplot(111)
            if colored:
                # The colour-mapped 2D plot keeps matplotlib's automatic ticks.
                points = ax.scatter(ge, go, c=costs, cmap='plasma', s=50)
            else:
                points = ax.scatter(ge, go, s=50)
                ax.set_xticks(gap_ticks)
                ax.set_yticks(gap_ticks)
            ax.set_title("Exploration of the Parameter Space - gap-penalties")
        else:
            ax = fig.add_subplot(111, projection='3d')
            if colored:
                points = ax.scatter(ge, go, build, c=costs, cmap='viridis', s=50, alpha=0.9)
            else:
                points = ax.scatter(ge, go, build, s=50)
            ax.set_xticks(gap_ticks)
            ax.set_yticks(gap_ticks)
            ax.set_zlabel("build")
            ax.set_zticks([0.0, 1.0])
            ax.set_zticklabels(["False", "True"])
            ax.set_title("Exploration of the Parameter Space - all parameters")

        ax.set_xlabel("ge")
        ax.set_ylabel("go")

        if colored:
            colorbar = fig.colorbar(points, ax=ax, pad=0.1)
            colorbar.set_label("Cost")

        for extension in ("png", "svg"):
            fig.savefig(
                os.path.join(directory_experiment, f"{file_stem}.{extension}"),
                bbox_inches='tight',
                pad_inches=0.2,
            )
    finally:
        # Always release the figure so that repeated calls do not pile up
        # open figures, even if plotting or saving fails.
        plt.close(fig)

'''
Create graph for all configurations or all incumbents
file_type: "svg" or "png"
type: "all" or "incumbents"
'''
def create_graph(file_type, type):
    '''
    Plot the cost over the trials and save the graph into the experiment directory.

    file_type: "svg" or "png", used as the file extension
    type: "all" plots every configuration returned by get_runhistory_all_configs(),
          "incumbents" plots the intermediate incumbents as a stair that runs
          until the last finished trial

    The final incumbent is marked with a blue dot and a dashed vertical line
    if its id is part of the plotted data.
    Raises ValueError for an unknown type. Returns 0.
    '''
    if type == 'all':
        data = get_runhistory_all_configs()
        title = "Cost Over Trials - All Configurations"
        file_name = "cost_over_trials_all_configurations." + file_type
        coordinates = sorted(data.items())
    elif type == 'incumbents':
        print("incumbents")
        data = create_all_incumbents_manually()
        title = "Cost Over Trials - All Intermediate Incumbents"
        file_name = "cost_over_trials_all_intermediate_incumbents." + file_type
        # fill in values in between such that it creates a stair and only stops at end of run
        number_trials = get_number_trials()
        coordinates = get_complete_data_for_incumbents_graph(sorted(data.items()), number_trials)
    else:
        raise ValueError(f"type must be 'all' or 'incumbents', got {type!r}")

    plt.clf()
    # Without any configuration there is nothing to unpack and nothing to draw.
    if coordinates:
        x_sorted, y_sorted = zip(*coordinates)
        plt.plot(x_sorted, y_sorted, color='#dc267f', linewidth=1, markersize=8)  # optional: marker='o' shows the points
    plt.xlabel('Trials')
    plt.ylabel('Cost')
    plt.title(title)

    # mark final incumbent; skipped if there is none or it is not part of the
    # plotted data (e.g. its cost was filtered out)
    x_mark = get_id_incumbent()
    if x_mark in data:
        plt.plot(x_mark, data[x_mark], 'bo')
        plt.axvline(x=x_mark, color='#648fff', linestyle='--', linewidth=1)

    plt.grid(True)
    plt.savefig(os.path.join(directory_experiment, file_name))
    return 0

'''
Get number of finished trials from runhistory json file
'''
def get_number_trials():
    '''
    Return the value of stats -> finished from the runhistory json file.
    '''
    with open(runhistory, "r") as f:
        return json.load(f)["stats"]["finished"]

'''
Get complete data for incumbents graph such that it creates a stair and only stops at end of run
'''
def get_complete_data_for_incumbents_graph(coordinates, number_trials):
    '''
    Turn the incumbent points into the corner points of a stair-shaped curve.

    coordinates: iterable of (x, cost) pairs in plotting order
    number_trials: x value at which the last step ends

    The curve starts at (0, 1). For every point the previous cost is held up
    to its x value and then dropped to the new cost. The last cost is held
    until number_trials. Returns the corner points as a list of tuples.
    '''
    stair = [(0, 1)]
    held_cost = 1.0
    for x, cost in coordinates:
        # horizontal part up to x, then the vertical drop at x
        stair.extend(((x, held_cost), (x, cost)))
        held_cost = cost
    stair.append((number_trials, held_cost))
    return stair

'''
Create all plots
'''
def create_all_plots():
    '''
    Create all four parameter-space plots (2D/3D, with and without cost colours).

    Reads the configurations and trials from the runhistory json file. The
    cost of a configuration is the mean cost of all its trials, matched via
    the config id. Configurations without any trial get NaN as cost, so that
    the cost list always lines up with the coordinate lists.
    Assumes every configuration has the keys 'go', 'ge' and 'build'.
    '''
    with open(runhistory, "r") as f:
        data = json.load(f)

    configs = data.get("configs", {})
    print(configs)
    data_with_costs = data.get("data", [])
    print(data_with_costs)

    # Collect the costs per configuration. The ids are compared as strings
    # because json object keys are always strings.
    costs_by_config = {}
    for trial in data_with_costs:
        costs_by_config.setdefault(str(trial['config_id']), []).append(trial['cost'])

    go = []
    ge = []
    build = []
    cost = []
    for config_nr, config in configs.items():
        print(config)
        go.append(config['go'])
        ge.append(config['ge'])
        # str() accepts both the string 'True' and a json boolean true
        build.append(1.0 if str(config['build']) == 'True' else 0.0)
        trial_costs = costs_by_config.get(str(config_nr))
        if trial_costs:
            cost.append(sum(trial_costs) / len(trial_costs))
        else:
            cost.append(float("nan"))

    create_plot(go, ge, build, cost, 3, False)
    create_plot(go, ge, build, cost, 2, False)
    create_plot(go, ge, build, cost, 3, True)
    create_plot(go, ge, build, cost, 2, True)
        
# Write both graphs in both file formats (svg first, then png),
# followed by the parameter-space plots.
for graph_file_type in ("svg", "png"):
    for graph_kind in ('incumbents', 'all'):
        create_graph(graph_file_type, graph_kind)
create_all_plots()