# laura maria engist, 2025
# script to run the LM-head pipeline with SMAC for hyperparameter optimization

import multiprocessing as mp

def main(n_workers: int = 5, n_trials: int = 1000):
    """Optimize the LM-head pipeline's hyperparameters with SMAC on a SLURM/Dask cluster.

    Submits ``n_workers`` SLURM jobs (one A100 GPU each) through dask-jobqueue,
    lets SMAC's ``BlackBoxFacade`` evaluate up to ``n_trials`` configurations in
    parallel on them, prints the costs of the default and the incumbent
    configuration, and finally runs the project's error-log check. The Dask
    client and the SLURM cluster are closed when this function exits, including
    when optimization raises.

    Args:
        n_workers: Number of SLURM jobs / Dask workers; also SMAC's number of
            parallel workers.
        n_trials: Maximum number of SMAC trials.
    """
    # Imports are local to main(); the nested class's methods reach them
    # through the closure.
    import gc
    import os
    import traceback

    import numpy as np
    import psutil
    from ConfigSpace import Configuration, ConfigurationSpace, CategoricalHyperparameter, EqualsCondition
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster
    from smac import BlackBoxFacade, Scenario

    from hypopthd import check_error_log
    from hypopthd.pipelines.lmhead_pipeline import nnp_lmhead

    class NnpLmHeadSmac:
        """SMAC wrapper around the LM-head pipeline.

        Provides the configuration space (``configspace``) and the target
        function (``safe_train``) that are handed to SMAC.
        """

        # Cost reported to SMAC when a trial crashes or yields a non-finite cost.
        CRASH_COST = 1.0

        # Hyperparameters that each come with a binary ``use_<name>`` switch,
        # listed in the order in which they are added to the space.
        TOGGLED_PARAMS = (
            "pos_weight",
            "neg_weight",
            "si_loss",
            "ss_loss",
            "mean_entropy",
            "uniform_kl_loss",
            "temperature",
        )

        @property
        def configspace(self) -> ConfigurationSpace:
            """Build, print and return the configuration space of the LM-head pipeline.

            Every name in ``TOGGLED_PARAMS`` gets a categorical switch
            ``use_<name>`` with choices {0, 1} and a categorical value that is
            only active when its switch equals 1. Loss weights take the values
            0.00 .. 1.00 in steps of 0.01; ``temperature`` takes 0.01 .. 1.00
            (0 is excluded). A new ConfigurationSpace (seed 0) is built on every
            access.
            """
            cs = ConfigurationSpace(seed=0)
            weight_choices = [round(x * 0.01, 2) for x in range(101)]  # 0.00 .. 1.00
            temperature_choices = [round(x * 0.01, 2) for x in range(1, 101)]  # 0.01 .. 1.00

            switches = [
                CategoricalHyperparameter(f"use_{name}", [0, 1])
                for name in self.TOGGLED_PARAMS
            ]
            values = [
                CategoricalHyperparameter(
                    name,
                    choices=temperature_choices if name == "temperature" else weight_choices,
                )
                for name in self.TOGGLED_PARAMS
            ]
            cs.add(switches)
            cs.add(values)

            # A value is only active (present in a sampled Configuration) when its
            # switch is 1. An active loss weight may still be 0.0.
            cs.add([EqualsCondition(value, switch, 1) for value, switch in zip(values, switches)])

            print(cs)
            return cs

        def train(self, config: Configuration, config_id, seed: int = 0) -> float:
            """Run the LM-head pipeline for one configuration and return its cost.

            Args:
                config: Hyperparameter configuration sampled by SMAC.
                config_id: Identifier forwarded to the pipeline (taken from
                    ``config.config_id`` by ``safe_train``).
                seed: Accepted for SMAC's target-function signature; unused here.

            Returns:
                The pipeline's cost, or ``CRASH_COST`` if the pipeline raised or
                returned a non-finite cost.
            """
            current_dir_name = os.path.basename(os.getcwd())
            self.log_memory("START trial")

            pipeline = None
            try:
                pipeline = nnp_lmhead.NnLMheadPipeline(config, config_id, current_dir_name)
                cost = pipeline.run_nn_lmhead_pipeline()
                self.print_and_flush(f"to smac returned cost: {cost}")
                if not np.isfinite(cost):
                    self.print_and_flush(f"Invalid result detected: {cost} \n [TRAIN END] result: not finite")
                    return self.CRASH_COST
                self.print_and_flush(f"[TRAIN END] result: {cost}")
                return cost
            except Exception as e:
                self.print_and_flush(f"[CRASHED CONFIG] {config} → Exception: {e} \n [TRAIN END] result: exception")
                # Full traceback so crashed configurations can be debugged from the logs.
                self.print_and_flush(traceback.format_exc())
                return self.CRASH_COST
            finally:
                # Drop our reference first; otherwise gc.collect() cannot reclaim the
                # pipeline and the memory logged below still includes it.
                del pipeline
                gc.collect()
                self.log_memory("END trial (after cleanup)")

        def safe_train(self, config: Configuration, seed: int = 0) -> float:
            """SMAC target function: log worker diagnostics, then delegate to ``train``.

            Any exception escaping ``train`` is logged and converted into
            ``CRASH_COST`` instead of propagating to SMAC.
            """
            self._log_worker_info()
            config_id = config.config_id
            self.print_and_flush(f"CONFIG_ID: {config_id}")
            try:
                return self.train(config, config_id, seed)
            except Exception as e:
                self.print_and_flush(f"[CRASH] Safe fallback for config {config}: {e}")
                gc.collect()
                return self.CRASH_COST

        def _log_worker_info(self):
            """Print the hostname and CUDA availability/device of the current worker."""
            try:
                import socket

                import torch

                cuda_available = torch.cuda.is_available()
                self.print_and_flush(f"Worker hostname: {socket.gethostname()}")
                self.print_and_flush(f"CUDA available: {cuda_available}")
                if cuda_available:
                    self.print_and_flush(f"Device name: {torch.cuda.get_device_name(0)}")
            except Exception as e:
                # Diagnostics must never prevent the trial itself from running.
                self.print_and_flush(f"[WARN] Could not collect worker diagnostics: {e}")

        def log_memory(self, tag=""):
            """Print the resident set size (RSS) of the current process in GB, prefixed by ``tag``."""
            rss_gb = psutil.Process(os.getpid()).memory_info().rss / (1024 ** 3)
            self.print_and_flush(f"[{tag}] Memory usage: {rss_gb:.2f} GB")

        def print_and_flush(self, text):
            """Print ``text`` and flush stdout immediately.

            Flushing keeps log lines from sitting in the stdout buffer of a
            long-running Dask worker.
            """
            print(text, flush=True)

    def test_gpu():
        """Return ``(cuda_available, device name or "CPU")`` for the worker it runs on."""
        import torch

        cuda_available = torch.cuda.is_available()
        return cuda_available, torch.cuda.get_device_name(0) if cuda_available else "CPU"

    model = NnpLmHeadSmac()
    # Build the space once so SMAC and the default-configuration lookup below
    # use the same ConfigurationSpace object (and it is printed only once).
    configspace = model.configspace

    scenario = Scenario(
        configspace=configspace,
        deterministic=True,
        n_trials=n_trials,  # maximum number of trials
        n_workers=n_workers,  # number of parallel workers (parallel evaluation)
    )

    current_dir_name = os.path.basename(os.getcwd())

    job_script_prologue = [
        'source /etc/profile.d/soft_stacks.sh',
        'module purge',
        'module load CUDA/12.4.0',
        'export PATH=$HOME/miniforge3/bin:$PATH',
        'source activate alphabeta'
    ]

    job_extra_directives = [
        '--qos=gpu1week',
        # '--reservation=TODO',  # add a reservation if needed
        '--time=7-00:00:00',
        '--gpus-per-task=a100:1'
    ]

    # Both context managers close on exit, so the SLURM jobs are released even
    # if optimization or validation raises.
    with SLURMCluster(
        job_name=f"lmhead_hypopthd_{current_dir_name}",
        queue="a100",
        cores=1,
        memory="60 GB",
        processes=1,
        log_directory="loggs",
        job_script_prologue=job_script_prologue,
        job_extra_directives=job_extra_directives,
        nanny=False,
    ) as cluster:
        cluster.scale(jobs=n_workers)

        with Client(cluster) as client:
            client.wait_for_workers(n_workers)

            # Check GPU access on all workers.
            print("GPU check on Dask workers:")
            print(client.run(test_gpu))

            # Optimize with SMAC.
            smac = BlackBoxFacade(
                scenario=scenario,
                target_function=model.safe_train,
                dask_client=client,
                overwrite=False
            )
            # NOTE: private SMAC attribute (may break on SMAC upgrades). Intended to
            # give slow-starting SLURM workers more time to come up.
            smac._runner._patience = 120
            incumbent = smac.optimize()

            # NOTE(review): validate() re-runs the target function for the given
            # configuration; confirm where it executes (presumably in this driver
            # process rather than on a GPU worker).
            default_cost = smac.validate(configspace.get_default_configuration())
            print(f"Default cost: {default_cost}")

            # Cost of the best configuration found.
            incumbent_cost = smac.validate(incumbent)
            print(f"Incumbent cost: {incumbent_cost}")

            # Check the error log (distinct name so the imported module isn't shadowed).
            error_log_checker = check_error_log.CheckErrorLog(os.getcwd())
            error_log_checker.alert_me()

if __name__ == "__main__":
    # "spawn" starts child processes from a fresh interpreter instead of forking
    # this one; force=True overrides any start method that was already set.
    mp.set_start_method(method="spawn", force=True)
    main()
