# laura maria engist, 2025
# script to compute a customized substitution matrix
# compute_matrix_build for the Amino Acid Pipeline; compute_matrix_kmeans for the K-means Pipeline;
# compute_matrix_nnp for the Neural Network Pipelines VQ-VAE and LM-head; compute_matrix_aqp for the AQP Pipeline

import os
import numpy as np
import json
import sys
sys.path.append('/scicore/home/schwede/pantol0000/repositories/glyphtemys/glyphtemys')
#import models as models
import pandas as pd
from evotuner import constants
from collections import defaultdict, Counter
from evotuner.lambda_for_cma import lambda_for_cma
from Bio import SeqIO

class CMA():
    """Compute a customized substitution matrix from pairwise structural alignments.

    Alignments are read as JSON files from
    ``<path_structural_alignments_after_embedding>/<family_dir>/<file>``. Each file
    provides ``line1`` and ``line2`` (the aligned symbols), ``distances`` (one value
    per aligned position, possibly ``None``) and ``tm_score``. A position contributes
    to the symbol and symbol-pair counts only when its distance is below
    ``distance_threshold`` and the alignment's ``tm_score`` is above
    ``tm_score_threshold``. Log-odds scores are derived from those counts.

    The ``compute_matrix_*`` methods write the matrix for the Amino Acid
    (``build``), K-means, neural network (``nnp``) and AQP pipelines.
    """

    def __init__(self, path_structural_alignments, path_structural_alignments_after_embedding, distance_threshold, tm_score_threshold, m_b):
        self.path_structural_alignments = path_structural_alignments
        # Same as path_structural_alignments when the original amino acid alphabet is used.
        self.path_structural_alignments_after_embedding = path_structural_alignments_after_embedding
        self.distance_threshold = distance_threshold  # a position is used only if distance < this
        self.tm_score_threshold = tm_score_threshold  # an alignment is used only if tm_score > this
        self.code_book_size = 20
        # Symbols that are counted; a symbol's position in this list is its matrix row/column index.
        self.alphabet = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y', 'X']
        self.m_b = m_b

        # Lazily computed caches (None means "not computed yet").
        self._aa_pair_counts = None  # (min(a, b), max(a, b)) -> number of aligned positions
        self._aa_counts = None  # symbol -> count, e.g. 'A': 36
        self._n_pairs = None  # sum of the values in _aa_pair_counts
        self._n_aa = None  # sum of the values in _aa_counts
        self._aa_pair_prob = None  # pair probabilities
        self._aa_prob = None  # per-symbol probabilities
        self._subst_score = None  # log-odds scores per observed pair
        self._sum_aa_prob = None
        self._sum_aa_pair_prob = None
        self._count_alignments = None  # aligned positions read, before filtering
        self._count_alignments_used = None  # positions passing the filters with both symbols in the alphabet
        self._matrix_array = None
        self._matrix_parasail = None  # matrix as text to insert into a file - parasail format
        self._matrix_mmseqs = None  # matrix as text to insert into a file - mmseqs format
        self._alginments_created = None

    def _reset_derived_caches(self):
        """Forget every cached value that is derived from the raw counts."""
        self._n_pairs = None
        self._n_aa = None
        self._aa_pair_prob = None
        self._aa_prob = None
        self._subst_score = None
        self._sum_aa_prob = None
        self._sum_aa_pair_prob = None
        self._matrix_array = None
        self._matrix_parasail = None
        self._matrix_mmseqs = None

    @property
    def alginments_created(self):
        """True if at least one file exists anywhere below the alignment directory.

        A missing directory counts as "not created". ``any()`` is required here:
        a bare generator expression in an ``if`` is always truthy.
        """
        if self._alginments_created is None:
            root = self.path_structural_alignments_after_embedding
            # os.walk yields nothing for a missing directory, so this is False then.
            self._alginments_created = any(files for _, _, files in os.walk(root))
        return self._alginments_created

    @property
    def alignments_created(self):
        """Correctly spelled alias of ``alginments_created``."""
        return self.alginments_created

    @property
    def aa_pair_counts(self):
        if self._aa_pair_counts is None:
            self.compute_counts()
        return self._aa_pair_counts

    @property
    def aa_counts(self):
        if self._aa_counts is None:
            self.compute_counts()
        return self._aa_counts

    @property
    def n_aa(self):
        """Total number of counted symbols."""
        if self._n_aa is None:
            self._n_aa = sum(self.aa_counts.values())
        return self._n_aa

    @property
    def n_pairs(self):
        """Total number of counted symbol pairs."""
        if self._n_pairs is None:
            self._n_pairs = sum(self.aa_pair_counts.values())
        return self._n_pairs

    @property
    def aa_prob(self):
        """Relative frequency of each counted symbol."""
        if self._aa_prob is None:
            self._aa_prob = {k: v / self.n_aa for k, v in self.aa_counts.items()}
        return self._aa_prob

    @property
    def aa_pair_prob(self):
        """Relative frequency of each counted symbol pair."""
        if self._aa_pair_prob is None:
            self._aa_pair_prob = {k: v / self.n_pairs for k, v in self.aa_pair_counts.items()}
        return self._aa_pair_prob

    @property
    def subst_score(self):
        """Log-odds score ``log2(p(a, b) / (p(a) * p(b)))`` per observed pair, in bits.

        Unlike compute_matrix_kmeans/nnp/aqp, no factor 2 is applied here.
        """
        if self._subst_score is None:
            self._subst_score = {
                k: np.log2(v / (self.aa_prob[k[0]] * self.aa_prob[k[1]]))
                for k, v in self.aa_pair_prob.items()
            }
        return self._subst_score

    @property
    def matrix_mmseqs(self):
        """Lazily computed mmseqs matrix text.

        NOTE(review): this calls ``self.compute_matrix()``, which is not defined in
        this class. Accessing the property raises AttributeError unless a subclass
        provides that method.
        """
        if self._matrix_mmseqs is None:
            self._matrix_mmseqs = self.compute_matrix()
        return self._matrix_mmseqs

    @property
    def matrix_array(self):
        """Symmetric matrix of ``subst_score`` values, rounded to 4 decimals.

        Rows and columns are indexed like ``self.alphabet``; pairs that were never
        observed stay 0.
        """
        if self._matrix_array is None:
            self.create_matrix_array(self.subst_score.items(), 4)
        return self._matrix_array

    @property
    def sum_aa_prob(self):
        """Sum of the per-symbol probabilities (a sanity check; should be ~1)."""
        if self._sum_aa_prob is None:
            self._sum_aa_prob = sum(v / self.n_aa for v in self.aa_counts.values())
        return self._sum_aa_prob

    @property
    def sum_aa_pair_prob(self):
        """Sum of the pair probabilities (a sanity check; should be ~1)."""
        if self._sum_aa_pair_prob is None:
            self._sum_aa_pair_prob = sum(v / self.n_pairs for v in self.aa_pair_counts.values())
        return self._sum_aa_pair_prob

    @property
    def count_alignments(self):
        if self._count_alignments is None:
            self.compute_counts()
        return self._count_alignments

    @property
    def count_alignments_used(self):
        if self._count_alignments_used is None:
            self.compute_counts()
        return self._count_alignments_used

    def compute_counts(self):
        """Count symbols and symbol pairs over all alignment files.

        Every directory directly below ``path_structural_alignments_after_embedding``
        is treated as a family directory, except ``one_file_per_family``. Plain
        files at that level are skipped. Every JSON file inside a family directory
        is read, and empty files are reported and skipped.

        A position passes the filter when ``distance`` is not None and below
        ``distance_threshold``, and the file's ``tm_score`` is above
        ``tm_score_threshold``. For passing positions:

        * each symbol in the alphabet is counted in ``aa_counts``;
        * if both symbols are in the alphabet, the unordered pair is counted in
          ``aa_pair_counts`` and in ``count_alignments_used``.

        ``count_alignments`` counts every position read, before filtering.
        Previously derived values (probabilities, scores, matrix) are reset.
        """
        self._reset_derived_caches()
        root = self.path_structural_alignments_after_embedding
        symbols = set(self.alphabet)
        aa_pair_counts = {}
        aa_counts = {}
        count_alignments = 0
        count_alignments_used = 0

        for fdir in os.listdir(root):
            family_path = os.path.join(root, fdir)
            if fdir == "one_file_per_family" or not os.path.isdir(family_path):
                continue
            for af in os.listdir(family_path):
                aln_path = os.path.join(family_path, af)
                if os.path.getsize(aln_path) == 0:
                    print(f"empty: {os.path.join(fdir, af)}")
                    continue
                with open(aln_path, 'r') as fh:
                    data = json.load(fh)
                for a, b, dist in zip(data["line1"], data["line2"], data["distances"]):
                    count_alignments += 1
                    # tm_score is read only after the distance test passes, as a lazy check.
                    if not (dist is not None and dist < self.distance_threshold
                            and data["tm_score"] > self.tm_score_threshold):
                        continue
                    a_known = a in symbols
                    b_known = b in symbols
                    if a_known:
                        aa_counts[a] = aa_counts.get(a, 0) + 1
                    if b_known:
                        aa_counts[b] = aa_counts.get(b, 0) + 1
                    if a_known and b_known:
                        count_alignments_used += 1
                        pair_key = (min(a, b), max(a, b))  # unordered pair
                        aa_pair_counts[pair_key] = aa_pair_counts.get(pair_key, 0) + 1

        self._aa_pair_counts = aa_pair_counts
        self._aa_counts = aa_counts
        self._count_alignments = count_alignments
        self._count_alignments_used = count_alignments_used
        print(f"aa pair counts: {self._aa_pair_counts}")
        print(f"aa counts: {self._aa_counts}")
        print(f"count alignments: {self._count_alignments}")
        print(f"count_alignments_used: {self._count_alignments_used}")

    def create_matrix_array(self, scores, round_n):
        """Build a symmetric score matrix indexed like ``self.alphabet``.

        Args:
            scores: iterable of ``(pair_key, score)``, where ``pair_key[0]`` and
                ``pair_key[1]`` are alphabet symbols.
            round_n: number of decimals when > 0; otherwise values are rounded
                with ``round(v)``, i.e. to whole numbers.

        Returns:
            The matrix as a list of lists (unset cells are 0). It is also stored
            in ``self._matrix_array``, so the ``matrix_array`` property returns
            the most recently built matrix.
        """
        print("CREATE MATRIX ARRAY")
        size = len(self.alphabet)
        matrix_array = [[0] * size for _ in range(size)]
        for key, value in scores:
            row = self.get_index_of_character(key[0])
            col = self.get_index_of_character(key[1])
            cell = round(value, round_n) if round_n > 0 else round(value)
            matrix_array[row][col] = cell
            matrix_array[col][row] = cell
        self._matrix_array = matrix_array
        print(matrix_array)
        return matrix_array

    def get_index_of_character(self, character):
        """Return the index of ``character`` in ``self.alphabet``, or None if absent."""
        try:
            return self.alphabet.index(character)
        except ValueError:
            return None

    def get_sequences_from_fasta(self, fasta_file):
        """Read a FASTA file into ``{record.id: sequence}``; later duplicate ids win."""
        with open(fasta_file, "r") as f:
            return {record.id: str(record.seq) for record in SeqIO.parse(f, "fasta")}

    @staticmethod
    def _half_bit_scores(aa_pair_counts, aa_counts):
        """Return ``2 * log2(p(a, b) / (p(a) * p(b)))`` for every counted pair.

        Probabilities are relative frequencies of the given counts.
        """
        n_aa = sum(aa_counts.values())
        n_pairs = sum(aa_pair_counts.values())
        aa_prob = {k: v / n_aa for k, v in aa_counts.items()}
        return {
            k: 2 * np.log2((v / n_pairs) / (aa_prob[k[0]] * aa_prob[k[1]]))
            for k, v in aa_pair_counts.items()
        }

    def _write_mmseqs_matrix(self, m, title_line, matrix_array, frequencies, lambda_val):
        """Write ``matrix_array`` to the file ``m`` (overwritten) in the mmseqs text layout.

        The layout is: ``title_line``, background frequencies, lambda, a header
        row of symbols, then one row per symbol with values formatted as ``%.4f``.
        The text is built completely before the file is opened.
        """
        size = len(self.alphabet)
        lines = [
            title_line,
            "# Background (precomputed optional): " + " ".join(f"{freq:.4f}" for freq in frequencies) + "\n",
            "# Lambda     (precomputed optional): " + f"{lambda_val:.4f}" + "\n",
            "   " + "       ".join(self.alphabet) + "\n",
        ]
        for i, aa in enumerate(self.alphabet):
            lines.append(f"{aa} " + "".join(f"{matrix_array[i][j]:.4f} " for j in range(size)) + "\n")
        with open(m, "w") as f:
            f.writelines(lines)

    def compute_matrix_build(self, m, sequences_path):
        """Write the Amino Acid pipeline matrix (bit scores, 4 decimals) in mmseqs format.

        Args:
            m: output path of the matrix file (overwritten).
            sequences_path: FASTA file used for the background frequencies.

        Returns:
            ``(minimum, maximum + 1)`` of the matrix values (see get_min_max).
        """
        sequences = self.get_sequences_from_fasta(sequences_path)
        matrix_array = self.matrix_array
        minimum, maximum = self.get_min_max(matrix_array)
        frequencies = self.calculate_background_frequencies(sequences, self.alphabet)
        # NOTE(review): 0.34657 ~= ln(2)/2, which matches half-bit (2*log2) scores,
        # but matrix_array holds bit (log2) scores here. Confirm this is intended.
        lambda_val = 0.34657
        self._write_mmseqs_matrix(m, "# HypOptHD CMA AA Build\n", matrix_array, frequencies, lambda_val)
        return minimum, maximum + 1

    def get_min_max(self, matrix_array):
        """Return ``(min, max)`` over all matrix values, always including 0.

        Because 0 is included, the result satisfies ``min <= 0 <= max``.
        An empty matrix gives ``(0, 0)``.
        """
        values = [value for row in matrix_array for value in row]
        return min([0, *values]), max([0, *values])

    def compute_matrix_kmeans(self, m, aa_pair_counts, aa_counts, count_alignments, count_alignments_used, text_dir, sequences_path):
        """Write the K-means pipeline matrix (half-bit scores, whole numbers) in mmseqs format.

        The matrix and its bounds are appended to ``text_dir/text_loggs_debugging.txt``.
        Lambda comes from ``calculate_lambda``. ``count_alignments`` and
        ``count_alignments_used`` are accepted but not used.

        Returns:
            ``(minimum, maximum + 1)`` of the matrix values.
        """
        sequences = self.get_sequences_from_fasta(sequences_path)
        subst_score = self._half_bit_scores(aa_pair_counts, aa_counts)
        matrix_array = self.create_matrix_array(subst_score.items(), 0)
        minimum, maximum = self.get_min_max(matrix_array)
        with open(os.path.join(text_dir, "text_loggs_debugging.txt"), "a") as f:
            f.write(f"\n matrix array: {matrix_array} \n minimum: {minimum}, maximum: {maximum} \n")
        frequencies = self.calculate_background_frequencies(sequences, self.alphabet)
        lambda_val = self.calculate_lambda(matrix_array, text_dir)
        self._write_mmseqs_matrix(m, "# HypOptHD CMA AA Build\n", matrix_array, frequencies, lambda_val)
        return minimum, maximum + 1

    def compute_matrix_nnp(self, m, aa_pair_counts, aa_counts, count_alignments, count_alignments_used, dir_experiment, embedded_sequences, sequences_path):
        """Write the neural network pipeline matrix (half-bit scores, 4 decimals) in mmseqs format.

        Background frequencies come from the FASTA file at ``sequences_path``.
        ``count_alignments``, ``count_alignments_used``, ``dir_experiment`` and
        ``embedded_sequences`` are accepted but not used.

        Returns:
            ``(minimum, maximum + 1)`` of the matrix values.
        """
        sequences = self.get_sequences_from_fasta(sequences_path)
        subst_score = self._half_bit_scores(aa_pair_counts, aa_counts)
        matrix_array = self.create_matrix_array(subst_score.items(), 4)
        minimum, maximum = self.get_min_max(matrix_array)
        frequencies = self.calculate_background_frequencies(sequences, self.alphabet)
        lambda_val = 0.34657  # ~= ln(2)/2
        self._write_mmseqs_matrix(m, "# HypOptHD CMA\n", matrix_array, frequencies, lambda_val)
        return minimum, maximum + 1

    def calculate_background_frequencies(self, embedded_sequences, alphabet):
        """Return the relative frequency of each symbol of ``alphabet``, in alphabet order.

        Args:
            embedded_sequences: mapping of id -> sequence; symbols outside
                ``alphabet`` are ignored.
            alphabet: symbols to count.

        Raises:
            ZeroDivisionError: if no symbol of ``alphabet`` occurs in any sequence.
        """
        allowed = set(alphabet)
        counts = Counter()
        for key in embedded_sequences:
            counts.update(symbol for symbol in embedded_sequences[key] if symbol in allowed)
        total = sum(counts.values())
        return [counts[symbol] / total for symbol in alphabet]

    def calculate_lambda(self, matrix_array, dir_experiment):
        """Delegate lambda estimation for ``matrix_array`` to ``lambda_for_cma.LambdaForCMA``."""
        lmb = lambda_for_cma.LambdaForCMA(matrix_array, dir_experiment)
        return lmb.calculate_lambda()

    def compute_matrix_aqp(self, m, aa_pair_counts, aa_counts, count_alignments, count_alignments_used, text_dir, embedded_sequences):
        """Write the AQP pipeline matrix (half-bit scores, whole numbers) in parasail format.

        Background frequencies are computed from ``embedded_sequences``
        (a mapping of id -> sequence). ``count_alignments``,
        ``count_alignments_used`` and ``text_dir`` are accepted but not used.

        Returns:
            ``(minimum, maximum)`` of the matrix values. NOTE(review): unlike the
            other compute_matrix_* methods, the maximum is not incremented. Confirm
            this is intended.
        """
        subst_score = self._half_bit_scores(aa_pair_counts, aa_counts)
        matrix_array = self.create_matrix_array(subst_score.items(), 0)
        minimum, maximum = self.get_min_max(matrix_array)
        frequencies = self.calculate_background_frequencies(embedded_sequences, self.alphabet)
        lambda_val = 0.34657  # ~= ln(2)/2
        frequencies_string = "# Background (precomputed optional): " + " ".join(f"{freq:.4f}" for freq in frequencies) + "\n"
        lambda_string = "# Lambda     (precomputed optional): " + f"{lambda_val:.4f}" + "\n"
        text = self.create_matrix_string_parasail(matrix_array, minimum, frequencies_string, lambda_string)
        with open(m, "w") as f:
            f.write(text)
        return minimum, maximum

    def create_matrix_string_parasail(self, matrix_array, min_value, frequencies_string, lambda_string):
        """Render ``matrix_array`` as parasail matrix text; needed only when parasail replaces MMseqs2.

        The output contains:

        * comment lines with the thresholds, the given frequency/lambda lines and
          a header row of ``self.alphabet`` plus ``*``;
        * one row per symbol, with a final ``*`` column holding ``min_value``;
        * a ``*`` row of ``min_value`` that ends with ``1`` for ``*`` vs ``*``.

        Spacing pads non-negative single-digit values with a leading space, and
        uses a single separator before values below -9, so that columns line up.
        """
        def cell(value):
            # Negative values and values > 9 are printed as is; 0..9 get a leading space.
            return str(value) if value < 0 or value > 9 else " " + str(value)

        parts = [
            "# scores matrix with distance < " + str(self.distance_threshold) + " and tm_score > " + str(self.tm_score_threshold) + "\n",
            "# HypOptHD CMA AQP\n",
            frequencies_string,
            lambda_string,
            # Header uses the same alphabet as the row labels so the columns match.
            "    " + "".join(letter + "   " for letter in self.alphabet) + "*\n",
        ]

        min_column = str(min_value) if min_value < 0 else " " + str(min_value)
        for row_index, row in enumerate(matrix_array):
            if row:
                parts.append(self.alphabet[row_index] + (" " if row[0] < -9 else "  "))
            last = len(row) - 1
            for col, value in enumerate(row):
                # Narrow separator when the next value needs an extra character (< -9).
                narrow = col < last and row[col + 1] < -9
                parts.append(cell(value) + (" " if narrow else "  "))
            parts.append(min_column + "\n")

        wide = min_value < -9 or min_value > 9
        if min_value < -9:
            star_cell = str(min_value) + " "
        elif min_value < 0:
            star_cell = str(min_value) + "  "
        elif min_value > 9:
            star_cell = " " + str(min_value)
        else:
            star_cell = " " + str(min_value) + "  "
        parts.append(("* " if wide else "*  ") + star_cell * len(self.alphabet) + ("   1\n" if wide else " 1\n"))
        return "".join(parts)