# laura maria engist, 2025
# script to set the parameter values in the config file

from evotuner import constants
import os
import yaml
import numpy as np
from ruamel.yaml import YAML

class ConfigTraining():
    """Prepare the YAML config file used for one training run.

    On construction the base config is read from
    ``constants.CONFIG_FILE_VQVAE``. ``update_config_file`` then overrides
    selected entries in memory, and ``save_config_file`` writes the result to
    ``config_file_for_training``.
    """

    # Loss-weight entries written by ``update_config_file``. Each name is both
    # the attribute on this object and the key under ``loss_weights``; the
    # order is the order in which the keys are assigned.
    _LOSS_WEIGHT_NAMES = (
        'cnE_euc',
        'cnD_euc',
        'cnQ_euc',
        'fs_euc',
        'rcP_euc',
        'us',
        'dv',
        'en',
        'cb_euc',
        'em_euc',
    )

    def __init__(self, config_file_for_training, output_dir_training, checkpoints_path, margin, cnE_euc, cnD_euc, cnQ_euc, fs_euc, rcP_euc, us, dv, en, cb_euc, em_euc):
        """Store the run settings and load the base config.

        Args:
            config_file_for_training: path the updated config is written to
                by ``save_config_file``.
            output_dir_training: value stored as the model's ``output_dir``.
            checkpoints_path: directory stored (with a trailing ``/``) as the
                ``dirpath`` of the first trainer callback.
            margin: value convertible with ``float()``; stored as ``margin``.
            cnE_euc, cnD_euc, cnQ_euc, fs_euc, rcP_euc, us, dv, en, cb_euc,
            em_euc: values convertible with ``float()``; stored under
                ``loss_weights`` with the same names.
        """
        self.config_file_for_training = config_file_for_training
        self.output_dir_training = output_dir_training
        self.checkpoints_path = checkpoints_path
        self.config = {}
        self.margin = margin
        self.cnE_euc = cnE_euc
        self.cnD_euc = cnD_euc
        self.cnQ_euc = cnQ_euc
        self.fs_euc = fs_euc
        self.rcP_euc = rcP_euc
        self.us = us
        self.dv = dv
        self.en = en
        self.cb_euc = cb_euc
        self.em_euc = em_euc

        # Round-trip YAML handler configured to keep quoting and use a fixed
        # indentation layout when the config is dumped again.
        self.yaml = YAML()
        self.yaml.preserve_quotes = True
        self.yaml.indent(mapping=2, sequence=4, offset=2)
        self.load_config_file()

    def load_config_file(self) -> None:
        """Load the base config from ``constants.CONFIG_FILE_VQVAE``.

        If the file does not exist, or it parses to nothing, ``self.config``
        is set to an empty dict.
        """
        if os.path.exists(constants.CONFIG_FILE_VQVAE):
            with open(constants.CONFIG_FILE_VQVAE, "r", encoding="utf-8") as file:
                # ``or {}`` covers a file whose content loads as None.
                self.config = self.yaml.load(file) or {}
                print(f"self.config: {self.config}")
        else:
            self.config = {}

    def save_config_file(self) -> None:
        """Write ``self.config`` to ``self.config_file_for_training``."""
        print(f"save config file with config: {self.config}")
        with open(self.config_file_for_training, "w", encoding="utf-8") as file:
            self.yaml.dump(self.config, file)

    @staticmethod
    def _plain_path(value):
        """Return ``os.fspath(value)`` for path-like objects, else ``value``."""
        return os.fspath(value) if isinstance(value, os.PathLike) else value

    def update_config_file(self) -> None:
        """Override the run-specific entries of the loaded config in memory.

        Sets the margin, the loss weights, the output and checkpoint
        directories, the logger project name and the data-loader settings.
        Nothing is written to disk; call ``save_config_file`` for that.

        Raises:
            KeyError: if the loaded config is empty, or if it lacks one of
                the nested sections that are updated here.
        """
        if not self.config:
            # load_config_file() falls back to {} when the base file is
            # missing or empty; fail with an explanation instead of a bare
            # KeyError: 'model' from the first lookup below.
            raise KeyError(
                "base training config is empty (the base config file was "
                "missing or empty when load_config_file() ran); "
                "nothing to update"
            )

        model_args = self.config['model']['init_args']['model']['init_args']
        model_args['margin'] = float(self.margin)

        loss_weights = model_args['loss_weights']
        for weight_name in self._LOSS_WEIGHT_NAMES:
            loss_weights[weight_name] = float(getattr(self, weight_name))

        # update output_dir parts in config file
        self.config['model']['init_args']['output_dir'] = self._plain_path(self.output_dir_training)
        checkpoint_dir = self._plain_path(self.checkpoints_path)
        # Ensure exactly one trailing slash; a non-string value still hits the
        # concatenation and raises TypeError as before.
        if not (isinstance(checkpoint_dir, str) and checkpoint_dir.endswith('/')):
            checkpoint_dir = checkpoint_dir + '/'
        self.config['trainer']['callbacks'][0]['init_args']['dirpath'] = checkpoint_dir

        # update project name
        self.config['trainer']['logger']['init_args']['project'] = "HypOptHD"

        # settings used so that the run works on the cluster
        data_args = self.config['data']['init_args']
        data_args['num_workers'] = 0
        data_args['batch_size'] = 16