# laura maria engist, 2025
# script to compute the alignment quality assessment

import os
import subprocess
from evotuner import constants
import sys
import json
import matplotlib.pyplot as plt
import numpy as np
import re
from Bio import SeqIO
import random
import resource

class AlignmentQuality:
    def __init__(self, go, ge, cma, fasta_embedded_sequences):
        """Store the alignment settings and initialise the lazy caches.

        Args:
            go: gap-open value; forwarded to MMseqs2 by compute_alignment_quality().
            ge: gap-extend value; forwarded to MMseqs2 by compute_alignment_quality().
            cma: substitution-matrix file path forwarded to MMseqs2.
            fasta_embedded_sequences: path of the FASTA file holding the embedded sequences.
        """
        # Caches that are filled on demand by the properties of the same name.
        self._start_indices = None
        self._scope_alignments = None
        self._embedded_sequences_aligned = None
        self._embedded_sequences = None
        # Settings supplied by the caller.
        self.fasta_embedded_sequences = fasta_embedded_sequences
        self.cma, self.ge, self.go = cma, ge, go
    
    @property
    def embedded_sequences(self):
        """Return the embedded sequences, loading them from the FASTA file on first access."""
        cached = self._embedded_sequences
        if cached is None:
            cached = self.load_embedded_sequences()
            self._embedded_sequences = cached
        return cached
    
    @property
    def embedded_sequences_aligned(self):
        """Return the dict of aligned embedded sequence pairs, creating it empty on first access."""
        store = self._embedded_sequences_aligned
        if store is None:
            store = {}
            self._embedded_sequences_aligned = store
        return store
    
    @property
    def scope_alignments(self):
        """Return the dict of SCOPe reference alignments, creating it empty on first access."""
        store = self._scope_alignments
        if store is None:
            store = {}
            self._scope_alignments = store
        return store
    
    @property
    def start_indices(self):
        """Return the dict of alignment start indices, creating it empty on first access."""
        store = self._start_indices
        if store is None:
            store = {}
            self._start_indices = store
        return store

    ''' 
    method to compute the alignment quality 
    remap: boolean, if True remap the aligned embedded sequences to the scope sequences
    '''
    def compute_alignment_quality(self, aln_file, remap):
        """Align the embedded sequences against themselves and score the result.

        Args:
            aln_file: path MMseqs2 writes its result to; read back afterwards.
            remap: forwarded to sensitivity_and_precision(); if True the aligned
                embedded sequences are remapped onto the SCOPe sequences first.

        Returns:
            ``(pa, alignments_used, alignments_existing)``. When MMseqs2 produced
            no alignments, ``pa`` and ``alignments_used`` are both 0.
        """
        # Reference alignments come from the JSON sample, which must have been
        # written once beforehand by load_scope_alignments_into_json().
        self.load_scope_alignments()

        # All-vs-all search of the embedded sequences with MMseqs2.
        fasta = self.fasta_embedded_sequences
        alignments_existing = self.align_two_sequences_mmseqs(
            fasta, fasta, aln_file, self.cma, self.go, self.ge
        )
        if not alignments_existing:
            return 0, 0, alignments_existing

        self.load_aligned_embedded_sequences(aln_file)
        # compute sensitivity and precision
        pa, alignments_used = self.sensitivity_and_precision(remap)
        return pa, alignments_used, alignments_existing
    
    def alignment_quality_mmseqs(self):
        """Load the SCOPe reference alignments via load_scope_alignments().

        This method only delegates to that loader and returns nothing.
        """
        # load SCOPe alignments
        # self.load_scope_alignments_into_json() # must be done only once
        self.load_scope_alignments()
    
    '''
    method to convert the mmseqs output file to a more usable format
    '''
    def load_aligned_embedded_sequences(self, mmseqs_output_file_scope):
        """Read the MMseqs2 result file into the alignment and start-index caches.

        Each usable line holds six tab-separated fields in the order
        query, target, aligned query, aligned target, query start, target start.
        Lines with fewer than six fields (blank or truncated lines) are skipped,
        as are self-hits where query and target are the same sequence.

        Args:
            mmseqs_output_file_scope: path of the MMseqs2 output file.

        Side effects:
            Fills ``self.embedded_sequences_aligned`` and ``self.start_indices``,
            both keyed by ``(query, target)``. Start indices are stored as the
            strings read from the file.
        """
        expected_fields = 6
        with open(mmseqs_output_file_scope, "r") as f:
            for line in f:
                parts = line.strip().split("\t")
                # Every field below is required, so a short line cannot be used.
                if len(parts) < expected_fields:
                    continue
                query, target, qaln, taln, qstart, tstart = parts[:expected_fields]
                if query == target:  # only pairs of different sequences are stored
                    continue
                pair = (query, target)
                self.embedded_sequences_aligned[pair] = (qaln, taln)
                self.start_indices[pair] = (qstart, tstart)
    
    '''
    method to remap the aligned embedded sequences to the scope sequences
    '''
    def re_map_aligned_embedded_sequences(self, aligned_embedded_sequence, aligned_scope_sequence, start_indices):
        """Rewrite an embedded alignment using the residues of the SCOPe sequences.

        The gap pattern of the embedded alignment is kept. Every non-gap
        character is replaced by the residue at the matching position of the
        ungapped SCOPe sequence, walking forward one residue per character.

        Args:
            aligned_embedded_sequence: pair ``(aligned_seq1, aligned_seq2)`` of the
                embedded alignment, with '-' marking gaps.
            aligned_scope_sequence: pair of aligned SCOPe sequences; their '-'
                characters are removed to obtain the plain residue strings.
            start_indices: pair of 1-based start positions of the embedded
                alignment (int or numeric string).
                NOTE(review): assumes embedded position i corresponds to SCOPe
                residue i — confirm against the embedding step.

        Returns:
            Tuple ``(remapped_seq1, remapped_seq2)``.
        """
        def remap_one(aligned_embedded, aligned_scope, start):
            residues = aligned_scope.replace("-", "")
            position = int(start) - 1  # start positions are 1-based
            remapped = []
            for character in aligned_embedded:
                if character == '-':
                    remapped.append('-')
                    continue
                if 0 <= position < len(residues):
                    remapped.append(residues[position])
                else:
                    # Outside the SCOPe sequence: keep the column as a non-gap
                    # placeholder so the gap pattern stays intact.
                    remapped.append('X')
                position += 1
            return "".join(remapped)

        aligned_embedded_sequence_remapped = (
            remap_one(aligned_embedded_sequence[0], aligned_scope_sequence[0], start_indices[0]),
            remap_one(aligned_embedded_sequence[1], aligned_scope_sequence[1], start_indices[1]),
        )
        print(f"remapped: {aligned_embedded_sequence_remapped}")
        return aligned_embedded_sequence_remapped

    
    '''
    method to compute the alignment quality metric based on HOMSTRAD benchmarking approach but we compare to the SCOPe alignments
    '''
    def get_alignment_quality_metric(self, mmseqs_output_file_scope):
        """Return the alignment quality metric; the value is never changed from 0 here.

        Args:
            mmseqs_output_file_scope: path passed on to convert_mmseqs_output_to_var().

        Returns:
            ``alignment_quality``, which is initialised to 0 and not updated.
        """
        alignment_quality = 0
        # read but not used further in this method
        scope_alignments = constants.GTA_S_AA
        # NOTE(review): convert_mmseqs_output_to_var is not defined in the visible part
        # of this class — confirm it exists elsewhere, otherwise this call raises AttributeError.
        mmseqs_alignments, families = self.convert_mmseqs_output_to_var(mmseqs_output_file_scope)
        return alignment_quality
    
    ''' 
    method to align two sequences with MMseqs2 and a memory limit of 50 GB 
    '''
    def align_two_sequences_mmseqs(self, temp_path1, temp_path2, tmp_output_file, matrix_file_path, go, ge):
        """Run ``mmseqs easy-search`` under a 50 GB address-space limit.

        Args:
            temp_path1: query FASTA path.
            temp_path2: target FASTA path.
            tmp_output_file: path of the result file MMseqs2 should write.
            matrix_file_path: substitution matrix passed as ``--sub-mat``;
                the option is omitted when this is the empty string.
            go: gap-open value, passed as ``--gap-open``.
            ge: gap-extend value, passed as ``--gap-extend``.

        Returns:
            True if MMseqs2 exited with code 0 and the result file exists and is
            non-empty. False in every other case, including a failure to start
            the process.
        """
        command = [
            constants.MMSEQS, "easy-search", temp_path1, temp_path2, tmp_output_file, "tmp",
            "--gap-open", str(go),
            "--gap-extend", str(ge),
        ]
        if matrix_file_path != "":
            command += ["--sub-mat", matrix_file_path]
        command += [
            "--format-output", "query,target,qaln,taln,qstart,tstart",
            "--exhaustive-search",
        ]
        print(f"command: {command}")

        MEM_LIMIT = 50 * 1024 * 1024 * 1024  # 50 GB

        def limit_child_memory():
            # Runs in the child between fork and exec, so only the MMseqs2
            # process is limited, not this Python process.
            resource.setrlimit(resource.RLIMIT_AS, (MEM_LIMIT, MEM_LIMIT))

        try:
            proc = subprocess.Popen(command, preexec_fn=limit_child_memory)
        except (OSError, subprocess.SubprocessError) as e:
            # OSError: the executable could not be started.
            # SubprocessError: the pre-exec hook failed in the child.
            print(f"failed with error: {e}")
            return False

        print("Subprozess PID:", proc.pid)
        sys.stdout.flush()
        proc.wait()
        print("Rückgabecode:", proc.returncode)
        if proc.returncode != 0:
            # A failed run must not be reported as success just because an
            # output file from an earlier run is still on disk.
            return False

        # check if alignments created
        if not os.path.exists(tmp_output_file):
            print("❌ File 'alnResult.m8' was not created!")
            return False
        if os.path.getsize(tmp_output_file) == 0:
            print("⚠️ File 'alnResult.m8' was created, but is empty.")
            return False
        print("✅ File 'alnResult.m8' found!")
        return True
        
    '''
    method to load the embedded sequences from fasta file
    '''
    def load_embedded_sequences(self):
        """Read the embedded sequences from ``self.fasta_embedded_sequences``.

        Returns:
            Dict mapping each FASTA record id to its sequence as a plain string.
            If an id occurs more than once, the last record wins.
        """
        # Parse from the handle opened here so the file is opened only once and
        # is closed by the context manager.
        with open(self.fasta_embedded_sequences, "r") as f:
            return {record.id: str(record.seq) for record in SeqIO.parse(f, "fasta")}
    
    '''
    method to load a random ten percent of the SCOPe alignments into a json file
    '''
    def load_scope_alignments_into_json(self, n_alignments=5435, max_consecutive_misses=100000):
        """Write a random sample of SCOPe alignments to the JSON sample file.

        A family directory below ``constants.GTA_S_AA`` is drawn at random, then
        one alignment file inside it. The alignment is kept if both of its
        descriptions are known embedded sequences and it was not drawn before.
        Entries whose name contains "one_file_per_family" are never drawn.

        Args:
            n_alignments: number of distinct alignments to collect. The default
                5435 is ten percent of the 54359 alignments
                (58520 - 4161 families).
            max_consecutive_misses: number of draws in a row that may fail to
                yield a new alignment before sampling is aborted.

        Raises:
            RuntimeError: if there is no family directory to sample from, or if
                the requested number of alignments cannot be reached.

        Side effects:
            Writes the collected alignments as a JSON list to
            ``constants.RANDOM_TEN_PERCENT_SCOPE``.
        """
        root = constants.GTA_S_AA
        families = [
            entry for entry in os.listdir(root)
            if "one_file_per_family" not in entry and os.path.isdir(os.path.join(root, entry))
        ]
        if n_alignments > 0 and not families:
            raise RuntimeError(f"no family directories found in {root}")

        known_sequences = self.embedded_sequences
        files_per_family = {}  # directory listings are reused between draws
        seen = set()           # fingerprints of the alignments already collected
        json_objects = []
        misses = 0
        while len(json_objects) < n_alignments:
            if misses >= max_consecutive_misses:
                raise RuntimeError(
                    f"only {len(json_objects)} of {n_alignments} alignments found "
                    f"after {misses} consecutive unsuccessful draws"
                )
            misses += 1

            family = random.choice(families)
            files = files_per_family.get(family)
            if files is None:
                files = os.listdir(os.path.join(root, family))
                files_per_family[family] = files
            if not files:
                continue

            with open(os.path.join(root, family, random.choice(files)), "r") as f:
                data = json.load(f)
            json_object = {
                "description1": data['description1'],
                "description2": data['description2'],
                "line1": data['line1'],
                "line2": data['line2'],
            }
            if json_object["description1"] not in known_sequences:
                continue
            if json_object["description2"] not in known_sequences:
                continue

            # Serialised form serves as a hashable key for the duplicate check.
            fingerprint = json.dumps(json_object, sort_keys=True)
            if fingerprint in seen:
                continue
            seen.add(fingerprint)
            json_objects.append(json_object)
            misses = 0

        with open(constants.RANDOM_TEN_PERCENT_SCOPE, "w") as f:
            json.dump(json_objects, f)
    
    '''
    method to load the scope alignments from the json file
    '''
    def load_scope_alignments(self):
        """Load the sampled SCOPe alignments from the JSON file into ``self._scope_alignments``.

        The file at ``constants.RANDOM_TEN_PERCENT_SCOPE`` holds a list of entries
        with the keys description1, description2, line1 and line2. The result is a
        dict mapping ``(description1, description2)`` to ``(line1, line2)``.
        """
        with open(constants.RANDOM_TEN_PERCENT_SCOPE, "r") as f:
            entries = json.load(f)
        self._scope_alignments = {
            (entry['description1'], entry['description2']): (entry['line1'], entry['line2'])
            for entry in entries
        }
    
    '''
    method to create the alignments for the pairs in scope_alignments with parasail
    '''
    def create_alignments(self):
        """Call align_two_sequences() for the embedded sequences of every pair in ``self.scope_alignments``."""
        for pair in self.scope_alignments:
            description1, description2 = pair
            # the reference alignment lines are unpacked here but not used below
            sequence1, sequence2 = self.scope_alignments[pair]
            # NOTE(review): align_two_sequences is not defined in the visible part of this
            # class (only align_two_sequences_mmseqs is) — confirm it exists elsewhere,
            # otherwise this call raises AttributeError.
            self.align_two_sequences(description1, description2, self.embedded_sequences[description1], self.embedded_sequences[description2])
    
    '''
    method to create pairs of aligned positions from two aligned sequences
    '''
    def pairs_from_aln(self, start1, aln1, start2, aln2):
        """List the residue index pairs that two aligned sequences place in the same column.

        Args:
            start1: index given to the first residue of ``aln1``.
            aln1: first aligned sequence; '-' and '/' count as gaps.
            start2: index given to the first residue of ``aln2``.
            aln2: second aligned sequence; '-' and '/' count as gaps.

        Returns:
            List of ``(index1, index2)`` tuples, one per column in which both
            sequences carry a residue. Columns beyond the shorter sequence are ignored.
        """
        gap_symbols = ('-', '/')
        pos1 = start1
        pos2 = start2
        matched = []
        for symbol1, symbol2 in zip(aln1, aln2):
            has_residue1 = symbol1 not in gap_symbols
            has_residue2 = symbol2 not in gap_symbols
            if has_residue1 and has_residue2:
                matched.append((pos1, pos2))
            # Each sequence advances only over its own residues.
            if has_residue1:
                pos1 += 1
            if has_residue2:
                pos2 += 1
        return matched
    
    '''
    method to parse an alignment
    '''
    def parse_custom_ali(self, pair,alg_sequences):
        """Wrap one aligned sequence pair in a dict keyed by the pair of names.

        Args:
            pair: the two sequence names.
            alg_sequences: the two aligned sequences.

        Returns:
            ``{(name1, name2): (1, aligned1, 1, aligned2)}``, or an empty dict
            when both names are equal.
        """
        first_name, second_name = pair[0], pair[1]
        if first_name == second_name:
            return {}
        return {(first_name, second_name): (1, alg_sequences[0], 1, alg_sequences[1])}
    
    '''
    method to compare the alignments based on sensitivity and precision
    '''
    def check_alignment(self, pair, embedded_alignment, scope_alignment, start_indices):
        """Score one embedded alignment against its SCOPe reference alignment.

        Both alignments are reduced to lists of aligned residue index pairs.
        Sensitivity is the share of reference pairs that the embedded alignment
        reproduces; precision is the share of embedded pairs that also occur in
        the reference.

        Args:
            pair: the two sequence names.
            embedded_alignment: the two aligned embedded sequences.
            scope_alignment: the two aligned reference sequences; their residue
                indices start at 1.
            start_indices: start positions of the embedded alignment in the two
                sequences (int or numeric string).

        Returns:
            A one-element list ``[[name1, name2, sensitivity, precision, predicted]]``
            where ``predicted`` is ``(1, aligned1, 1, aligned2)``. When both names
            are equal the entry is ``[name1, name2, None, None, None]``. A score
            whose denominator would be zero is reported as 0.0.

        Raises:
            AssertionError: if the two embedded aligned sequences differ in length.
        """
        name1, name2 = pair[0], pair[1]

        aln = self.parse_custom_ali(pair, embedded_alignment)
        if (name1, name2) not in aln:
            return [[name1, name2, None, None, None]]
        predicted_aln = aln[(name1, name2)]
        _, aln1, _, aln2 = predicted_aln
        assert len(aln1) == len(aln2)

        ref_pairs = self.pairs_from_aln(1, scope_alignment[0], 1, scope_alignment[1])
        pairs = self.pairs_from_aln(int(start_indices[0]), aln1, int(start_indices[1]), aln2)

        # Set membership keeps the comparison linear in the alignment length.
        ref_lookup = set(ref_pairs)
        pair_lookup = set(pairs)
        if ref_pairs:
            sensitivity = sum(1 for p in ref_pairs if p in pair_lookup) / len(ref_pairs)
        else:
            sensitivity = 0.0
        if pairs:
            accuracy = sum(1 for p in pairs if p in ref_lookup) / len(pairs)
        else:
            accuracy = 0.0

        return [[name1, name2, sensitivity, accuracy, predicted_aln]]
    
    '''
    method to compute sensitivity and precision
    '''
    def stats(self, res):
        """Aggregate per-alignment results into overall figures.

        Args:
            res: list of entries ``[name1, name2, sensitivity, precision, ...]``;
                sensitivity and precision are None for alignments that were not found.

        Returns:
            ``(found_ratio, sensitivity, precision)``:
            found_ratio is the share of entries with a sensitivity value;
            sensitivity is the mean over all entries, counting missing values as 0;
            precision is the mean over the entries that have a precision value.
            An empty ``res`` yields ``(0.0, 0.0, 0.0)``, and precision is 0.0 when
            no entry has a precision value.
        """
        if not res:
            return 0.0, 0.0, 0.0
        found_ration = sum(1 for r in res if r[2] is not None) / len(res)
        sensitivity = np.mean([r[2] or 0 for r in res])
        precisions = [r[3] for r in res if r[3] is not None]
        # The mean of an empty list would be NaN plus a RuntimeWarning.
        precision = np.mean(precisions) if precisions else 0.0
        return found_ration, sensitivity, precision
    
    '''
    method to collect computed sensitivity and precision
    '''
    def sensitivity_and_precision(self, remap):
        """Score every embedded alignment that has a SCOPe reference and aggregate the scores.

        Args:
            remap: if True, each embedded alignment is first passed through
                re_map_aligned_embedded_sequences().

        Returns:
            ``(pa, alignments_used)`` where ``pa`` is the
            ``(found_ratio, sensitivity, precision)`` tuple from stats() computed
            over all scored alignments, and ``alignments_used`` is their number.
            When no alignment has a reference, ``pa`` is ``(0.0, 0.0, 0.0)``.

        Side effects:
            Stores ``pa`` in ``self._results``.
        """
        res = []  # collected across all pairs so that stats() sees every alignment
        alignments_used = 0
        scope_alignments = self.scope_alignments
        for k, aligned in self.embedded_sequences_aligned.items():
            if k not in scope_alignments:
                continue
            reference = scope_alignments[k]
            starts = self.start_indices[k]
            if remap:
                embedded_alignment = self.re_map_aligned_embedded_sequences(aligned, reference, starts)
            else:
                embedded_alignment = aligned
            res += self.check_alignment(k, embedded_alignment, reference, starts)
            alignments_used += 1

        # With nothing to score, report zeros instead of aggregating an empty list.
        pa = self.stats(res) if res else (0.0, 0.0, 0.0)
        self._results = pa
        return pa, alignments_used
    
    '''
    method to create the plot and save it
    '''
    def create_plot_of_comparison(self):
        """Draw the sensitivity/precision scatter plot of the recorded results, save it, then show it.

        The plotted values are the fixed ``(sensitivity, precision)`` results
        listed below. The figure and a text table are written through
        save_plot_and_table() under the name "aqp".
        """
        # (sensitivity, precision) per substitution matrix; the "Default" entries
        # were obtained with go: 11, ge: 1.
        all_results = {
            'Best500': (0.6429, 0.5556),
            '3di': (0.6571, 0.5679),
            'ml': (0.7143, 0.6250),
            'Best500 Default': (0.6286, 0.5301),
            '3di Default': (0.3857, 0.3139),
            'ml Default': (0.7143, 0.6173),
        }

        colors = {"Best500":"#dc267f","3di":"#648fff","ml":"#fe6100","Best500 Default":"#ffb000","3di Default":"#ffb000","ml Default":"#ffb000"}
        markers = {"Best500":"*","3di":"o","ml":"X","Best500 Default":"*","3di Default":"o","ml Default":"X"}

        lines = []
        for name, (sens, prec) in all_results.items():
            lines.append(f'{name:<10}\t{sens:.3f}\t\t{prec:.3f}')
            plt.scatter([prec], [sens], label=name, color=colors[name],
                        marker=markers[name],edgecolors="black", s=200)
        # Plot
        background_color = '#F2F2F2'
        plt.gca().set_facecolor(background_color)
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.ylabel('Sensitivity',fontsize="x-large")
        plt.xlabel('Precision',fontsize="x-large")
        plt.legend(frameon=False)

        # Save before showing: with an interactive backend the figure may be
        # gone once the window opened by show() is closed.
        self.save_plot_and_table(plt, lines, "aqp")
        plt.show()
    
    '''
    method to save plot and corresponding information in files
    '''
    def save_plot_and_table(self, plot, lines, name):
        """Save the plot twice and write the result table next to it.

        Args:
            plot: object providing ``savefig`` (create_plot_of_comparison passes
                ``matplotlib.pyplot``).
            lines: table rows of the form ``name<TAB>sensitivity<TAB><TAB>precision``.
            name: base name of the output files, without extension.

        Side effects:
            Calls ``plot.savefig`` for ``name`` and ``name + ".svg"`` and writes
            ``name + ".txt"`` with a header and the rows sorted by descending precision.
        """
        directory_experiment = ""
        path_for_plot = os.path.join(directory_experiment, name)
        plot.savefig(path_for_plot, dpi=300, bbox_inches='tight')
        plot.savefig(path_for_plot + ".svg", dpi=300, bbox_inches='tight')

        def precision_of(row):
            # Precision is the last column. Taking the last token stays correct
            # for names that contain spaces, and float() gives a numeric order.
            return float(row.split()[-1])

        with open(path_for_plot + '.txt', 'w') as f:
            f.write('Name     \tSensitvity\tPrecision')
            f.write('\n')
            f.write('-' * 33)
            f.write('\n')
            f.write('\n'.join(sorted(lines, key=precision_of, reverse=True)))