# laura maria engist, 2025
# script to convert the ground truth alignments and compute the custom substitution matrix, given a fasta file of protein sequences
# two steps in one script to save computing time 
# A = convert_gta(GTA(S_AA), DCBV)
# M = CMA(A)

from evotuner import constants
from evotuner import cma
from Bio import SeqIO
import torch
import json
import os
import re

class KmpConvertCma:
    """Convert ground truth alignments to the embedding alphabet and count substitutions.

    The alignments are read from ``constants.GTA_S_AA``; the accumulated counts
    are handed to ``cma.CMA.compute_matrix_nnp`` which computes the custom
    substitution matrix.
    """

    def __init__(self, text_dir, liblambda_shared_library_dir):
        """Store the directories used by ``gta_and_m``.

        Args:
            text_dir: directory in which ``embedded_sequences.fasta`` is written.
            liblambda_shared_library_dir: forwarded unchanged to
                ``cma.CMA.compute_matrix_nnp``.
        """
        self.text_dir = text_dir
        self.liblambda_shared_library_dir = liblambda_shared_library_dir

    def gta_and_m(self, s, dcbv, m):
        """Map the gta to the embeddings and directly compute the counts for m.

        Every alignment file is a JSON document with the keys ``description1``,
        ``description2``, ``line1``, ``line2``, ``distances`` and ``tm_score``.
        Alignments for which one of the two sequences has no embedding, or for
        which the embedding cannot be mapped onto the aligned line, are skipped.

        Args:
            s: path to the fasta file of protein sequences. Not used by this
                method; kept so that existing callers keep working.
            dcbv: path to the embeddings of the protein sequences in s
                (loaded with ``torch.load``).
            m: path to the output custom substitution matrix.

        Returns:
            A tuple ``(m, minimum, maximum, embedded_sequences)`` where
            ``minimum`` and ``maximum`` are the two values returned by
            ``cma.CMA.compute_matrix_nnp`` and ``embedded_sequences`` maps each
            sequence description to its embedding string.
        """
        gta = constants.GTA_S_AA
        distance_threshold = constants.DISTANCE_THRESHOLD
        tm_score_threshold = constants.TM_SCORE_THRESHOLD
        alphabet = constants.ALPHABET

        embedded_sequences_path = os.path.join(self.text_dir, "embedded_sequences.fasta")
        embedded_sequences = self.load_embeddings_in_var_from_pt(dcbv, embedded_sequences_path)
        cma_process = cma.CMA(gta, "", distance_threshold, tm_score_threshold, m)

        aa_pair_counts = dict()
        aa_counts = dict()
        count_alignments = 0  # number of aligned columns seen
        count_alignments_used = 0  # number of aligned columns that contributed a pair
        # NOTE(review): descriptions without an embedding are collected here but
        # not reported anywhere yet.
        not_in_embeddings = set()

        for data in self._iter_alignments(gta):
            description1 = data['description1']
            description2 = data['description2']
            missing = [d for d in (description1, description2) if d not in embedded_sequences]
            if missing:
                not_in_embeddings.update(missing)
                continue

            embedded_line1 = self.map_embedded_sequence_to_aligned_embedded_sequence(
                embedded_sequences[description1], data['line1'])
            embedded_line2 = self.map_embedded_sequence_to_aligned_embedded_sequence(
                embedded_sequences[description2], data['line2'])
            if embedded_line1 is None or embedded_line2 is None:
                continue

            for a, b, dist in zip(embedded_line1, embedded_line2, data['distances']):
                count_alignments += 1
                # Only columns that are close in space, in a confidently
                # superimposed pair of structures, contribute to the counts.
                usable = (
                    dist is not None
                    and dist < distance_threshold
                    and data["tm_score"] > tm_score_threshold
                )
                if not usable:
                    continue

                a_known = a in alphabet
                b_known = b in alphabet
                if a_known:
                    aa_counts[a] = aa_counts.get(a, 0) + 1
                if b_known:
                    aa_counts[b] = aa_counts.get(b, 0) + 1
                if a_known and b_known:
                    count_alignments_used += 1
                    # The pair is stored order-independently.
                    pair_key = (min(a, b), max(a, b))
                    aa_pair_counts[pair_key] = aa_pair_counts.get(pair_key, 0) + 1

        minimum, maximum = cma_process.compute_matrix_nnp(
            m,
            aa_pair_counts,
            aa_counts,
            count_alignments,
            count_alignments_used,
            self.liblambda_shared_library_dir,
            embedded_sequences,
            embedded_sequences_path,
        )
        return m, minimum, maximum, embedded_sequences

    @staticmethod
    def _iter_alignments(gta):
        """Yield the parsed JSON content of every non-empty alignment file below gta.

        The entry ``one_file_per_family`` directly inside gta is ignored.
        """
        for fam in os.listdir(gta):
            if fam == "one_file_per_family":
                continue
            fam_dir = os.path.join(gta, fam)
            for aln_file in os.listdir(fam_dir):
                aln_path = os.path.join(fam_dir, aln_file)
                if os.path.getsize(aln_path) == 0:
                    continue
                with open(aln_path, "r", encoding="utf-8") as aln:
                    data = json.load(aln)
                yield data

    def load_embeddings_in_var(self, dcbv):
        """Load the embeddings for all sequences from a fasta file.

        Args:
            dcbv: path to a fasta file with one embedded sequence per record.

        Returns:
            A dict mapping each record id to its sequence as a string. If an id
            occurs more than once, the last record wins.
        """
        return {record.id: str(record.seq) for record in SeqIO.parse(dcbv, "fasta")}

    def load_embeddings_in_var_from_pt(self, e, fasta_path):
        """Load the embeddings for all sequences and also write them to a fasta file.

        Args:
            e: path to a file readable by ``torch.load`` that maps each sequence
                description to an iterable of indices into ``constants.ALPHABET``.
            fasta_path: path of the fasta file to write. An existing file is
                overwritten; the file is created even if there are no embeddings.

        Returns:
            A dict mapping each description to its embedding as a string of
            alphabet characters.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects, which
        # can execute code. Only load embedding files from a trusted source.
        embedded_sequences_num = torch.load(e, weights_only=False)
        alphabet = constants.ALPHABET
        embedded_sequences_char = dict()

        # Open the output once instead of re-opening it for every sequence.
        with open(fasta_path, "w") as f:
            for description in embedded_sequences_num:
                seq_char = "".join(alphabet[num] for num in embedded_sequences_num[description])
                embedded_sequences_char[description] = seq_char
                f.write(">" + description + "\n" + seq_char + "\n")

        return embedded_sequences_char

    def map_embedded_sequence_to_aligned_embedded_sequence(self, embedded_sequence, aligned_sequence):
        """Map the embedded sequence (without gaps) to the aligned sequence (with gaps).

        Every ``-`` of the aligned sequence is copied; every other position
        receives the next character of the embedded sequence. If the embedded
        sequence and the gap-free aligned sequence differ in length, ``X`` in
        the aligned sequence is copied as ``X`` without consuming an embedded
        character.

        Args:
            embedded_sequence: embedding string without gaps.
            aligned_sequence: one line of an alignment, with ``-`` as gap.

        Returns:
            The aligned embedded sequence, or ``None`` if the embedded sequence
            runs out of characters before the aligned sequence is fully mapped.
        """
        aligned_without_gaps = aligned_sequence.replace("-", "")
        # different lengths --> keep X as X
        keep_x = len(embedded_sequence) != len(aligned_without_gaps)

        mapped = []
        index = 0
        for char in aligned_sequence:
            if char == '-':
                mapped.append('-')
            elif keep_x and char == 'X':
                mapped.append('X')
            else:
                if index >= len(embedded_sequence):
                    # The embedding is too short for this alignment; let the
                    # caller skip it instead of failing with an IndexError.
                    return None
                mapped.append(embedded_sequence[index])
                index += 1
        # NOTE(review): if the embedding is longer than the residues of the
        # aligned sequence, the surplus characters are silently ignored.
        return "".join(mapped)