# laura maria engist, 2025
# script to load the embedder of a protein language model and get embeddings for protein sequences

import torch
from EBA.eba import plm_extractor as plm
from Bio import SeqIO
from evotuner import constants

# Paths taken from the project's constants module: the FASTA file to read
# and the file the embeddings are written to.
output_path_emb = constants.PROTT5_EMBEDDINGS
scop_fasta_file = constants.S_AA

### load language model extractor: ProtT5 or ESMb1
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
protT5_ext = plm.load_extractor('ProtT5', 'residue', device=device)


# Parse the FASTA file into a mapping of record id -> sequence string.
# If an id occurs more than once, the last record with that id is kept.
with open(scop_fasta_file, "r") as fasta_file:
    scop_sequences = {
        record.id: str(record.seq)
        for record in SeqIO.parse(fasta_file, "fasta")
    }

# Embed every sequence with the loaded extractor, collect the results keyed
# by sequence id, and store the whole mapping with torch.save.
embeddings = dict()
for seq_id, sequence in scop_sequences.items():
    embedding = protT5_ext.extract(sequence)

    # The script expects the first embedding dimension to equal the sequence
    # length. Raise explicitly instead of using `assert`: the previous
    # `assert(cond, msg)` asserted a non-empty tuple and could never fail,
    # and a plain `assert` would be stripped when running with `python -O`.
    # NOTE(review): to skip mismatching sequences instead of aborting,
    # replace the `raise` with a warning followed by `continue`.
    if len(sequence) != embedding.shape[0]:
        raise ValueError(
            f'Length mismatch: {seq_id}, '
            f'{len(sequence)}-{embedding.shape[0]}'
        )

    embeddings[seq_id] = embedding

torch.save(embeddings, output_path_emb)

