# laura maria engist, 2025
# Clustering for the K-Means Pipeline
# S = E(S_AA, init, normalize)

import torch
import os
from scipy.cluster.vq import kmeans2
from sklearn.cluster import BisectingKMeans, KMeans
from evotuner import constants

class KmpClusterBisecting:
    """Cluster per-residue protein embeddings into a discrete alphabet.

    The residue embeddings of all proteins are stacked into one matrix and
    clustered into ``constants.ALPHABET_SIZE`` clusters. Each protein is then
    re-expressed as its sequence of cluster labels (the "alphabeta").
    """

    # Centroid initialisations accepted by ``e`` / ``get_cluster``.
    _SUPPORTED_INITS = ('random', 'k-means++')

    def __init__(self, non_clustered_embeddings, tmpdir):
        """
        Args:
            non_clustered_embeddings: path to a ``torch.save``-d dict
                ``{protein_id: tensor of shape (L, D)}`` of per-residue
                embeddings (not yet clustered).
            tmpdir: directory under which the output folder is created.
        """
        self.non_clustered_embeddings = non_clustered_embeddings
        self.tmpdir = tmpdir

    def e(self, S, init, normalize):
        """Run the clustering pipeline S = E(S_AA, init, normalize).

        Args:
            S: unused. The embeddings are read from
                ``self.non_clustered_embeddings`` instead. The parameter is
                kept for interface compatibility.
            init: ``'random'`` or ``'k-means++'``.
            normalize: if True, standardize every embedding dimension to
                zero mean / unit std before clustering.

        Returns:
            Path to ``alphabeta.pt``, a saved dict ``{protein_id: labels}``.
            The cluster centers are saved as ``codebook.pt`` in the same
            directory.

        Raises:
            ValueError: if ``init`` is unsupported or there are no embeddings.
        """
        # Fail fast, before the (possibly large) embedding file is loaded.
        self._validate_init(init)
        clustering_alg_name = "Bisecting K-Means"
        # S is already embedded here but not yet clustered -> self.non_clustered_embeddings
        scop_embeddings_path = self.non_clustered_embeddings
        # File name up to its first '.', e.g. ".../emb.pt" -> "emb".
        emb_type = os.path.basename(scop_embeddings_path).split('.')[0]
        embeddings_dict = self.load_embeddings_dict(scop_embeddings_path)
        print(f"Landscape of {len(embeddings_dict)} proteins")

        # Keep a fixed protein order so the flat label vector can later be
        # split back into per-protein label sequences.
        tensor_list, sequences_ids = self.keep_order(embeddings_dict)
        print("KEEP ORDER DONE")
        data_landscape = self.get_data_landscape(tensor_list)
        print(f'Number of residues: {data_landscape.shape[0]}')

        if normalize:
            data_landscape = self.normalize(data_landscape)
            print(f'Number of residues after normalization: {data_landscape.shape[0]}')

        clusters = self.get_cluster(init, clustering_alg_name, data_landscape)
        print("GOT CLUSTERS")

        codebook = clusters.cluster_centers_
        labels = clusters.labels_

        alphabet_path = os.path.join(self.tmpdir, f'bisectingkmeans_std_{constants.ALPHABET_SIZE}_{emb_type}')
        # exist_ok avoids the check-then-create race; makedirs also creates
        # tmpdir itself if it does not exist yet.
        os.makedirs(alphabet_path, exist_ok=True)

        codebook_path = os.path.join(alphabet_path, 'codebook.pt')
        torch.save(codebook, codebook_path)
        print("CODEBOOK SAVED")

        alphabeta = self.create_alphabeta(sequences_ids, embeddings_dict, labels)
        print("ALPHABETA CREATED")

        fasta_path = os.path.join(alphabet_path, 'alphabeta.pt')
        torch.save(alphabeta, fasta_path)
        print(f"fasta_path: {fasta_path}")
        return fasta_path

    # ---- Helper functions ----

    def _validate_init(self, init):
        """Raise ValueError unless ``init`` is a supported initialisation."""
        if init not in self._SUPPORTED_INITS:
            raise ValueError(
                f"init must be one of {self._SUPPORTED_INITS}, got {init!r}"
            )

    def load_embeddings_dict(self, scop_embeddings_path):
        """Load the embeddings dict from disk onto the CPU.

        NOTE(review): depending on the torch version, ``torch.load`` may
        unpickle arbitrary objects. Only load files from trusted sources.
        """
        embeddings_dict = torch.load(scop_embeddings_path, map_location=torch.device('cpu'))
        return embeddings_dict

    def keep_order(self, embeddings_dict):
        """Return ``(tensors, protein_ids)`` in the dict's iteration order."""
        # keys() and values() of the same dict iterate in matching order.
        tensor_list = list(embeddings_dict.values())
        sequences_ids = list(embeddings_dict.keys())
        return tensor_list, sequences_ids

    def get_data_landscape(self, tensor_list):
        """Concatenate per-protein (L, D) tensors into one (sum L, D) tensor.

        Raises:
            ValueError: if ``tensor_list`` is empty. Without this check,
                ``torch.cat`` fails with an opaque RuntimeError.
        """
        if not tensor_list:
            raise ValueError("no embeddings to cluster: the embeddings dict is empty")
        data_landscape = torch.cat(tensor_list, dim=0)
        print("DATA LANDSCAPE DONE")
        return data_landscape

    def normalize(self, data_landscape):
        """Standardize each column (embedding dimension) to zero mean / unit std.

        ``torch.std`` uses the unbiased (N-1) estimator by default, so a
        single-residue landscape yields NaN. The 1e-8 term only guards
        against constant columns.
        """
        mean = torch.mean(data_landscape, dim=0, keepdim=True)
        std = torch.std(data_landscape, dim=0, keepdim=True)
        data_landscape = (data_landscape - mean) / (std + 1e-8)
        return data_landscape

    def get_cluster(self, init, clustering_alg_name, data_landscape):
        """Fit the named clustering algorithm on ``data_landscape``.

        Args:
            init: ``'random'`` or ``'k-means++'``.
            clustering_alg_name: ``"Bisecting K-Means"`` or ``"K-Means"``.
            data_landscape: 2-D data of shape (n_residues, D).

        Returns:
            The fitted scikit-learn estimator (``fit`` returns the estimator).

        Raises:
            ValueError: if ``init`` is unsupported. Previously such a value
                left ``algo`` unbound and crashed with UnboundLocalError.
            KeyError: if ``clustering_alg_name`` is unknown.
        """
        self._validate_init(init)
        clustering_algorithms = {
            "Bisecting K-Means": BisectingKMeans,
            "K-Means": KMeans,
        }
        algo = clustering_algorithms[clustering_alg_name](init=init, n_clusters=constants.ALPHABET_SIZE)
        clusters = algo.fit(data_landscape)
        return clusters

    def create_alphabeta(self, sequences_ids, embeddings_dict, labels):
        """Split the flat label vector back into per-protein label slices.

        ``labels`` must be ordered like the concatenation produced by
        ``keep_order`` + ``get_data_landscape``.
        """
        alphabeta = dict()
        start = 0
        for p in sequences_ids:
            seq_length = embeddings_dict[p].shape[0]
            alphabeta[p] = labels[start:start + seq_length]
            start += seq_length
        return alphabeta