# laura maria engist, 2025
# script to run the VQ-VAE pipeline with SMAC for hyperparameter optimization

import multiprocessing as mp

def main():
    """Optimize the VQ-VAE pipeline's loss hyperparameters with SMAC on SLURM.

    Builds the configuration space, starts ``n_workers`` SLURM jobs through
    dask-jobqueue, runs a SMAC ``BlackBoxFacade`` optimization whose trials
    are evaluated on those Dask workers, validates the default and the
    incumbent configuration, and finally runs the evotuner error-log check
    on the current working directory.

    The Dask client and the SLURM cluster are always closed on exit, also
    when the optimization raises, so the GPU jobs are released explicitly.
    """
    # All imports are local to main(), as in the original script.
    import gc
    import os
    import sys

    import numpy as np
    import psutil
    from ConfigSpace import (
        CategoricalHyperparameter,
        Configuration,
        ConfigurationSpace,
        EqualsCondition,
    )
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster
    from smac import BlackBoxFacade, Scenario

    from evotuner import check_error_log
    from evotuner.pipelines.vqvae_pipeline import nnp_vqvae

    class NnpVqvaeSmac(object):
        """Class to optimize the VQ-VAE pipeline with SMAC."""

        @property
        def configspace(self) -> ConfigurationSpace:
            """Define the configuration space for the VQ-VAE pipeline.

            Every loss term gets a 0/1 switch ``use_<name>`` and a weight
            ``<name>``; the weight is conditional on its switch being 1.
            ``margin`` is an unconditional hyperparameter.

            Note: a new ConfigurationSpace is built (and printed) on every
            access of this property.

            Returns:
                The freshly built ConfigurationSpace (seeded with 0).
            """
            cs = ConfigurationSpace(seed=0)

            # 0.0 to 1.0 in steps of 0.01
            weight_choices = [round(x * 0.01, 2) for x in range(101)]
            # starts at 0.001, then steps of 0.5 (0.001, 0.501, ...);
            # np.arange is called with 10.001 as its stop value
            margin_choices = np.round(np.arange(0.001, 10.001, 0.5), 3)

            loss_names = [
                "cnE_euc",
                "cnD_euc",
                "cnQ_euc",
                "fs_euc",
                "rcP_euc",
                "us",
                "dv",
                "en",
                "cb_euc",
                "em_euc",
            ]

            # parameters deciding whether a certain loss is used at all
            switches = {
                name: CategoricalHyperparameter(f"use_{name}", [0, 1])
                for name in loss_names
            }
            margin = CategoricalHyperparameter("margin", choices=margin_choices)
            weights = {
                name: CategoricalHyperparameter(name, choices=weight_choices)
                for name in loss_names
            }

            # Same add order as before: all switches first, then margin + weights.
            cs.add([switches[name] for name in loss_names])
            cs.add([margin] + [weights[name] for name in loss_names])

            # Conditional: a weight is only active when its switch equals 1.
            # NOTE(review): presumably the pipeline treats an inactive weight
            # as 0 - confirm against NnVqvaePipeline.
            conditions = [
                EqualsCondition(weights[name], switches[name], 1)
                for name in loss_names
            ]
            cs.add(*conditions)

            print(cs)
            return cs

        def train(self, config: Configuration, config_id, seed: int = 0) -> float:
            """Train the VQ-VAE pipeline with the given configuration.

            Args:
                config: Configuration to evaluate.
                config_id: Identifier of the configuration, passed on to the pipeline.
                seed: Accepted for interface compatibility; not used here.

            Returns:
                The cost returned by the pipeline, or 1.0 when the cost is
                not finite or the pipeline raised an exception.
            """
            current_dir_name = os.path.basename(os.getcwd())
            self.log_memory("START trial")

            pipeline = None
            try:
                pipeline = nnp_vqvae.NnVqvaePipeline(config, config_id, current_dir_name)
                cost = pipeline.run_nn_vqvae_pipeline()
                self.print_and_flush(f"to smac returned cost: {cost}")
                if not np.isfinite(cost):
                    self.print_and_flush(f"Invalid result detected: {cost} \n [TRAIN END] result: not finite")
                    return 1.0
                self.print_and_flush(f"[TRAIN END] result: {cost}")
                return cost
            except Exception as e:
                # Best effort by design: a crashed trial is reported to SMAC
                # with cost 1.0 instead of aborting the whole optimization.
                self.print_and_flush(f"[CRASHED CONFIG] {config} → Exception: {e} \n [TRAIN END] result: exception")
                return 1.0
            finally:
                # Single cleanup path for success, invalid result and crash.
                # Drop the pipeline reference first so the collection below
                # is able to reclaim it before memory is logged.
                pipeline = None
                gc.collect()
                self.log_memory("END trial (after cleanup)")

        def safe_train(self, config: Configuration, seed: int = 0) -> float:
            """Safe train function to catch exceptions and log GPU info.

            This is the target function handed to SMAC. It prints GPU/host
            information, reads the id from ``config.config_id`` and delegates
            to :meth:`train`.

            Args:
                config: Configuration to evaluate.
                seed: Seed forwarded to :meth:`train`.

            Returns:
                The cost from :meth:`train`, or 1.0 if it raised.
            """
            import socket

            import torch

            cuda_available = torch.cuda.is_available()
            print("Running on GPU?", cuda_available)
            if cuda_available:
                print("Device name:", torch.cuda.get_device_name(0))
            print("Worker hostname:", socket.gethostname())
            print("CUDA available:", cuda_available)

            config_id = config.config_id
            # NOTE(review): CONFIG_ID is emitted twice, exactly as before;
            # confirm nothing parses the logs before removing one of them.
            print(f"CONFIG_ID: {config_id}")
            sys.stdout.flush()
            self.print_and_flush(f"CONFIG_ID: {config_id}")
            try:
                return self.train(config, config_id, seed)
            except Exception as e:
                self.print_and_flush(f"[CRASH] Safe fallback for config {config}: {e}")
                gc.collect()
                return 1.0

        def log_memory(self, tag=""):
            """Log the resident memory (RSS, in GB) of this process with a tag."""
            process = psutil.Process(os.getpid())
            mem = process.memory_info().rss / (1024 ** 3)  # in GB
            self.print_and_flush(f"[{tag}] Memory usage: {mem:.2f} GB")

        def print_and_flush(self, text):
            """Print and flush output.

            Makes sure that output is not lost when using Dask.
            """
            print(text)
            sys.stdout.flush()

    model = NnpVqvaeSmac()
    # The property rebuilds (and prints) the space on every access, so build
    # it once and reuse the same object for the scenario and the validation.
    configspace = model.configspace

    n_workers = 5  # set number of workers to parallelize on
    scenario = Scenario(
        configspace=configspace,
        deterministic=True,
        n_trials=1000,  # set number of trials
        n_workers=n_workers,  # to trigger parallel evaluation
    )

    current_dir_name = os.path.basename(os.getcwd())

    # Shell lines written into every SLURM job script before the worker starts.
    job_script_prologue = [
        'source /etc/profile.d/soft_stacks.sh',
        'module purge',
        'module load CUDA/12.4.0',
        'export PATH=$HOME/miniforge3/bin:$PATH',
        'source activate alphabeta'
    ]

    # Additional sbatch directives for every job.
    job_extra_directives = [
        '--qos=gpu1week',
        #'--reservation=TODO', # if you want to use a reservation
        '--time=7-00:00:00',
        '--gpus-per-task=a100:1'
    ]
    cluster = SLURMCluster(
        job_name=f"vqvae_hypopthd_{current_dir_name}",
        queue="a100",
        cores=1,
        memory="60 GB",
        processes=1,
        log_directory="loggs",
        job_script_prologue=job_script_prologue,
        job_extra_directives=job_extra_directives,
        nanny=False,
    )

    client = None
    try:
        cluster.scale(jobs=n_workers)

        client = Client(cluster)
        client.wait_for_workers(n_workers)

        # test GPU access on all workers
        def test_gpu():
            """Return (cuda_available, device name or "CPU") for one worker."""
            import torch
            return torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"

        print("GPU check on Dask workers:")
        print(client.run(test_gpu))

        # Optimize with SMAC
        smac = BlackBoxFacade(
            scenario=scenario,
            target_function=model.safe_train,
            dask_client=client,
            overwrite=False
        )
        # Private attribute of SMAC's runner: wait that workers really are up.
        smac._runner._patience = 120
        incumbent = smac.optimize()

        # Get cost of default configuration
        default_cost = smac.validate(configspace.get_default_configuration())
        print(f"Default cost: {default_cost}")

        # Cost of best found configuration
        incumbent_cost = smac.validate(incumbent)
        print(f"Incumbent cost: {incumbent_cost}")
    finally:
        # Always release the Dask client and the SLURM jobs, also on failure.
        if client is not None:
            client.close()
        cluster.close()

    # check error log (separate name so the imported module is not shadowed)
    error_log_checker = check_error_log.CheckErrorLog(os.getcwd())
    error_log_checker.alert_me()

if __name__ == "__main__":
    # Force the "spawn" start method before anything else runs.
    # NOTE(review): presumably chosen so child processes do not inherit
    # CUDA/Dask state via fork - confirm.
    mp.set_start_method(method="spawn", force=True)
    main()
