# laura maria engist, 2025
# The K-Means Pipeline
# C = 1 - Quality(M, S, go, ge)

import tempfile
import os
import subprocess
from evotuner import constants
import sys
import resource

class KmpCost:
    """Cost function C = 1 - Q(M, S, go, ge) for the K-Means pipeline.

    An mmseqs ``easy-search`` of a FASTA file against itself is run with the
    given substitution matrix and gap penalties. Its hits are scored against
    the SCOP40 lookup with an awk script, and the per-query AUC values are
    combined into one quality value Q according to ``return_type``. SMAC
    minimizes the returned cost.
    """

    # Upper bound (bytes) on the virtual address space of the mmseqs child.
    MEM_LIMIT = 50 * 1024 * 1024 * 1024  # 50 GB

    # Accepted values for ``return_type`` (see ``get_auc_metrics``).
    RETURN_TYPES = ("all", "sfam", "fam", "sfam_fam")

    def __init__(self, tmpdir, minimum, maximum, return_type):
        """
        tmpdir: directory in which the intermediate result files are written
        minimum, maximum: stored on the instance; not used by this class
        return_type: which AUC combination to optimize, one of RETURN_TYPES
        """
        self.tmpdir = tmpdir  # temporary directory
        self.minimum = minimum
        self.maximum = maximum
        self.return_type = return_type

    def cost(self, m, s, go, ge):
        """Cost function to be minimized by SMAC: C = 1 - Q(M, S, go, ge).

        m: path to the custom substitution matrix ("" or None -> mmseqs default)
        s: path to the FASTA file of protein sequences
        go: gap open penalty
        ge: gap extend penalty

        Returns a tuple (cost, auc). If mmseqs produces no usable alignments,
        or any step of the pipeline raises, the worst cost (1.0, 0) is
        returned so the optimizer can continue.
        """
        mmseqs_output_file_scope = os.path.join(self.tmpdir, "alnResult.m8")
        try:
            alignments_existing, _, _ = self.run_mmseqs(
                s, mmseqs_output_file_scope, m, go, ge
            )
            if not alignments_existing:
                # Some parameter combinations build no alignments at all;
                # report the highest possible cost to SMAC.
                print("no alignments built")
                sys.stdout.flush()
                return 1.0, 0

            output_file = os.path.join(self.tmpdir, "output_mmseqs_kmeans_build.rocx")
            self.benchmark_scop40(
                mmseqs_output_file_scope, constants.AWK_FILE, constants.SCOP_LOOKUP_FILE, output_file
            )
            auc = self.get_auc_metrics(output_file, self.return_type)
            cost = 1.0 - auc
            print(f"RETURNED AUC: {auc}")
            print(f"RETURNED COST: {cost}")
            sys.stdout.flush()
            return cost, auc
        except Exception as e:
            # Top-level boundary for the optimizer: log and report the worst cost.
            print(f"Error running mmseqs: {e}")
            sys.stdout.flush()
            return 1.0, 0

    @staticmethod
    def _limit_address_space(limit):
        """Return a preexec_fn that caps the child's RLIMIT_AS at ``limit`` bytes.

        The returned function runs in the child after fork and before exec,
        so the parent's own limits are untouched. The cap never exceeds the
        current hard limit; an unprivileged process cannot raise that limit.
        """
        def _apply():
            _, hard = resource.getrlimit(resource.RLIMIT_AS)
            capped = limit if hard == resource.RLIM_INFINITY else min(limit, hard)
            resource.setrlimit(resource.RLIMIT_AS, (capped, capped))
        return _apply

    def run_mmseqs(self, fasta_file, output_file, mat_file, go, ge):
        """Run ``mmseqs easy-search`` of ``fasta_file`` against itself under a memory limit.

        fasta_file: query and target FASTA file
        output_file: path of the tabular result mmseqs writes
        mat_file: substitution matrix passed via --sub-mat; "" or None to omit it
        go, ge: gap open / gap extend penalties

        Returns (alignments_existing, go, ge). alignments_existing is True only
        if mmseqs exited with status 0 and wrote a non-empty output_file.
        """
        # NOTE(review): mmseqs' temporary directory "tmp" is relative to the
        # current working directory, not to self.tmpdir.
        command = [
            constants.MMSEQS, "easy-search", fasta_file, fasta_file, output_file, "tmp",
            "--gap-open", str(go),
            "--gap-extend", str(ge),
        ]
        if mat_file:
            command += ["--sub-mat", mat_file]
        print(f"command: {command}")

        # Drop a result left over from a previous evaluation so a failed run
        # can never be scored with stale alignments.
        try:
            os.remove(output_file)
        except FileNotFoundError:
            pass

        try:
            proc = subprocess.Popen(
                command, preexec_fn=self._limit_address_space(self.MEM_LIMIT)
            )
        except (OSError, subprocess.SubprocessError) as e:
            # binary missing / not executable, or the rlimit could not be applied
            print(f"failed to start mmseqs: {e}")
            sys.stdout.flush()
            return False, go, ge
        print("Subprozess PID:", proc.pid)
        sys.stdout.flush()
        proc.wait()
        print("return code:", proc.returncode)
        if proc.returncode != 0:
            # includes being killed after hitting the memory limit
            print(f"failed with error: {proc.returncode}")
            sys.stdout.flush()
            return False, go, ge

        # check if alignments were created
        name = os.path.basename(output_file)
        alignments_existing = False
        print(f"output_file: {output_file}")
        if not os.path.exists(output_file):
            print(f"❌ File '{name}' was not created!")
        elif os.path.getsize(output_file) == 0:
            print(f"⚠️ File '{name}' was created, but is empty.")
        else:
            print(f"✅ File '{name}' found!")
            alignments_existing = True

        return alignments_existing, go, ge

    def benchmark_scop40(self, mmseqs_file, awk_file, scop_lookup_file, output_file):
        """Benchmark the mmseqs output file against the SCOP40 benchmark set.

        Runs ``awk -f awk_file scop_lookup_file -`` with mmseqs_file on stdin
        and writes awk's stdout to output_file. An awk failure is logged, not
        raised (best effort, as before).

        mmseqs_file: path to the mmseqs output file
        awk_file: path to the awk script that computes the AUC values
        scop_lookup_file: path to the SCOP40 lookup file
        output_file: path of the file the AUC values are written to
        """
        print("NOW IN BENCHAMRK SCOP40")
        # Feed the file straight to awk's stdin; both handles are closed on exit.
        with open(mmseqs_file, "rb") as src, open(output_file, "w") as dst:
            try:
                subprocess.check_call(
                    ["awk", "-f", awk_file, scop_lookup_file, "-"],
                    stdin=src,
                    stdout=dst,
                )
            except subprocess.CalledProcessError as e:
                print(f"failed with error: {e.returncode}; {e.output}; {e.stdout}; {e.stderr}")

    def get_auc_metrics(self, output_file, return_type):
        """Average the AUC columns of output_file and combine them per return_type.

        Columns 3, 4 and 5 are summed as family, superfamily and fold AUC and
        divided by the number of lines.

        output_file: path of the file with the per-query AUC values
        return_type: one of RETURN_TYPES
            "all"      -> mean of family, superfamily and fold AUC
            "sfam"     -> superfamily AUC
            "fam"      -> family AUC
            "sfam_fam" -> mean of family and superfamily AUC

        Raises ValueError for an unknown return_type, and
        subprocess.CalledProcessError if awk fails.
        """
        if return_type not in self.RETURN_TYPES:
            raise ValueError(
                f"unknown return_type {return_type!r}; expected one of {self.RETURN_TYPES}"
            )
        result = subprocess.run(
            [
                "awk",
                "{ famsum+=$3; supfamsum+=$4; foldsum+=$5}END{print famsum/NR,supfamsum/NR,foldsum/NR}",
                output_file,
            ],
            capture_output=True,
            text=True,
            check=True,
        )
        # Parse the output into floats
        print(f"output file: {output_file}")
        fam_auc, supfam_auc, fold_auc = map(float, result.stdout.strip().split())
        print(f"family_auc: {fam_auc}, superfamily_auc: {supfam_auc}, fold_auc: {fold_auc}")
        sys.stdout.flush()

        if return_type == "all":
            return (fam_auc + supfam_auc + fold_auc) / 3  # arithmetic mean
        if return_type == "sfam":
            return supfam_auc
        if return_type == "fam":
            return fam_auc
        return (fam_auc + supfam_auc) / 2  # "sfam_fam"