# laura maria engist, 2025
# script to set the parameter values in the training config file

from evotuner import constants
import os
import yaml
import numpy as np
from ruamel.yaml import YAML

class ConfigTraining():
    def __init__(
        self,
        config_file_for_training,
        output_dir_training,
        checkpoints_path,
        pos_weight,
        neg_weight,
        si_loss,
        ss_loss,
        mean_entropy,
        uniform_kl_loss,
        temperature,
    ):
        """Store this run's settings and load the template config.

        Args:
            config_file_for_training: path that ``save_config_file`` writes the
                updated config to.
            output_dir_training: value written to ``model.init_args.output_dir``
                by ``update_config_file``.
            checkpoints_path: directory that ``update_config_file`` writes as the
                ``dirpath`` of the first entry in ``trainer.callbacks``.
            pos_weight, neg_weight, si_loss, ss_loss, mean_entropy,
            uniform_kl_loss: loss weights that ``update_config_file`` writes
                (as floats) under ``model.init_args.model.init_args.loss_weights``.
            temperature: written (as a float) to
                ``model.init_args.model.init_args.temperature``.
        """
        # output locations
        self.config_file_for_training = config_file_for_training
        self.output_dir_training = output_dir_training
        self.checkpoints_path = checkpoints_path

        # hyperparameters that update_config_file copies into the config
        (
            self.pos_weight,
            self.neg_weight,
            self.si_loss,
            self.ss_loss,
            self.mean_entropy,
            self.uniform_kl_loss,
            self.temperature,
        ) = (
            pos_weight,
            neg_weight,
            si_loss,
            ss_loss,
            mean_entropy,
            uniform_kl_loss,
            temperature,
        )

        # ruamel.yaml handler shared by load/save: keep the template's quoting
        # and dump with 2-space mappings and sequence indent 4 / dash offset 2
        handler = YAML()
        handler.preserve_quotes = True
        handler.indent(mapping=2, sequence=4, offset=2)
        self.yaml = handler

        # empty default; replaced by whatever load_config_file finds
        self.config = {}
        self.load_config_file()

    def load_config_file(self):
        """Load the template config into ``self.config``.

        The template is read from ``constants.CONFIG_FILE_LMHEAD`` (not from
        ``self.config_file_for_training``, which is only the save target).
        When the template file is missing, or loads to a falsy value such as
        an empty document, ``self.config`` becomes an empty dict. A loaded
        config is echoed to stdout.
        """
        template_path = constants.CONFIG_FILE_LMHEAD
        if not os.path.exists(template_path):
            self.config = {}
            return

        with open(template_path, "r", encoding="utf-8") as handle:
            loaded = self.yaml.load(handle)
            self.config = loaded or {}
            print(f"self.config: {self.config}")
    
    '''
    save config file with updated values
    '''
    def save_config_file(self):
        print(f"save config file with config: {self.config}")
        with open(self.config_file_for_training, "w", encoding="utf-8") as file:
            self.yaml.dump(self.config, file)

    ''' 
    update config file with new values
    '''
    def update_config_file(self):
        self.config['model']['init_args']['model']['init_args']['loss_weights']['pos_weight'] = float(self.pos_weight)
        self.config['model']['init_args']['model']['init_args']['loss_weights']['neg_weight'] = float(self.neg_weight)
        self.config['model']['init_args']['model']['init_args']['loss_weights']['si_loss'] = float(self.si_loss)
        self.config['model']['init_args']['model']['init_args']['loss_weights']['ss_loss'] = float(self.ss_loss)
        self.config['model']['init_args']['model']['init_args']['loss_weights']['mean_entropy'] = float(self.mean_entropy)
        self.config['model']['init_args']['model']['init_args']['loss_weights']['uniform_kl_loss'] = float(self.uniform_kl_loss)
        self.config['model']['init_args']['model']['init_args']['temperature'] = float(self.temperature)

        # update output_dir parts in config file
        self.config['model']['init_args']['output_dir'] = self.output_dir_training
        self.config['trainer']['callbacks'][0]['init_args']['dirpath'] = self.checkpoints_path + '/'

        # update project name
        self.config['trainer']['logger']['init_args']['project'] = "HypOptHD"

        # update stuff that it runs on the cluster
        self.config['data']['init_args']['num_workers'] = 0
        self.config['data']['init_args']['batch_size'] = 16