# laura maria engist, 2025
# The LM-head Pipeline
# Hyperparameters varied by SMAC: the loss weights and the temperature, each with a flag that switches it on or off

from evotuner.pipelines import nnp_lmhead_config
import os
from evotuner import constants
import subprocess
import sys
from evotuner import get_best_cost_per_trial
import torch

class NnLMheadPipeline:
    """Run one LM-head training trial for a single hyperparameter configuration.

    The constructor resolves the optional loss weights and the temperature
    from ``config``; ``run_nn_lmhead_pipeline`` writes the training config,
    launches the training subprocess and returns the trial's best cost.
    """

    def __init__(self, config, config_id, current_dir_name):
        """Resolve the hyperparameters of this trial.

        Every tunable value ``<name>`` is guarded by a ``use_<name>`` entry in
        ``config``: the value is read only when the flag equals 1, otherwise a
        fixed fallback is stored (0 for all loss weights, 0.01 for the
        temperature).

        Args:
            config: Subscriptable configuration holding the ``use_<name>``
                flags and, for every enabled flag, the ``<name>`` value.
            config_id: Identifier of this configuration; used in the job name
                and as the name of the output directory.
            current_dir_name: Name embedded in the job name.

        Raises:
            KeyError: If a ``use_<name>`` flag, or the value of an enabled
                flag, is missing from a dict-like ``config``.
        """
        self.go = 12  # fixed value carried over from the CPU pipeline
        self.ge = 1  # fixed value carried over from the CPU pipeline

        # Same lookup order as the flags are listed here, so a missing key
        # surfaces for the first offending hyperparameter.
        self.pos_weight = self._value_if_enabled(config, "pos_weight")
        self.neg_weight = self._value_if_enabled(config, "neg_weight")
        self.si_loss = self._value_if_enabled(config, "si_loss")
        self.ss_loss = self._value_if_enabled(config, "ss_loss")
        self.mean_entropy = self._value_if_enabled(config, "mean_entropy")
        self.uniform_kl_loss = self._value_if_enabled(config, "uniform_kl_loss")
        self.temperature = self._value_if_enabled(config, "temperature", 0.01)

        self.config_id = config_id
        self.current_dir_name = current_dir_name

    @staticmethod
    def _value_if_enabled(config, name, disabled_value=0):
        """Return ``config[name]`` when ``config["use_" + name] == 1``, else ``disabled_value``."""
        if config[f"use_{name}"] == 1:
            return config[name]
        return disabled_value

    @staticmethod
    def _print_environment_check():
        """Print host, Python, CUDA and SLURM details of the current process for debugging."""
        print("🔍 Environment Check:")
        print("Hostname:", subprocess.getoutput("hostname"))
        print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
        print("Python:", subprocess.getoutput("which python"))
        print("Python version:", subprocess.getoutput("python --version"))
        print("CUDA Info:\n", subprocess.getoutput("nvidia-smi"))
        print("SLURM Job ID:", os.environ.get("SLURM_JOB_ID"))
        cuda_available = torch.cuda.is_available()
        print("CUDA available:", cuda_available)
        print("GPU name:", torch.cuda.get_device_name(0) if cuda_available else "N/A")

    def run_nn_lmhead_pipeline(self):
        """Run the LM-head pipeline for this trial and return its cost.

        Creates ``logs/wandb/<config_id>/`` with a ``checkpoints``
        subdirectory below the current working directory, writes the training
        YAML file there, runs PALM's ``main.py fit`` on it as a subprocess
        (with ``HF_DATASETS_OFFLINE`` and ``TRANSFORMERS_OFFLINE`` set to
        ``"1"``) and then looks up the cost of the trial.

        Returns:
            The ``cost`` element returned by ``NnpGetBestCostPerTrial.get_cost``.

        Raises:
            FileExistsError: If the output directory of this ``config_id``
                (or its ``checkpoints`` subdirectory) already exists.
            subprocess.CalledProcessError: If the training command exits with
                a non-zero status; its captured output is printed first.
        """
        job_name = f"laen_nnp_{self.current_dir_name}_{self.config_id}"

        # Create the output directories. No exist_ok: an already existing
        # directory for this config id aborts the run instead of being reused.
        path_for_this_config_output = os.path.join(os.getcwd(), "logs", "wandb", f"{self.config_id}")
        os.makedirs(path_for_this_config_output)
        checkpoints_path = os.path.join(path_for_this_config_output, "checkpoints")
        os.makedirs(checkpoints_path)

        # Create the config file for the training run.
        config_file = os.path.join(path_for_this_config_output, f"{job_name}.yaml")
        config_training = nnp_lmhead_config.ConfigTraining(
            config_file,
            path_for_this_config_output,
            checkpoints_path,
            self.pos_weight,
            self.neg_weight,
            self.si_loss,
            self.ss_loss,
            self.mean_entropy,
            self.uniform_kl_loss,
            self.temperature,
        )
        config_training.update_config_file()
        config_training.save_config_file()

        self._print_environment_check()

        # The child inherits the current environment plus the offline flags.
        environment = os.environ.copy()
        environment["HF_DATASETS_OFFLINE"] = "1"
        environment["TRANSFORMERS_OFFLINE"] = "1"
        # Started directly with "python", not through "srun".
        command = [
            "python",
            "/scicore/home/schwede/engist0000/HypOptHD/PALM/main.py",
            "fit",
            "-c",
            config_file,
            f"--trainer.logger.name={job_name}",
        ]
        print(f"COMMAND: {command}")
        # NOTE(review): this writes the complete environment to the log,
        # including any credentials stored in environment variables.
        print(f"ENVIRONMENT: {environment}")

        try:
            # stdout/stderr are captured so they can be reported on failure;
            # on success the captured output is not used.
            subprocess.run(
                command,
                env=environment,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            print("❌ Subprocess failed!")
            print("👉 Command:", e.cmd)
            print("👉 Return code:", e.returncode)
            print("👉 STDOUT:\n", e.stdout)
            print("👉 STDERR:\n", e.stderr)
            raise
        else:
            # Ask ps for the thread count (nlwp) of *this* process. "$$" must
            # not be used here: inside getoutput's shell it is the PID of that
            # short-lived shell, not of the Python process.
            thread_count = subprocess.getoutput(f"ps -o nlwp= -p {os.getpid()}").strip()
            print("Threads in process:", thread_count)

        # Per the original author's note: reads the result files of the trial
        # and returns the best cost over all epochs. The epoch and the trial
        # number are returned as well but are not needed here.
        nnp_get_best_cost = get_best_cost_per_trial.NnpGetBestCostPerTrial(path_for_this_config_output)
        _epoch, cost, _trial_number = nnp_get_best_cost.get_cost(path_for_this_config_output, self.config_id)

        return cost
