# laura maria engist, 2025
# Script to return the best cost (1 - mean AUC) for a single trial run ('get_cost(...)') or across all trials ('get_best_cost_of_all_trials()').

import os
import subprocess
import sys

class NnpGetBestCostPerTrial:
    """Find the lowest cost among the .rocx result files of several trials.

    The cost of one .rocx file is ``1 - mean AUC``. The mean AUC is the average
    of the values computed by ``get_auc_metrics`` (fam, supfam and fold AUC).
    """

    # Suffix of the result files that are evaluated.
    ROCX_SUFFIX = '.rocx'
    # .rocx files ending in one of these suffixes are skipped.
    EXCLUDED_ROCX_SUFFIXES = ('3di.rocx', 'aa.rocx')
    # Entries of the trials directory whose name contains any of these substrings are not trials.
    EXCLUDED_TRIAL_ENTRIES = ("debug-internal.log", "debug.log", "offline-run", "latest", "None")

    def __init__(self, wandb_directories_all_trials):
        """
        wandb_directories_all_trials: path to the directory whose sub-directories
            are the individual trials (each one holding .rocx files)
        """
        self.wandb_directories_all_trials = wandb_directories_all_trials

    @classmethod
    def _epoch_from_filename(cls, file_name):
        """Return the epoch number (as string) written right before the '.rocx' suffix.

        All trailing digits are used, so 'model_12.rocx' yields '12'. A
        single-character slice would yield '2' and let epoch 12 overwrite epoch 2.
        If there are no trailing digits, the single character before the suffix
        is returned.
        """
        stem = file_name[:-len(cls.ROCX_SUFFIX)]
        digits = stem[len(stem.rstrip('0123456789')):]
        return digits or stem[-1:]

    def get_cost(self, rocx_files_dir, trial_number):
        """Get the best cost of a single trial.

        rocx_files_dir: path to the directory containing the .rocx files
        trial_number: the trial number as string (returned unchanged)
        return: (epoch number with the best cost, best cost, trial_number);
            (0, 1, trial_number) when the directory has no usable .rocx file
        """
        files = os.listdir(rocx_files_dir)
        costs = dict()
        print(f"files: {files}")
        for file_name in files:
            # Only plain .rocx results; the 3di/aa variants are ignored.
            if not file_name.endswith(self.ROCX_SUFFIX) or file_name.endswith(self.EXCLUDED_ROCX_SUFFIXES):
                continue
            file_path = os.path.join(rocx_files_dir, file_name)
            print(f"for this file: {file_path}")
            auc = self.get_auc_metrics(file_path)
            costs[self._epoch_from_filename(file_name)] = 1 - auc
        print(f"costs: {costs}")
        sys.stdout.flush()
        min_key, best_cost = self.get_min_cost(costs)
        return min_key, best_cost, trial_number

    def get_min_cost(self, costs):
        """Get the minimum cost from a dictionary of costs.

        costs: dictionary with a key (e.g. epoch number) and the cost as value
        return: (key with the lowest cost, lowest cost); (0, 1) for an empty dict.
            Ties resolve to the first key in iteration order.
        """
        if not costs:
            return 0, 1
        min_key = min(costs, key=costs.get)
        return min_key, costs[min_key]

    def get_auc_metrics(self, output_file):
        """Get the mean AUC from a .rocx file.

        The column means of fields 3, 4 and 5 (fam, supfam and fold AUC) are
        computed with awk over all lines of the file. awk prints them with its
        default output format OFMT ("%.6g"), so the values are rounded to
        6 significant digits.

        output_file: path to the .rocx file
        return: the mean of fam_auc, supfam_auc and fold_auc
        raises: subprocess.CalledProcessError if awk exits with a non-zero status
        """
        result = subprocess.run(
            [
                "awk",
                "{ famsum+=$3; supfamsum+=$4; foldsum+=$5}END{print famsum/NR,supfamsum/NR,foldsum/NR}",
                output_file,
            ],
            capture_output=True,
            text=True,
            check=True,
        )
        fam_auc, supfam_auc, fold_auc = map(float, result.stdout.strip().split())
        print(f"AUC VALUES: fam_auc: {fam_auc}, supfam_auc: {supfam_auc}, fold_auc: {fold_auc}")
        sys.stdout.flush()
        auc_mean = (fam_auc + supfam_auc + fold_auc) / 3
        return auc_mean

    def get_best_cost_of_all_trials(self):
        """Get the best cost over all trial sub-directories of the wandb directory.

        Entries matching EXCLUDED_TRIAL_ENTRIES and entries that are not
        directories are skipped.
        return: ((trial number, epoch number), best cost); (0, 1) if there is no trial
        """
        best_costs_and_epochs = dict()
        for entry in os.listdir(self.wandb_directories_all_trials):
            if any(excluded in entry for excluded in self.EXCLUDED_TRIAL_ENTRIES):
                continue
            trial_dir = os.path.join(self.wandb_directories_all_trials, entry)
            # A stray file would make os.listdir() in get_cost() fail; only directories are trials.
            if not os.path.isdir(trial_dir):
                continue
            print(f"dir: {entry}")
            min_key, cost, trial_number = self.get_cost(trial_dir, entry)
            best_costs_and_epochs[(trial_number, min_key)] = cost
        return self.get_best_from_dict(best_costs_and_epochs)

    def get_best_from_dict(self, best_costs_and_epochs):
        """Get the best cost from a dictionary of costs.

        best_costs_and_epochs: dictionary with (trial number, epoch number) as key and cost as value
        return: ((trial number, epoch number) with the best cost, best cost); (0, 1) for an empty dict
        """
        return self.get_min_cost(best_costs_and_epochs)

if __name__ == "__main__":
    # Example usage: python <this_script> <wandb_directory_of_experiment>
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <wandb_directory_of_experiment>")
    nnp_get_cost = NnpGetBestCostPerTrial(sys.argv[1])
    print(f"min cost: {nnp_get_cost.get_best_cost_of_all_trials()}")