# laura maria engist, 2025
# script to compute the average number of epochs per trial
# across all trials in a given experiment

import os

class AverageNumEpochsPerTrial:
    """Compute the average number of epochs per trial of an experiment.

    The counts are derived from the layout of the wandb logs directory:

    * 1 trial = 1 sub-directory of the wandb logs directory whose name
      contains neither 'latest-run' nor 'offline'
    * 1 epoch = 1 entry ending in '.out' inside a trial directory
      (one matrix file)

    The directory is scanned lazily on the first access to ``total_epochs``
    or ``total_trials``; later accesses reuse the cached counts.
    """

    # A directory whose name contains any of these markers is not a trial.
    _EXCLUDED_MARKERS = ('latest-run', 'offline')

    def __init__(self, experiment_wandb_logs):
        """Remember the wandb logs directory; nothing is read from disk yet.

        Args:
            experiment_wandb_logs: Path to the wandb logs directory of one
                experiment.
        """
        self.experiment_wandb_logs = experiment_wandb_logs
        self._total_epochs = 0
        self._total_trials = 0
        # Explicit "already scanned" flag: using "both counters are 0" as the
        # sentinel would re-scan an experiment without trials on every access.
        self._scanned = False

    def _ensure_scanned(self) -> None:
        """Scan the logs directory once, on first use."""
        if not self._scanned:
            self.get_total_epochs_and_trials()

    @property
    def total_epochs(self) -> int:
        """Total number of epochs summed over all trials."""
        self._ensure_scanned()
        return self._total_epochs

    @property
    def total_trials(self) -> int:
        """Total number of trial directories."""
        self._ensure_scanned()
        return self._total_trials

    def get_total_epochs_and_trials(self) -> None:
        """Scan the wandb logs directory and store the epoch and trial totals.

        Prints one summary line per trial. The totals are recomputed from
        scratch on every call, so calling this method repeatedly refreshes
        the counts instead of adding to the previous ones.

        Raises:
            OSError: If the logs directory or a trial directory cannot be
                listed (e.g. the path does not exist).
        """
        logs_root = self.experiment_wandb_logs
        epochs = 0
        trials = 0
        # Sorted so the per-trial output order does not depend on the
        # arbitrary order returned by os.listdir.
        for entry in sorted(os.listdir(logs_root)):
            if any(marker in entry for marker in self._EXCLUDED_MARKERS):
                continue
            trial_path = os.path.join(logs_root, entry)
            if not os.path.isdir(trial_path):
                continue
            trial_epochs = sum(
                1 for name in os.listdir(trial_path) if name.endswith('.out')
            )
            epochs += trial_epochs
            trials += 1
            print(f"Trial {entry}: {trial_epochs} epochs")
        # Assign only after a complete scan so a failure part-way through
        # does not leave partial totals behind.
        self._total_epochs = epochs
        self._total_trials = trials
        self._scanned = True

    def compute(self) -> float:
        """Return the average number of epochs per trial.

        Raises:
            ZeroDivisionError: If the logs directory contains no trial
                directories, so the average is undefined.
        """
        trials = self.total_trials
        if trials == 0:
            raise ZeroDivisionError(
                f"no trial directories found in {self.experiment_wandb_logs!r}; "
                "cannot compute the average number of epochs per trial"
            )
        return self.total_epochs / trials

# Example usage. Guarded so that it only runs when this file is executed as a
# script; importing the module must not scan a directory or print anything.
if __name__ == "__main__":
    # Set this to the wandb directory of an experiment before running; the
    # empty placeholder is not a valid directory to list.
    path_to_wandb_directory = "" # path to wandb directory of an experiment - stores files created within the training
    average_num_epochs_per_trial = AverageNumEpochsPerTrial(path_to_wandb_directory)
    print(average_num_epochs_per_trial.compute())