# laura maria engist, 2025
# script to plot the exploration of the parameter space

from mpl_toolkits.mplot3d import Axes3D  # aktiviert 3D-Modus
import matplotlib.pyplot as plt
import sys
import json
import os

# Command-line arguments: the runhistory JSON file to read and the experiment
# directory the plots are written into. Fail with a usage message instead of a
# bare IndexError when they are missing.
if len(sys.argv) < 3:
    sys.exit(f"usage: python {os.path.basename(sys.argv[0])} <runhistory.json> <experiment_directory>")
runhistory = sys.argv[1]  # path to runhistory file
directory_experiment = sys.argv[2]  # directory the plots are saved to

def create_plot(go, ge, build, costs, dimension, colormapping, out_dir=None):
    """Create a scatter plot of the explored parameter space and save it.

    The same figure is written twice, as PNG and as SVG. The file names are:
    ``exploration_of_parameter_space[_gap_penalties][_colormap].{png,svg}``.
    ``_gap_penalties`` marks the 2D plot and ``_colormap`` marks the
    cost-coloured variant.

    Args:
        go: values of the ``go`` parameter, plotted on the y axis.
        ge: values of the ``ge`` parameter, plotted on the x axis.
        build: ``build`` flag per configuration, encoded as 0.0 / 1.0.
            It is plotted on the z axis of the 3D plots only.
        costs: cost per point. Only used when ``colormapping`` is true, where
            it colours the points and feeds the colour bar.
        dimension: 2 for a ge/go plot, 3 to add ``build`` as the z axis.
        colormapping: if true, colour the points by ``costs`` and add a
            colour bar labelled "Cost".
        out_dir: directory to save into. Defaults to the module-level
            ``directory_experiment``.

    Raises:
        ValueError: if ``dimension`` is neither 2 nor 3.
    """
    if dimension not in (2, 3):
        raise ValueError(f"dimension must be 2 or 3, got {dimension!r}")
    if out_dir is None:
        out_dir = directory_experiment

    # ge and go are displayed on the integer grid 1..11 in every variant.
    gap_ticks = list(range(1, 12))

    if dimension == 2:
        title = "Exploration of the Parameter Space - gap-penalties"
        stem = "exploration_of_parameter_space_gap_penalties"
    else:
        title = "Exploration of the Parameter Space - all parameters"
        stem = "exploration_of_parameter_space"
    if colormapping:
        stem += "_colormap"

    fig = plt.figure(figsize=(10, 7))
    try:
        if dimension == 3:
            ax = fig.add_subplot(111, projection='3d')
            points = (ge, go, build)
        else:
            ax = fig.add_subplot(111)
            points = (ge, go)

        if not colormapping:
            ax.scatter(*points, s=50)
        else:
            # The 2D and 3D variants use different colour maps (kept as before).
            if dimension == 3:
                sc = ax.scatter(*points, c=costs, cmap='viridis', s=50, alpha=0.9)
            else:
                sc = ax.scatter(*points, c=costs, cmap='plasma', s=50)
            cb = fig.colorbar(sc, ax=ax, pad=0.1)
            cb.set_label("Cost")

        ax.set_xlabel("ge")
        ax.set_xticks(gap_ticks)
        ax.set_ylabel("go")
        ax.set_yticks(gap_ticks)
        if dimension == 3:
            # The z axis only takes the values 0.0 / 1.0; show them as False / True.
            ax.set_zlabel("build")
            ax.set_zticks([0.0, 1.0])
            ax.set_zticklabels(["False", "True"])
        ax.set_title(title)

        for ext in ("png", "svg"):
            fig.savefig(os.path.join(out_dir, f"{stem}.{ext}"), bbox_inches='tight', pad_inches=0.2)
    finally:
        # Release the figure; otherwise every call leaves an open figure behind.
        plt.close(fig)

# Read the configurations and their costs from the runhistory JSON file and
# produce the four exploration plots (2D/3D, with and without cost colouring).
#
# NOTE(review): the expected layout is inferred from how the file is indexed
# here. "configs" maps an id to a dict with 'go', 'ge' and 'build' keys, and
# "data" is a list of run records, each with a 'cost' key. This looks like a
# SMAC runhistory; confirm against the optimizer version in use.
with open(runhistory, "r", encoding="utf-8") as f:
    data = json.load(f)

configs = data.get("configs", {})
print(configs)
data_with_costs = data.get("data", [])
print(data_with_costs)

go = []
ge = []
build = []
for config in configs.values():
    print(config)
    go.append(config['go'])
    ge.append(config['ge'])
    # build is plotted as 1.0 / 0.0. Accept both the string "True" and a
    # JSON boolean true; previously a boolean was silently treated as False.
    build.append(1.0 if config['build'] in ("True", True) else 0.0)

# One cost per run record, in file order.
cost = [d['cost'] for d in data_with_costs]

create_plot(go, ge, build, cost, 3, False)
create_plot(go, ge, build, cost, 2, False)

# The colour-mapped plots pair the i-th cost with the i-th configuration.
# NOTE(review): this positional pairing is only correct if "data" holds exactly
# one record per configuration, in the same order as "configs". With several
# runs per configuration, costs would have to be matched by configuration id.
if len(cost) != len(go):
    raise ValueError(
        f"runhistory has {len(go)} configurations but {len(cost)} cost entries; "
        "cannot colour the plots by cost"
    )
create_plot(go, ge, build, cost, 3, True)
create_plot(go, ge, build, cost, 2, True)

    