#!/usr/bin/env bash
# laura maria engist, 2025
#
# Submit a single-GPU training job to Slurm.
#
# Usage: ./sbatch_train.sh <job_name> <config_file>
#   job_name     Used as the Slurm job name, for the log file
#                logs/slurm/<job_name>.out, for the config snapshot
#                logs/wandb/<job_name>.yaml and as the trainer logger name.
#   config_file  Config file; it is copied to the snapshot path when the job
#                starts and the snapshot is what gets passed to main.py.
#
# Exit status: 2 on wrong usage, 1 if the config file is missing, otherwise
# the exit status of sbatch.

set -eu

if [ "$#" -ne 2 ]; then
  printf 'Usage: %s <job_name> <config_file>\n' "${0##*/}" >&2
  exit 2
fi

# Retrieve the job name and config file from the command line
JOB_NAME=$1
CONFIG_FILE=$2

# Fail at submission time rather than after the job has waited in the queue.
if [ ! -f "$CONFIG_FILE" ]; then
  printf 'error: config file not found: %s\n' "$CONFIG_FILE" >&2
  exit 1
fi

mkdir -p logs/slurm logs/wandb data/lm_head_pfam

# The heredoc delimiter is unquoted on purpose: JOB_NAME and CONFIG_FILE are
# substituted now, at submission time. Anything that must be evaluated on the
# compute node instead has its dollar sign escaped with a backslash.
sbatch << EOF
#!/bin/bash
#SBATCH --job-name="$JOB_NAME"
#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=32G
#SBATCH --partition=a100
#SBATCH --output=logs/slurm/"$JOB_NAME".out
#SBATCH --qos=gpu6hours

# Load modules and set environment
source /etc/profile.d/soft_stacks.sh
module purge
module load CUDA/12.4.0
# HOME and PATH are expanded inside the job (escaped here) so that PATH entries
# added by the module load above are kept instead of being overwritten with
# the PATH of the submitting shell.
export PATH=\$HOME/miniforge3/bin:\$PATH
source activate alphabeta

# Snapshot the config next to the run logs; abort if the copy fails so the
# training never starts from a missing or stale snapshot.
cp -- "$CONFIG_FILE" "logs/wandb/$JOB_NAME.yaml" || exit 1
HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 srun --cpu-bind=cores python /PALM/main.py fit -c "logs/wandb/$JOB_NAME.yaml" --trainer.logger.name="$JOB_NAME"
EOF