import pickle
import logging

import torch
from transformers import EsmModel, AutoTokenizer
from transformers import T5Tokenizer, T5EncoderModel

from fuson_plm.utils.logging import log_update


def redump_pickle_dictionary(pickle_path):
    """
    Loads a pickle dictionary and redumps it in its location.
    This allows a clean reset for a pickle built with 'ab+'.
    """
    entries = {}
    # Load one by one
    with open(pickle_path, 'rb') as f:
        while True:
            try:
                entry = pickle.load(f)
                entries.update(entry)
            except EOFError:
                break  # End of file reached
            except Exception as e:
                print(f"An error occurred: {e}")
                break
    # Redump
    with open(pickle_path, 'wb') as f:
        pickle.dump(entries, f)


def load_esm2_type(esm_type, device=None):
    """
    Loads the specified ESM-2 checkpoint (e.g. esm2_t33_650M_UR50D).
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = EsmModel.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")
    model.to(device)
    model.eval()  # disables dropout for deterministic results
    return model, tokenizer, device


def load_prott5():
    """
    Loads the ProtT5 encoder (Rostlab/prot_t5_xl_half_uniref50-enc) and its tokenizer.
    """
    # Initialize tokenizer and model
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    if device == torch.device('cpu'):
        model.to(torch.float32)  # half precision is not supported on CPU
    model.to(device)
    return model, tokenizer, device


def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ESM embeddings.

    Args:
        model: ESM-2 model returned by load_esm2_type
        tokenizer: matching tokenizer returned by load_esm2_type
        sequences: list of amino acid sequences to embed
        device: torch device to run inference on
        average: if True, the per-residue embeddings are averaged into one vector per sequence
        print_updates: if True, log progress after each sequence
        savepath: if savepath is not None, the embeddings will be saved there. It must be a pickle
        save_at_end: if True, dump the full dictionary once at the end; otherwise append after each sequence
        max_length: maximum tokenized length; defaults to the longest sequence plus special tokens
    """
    # Correct save path to pickle if necessary
    if savepath is not None:
        if not savepath.endswith('.pkl'):
            savepath += '.pkl'

    # If no max length was passed, just set it to the maximum in the dataset
    if max_length is None:
        max_seq_len = max([len(s) for s in sequences])
        max_length = max_seq_len + 2  # +2 for BOS, EOS

    # Initialize an empty dict to store the ESM embeddings
    embedding_dict = {}

    # Iterate through the seqs
    for i in range(len(sequences)):
        sequence = sequences[i]

        # Get the embeddings
        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

            # remove extra dimension
            embedding = embedding.squeeze(0)
            # remove BOS and EOS tokens
            embedding = embedding[1:-1, :]

            # Convert embeddings to numpy array (if needed)
            embedding = embedding.cpu().numpy()

        # Average (if necessary)
        if average:
            embedding = embedding.mean(0)

        # Add to dictionary
        embedding_dict[sequence] = embedding

        # Save individual embedding (if necessary)
        if savepath is not None and not save_at_end:
            with open(savepath, 'ab+') as f:
                d = {sequence: embedding}
                pickle.dump(d, f)

        # Print update (if necessary)
        if print_updates:
            log_update(f"sequence {i+1}: {sequence[0:10]}...")

    # Dump all at once at the end (if necessary)
    if savepath is not None:
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict


def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ProtT5 embeddings. Arguments mirror get_esm_embeddings.
    """
    # Correct save path to pickle if necessary
    if savepath is not None:
        if not savepath.endswith('.pkl'):
            savepath += '.pkl'

    # If no max length was passed, just set it to the maximum in the dataset
    if max_length is None:
        max_seq_len = max([len(s) for s in sequences])
        max_length = max_seq_len + 2  # leaves room for special tokens (ProtT5 only adds EOS)

    # the ProtT5 tokenizer requires that there are spaces between residues
    spaced_sequences = [' '.join(list(seq)) for seq in sequences]

    # Store embeddings here
    embedding_dict = {}

    for i in range(len(spaced_sequences)):
        spaced_sequence = spaced_sequences[i]  # get current sequence
        seq = spaced_sequence.replace(" ", "")

        with torch.no_grad():
            inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=max_length)  # shouldn't have to pad because batch size is 1
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Pass through the model with no gradient to get embeddings
            embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Process the embedding
        seq_length = len(seq)  # length of the sequence is after you remove spaces
        embedding = embedding_repr.last_hidden_state.squeeze(0)  # remove batch dimension
        embedding = embedding[0:-1]  # remove EOS token (there is no BOS token)
        embedding = embedding.cpu().numpy()  # put on CPU and numpy
        embedding_log = f"\tembedding shape: {embedding.shape}"

        # MAKE SURE the embedding lengths are right with an assert.
        # We expect embedding dimension 1024, and sequence length to match real sequence length
        assert embedding.shape[1] == 1024
        assert embedding.shape[0] == seq_length

        # Average (if necessary)
        if average:
            dim_before = embedding.shape
            embedding = embedding.mean(0)
            embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"

        # Add the embedding to the dictionary
        embedding_dict[seq] = embedding

        # Save individual embedding (if necessary)
        if savepath is not None and not save_at_end:
            with open(savepath, 'ab+') as f:
                d = {seq: embedding}
                pickle.dump(d, f)

        if print_updates:
            log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")

    # Dump all at once at the end (if necessary)
    if savepath is not None:
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict
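

# Illustrative usage sketch (not part of the original pipeline): embed a couple of toy
# sequences with ESM-2 and ProtT5 and save the averaged embeddings to pickles. The toy
# sequences and output paths below are placeholders; both loaders download large
# pretrained checkpoints from the Hugging Face Hub on first use.
if __name__ == "__main__":
    example_sequences = ["MKTAYIAKQR", "GSHMLEDPV"]  # hypothetical toy sequences

    # ESM-2 embeddings (checkpoint name taken from the load_esm2_type docstring)
    esm_model, esm_tokenizer, esm_device = load_esm2_type("esm2_t33_650M_UR50D")
    esm_embeddings = get_esm_embeddings(
        esm_model, esm_tokenizer, example_sequences, esm_device,
        average=True, print_updates=True,
        savepath="example_esm2_embeddings.pkl",  # hypothetical output path
        save_at_end=True,
    )

    # ProtT5 embeddings
    t5_model, t5_tokenizer, t5_device = load_prott5()
    t5_embeddings = get_prott5_embeddings(
        t5_model, t5_tokenizer, example_sequences, t5_device,
        average=True, print_updates=True,
        savepath="example_prott5_embeddings.pkl",  # hypothetical output path
        save_at_end=True,
    )

    # Averaged embeddings are 1D numpy vectors keyed by the raw sequence
    for seq in example_sequences:
        log_update(
            f"{seq}: ESM-2 shape {esm_embeddings[seq].shape}, ProtT5 shape {t5_embeddings[seq].shape}"
        )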