# laura maria engist, 2025
# The Amino Acid Pipeline
# Parameters varied by / supplied by SMAC: go (gap-open penalty), ge (gap-extension penalty), build (whether to build a matrix or use the default one)

import tempfile
import os
from evotuner import cma
from evotuner import aap_cost
from evotuner import constants
import sys

class AaPipeline():
    """Amino acid pipeline: compute the cost C = 1 - AUC(M, S, go, ge).

    Attributes:
        go: gap-open penalty.
        ge: gap-extension penalty.
        build: truthy to build the matrix with CMA, falsy to use the
            default matrix.
        return_type: type of AUC to return (all, sfam, fam, sfam_fam).
    """

    # Cost reported when the cost computation returns an empty dict.
    # NOTE(review): presumably marks a failed evaluation - confirm against
    # aap_cost.AapCost.cost.
    _FAILURE_COST = 300.0

    # Value range passed to AapCost when no matrix is built.
    # NOTE(review): presumably the score range of the mmseqs default
    # matrix - confirm.
    _DEFAULT_MINIMUM = -4
    _DEFAULT_MAXIMUM = 11

    # File name of the built matrix inside the temporary directory.
    _MATRIX_FILENAME = "cma_matrix_aa.out"

    def __init__(self, go, ge, build, return_type):
        """Store the penalties and options used by run_aa_pipeline."""
        self.go = go
        self.ge = ge
        self.build = build
        self.return_type = return_type

    def run_aa_pipeline(self):
        """Run the amino acid pipeline.

        If ``self.build`` is truthy the matrix M = CMA(A) is built first;
        otherwise an empty matrix path and the default value range are
        used. The cost is then computed with the stored penalties.

        Returns:
            A tuple ``(cost, auc, minimum, maximum)``. ``cost`` is replaced
            by 300.0 when the cost computation returned an empty dict.
        """
        s = constants.S_AA
        # tmp directory for the duration of the process
        with tempfile.TemporaryDirectory() as tmpdir:
            if self.build:
                label = "aap build"
                m, minimum, maximum = self._build_matrix(tmpdir, s)
            else:
                label = "aap no build"
                # placeholder because the default matrix is already
                # implemented in mmseqs
                m = ""
                minimum = self._DEFAULT_MINIMUM
                maximum = self._DEFAULT_MAXIMUM

            # compute C = 1 - AUC(M, S, go, ge)
            aap_cost_process = aap_cost.AapCost(tmpdir, minimum, maximum, self.return_type)
            cost, auc = aap_cost_process.cost(m, s, self.go, self.ge)

            # Flush in both modes so the log line is visible immediately.
            print(
                f"{label} - cost: {cost} with penalties: {self.go}, {self.ge}; and min_value: {minimum}, max_value: {maximum}",
                flush=True,
            )
            if cost == {}:
                cost = self._FAILURE_COST
                print(f"go: {self.go}; ge: {self.ge}", flush=True)
            return cost, auc, minimum, maximum

    def _build_matrix(self, tmpdir, s):
        """Build the matrix M = CMA(A) inside tmpdir.

        Returns:
            A tuple ``(matrix_path, minimum, maximum)`` where minimum and
            maximum are the values returned by ``compute_matrix_build``.
        """
        text_dir = os.getcwd()
        m = os.path.join(tmpdir, self._MATRIX_FILENAME)
        a = constants.GTA_S_AA
        cma_process = cma.CMA(a, a, constants.DISTANCE_THRESHOLD, constants.TM_SCORE_THRESHOLD, m)
        minimum, maximum = cma_process.compute_matrix_build(m, text_dir, s)
        return m, minimum, maximum