# laura maria engist, 2025
# script to create a visualization of the parameter importance regarding all incumbents

import json

class UseInIncumbents:
    """Compute and plot how often each ``use_*`` switch is enabled among the incumbents.

    The incumbent configuration ids are read from the intensifier JSON file
    (``data["trajectory"][i]["config_ids"][0]``) and the matching
    configurations from the run-history JSON file
    (``data["configs"][str(config_id)]``).

    Args:
        runhistory: Path to the run-history JSON file.
        intensifier: Path to the intensifier JSON file.
        nn: Model selector. ``"vq_vae"`` and ``"lm_head"`` trigger the
            computation and the plot immediately; any other value only stores
            the arguments.
        file_plot_svg: Path the plot is saved to as SVG.
        file_plot_png: Path the plot is saved to as PNG.

    Attributes:
        ordered_ratio: Ratios computed in the constructor (label -> fraction,
            sorted decreasingly), or ``None`` if ``nn`` selected no model.
    """

    # Plot label -> configuration key. The insertion order is the tie-break
    # order after sorting, because the sort is stable.
    _LM_HEAD_PARAMETERS = {
        "pos": "use_pos_weight",
        "neg": "use_neg_weight",
        "si": "use_si_loss",
        "ss": "use_ss_loss",
        "en": "use_mean_entropy",
        "us": "use_uniform_kl_loss",
        "temp": "use_temperature",
    }
    _VQ_VAE_PARAMETERS = {
        "cnE": "use_cnE_euc",
        "cnD": "use_cnD_euc",
        "cnQ": "use_cnQ_euc",
        "fs": "use_fs_euc",
        "rcP": "use_rcP_euc",
        "us": "use_us",
        "dv": "use_dv",
        "en": "use_en",
        "cb": "use_cb_euc",
        "em": "use_em_euc",
    }

    def __init__(self, runhistory, intensifier, nn, file_plot_svg, file_plot_png):
        self.runhistory = runhistory
        self.intensifier = intensifier
        self.nn = nn
        self.file_plot_svg = file_plot_svg
        self.file_plot_png = file_plot_png
        # The result is kept under its own name; storing it as
        # ``self.count_use_lm_head`` would shadow the method of that name.
        self.ordered_ratio = None

        if nn == "vq_vae":
            print("ratio for vq-vae:")
            self.ordered_ratio = self.count_use_vq_vae()
            self.create_bar_plot_for_ratio(self.ordered_ratio)
        elif nn == "lm_head":
            print("ratio for lm-head:")
            self.ordered_ratio = self.count_use_lm_head()
            self.create_bar_plot_for_ratio(self.ordered_ratio)

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path`` and return its content."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _count_use(self, parameters):
        """Return the fraction of incumbents in which each parameter equals 1.

        Args:
            parameters: Mapping of plot label -> configuration key.

        Returns:
            Dict of label -> ratio, sorted decreasingly by ratio. All ratios
            are ``0.0`` when the trajectory contains no incumbents.

        Raises:
            KeyError: If an incumbent id or a parameter key is missing in the
                run history.
        """
        incumbents = self.get_incumbents()
        # Parse the run history once instead of once per incumbent.
        configs = self._load_json(self.runhistory)["configs"] if incumbents else {}

        counts = dict.fromkeys(parameters, 0)
        for config_id in incumbents:
            config = configs[str(config_id)]
            for label, key in parameters.items():
                if config[key] == 1:
                    counts[label] += 1

        total = len(incumbents)
        ratios = {
            label: (count / total if total else 0.0)
            for label, count in counts.items()
        }
        ordered = self.order_dic_decreasingly(ratios)
        print(ordered)
        return ordered

    def get_incumbents(self):
        """Return the first config id of every trajectory entry in the intensifier file."""
        data = self._load_json(self.intensifier)
        return [entry["config_ids"][0] for entry in data["trajectory"]]

    def count_use_lm_head(self):
        """Return the usage ratio of each LM-head parameter over all incumbents.

        A parameter counts as used when its configuration value equals 1.
        7 parameters: use_pos_weight, use_neg_weight, use_si_loss,
        use_ss_loss, use_mean_entropy, use_uniform_kl_loss, use_temperature.

        Returns:
            Dict of short label -> ratio, sorted decreasingly; it is also printed.
        """
        return self._count_use(self._LM_HEAD_PARAMETERS)

    def count_use_vq_vae(self):
        """Return the usage ratio of each VQ-VAE parameter over all incumbents.

        A parameter counts as used when its configuration value equals 1.
        10 parameters: use_cnE_euc, use_cnD_euc, use_cnQ_euc, use_fs_euc,
        use_rcP_euc, use_us, use_dv, use_en, use_cb_euc, use_em_euc.

        Returns:
            Dict of short label -> ratio, sorted decreasingly; it is also printed.
        """
        return self._count_use(self._VQ_VAE_PARAMETERS)

    def get_config(self, config_id):
        """Return the configuration stored in the run-history file for ``config_id``.

        Raises:
            KeyError: If ``config_id`` is not present in the run history.
        """
        return self._load_json(self.runhistory)["configs"][str(config_id)]

    def order_dic_decreasingly(self, dic):
        """Return a new dict with the items of ``dic`` ordered by decreasing value."""
        return dict(sorted(dic.items(), key=lambda item: item[1], reverse=True))

    def create_bar_plot_for_ratio(self, ratio_per_use_parameter):
        """Draw a bar plot of the usage ratios, save it as SVG and PNG, then show it.

        Args:
            ratio_per_use_parameter: Dict of label -> ratio; bars are drawn in
                the dict's iteration order.
        """
        import matplotlib.pyplot as plt

        labels = list(ratio_per_use_parameter.keys())
        values = list(ratio_per_use_parameter.values())
        positions = list(range(len(labels)))

        plt.bar(positions, values, align='center', alpha=0.5, color='#648fff')
        plt.xticks(positions, labels, rotation=45)
        plt.ylabel('Fraction used in Incumbents')

        plt.tight_layout()
        # Save before showing: once show() returns the figure may already be
        # released, and saving afterwards would write an empty canvas.
        plt.savefig(self.file_plot_svg)
        plt.savefig(self.file_plot_png)
        plt.show()

# Example usage: fill in the paths and the model name below.
path_to_json_runhistory = ""  # path to runhistory json file created by SMAC
path_to_json_intensifier = ""  # path to intensifier json file created by SMAC
file_plot_png = ""  # path to save png file of the plot
file_plot_svg = ""  # path to save svg file of the plot
model_name = ""  # "lm_head" or "vq_vae"

# With an empty model name the constructor only stores its arguments.
use_in_incumbents = UseInIncumbents(
    runhistory=path_to_json_runhistory,
    intensifier=path_to_json_intensifier,
    nn=model_name,
    file_plot_svg=file_plot_svg,
    file_plot_png=file_plot_png,
)
