# laura maria engist, 2025
# The VQ-VAE Pipeline
# Hyperparameters varied by SMAC: the loss weights and, for each, a flag that controls whether that loss term is used

from evotuner.pipelines.vqvae_pipeline import nnp_vqvae_config
import os
from evotuner import constants
import subprocess
import sys
from evotuner import get_best_cost_per_trial
import torch

class NnVqvaePipeline:
    """Run one training of the VQ-VAE pipeline for a single configuration.

    Each optional loss weight ``<name>`` is read from ``config`` only when its
    companion flag ``use_<name>`` equals 1; otherwise the weight is 0. An
    inactive term's value therefore does not need to be present in ``config``.
    """

    # Training entry point used unless another one is passed to __init__.
    DEFAULT_MAIN_SCRIPT = "/scicore/home/schwede/engist0000/HypOptHD/PALM/main.py"

    def __init__(self, config, config_id, current_dir_name, *,
                 main_script=DEFAULT_MAIN_SCRIPT, python_executable="python"):
        """Read this configuration's hyperparameters.

        Args:
            config: Mapping of hyperparameter names to values. It must contain
                ``margin`` and every ``use_<name>`` flag. ``<name>`` itself is
                only read when its flag equals 1.
            config_id: Identifier of this configuration. Used in the output
                directory path and the job name.
            current_dir_name: Name embedded in the job / logger name.
            main_script: Path of the training script to launch.
            python_executable: Interpreter used to launch ``main_script``.

        Raises:
            Propagates whatever ``config[...]`` raises for a missing key
            (``KeyError`` for a plain dict).
        """
        self.go = 12  # fixed value carried over from the CPU pipeline
        self.ge = 1  # fixed value carried over from the CPU pipeline

        self.margin = config["margin"]
        self.cnE_euc = self._weight_if_enabled(config, "cnE_euc")
        self.cnD_euc = self._weight_if_enabled(config, "cnD_euc")
        self.cnQ_euc = self._weight_if_enabled(config, "cnQ_euc")
        self.fs_euc = self._weight_if_enabled(config, "fs_euc")
        self.rcP_euc = self._weight_if_enabled(config, "rcP_euc")
        self.us = self._weight_if_enabled(config, "us")
        self.dv = self._weight_if_enabled(config, "dv")
        self.en = self._weight_if_enabled(config, "en")
        self.cb_euc = self._weight_if_enabled(config, "cb_euc")
        self.em_euc = self._weight_if_enabled(config, "em_euc")

        self.config_id = config_id
        self.current_dir_name = current_dir_name
        self.main_script = main_script
        self.python_executable = python_executable

    @staticmethod
    def _weight_if_enabled(config, name):
        """Return ``config[name]`` if ``config["use_<name>"] == 1``, else 0."""
        # Only touch config[name] when enabled: an inactive term may be absent.
        if config[f"use_{name}"] == 1:
            return config[name]
        return 0

    def run_nn_vqvae_pipeline(self):
        """Train the VQ-VAE for this configuration and return its cost.

        Steps:
          1. Create ``<cwd>/logs/wandb/<config_id>`` and its ``checkpoints``
             subdirectory.
          2. Write the training YAML config via
             ``nnp_vqvae_config.ConfigTraining``.
          3. Print environment diagnostics.
          4. Run ``<python_executable> <main_script> fit -c <yaml>`` as a
             subprocess.
          5. Read the cost via ``get_best_cost_per_trial``.

        Returns:
            The cost returned by ``NnpGetBestCostPerTrial.get_cost``.

        Raises:
            FileExistsError: If the output directory already exists.
                ``os.makedirs`` is called without ``exist_ok``.
            subprocess.CalledProcessError: If training exits non-zero. The
                captured output is printed before re-raising.
        """
        job_name = f"laen_nnp_{self.current_dir_name}_{self.config_id}"

        output_dir = os.path.join(os.getcwd(), "logs", "wandb", f"{self.config_id}")
        # No exist_ok: fail instead of reusing a directory from a previous
        # run (presumably to avoid mixing result files of two runs).
        os.makedirs(output_dir)
        checkpoints_path = os.path.join(output_dir, "checkpoints")
        os.makedirs(checkpoints_path)
        config_file = os.path.join(output_dir, f"{job_name}.yaml")

        config_training = nnp_vqvae_config.ConfigTraining(
            config_file, output_dir, checkpoints_path, self.margin,
            self.cnE_euc, self.cnD_euc, self.cnQ_euc, self.fs_euc,
            self.rcP_euc, self.us, self.dv, self.en, self.cb_euc,
            self.em_euc,
        )
        config_training.update_config_file()
        config_training.save_config_file()

        self._print_environment_check()
        self._run_training(config_file, job_name)

        # Per the original author: reads all "rocx" files and returns the best
        # cost over all epochs. Epoch and trial number are also returned but
        # are not used here.
        best_cost_reader = get_best_cost_per_trial.NnpGetBestCostPerTrial(output_dir)
        _epoch, cost, _trial_number = best_cost_reader.get_cost(output_dir, self.config_id)
        return cost

    @staticmethod
    def _print_environment_check():
        """Print host, interpreter, CUDA and SLURM diagnostics to stdout."""
        print("🔍 Environment Check:")
        print("Hostname:", subprocess.getoutput("hostname"))
        print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
        print("Python:", subprocess.getoutput("which python"))
        print("Python version:", subprocess.getoutput("python --version"))
        print("CUDA Info:\n", subprocess.getoutput("nvidia-smi"))
        print("SLURM Job ID:", os.environ.get("SLURM_JOB_ID"))
        print("CUDA available:", torch.cuda.is_available())
        print("GPU name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A")

    def _run_training(self, config_file, job_name):
        """Run the training script on ``config_file``; print output and re-raise on failure."""
        environment = os.environ.copy()
        environment["HF_DATASETS_OFFLINE"] = "1"
        environment["TRANSFORMERS_OFFLINE"] = "1"
        command = [
            self.python_executable, self.main_script, "fit",
            "-c", config_file, f"--trainer.logger.name={job_name}",
        ]
        try:
            # stdout/stderr are captured in memory. They are only printed on
            # failure and are discarded when training succeeds.
            subprocess.run(
                command,
                env=environment,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            print("❌ Subprocess failed!")
            print("👉 Command:", e.cmd)
            print("👉 Return code:", e.returncode)
            print("👉 STDOUT:\n", e.stdout)
            print("👉 STDERR:\n", e.stderr)
            raise
