|
import logging
import pickle

import torch
from transformers import AutoTokenizer, EsmModel, T5EncoderModel, T5Tokenizer

from fuson_plm.utils.logging import log_update
|
|
def redump_pickle_dictionary(pickle_path):
    """
    Loads a pickle file containing one or more dumped dictionaries, merges them, and re-dumps the
    result as a single dictionary at the same location. This provides a clean reset for a pickle
    file that was built incrementally with 'ab+'.
    """
    entries = {}

    # Read every dictionary that was appended to the file and merge them into one.
    with open(pickle_path, 'rb') as f:
        while True:
            try:
                entry = pickle.load(f)
                entries.update(entry)
            except EOFError:
                # Reached the end of the file; stop reading.
                break
            except Exception as e:
                print(f"An error occurred: {e}")
                break

    # Overwrite the file with a single consolidated dictionary.
    with open(pickle_path, 'wb') as f:
        pickle.dump(entries, f)
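
# Example (sketch): append entries incrementally with 'ab+', then consolidate them into one dict.
# The file name and entries below are made-up placeholders.
#
#   for key, value in [("seq1", 1), ("seq2", 2)]:
#       with open("embeddings.pkl", "ab+") as f:
#           pickle.dump({key: value}, f)
#   redump_pickle_dictionary("embeddings.pkl")   # file now holds a single dict: {"seq1": 1, "seq2": 2}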
|
|
|
def load_esm2_type(esm_type, device=None):
    """
    Loads the specified ESM-2 checkpoint (e.g. esm2_t33_650M_UR50D) and its tokenizer from Hugging Face.
    Returns the model in eval mode on `device`, the tokenizer, and the device used.
    """
    # Suppress non-error messages from model loading (e.g. warnings about unused checkpoint weights).
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = EsmModel.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")

    model.to(device)
    model.eval()

    return model, tokenizer, device
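
# Example usage (sketch; esm2_t33_650M_UR50D is the 650M-parameter ESM-2 checkpoint on Hugging Face):
#
#   model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
#   # model is an EsmModel in eval mode on `device`, ready to pass to get_esm_embeddings(...)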
|
|
|
def load_prott5():
    """
    Loads the ProtT5-XL-UniRef50 encoder (Rostlab/prot_t5_xl_half_uniref50-enc) and its tokenizer.
    Returns the model on the available device, the tokenizer, and the device used.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")

    # The checkpoint is stored in half precision; cast to full precision when running on CPU.
    if device == torch.device('cpu'):
        model.to(torch.float32)
    model.to(device)

    return model, tokenizer, device
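
# Example usage (sketch):
#
#   model, tokenizer, device = load_prott5()
#   # model is a T5EncoderModel on `device` (half precision on GPU, float32 on CPU),
#   # ready to pass to get_prott5_embeddings(...)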
|
|
|
def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Computes ESM-2 embeddings for a list of protein sequences.

    Args:
        model: ESM-2 model returned by load_esm2_type
        tokenizer: matching tokenizer returned by load_esm2_type
        sequences: list of protein sequences (strings)
        device: torch.device the model lives on
        average: if True, average the per-residue embeddings into one vector per sequence
        print_updates: if True, log progress after each sequence
        savepath: if not None, path to a pickle file where embeddings are saved ('.pkl' is appended if missing)
        save_at_end: if True, write the full dictionary once at the end; otherwise append one entry per sequence and consolidate at the end
        max_length: maximum tokenized length; defaults to the longest sequence plus 2 special tokens

    Returns:
        dict mapping each sequence to its embedding (numpy array)
    """
    # Ensure the save path has a .pkl extension.
    if savepath is not None:
        if not savepath.endswith('.pkl'): savepath += '.pkl'

    # Default max_length: longest sequence plus 2 for the CLS and EOS tokens.
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len + 2

    embedding_dict = {}

    for i in range(len(sequences)):
        sequence = sequences[i]

        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

            # Drop the batch dimension and the CLS/EOS special tokens, leaving one row per residue.
            embedding = embedding.squeeze(0)
            embedding = embedding[1:-1, :]
            embedding = embedding.cpu().numpy()

            # Optionally average over residues to get a single fixed-length vector.
            if average:
                embedding = embedding.mean(0)

            embedding_dict[sequence] = embedding

            # Incrementally append this entry to the pickle so progress survives interruptions.
            if savepath is not None and not save_at_end:
                with open(savepath, 'ab+') as f:
                    d = {sequence: embedding}
                    pickle.dump(d, f)

            if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")

    if savepath is not None:
        if save_at_end:
            # Write the whole dictionary in one shot.
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        else:
            # Consolidate the incrementally appended entries into a single dictionary.
            redump_pickle_dictionary(savepath)

    return embedding_dict
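
# Example usage (sketch; the sequences and the file name are made-up placeholders):
#
#   model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
#   seqs = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MGSSHHHHHHSSGLVPRGSH"]
#   embeddings = get_esm_embeddings(model, tokenizer, seqs, device, average=True,
#                                   savepath="esm2_embeddings.pkl", save_at_end=True)
#   # embeddings[seqs[0]] is a 1D numpy array (average of the per-residue last hidden states)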
|
|
|
def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Computes ProtT5 embeddings for a list of protein sequences. Arguments and return value mirror
    get_esm_embeddings, except the model and tokenizer should come from load_prott5.
    """
    # Ensure the save path has a .pkl extension.
    if savepath is not None:
        if not savepath.endswith('.pkl'): savepath += '.pkl'

    # Default max_length: longest sequence plus 2 to leave room for special tokens.
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len + 2

    # ProtT5 expects residues separated by spaces (e.g. "M K T ...").
    spaced_sequences = [' '.join(list(seq)) for seq in sequences]

    embedding_dict = {}

    for i in range(len(spaced_sequences)):
        spaced_sequence = spaced_sequences[i]
        seq = spaced_sequence.replace(" ", "")

        with torch.no_grad():
            inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

            # Drop the batch dimension and the trailing </s> token, leaving one row per residue.
            seq_length = len(seq)
            embedding = embedding_repr.last_hidden_state.squeeze(0)
            embedding = embedding[0:-1]
            embedding = embedding.cpu().numpy()
            embedding_log = f"\tembedding shape: {embedding.shape}"

            # Sanity checks: ProtT5 hidden size is 1024, with one embedding per residue.
            assert embedding.shape[1] == 1024
            assert embedding.shape[0] == seq_length

            # Optionally average over residues to get a single fixed-length vector.
            if average:
                dim_before = embedding.shape
                embedding = embedding.mean(0)
                embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"

            embedding_dict[seq] = embedding

            # Incrementally append this entry to the pickle so progress survives interruptions.
            if savepath is not None and not save_at_end:
                with open(savepath, 'ab+') as f:
                    d = {seq: embedding}
                    pickle.dump(d, f)

            if print_updates: log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")

    if savepath is not None:
        if save_at_end:
            # Write the whole dictionary in one shot.
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        else:
            # Consolidate the incrementally appended entries into a single dictionary.
            redump_pickle_dictionary(savepath)

    return embedding_dict
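
# Example usage (sketch; the sequences and the file name are made-up placeholders):
#
#   model, tokenizer, device = load_prott5()
#   seqs = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MGSSHHHHHHSSGLVPRGSH"]
#   embeddings = get_prott5_embeddings(model, tokenizer, seqs, device, average=False,
#                                      savepath="prott5_embeddings.pkl")
#   # embeddings[seqs[0]] is a (sequence length, 1024) numpy array of per-residue ProtT5 embeddings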