import logging
import pickle

import torch
from transformers import EsmModel, AutoTokenizer
from transformers import T5Tokenizer, T5EncoderModel

from fuson_plm.utils.logging import log_update


def redump_pickle_dictionary(pickle_path):
    """
    Load a pickle file that was built incrementally with 'ab+' (one dictionary entry dumped at a time)
    and re-dump it to the same path as a single dictionary, so it can be read back with one pickle.load.
    """
    entries = {}
    # Load the appended entries one by one
    with open(pickle_path, 'rb') as f:
        while True:
            try:
                entry = pickle.load(f)
                entries.update(entry)
            except EOFError:
                break  # end of file reached
            except Exception as e:
                print(f"An error occurred: {e}")
                break
    # Re-dump everything as a single dictionary
    with open(pickle_path, 'wb') as f:
        pickle.dump(entries, f)
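
# Note (illustrative, not part of the original module): the embedding functions below call
# redump_pickle_dictionary automatically when an incremental run finishes. If a run crashes partway,
# the partially written pickle can still be consolidated manually, e.g.
#   redump_pickle_dictionary("esm_embeddings.pkl")
# where "esm_embeddings.pkl" is a hypothetical save path.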


def load_esm2_type(esm_type, device=None):
    """
    Load the specified ESM-2 checkpoint (e.g. esm2_t33_650M_UR50D) and its tokenizer.
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers
    # - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = EsmModel.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")
    model.to(device)
    model.eval()  # disable dropout for deterministic results
    return model, tokenizer, device


def load_prott5():
    """
    Load the ProtT5-XL-UniRef50 encoder (half precision) and its tokenizer.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    # The checkpoint ships in half precision; cast to float32 when running on CPU
    if device == torch.device('cpu'):
        model.to(torch.float32)
    model.to(device)
    return model, tokenizer, device


def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ESM-2 embeddings for a list of sequences.

    Args:
        model: ESM-2 model returned by load_esm2_type
        tokenizer: matching tokenizer returned by load_esm2_type
        sequences: list of amino acid sequence strings
        device: torch.device to run inference on
        average: if True, mean-pool over the sequence length so each embedding is a single vector
        print_updates: if True, log a progress line for each sequence
        savepath: if not None, the embeddings will be saved to this path; it must be a pickle (.pkl)
        save_at_end: if True, dump the whole dictionary once at the end; if False, append one entry per sequence
        max_length: maximum tokenized length; defaults to the longest sequence plus BOS/EOS

    Returns:
        dict mapping each sequence to its embedding (numpy array)
    """
    # Correct the save path to a pickle extension if necessary
    if savepath is not None:
        if not savepath.endswith('.pkl'):
            savepath += '.pkl'

    # If no max length was passed, set it to the longest sequence in the dataset (+2 for BOS, EOS)
    if max_length is None:
        max_length = max(len(s) for s in sequences) + 2

    # Initialize an empty dict to store the ESM embeddings
    embedding_dict = {}

    # Iterate through the sequences
    for i, sequence in enumerate(sequences):
        # Get the embeddings
        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

        # Remove the batch dimension
        embedding = embedding.squeeze(0)
        # Remove the BOS and EOS tokens
        embedding = embedding[1:-1, :]
        # Convert the embedding to a numpy array
        embedding = embedding.cpu().numpy()

        # Average over the sequence length (if requested)
        if average:
            embedding = embedding.mean(0)

        # Add to the dictionary
        embedding_dict[sequence] = embedding

        # Save the individual embedding (if saving incrementally)
        if savepath is not None and not save_at_end:
            with open(savepath, 'ab+') as f:
                pickle.dump({sequence: embedding}, f)

        # Print an update (if requested)
        if print_updates:
            log_update(f"sequence {i+1}: {sequence[0:10]}...")

    # Dump all at once at the end (if necessary)
    if savepath is not None:
        if save_at_end:
            # Saving for the first time: just dump the whole dictionary
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        else:
            # We've been appending all along and made it here without crashing;
            # consolidate the pickle so it can be loaded in a single call
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict
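
# Illustrative ESM-2 usage (a sketch, not part of the original module; the checkpoint name matches
# the example in the load_esm2_type docstring, and the sequence is a made-up toy input):
#   model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
#   embeddings = get_esm_embeddings(model, tokenizer, ["MKTAYIAKQR"], device, average=True)
#   # embeddings["MKTAYIAKQR"] is a mean-pooled vector (1280-dim for the 650M checkpoint)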


def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ProtT5 embeddings for a list of sequences. Arguments mirror get_esm_embeddings.
    """
    # Correct the save path to a pickle extension if necessary
    if savepath is not None:
        if not savepath.endswith('.pkl'):
            savepath += '.pkl'

    # If no max length was passed, set it to the longest sequence in the dataset (+2 for BOS, EOS)
    if max_length is None:
        max_length = max(len(s) for s in sequences) + 2

    # The ProtT5 tokenizer requires spaces between residues
    spaced_sequences = [' '.join(list(seq)) for seq in sequences]

    # Store embeddings here
    embedding_dict = {}

    for i, spaced_sequence in enumerate(spaced_sequences):
        seq = spaced_sequence.replace(" ", "")  # recover the raw sequence

        with torch.no_grad():
            # No padding needed because the batch size is 1
            inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # Pass through the model with no gradient to get embeddings
            embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Process the embedding
        seq_length = len(seq)  # length of the sequence after removing spaces
        embedding = embedding_repr.last_hidden_state.squeeze(0)  # remove the batch dimension
        embedding = embedding[0:-1]  # remove the EOS token (there is no BOS token)
        embedding = embedding.cpu().numpy()  # move to CPU and convert to numpy
        embedding_log = f"\tembedding shape: {embedding.shape}"

        # Sanity checks: the embedding dimension should be 1024 and the length should match the real sequence length
        assert embedding.shape[1] == 1024
        assert embedding.shape[0] == seq_length

        # Average over the sequence length (if requested)
        if average:
            dim_before = embedding.shape
            embedding = embedding.mean(0)
            embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"

        # Add the embedding to the dictionary
        embedding_dict[seq] = embedding

        # Save the individual embedding (if saving incrementally)
        if savepath is not None and not save_at_end:
            with open(savepath, 'ab+') as f:
                pickle.dump({seq: embedding}, f)

        # Print an update (if requested)
        if print_updates:
            log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")

    # Dump all at once at the end (if necessary)
    if savepath is not None:
        if save_at_end:
            # Saving for the first time: just dump the whole dictionary
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        else:
            # We've been appending all along and made it here without crashing;
            # consolidate the pickle so it can be loaded in a single call
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict
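

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original pipeline). It loads ProtT5
    # and embeds two made-up toy sequences; "toy_prott5_embeddings" is a hypothetical save path.
    # ESM-2 usage is analogous via load_esm2_type / get_esm_embeddings.
    model, tokenizer, device = load_prott5()
    toy_sequences = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "GSHMSLFDFFKNKGSAA"]
    embeddings = get_prott5_embeddings(
        model, tokenizer, toy_sequences, device,
        average=True, print_updates=False, savepath="toy_prott5_embeddings", save_at_end=True
    )
    for seq, emb in embeddings.items():
        print(seq[:10], emb.shape)  # each value is a 1024-dim mean-pooled vector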