# laura maria engist, 2025
# script to calculate the auc for computed alignments

from evotuner import constants
import subprocess
import resource
import os
import sys

class AlnIdent:
    """Turn one alignment-parameter setting into an optimisation cost.

    An instance bundles gap penalties, an optional substitution matrix and
    the file paths needed to (1) run ``constants.MMSEQS easy-search`` of
    ``sequences_to_align`` against itself, (2) post-process the alignment
    file with the awk script ``constants.AWK_FILE`` and (3) average the
    resulting per-row values into a single AUC.

    Attributes:
        go: gap open penalty, passed to mmseqs as ``--gap-open``.
        ge: gap extend penalty, passed to mmseqs as ``--gap-extend``.
        matrix_file: path passed as ``--sub-mat``; empty/None means the
            option is omitted.
        max_value_matrix: stored only; not read by any method of this class.
        aln_file: path of the alignment file written by mmseqs.
        rocx_file: path of the benchmark output written by awk.
        sequences_to_align: sequence file used as both query and target.
    """

    # Address-space cap (in bytes) applied to the mmseqs child process: 50 GB.
    MEM_LIMIT = 50 * 1024 * 1024 * 1024

    def __init__(self, go, ge, matrix_file, max_value_matrix, aln_file, rocx_file, sequences_to_align):
        self.go = go
        self.ge = ge
        self.matrix_file = matrix_file
        self.max_value_matrix = max_value_matrix
        self.aln_file = aln_file
        self.rocx_file = rocx_file
        self.sequences_to_align = sequences_to_align

    def cost(self):
        """Compute the cost for the configured gap open / gap extend penalties.

        Cost is defined as ``1 - auc``.

        Returns:
            ``1.0 - auc`` when alignments were produced,
            ``100.0`` if no alignments are found,
            ``400.0`` if an error occurs during the computation.
        """
        try:
            if not self.run_mmseqs():
                return 100.0
            self.benchmark_scop40(
                self.aln_file, constants.AWK_FILE, constants.SCOP_LOOKUP_FILE, self.rocx_file
            )
            return 1.0 - self.get_auc_metrics()
        except Exception as e:
            # Deliberate catch-all: the caller expects a numeric cost even
            # when a step fails, so failures are mapped to a penalty value.
            print(f"exception: {e}")
            return 400.0

    @classmethod
    def _limit_address_space(cls):
        """Cap RLIMIT_AS of the *current* process at ``MEM_LIMIT``.

        Meant to be used as ``preexec_fn`` so that it runs in the child
        process only. The cap is clamped to the inherited hard limit because
        an unprivileged process is not allowed to raise its hard limit.
        """
        limit = cls.MEM_LIMIT
        _, hard = resource.getrlimit(resource.RLIMIT_AS)
        if hard != resource.RLIM_INFINITY:
            limit = min(limit, hard)
        resource.setrlimit(resource.RLIMIT_AS, (limit, limit))

    def run_mmseqs(self):
        """Run the mmseqs search and report whether it produced alignments.

        Returns:
            True if the command exited with status 0 and ``aln_file`` exists
            and is non-empty; False otherwise.

        Raises:
            OSError / subprocess.SubprocessError: if the process cannot be
                started (propagated from ``subprocess.Popen``).
        """
        command = [
            constants.MMSEQS, "easy-search",
            self.sequences_to_align, self.sequences_to_align,
            self.aln_file, "tmp",
            "--gap-open", str(self.go),
            "--gap-extend", str(self.ge),
        ]
        if self.matrix_file:
            command += ["--sub-mat", self.matrix_file]
        print(f"command: {command}")

        # Pass the callable itself: it must execute in the child. Calling
        # setrlimit here would restrict this (parent) process instead.
        # NOTE: preexec_fn is documented as unsafe when the parent is
        # multi-threaded.
        proc = subprocess.Popen(command, preexec_fn=self._limit_address_space)
        print("Subprozess PID:", proc.pid)
        sys.stdout.flush()
        proc.wait()
        print("Rückgabecode:", proc.returncode)
        if proc.returncode != 0:
            # A failed or killed run must not be scored, even if an alignment
            # file from an earlier call is still present on disk.
            print(f"failed with error: {proc.returncode}")
            return False

        print(f"output_file: {self.aln_file}")
        if not os.path.exists(self.aln_file):
            print("❌ File was not created!")
            return False
        if os.path.getsize(self.aln_file) == 0:
            print("⚠️ File was created, but is empty.")
            return False
        print("✅ File found!")
        return True

    def benchmark_scop40(self, mmseqs_file, awk_file, scop_lookup_file, output_file):
        """Run ``awk -f awk_file scop_lookup_file -`` over an alignment file.

        The content of ``mmseqs_file`` is supplied on awk's standard input and
        awk's standard output is written to ``output_file`` (truncated first).
        A non-zero awk exit status is reported on stdout and not re-raised.

        Raises:
            OSError: if ``mmseqs_file`` cannot be opened for reading or
                ``output_file`` cannot be opened for writing.
        """
        # Context managers close both handles on every path.
        with open(mmseqs_file, "rb") as aln_in, open(output_file, "w") as roc_out:
            try:
                subprocess.check_call(
                    ["awk", "-f", awk_file, scop_lookup_file, "-"],
                    stdin=aln_in,
                    stdout=roc_out,
                )
            except subprocess.CalledProcessError as e:
                print(f"failed with error: {e.returncode}; {e.output}; {e.stdout}; {e.stderr}")

    def get_auc_metrics(self):
        """Return the mean of the column-3, column-4 and column-5 averages.

        awk averages whitespace-separated columns 3, 4 and 5 of ``rocx_file``
        over all rows (named family, superfamily and fold sums in the awk
        program); the three averages are then averaged once more.

        Returns:
            The arithmetic mean of the three column averages, as a float.

        Raises:
            subprocess.CalledProcessError: if awk exits with a non-zero
                status (expected for an empty file, where awk would divide
                by a row count of zero).
            ValueError: if awk's output is not exactly three numbers.
        """
        result = subprocess.run(
            [
                "awk",
                "{ famsum+=$3; supfamsum+=$4; foldsum+=$5}END{print famsum/NR,supfamsum/NR,foldsum/NR}",
                self.rocx_file,
            ],
            capture_output=True,
            text=True,
            check=True,
        )
        fam_auc, supfam_auc, fold_auc = map(float, result.stdout.strip().split())
        return (fam_auc + supfam_auc + fold_auc) / 3