# fuson_plm/utils/embedding.py
# Utility functions for computing ESM-2 and ProtT5 embeddings, used throughout FusOn-pLM training and benchmarking.

import logging
import pickle

import torch
from transformers import AutoTokenizer, EsmModel, T5EncoderModel, T5Tokenizer

from fuson_plm.utils.logging import log_update

def redump_pickle_dictionary(pickle_path):
"""
Loads a pickle dictionary and redumps it in its location. This allows a clean reset for a pickle built with 'ab+'
"""
entries = {}
# Load one by one
with open(pickle_path, 'rb') as f:
while True:
try:
entry = pickle.load(f)
entries.update(entry)
except EOFError:
break # End of file reached
except Exception as e:
print(f"An error occurred: {e}")
break
# Redump
with open(pickle_path, 'wb') as f:
pickle.dump(entries, f)
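
# Illustrative usage sketch (not part of the original module): a pickle written
# incrementally in 'ab+' mode contains one small dictionary per dump, so it has to be
# read in a loop; after redump_pickle_dictionary, a single pickle.load returns the
# merged dictionary. The file name and entries below are made-up placeholders.
#
#     with open("embeddings.pkl", "ab+") as f:
#         pickle.dump({"MKT": [0.1, 0.2]}, f)
#         pickle.dump({"AYI": [0.3, 0.4]}, f)
#     redump_pickle_dictionary("embeddings.pkl")
#     with open("embeddings.pkl", "rb") as f:
#         merged = pickle.load(f)  # {"MKT": [0.1, 0.2], "AYI": [0.3, 0.4]}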
def load_esm2_type(esm_type, device=None):
"""
Loads ESM-2 version of a specified version (e.g. esm2_t33_650M_UR50D)
"""
# Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = EsmModel.from_pretrained(f"facebook/{esm_type}")
tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")
model.to(device)
model.eval() # disables dropout for deterministic results
return model, tokenizer, device
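
# Illustrative usage sketch (assumed, not from the original file): load the 650M ESM-2
# checkpoint named in the docstring above and embed one toy sequence directly.
#
#     model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
#     toks = tokenizer("MKTAYIAKQRQISFVK", return_tensors="pt").to(device)
#     with torch.no_grad():
#         per_residue = model(**toks).last_hidden_state  # (1, seq_len + 2, hidden_dim), includes BOS/EOS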
def load_prott5():
    """
    Loads the ProtT5 encoder (Rostlab/prot_t5_xl_half_uniref50-enc) and its tokenizer.
    The checkpoint is half precision, so the model is cast to float32 when running on CPU.
    """
    # Initialize tokenizer and model
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
if device == torch.device('cpu'):
model.to(torch.float32)
model.to(device)
return model, tokenizer, device
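
# Illustrative note (assumed usage, not from the original file): the ProtT5 tokenizer
# expects residues separated by spaces; get_prott5_embeddings below handles this, but a
# direct call would look like:
#
#     model, tokenizer, device = load_prott5()
#     ids = tokenizer(" ".join("MKTAYIAK"), return_tensors="pt", add_special_tokens=True).to(device)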
def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ESM-2 embeddings for a list of sequences.

    Args:
        model: ESM-2 model returned by load_esm2_type
        tokenizer: matching tokenizer returned by load_esm2_type
        sequences: list of amino acid sequences (strings)
        device: torch.device the model is on
        average: if True, per-residue embeddings are mean-pooled into one vector per sequence
        print_updates: if True, log a progress line for each sequence
        savepath: if not None, embeddings are saved to this path as a pickle ('.pkl' is appended if missing)
        save_at_end: if True, dump the full dictionary once at the end; otherwise append each embedding as it is computed
        max_length: maximum tokenized length; defaults to the longest sequence plus 2 (BOS and EOS)

    Returns:
        dict mapping each sequence to its embedding (numpy array)
    """
    # Make sure the save path ends in .pkl
    if savepath is not None and not savepath.endswith('.pkl'):
        savepath += '.pkl'
    # If no max length was passed, default to the longest sequence in the dataset (+2 for BOS and EOS)
    if max_length is None:
        max_length = max(len(s) for s in sequences) + 2
# Initialize an empty dict to store the ESM embeddings
embedding_dict = {}
# Iterate through the seqs
for i in range(len(sequences)):
sequence = sequences[i]
# Get the embeddings
with torch.no_grad():
inputs = tokenizer(sequence, return_tensors="pt",padding=True, truncation=True,max_length=max_length)
inputs = {k:v.to(device) for k, v in inputs.items()}
outputs = model(**inputs)
embedding = outputs.last_hidden_state
# remove extra dimension
embedding = embedding.squeeze(0)
# remove BOS and EOS tokens
embedding = embedding[1:-1, :]
# Convert embeddings to numpy array (if needed)
embedding = embedding.cpu().numpy()
# Average (if necessary)
if average:
embedding = embedding.mean(0)
# Add to dictionary
embedding_dict[sequence] = embedding
        # Save the individual embedding as we go (if necessary)
        if savepath is not None and not save_at_end:
            with open(savepath, 'ab+') as f:
                pickle.dump({sequence: embedding}, f)
        # Print update (if necessary)
        if print_updates:
            log_update(f"sequence {i+1}: {sequence[0:10]}...")
    # Dump everything at once at the end (if necessary)
    if savepath is not None:
        # If we waited until the end, dump the whole dictionary in one go
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, rewrite the pickle so it loads with a single pickle.load
        else:
            redump_pickle_dictionary(savepath)
# Return the dictionary
return embedding_dict
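
# Illustrative usage sketch (assumed, not from the original file): mean-pooled ESM-2
# embeddings for a toy list of sequences, saved incrementally to a placeholder pickle
# path. With save_at_end=False each embedding is appended as it is computed, and the
# file is re-dumped into a single dictionary once the loop finishes.
#
#     model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
#     seqs = ["MKTAYIAKQRQISFVK", "GSHMSLYKKAGF"]
#     emb_dict = get_esm_embeddings(model, tokenizer, seqs, device, average=True,
#                                   savepath="esm2_embeddings.pkl", save_at_end=False,
#                                   print_updates=True)
#     # emb_dict maps each sequence string to one mean-pooled numpy vector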
def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ProtT5 embeddings for a list of sequences; arguments and return value mirror get_esm_embeddings.
    """
    # Make sure the save path ends in .pkl
    if savepath is not None and not savepath.endswith('.pkl'):
        savepath += '.pkl'
    # If no max length was passed, default to the longest sequence in the dataset (+2 for special tokens)
    if max_length is None:
        max_length = max(len(s) for s in sequences) + 2
# the ProtT5 tokenizer requires that there are spaces between residues
spaced_sequences = [' '.join(list(seq)) for seq in sequences] # Spaces between residues for Prot-T5 tokenizer
    # Store embeddings here
    embedding_dict = {}
for i in range(0, len(spaced_sequences)):
spaced_sequence = spaced_sequences[i] # get current sequence
seq = spaced_sequence.replace(" ", "")
with torch.no_grad():
inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True,max_length=max_length) # shouldn't have to pad because batch size is 1
inputs = {k:v.to(device) for k, v in inputs.items()}
# Pass through the model with no gradient to get embeddings
embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
            # Process the embedding
            seq_length = len(seq)  # length of the unspaced sequence
            embedding = embedding_repr.last_hidden_state.squeeze(0)  # remove batch dimension
            embedding = embedding[:-1]  # remove the EOS token (ProtT5 adds no BOS token)
            embedding = embedding.cpu().numpy()  # move to CPU and convert to numpy
            embedding_log = f"\tembedding shape: {embedding.shape}"
            # Sanity checks: ProtT5 embeddings are 1024-dimensional, and there should be one embedding per residue
            assert embedding.shape[1] == 1024
            assert embedding.shape[0] == seq_length
            # Average (if necessary)
            if average:
                dim_before = embedding.shape
                embedding = embedding.mean(0)
                embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"
            # Add the embedding to the dictionary
            embedding_dict[seq] = embedding
        # Save the individual embedding as we go (if necessary)
        if savepath is not None and not save_at_end:
            with open(savepath, 'ab+') as f:
                pickle.dump({seq: embedding}, f)
        # Print update (if necessary)
        if print_updates:
            log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")
    # Dump everything at once at the end (if necessary)
    if savepath is not None:
        # If we waited until the end, dump the whole dictionary in one go
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, rewrite the pickle so it loads with a single pickle.load
        else:
            redump_pickle_dictionary(savepath)
# Return the dictionary
return embedding_dict
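
# Minimal demo sketch (assumed, not part of the original module): running this file
# directly embeds a couple of placeholder sequences with ProtT5. The sequences are
# illustrative only; the Rostlab checkpoint is downloaded on first use.
if __name__ == "__main__":
    demo_sequences = ["MKTAYIAKQRQISFVK", "GSHMSLYKKAGF"]
    prott5_model, prott5_tokenizer, prott5_device = load_prott5()
    demo_embeddings = get_prott5_embeddings(
        prott5_model, prott5_tokenizer, demo_sequences, prott5_device,
        average=True, print_updates=True
    )
    for demo_seq, demo_emb in demo_embeddings.items():
        log_update(f"{demo_seq[0:10]}...\tmean-pooled embedding shape: {demo_emb.shape}")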