from transformers import T5EncoderModel, T5Tokenizer
import torch
import numpy as np
import time
import os


def generate_embedings(input_seq, output_path=None):
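    """Generate ProtT5 per-residue embeddings for a single amino-acid sequence.

    Creates the output directories, loads the ProtT5-XL-U50 encoder and returns
    the results dict produced by __get_embeddings (per-residue embeddings are
    stored under results["residue_embs"]["0"]).
    """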
    # Create directories
    protT5_path = "ps4_data/data/protT5"
    # where to store the embeddings
    per_residue_path = "ps4_data/data/protT5/output/per_residue_embeddings" if output_path is None else output_path
    for dir_path in [protT5_path, per_residue_path]:
        __create_dir(dir_path)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Using {}".format(device))

    # Load the encoder part of ProtT5-XL-U50 in half-precision (recommended)
    model, tokenizer = __get_T5_model(device)

    # Wrap the single input sequence in a dict keyed by a placeholder identifier
    all_seqs = {"0": input_seq}

    # Compute per-residue embeddings
    results = __get_embeddings(model, tokenizer, all_seqs, device)
    return results


def __get_T5_model(device):
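    """Load the ProtT5-XL-U50 encoder (half-precision checkpoint) and its tokenizer from Rostlab."""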
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    model = model.to(device)  # move model to GPU (or keep it on CPU)
    model = model.eval()  # set model to evaluation mode
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    return model, tokenizer


def __save_embeddings(emb_dict, out_path):
    np.savez_compressed(out_path, **emb_dict)


def __get_embeddings(model, tokenizer, seqs, device, per_residue=True,
                     max_residues=4000, max_seq_len=1000, max_batch=100):
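    """Compute ProtT5 embeddings for all sequences in `seqs`.

    Sequences are accumulated into a batch until max_batch sequences or roughly
    max_residues residues are reached, or a sequence exceeds max_seq_len; the
    batch is then embedded in one forward pass. Per-residue embeddings are
    returned in results["residue_embs"], keyed by sequence identifier.
    """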
    results = {"residue_embs": dict(),
               "protein_embs": dict(),
               "sec_structs": dict()
               }

    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
    seq_dict = sorted(seqs.items(), key=lambda kv: len(seqs[kv[0]]), reverse=True)
    start = time.time()
    batch = list()
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict, 1):
        seq_len = len(seq)
        seq = ' '.join(list(seq))
        batch.append((pdb_id, seq, seq_len))

        # count residues in the current batch and add the last sequence length to
        # avoid that batches with (n_res_batch > max_residues) get processed
        n_res_batch = sum([s_len for _, _, s_len in batch]) + seq_len
        if len(batch) >= max_batch or n_res_batch >= max_residues or seq_idx == len(seq_dict) or seq_len > max_seq_len:
            pdb_ids, seqs, seq_lens = zip(*batch)
            batch = list()

            # add_special_tokens adds an extra token at the end of each sequence
            token_encoding = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
            input_ids = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)

            try:
                with torch.no_grad():
                    # returns: (batch_size x max_seq_len_in_minibatch x embedding_dim)
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError:
                print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
                continue

            for batch_idx, identifier in enumerate(pdb_ids):  # for each protein in the current mini-batch
                s_len = seq_lens[batch_idx]
                # slice off padding --> batch_size x seq_len x embedding_dim
                emb = embedding_repr.last_hidden_state[batch_idx, :s_len]
                if per_residue:  # store per-residue embeddings (L x 1024)
                    results["residue_embs"][identifier] = emb.detach().cpu().squeeze()
                    print("emb_count:", len(results["residue_embs"]))

    passed_time = time.time() - start
    avg_time = passed_time / len(results["residue_embs"]) if per_residue else passed_time / len(results["protein_embs"])
    print('\n############# EMBEDDING STATS #############')
    print('Total number of per-residue embeddings: {}'.format(len(results["residue_embs"])))
    print("Time for generating embeddings: {:.1f}[m] ({:.3f}[s/protein])".format(
        passed_time / 60, avg_time))
    print('\n############# END #############')
    return results


def __create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)
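

# Minimal usage sketch (assumption: this file is executed directly; the example
# sequence below is a made-up placeholder, not part of the original project).
# It embeds one short sequence and prints the shape of the resulting
# per-residue embedding tensor (expected: L x 1024 for ProtT5-XL-U50).
if __name__ == "__main__":
    example_seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # hypothetical test sequence
    res = generate_embedings(example_seq)
    emb = res["residue_embs"]["0"]
    print("per-residue embedding shape:", tuple(emb.shape))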