# laura maria engist, 2025
# script to conduct experiments for gap penalties in alignment quality and identification
import os

import alignment_quality
from evotuner.gap_penalties_alignment_quality_or_identification import alignment_identification
from evotuner.pipelines.kmeans_pipeline import kmp_cluster_bisecting
from evotuner.pipelines.kmeans_pipeline import kmp_convert_gta_cma
from evotuner import constants
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from evotuner import cma
import json

class GapPenaltiesQualAlnQualIdent:
    """Grid-search gap penalties and record alignment-quality / identification costs.

    There is one ``compute_*`` method per model/pipeline. Each of them evaluates
    every (gap open, gap extend) pair, builds the qualityAln and qualityIdent
    dictionaries and appends them (together with their grid form) to the
    ``resulting_dictionaries_file`` given at construction time.
    """

    # Cost stored for (go, ge) pairs that are skipped (go <= ge) and for pairs
    # for which the alignment-quality computation reports no existing alignments.
    INVALID_COST = 400.0

    def __init__(self, resulting_dictionaries_file):
        """Remember the text file that all results are appended to."""
        self.resulting_dictionaries_file = resulting_dictionaries_file

    # ------------------------------------------------------------------
    # methods for each model/pipeline
    # ------------------------------------------------------------------
    def compute_aap_build(self, aa_sequences, directory_experiment, matrix_file_path):
        """Amino-acid pipeline: build the matrix at ``matrix_file_path``, then evaluate it."""
        cma_process = cma.CMA(constants.GTA_S_AA, constants.GTA_S_AA, constants.DISTANCE_THRESHOLD, constants.TM_SCORE_THRESHOLD, matrix_file_path)
        # Only the maximum matrix value is needed; it bounds the penalty grid.
        _, max_value_matrix = cma_process.compute_matrix_build(matrix_file_path, aa_sequences)
        self._compute_and_save("AAP BUILD", max_value_matrix, matrix_file_path, aa_sequences, directory_experiment, False)

    def compute_aap_no_build(self, aa_sequences, directory_experiment):
        """Amino-acid pipeline: evaluate the existing BLOSUM62 matrix file."""
        matrix_file = "/hypopthd/constants/blosum62.out"
        max_value_matrix = 11
        self._compute_and_save("AAP NO BUILD", max_value_matrix, matrix_file, aa_sequences, directory_experiment, False)

    def compute_kmp(self, init, normalize, aa_sequences, directory_experiment):
        """K-means pipeline: cluster, build the matrix, then evaluate gap penalties."""
        # NOTE(review): the embedded FASTA path returned by compute_kmeans is not
        # used here; the amino-acid sequences are passed on with remap=False, as in
        # the original code. Confirm this is intended.
        matrix_file, max_value_matrix, _ = self.compute_kmeans(init, normalize, aa_sequences, directory_experiment)
        self._compute_and_save(f"KMP {init} {normalize}", max_value_matrix, matrix_file, aa_sequences, directory_experiment, False)

    def compute_3di(self, directory_experiment):
        """3Di alphabet: evaluate the fixed 3Di matrix and FASTA files (remap enabled)."""
        matrix_file = "/hypopthd/constants/mat3di.out"
        max_value_matrix = 10
        fasta_embedded_sequences = "/hypopthd/constants/3di.fasta"
        self._compute_and_save("3di", max_value_matrix, matrix_file, fasta_embedded_sequences, directory_experiment, True)

    def compute_vqvae(self, fasta_embedded_sequences_for_alignments, fasta_embedded_sequences, sequences_for_alignments, matrix_file, max_value_matrix, directory_experiment):
        """VQ-VAE pipeline: write the reduced embedded FASTA, then evaluate it (remap enabled)."""
        self.create_embedded_fasta_for_aln(fasta_embedded_sequences, sequences_for_alignments, fasta_embedded_sequences_for_alignments)
        self._compute_and_save("VQ-VAE", max_value_matrix, matrix_file, fasta_embedded_sequences_for_alignments, directory_experiment, True)

    def compute_lmhead(self, fasta_embedded_sequences_for_alignments, fasta_embedded_sequences, sequences_for_alignments, matrix_file, max_value_matrix, directory_experiment):
        """LM-head pipeline: write the reduced embedded FASTA, then evaluate it (remap enabled)."""
        self.create_embedded_fasta_for_aln(fasta_embedded_sequences, sequences_for_alignments, fasta_embedded_sequences_for_alignments)
        self._compute_and_save("LM-head", max_value_matrix, matrix_file, fasta_embedded_sequences_for_alignments, directory_experiment, True)

    # ------------------------------------------------------------------
    # helper methods
    # ------------------------------------------------------------------
    def _compute_and_save(self, model_name, max_value_matrix, matrix_file, sequences, directory_experiment, remap):
        """Compute both quality dictionaries for one model and append them to the result file."""
        quality_aln, quality_ident = self.compute_quality_dictionaries(max_value_matrix, matrix_file, sequences, directory_experiment, remap)
        self.save_resulting_dictionary(f"{model_name} - qualityAln", self.dictionary_to_array(quality_aln, max_value_matrix), quality_aln)
        self.save_resulting_dictionary(f"{model_name} - qualityIdent", self.dictionary_to_array(quality_ident, max_value_matrix), quality_ident)

    def compute_kmeans(self, init, normalize, sequences, directory_experiments):
        """Run the bisecting k-means clustering and the subsequent matrix computation.

        Returns:
            ``(matrix_file, max_value_matrix, fasta_embedded_sequences)`` where the
            last element is ``<directory_experiments>/embedded_sequences.fasta``.
        """
        clustering = kmp_cluster_bisecting.KmpClusterBisecting(constants.PROTT5_EMBEDDINGS, directory_experiments)
        e = clustering.e(sequences, init, normalize)
        fasta_embedded_sequences = os.path.join(directory_experiments, "embedded_sequences.fasta")
        convert_and_cma = kmp_convert_gta_cma.KmpConvertCma(directory_experiments, "")
        matrix_file = os.path.join(directory_experiments, f"kmeans_{init}_{normalize}.out")
        # The minimum matrix value and the in-memory embedded sequences are not needed here.
        matrix_file, _, max_value_matrix, _ = convert_and_cma.gta_and_m(sequences, e, matrix_file)
        return matrix_file, max_value_matrix, fasta_embedded_sequences

    def compute_quality_dictionaries(self, max_value_matrix, matrix_file, fasta_embedded_sequences, aln_file_dir, remap):
        """Evaluate every (gap open, gap extend) pair of the penalty grid.

        Args:
            max_value_matrix: upper bound (inclusive, after rounding) of the penalty values.
            matrix_file: substitution matrix file handed to both evaluations.
            fasta_embedded_sequences: FASTA file handed to both evaluations.
            aln_file_dir: directory in which the per-pair ``.m8`` / ``.rocx`` files are placed.
            remap: forwarded to ``AlignmentQuality.compute_alignment_quality``.

        Returns:
            ``(qualityAln_dictionary, qualityIdent_dictionary)``, both keyed by
            ``(go, ge)``. Pairs with ``go <= ge`` get ``INVALID_COST`` in both.
        """
        quality_aln = {}
        quality_ident = {}
        penalties = self.get_possible_penalty_values(max_value_matrix)

        for go in penalties:
            for ge in penalties:
                key = (go, ge)
                if go <= ge:
                    # Only pairs with gap open strictly larger than gap extend are evaluated.
                    quality_aln[key] = self.INVALID_COST
                    quality_ident[key] = self.INVALID_COST
                    continue
                quality_aln[key] = self._alignment_quality_cost(go, ge, matrix_file, fasta_embedded_sequences, aln_file_dir, remap)
                quality_ident[key] = self._identification_cost(go, ge, matrix_file, max_value_matrix, fasta_embedded_sequences, aln_file_dir)

        return quality_aln, quality_ident

    def _alignment_quality_cost(self, go, ge, matrix_file, fasta_embedded_sequences, aln_file_dir, remap):
        """Return the qualityAln cost for one (go, ge) pair."""
        aln_file = os.path.join(aln_file_dir, f"{go}_{ge}_qualityAln.m8")
        aq = alignment_quality.AlignmentQuality(go, ge, matrix_file, fasta_embedded_sequences)
        pa, _, alignments_existing = aq.compute_alignment_quality(aln_file, remap)
        if not alignments_existing:
            return self.INVALID_COST
        sensitivity = round(pa[1], 3)
        precision = round(pa[2], 3)
        # cost = 1 - quality, with quality the mean of sensitivity and precision
        return 1 - (sensitivity + precision) / 2

    def _identification_cost(self, go, ge, matrix_file, max_value_matrix, fasta_embedded_sequences, aln_file_dir):
        """Return the qualityIdent cost for one (go, ge) pair."""
        aln_file = os.path.join(aln_file_dir, f"{go}_{ge}_qualityIdent.m8")
        rocx_file = os.path.join(aln_file_dir, f"{go}_{ge}_qualityIdent.rocx")
        process = alignment_identification.AlnIdent(go, ge, matrix_file, max_value_matrix, aln_file, rocx_file, fasta_embedded_sequences)
        return process.cost()

    def get_possible_penalty_values(self, max_value_matrix) -> list:
        """Return the integers ``1 .. round(max_value_matrix)`` (inclusive)."""
        return list(range(1, int(round(max_value_matrix)) + 1))

    def dictionary_to_array(self, dictionary_to_convert, maximum_m) -> list:
        """Convert a ``{(go, ge): value}`` dictionary into a square grid.

        The grid has ``round(maximum_m) + 1`` rows and columns and is indexed as
        ``grid[go][ge]``; cells without a dictionary entry stay ``0``.
        """
        print(f"dictionary_to_convert: {dictionary_to_convert}")
        size = int(round(maximum_m)) + 1
        grid_array = [[0] * size for _ in range(size)]
        for (go, ge), value in dictionary_to_convert.items():
            grid_array[go][ge] = float(value)
        print(grid_array)
        return grid_array

    def save_resulting_dictionary(self, model_name, grid, dictionary):
        """Append one result section (header, grid, dictionary, footer) to the result file."""
        with open(self.resulting_dictionaries_file, "a") as f:
            f.write(f"################ RESULT FOR MODEL: {model_name} ################\n")
            f.write(f"grid: {grid}\n")
            f.write(f"quality_dictionary: {dictionary}\n")
            f.write("###########################################################\n")

    def create_embedded_fasta_for_aln(self, embedded_sequences, sequences_for_alignments, path_sequences_for_alignments_embedded):
        """Write the subset of an embedded FASTA file that is needed for the alignments.

        Args:
            embedded_sequences: FASTA file with the embedded sequences (read).
            sequences_for_alignments: JSON file whose entries are the record ids to keep (read).
            path_sequences_for_alignments_embedded: FASTA file to write the kept records to.
        """
        with open(sequences_for_alignments, "r") as f:
            # A set gives O(1) membership tests while filtering the records below.
            wanted_ids = set(json.load(f))
        # SeqIO.parse is given the path directly, so no separate file handle is opened.
        selected_records = [
            SeqRecord(record.seq, id=record.id, description="")
            for record in SeqIO.parse(embedded_sequences, "fasta")
            if record.id in wanted_ids
        ]
        SeqIO.write(selected_records, path_sequences_for_alignments_embedded, "fasta")

    def save_sequences_needed_for_alignments(self, alignment_file, sequences_file):
        """Collect the distinct ``description1`` / ``description2`` values of an alignment JSON file.

        The values are written to ``sequences_file`` as a JSON list, in order of
        first appearance.
        """
        with open(alignment_file, "r") as f:
            data = json.load(f)
        # A dict keeps insertion order and de-duplicates in O(1) per entry.
        unique_descriptions = {}
        for entry in data:
            unique_descriptions.setdefault(entry['description1'])
            unique_descriptions.setdefault(entry['description2'])
        with open(sequences_file, "w") as f:
            json.dump(list(unique_descriptions), f)

# Experiment configuration: specify the parameters below, then run this file as a script.
text_file_to_store_dictionaries = "" # text file the resulting dictionaries are appended to TODO: set once
aa_sequences = "" # amino acid sequences in fasta format TODO: set once
directory_experiment = "" # directory of the experiment you want to run here TODO: change for each experiment
matrix_file_path_aap_build = "" # path to store the matrix computed in the aap build pipeline TODO: set once
fasta_embedded_sequences_for_alignments_vqvae = "" # path to store the embedded sequences for the alignments in the vq-vae pipeline TODO: set once
fasta_embedded_sequences_for_alignments_lmhead = "" # path to store the embedded sequences for the alignments in the lm-head pipeline TODO: set once
fasta_embedded_sequences_vqvae = "" # path to the embedded sequences (fasta) of the vq-vae pipeline; this file is read TODO: set once
fasta_embedded_sequences_lmhead = "" # path to the embedded sequences (fasta) of the lm-head pipeline; this file is read TODO: set once
sequences_for_alignments = "" # path to the json file listing the sequences needed for the alignments TODO: set once
matrix_file_vqvae = "" # path to the matrix computed in the vq-vae pipeline TODO: set once
matrix_file_lmhead = "" # path to the matrix computed in the lm-head pipeline TODO: set once
max_value_matrix_vqvae = 0 # max value of the matrix computed in the vq-vae pipeline TODO: set once
max_value_matrix_lmhead = 0 # max value of the matrix computed in the lm-head pipeline TODO: set once

# Constructing the object only stores the result-file path, so it stays at module level.
gap_pen_qualAl_qualId = GapPenaltiesQualAlnQualIdent(text_file_to_store_dictionaries)

# Guarded so that merely importing this module does not start the experiments.
if __name__ == "__main__":
    gap_pen_qualAl_qualId.compute_aap_build(aa_sequences, directory_experiment, matrix_file_path_aap_build) # amino acid pipeline: compute matrix
    gap_pen_qualAl_qualId.compute_aap_no_build(aa_sequences, directory_experiment) # amino acid pipeline: use existing matrix
    gap_pen_qualAl_qualId.compute_kmp('k-means++', False, aa_sequences, directory_experiment) # k-means pipeline: k-means++ init, no normalization
    gap_pen_qualAl_qualId.compute_kmp('k-means++', True, aa_sequences, directory_experiment) # k-means pipeline: k-means++ init, with normalization
    gap_pen_qualAl_qualId.compute_kmp('random', False, aa_sequences, directory_experiment) # k-means pipeline: random init, no normalization
    gap_pen_qualAl_qualId.compute_kmp('random', True, aa_sequences, directory_experiment) # k-means pipeline: random init, with normalization
    gap_pen_qualAl_qualId.compute_3di(directory_experiment) # 3Di alphabet from foldseek
    gap_pen_qualAl_qualId.compute_vqvae(fasta_embedded_sequences_for_alignments_vqvae, fasta_embedded_sequences_vqvae, sequences_for_alignments, matrix_file_vqvae, max_value_matrix_vqvae, directory_experiment) # vq-vae pipeline
    gap_pen_qualAl_qualId.compute_lmhead(fasta_embedded_sequences_for_alignments_lmhead, fasta_embedded_sequences_lmhead, sequences_for_alignments, matrix_file_lmhead, max_value_matrix_lmhead, directory_experiment) # lm-head pipeline


'''Run once to save the sequences needed for the alignments. This is only needed if you want to create files other than those in the constants directory.'''
#path_json_file_for_alignments = "" # TODO: specify the path to store the alignments
#path_json_file_for_sequences = "" # TODO: specify the path to store the sequences
#gap_pen_qualAl_qualId.save_sequences_needed_for_alignments(path_json_file_for_alignments, path_json_file_for_sequences)


       