import pickle
import logging

import torch
from transformers import EsmModel, AutoTokenizer
from transformers import T5Tokenizer, T5EncoderModel

from fuson_plm.utils.logging import log_update


def redump_pickle_dictionary(pickle_path):
    """
    Loads a pickle file that was written incrementally (mode 'ab+') as a series of small
    dictionaries, merges the entries, and re-dumps them to the same path as one dictionary.
    """
    entries = {}
    # Load the appended dictionaries one at a time until the end of the file
    with open(pickle_path, 'rb') as f:
        while True:
            try:
                entry = pickle.load(f)
                entries.update(entry)
            except EOFError:
                break  # End of file reached
            except Exception as e:
                print(f"An error occurred: {e}")
                break
    # Redump
    with open(pickle_path, 'wb') as f:
        pickle.dump(entries, f)
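# Example (a minimal sketch, not part of the original pipeline; the path below is a placeholder):
# consolidate a pickle that was appended to entry-by-entry, e.g. by an embedding run with
# save_at_end=False that crashed before its final consolidation step.
#
#   redump_pickle_dictionary("embeddings/esm2_embeddings.pkl")
#   with open("embeddings/esm2_embeddings.pkl", "rb") as f:
#       embeddings = pickle.load(f)  # now loads as a single {sequence: embedding} dictionary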
        
def load_esm2_type(esm_type, device=None):
    """
    Loads the specified ESM-2 checkpoint (e.g. esm2_t33_650M_UR50D) and its tokenizer.
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias' and 'esm.pooler.dense.weight'
    # layers; the pooler is not used when extracting embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
    
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    model = EsmModel.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")

    model.to(device)
    model.eval()  # disables dropout for deterministic results
    
    return model, tokenizer, device
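# Example usage (a minimal sketch): "esm2_t33_650M_UR50D" is one of the public facebook ESM-2 checkpoints.
#
#   model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")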

def load_prott5():
    """
    Loads the ProtT5-XL half-precision encoder (Rostlab/prot_t5_xl_half_uniref50-enc) and its tokenizer.
    """
    # Initialize tokenizer and model
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    # The checkpoint is stored in half precision; cast it back to float32 when running on CPU
    if device == torch.device('cpu'):
        model.to(torch.float32)
    model.to(device)
    return model, tokenizer, device
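# Example usage (a minimal sketch). Note that T5Tokenizer needs the sentencepiece package installed.
#
#   model, tokenizer, device = load_prott5()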

def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Computes ESM embeddings for a list of sequences.

    Args:
        model: ESM-2 model returned by load_esm2_type
        tokenizer: matching tokenizer returned by load_esm2_type
        sequences: list of amino acid sequences (strings)
        device: torch device the model is on
        average: if True, per-residue embeddings are mean-pooled into one vector per sequence
        print_updates: if True, logs a progress line for each sequence
        savepath: if not None, embeddings are saved to this path (must be a pickle file)
        save_at_end: if True, dump the whole dictionary once at the end; otherwise append one entry per sequence
        max_length: maximum tokenized length; defaults to the longest sequence plus 2 for BOS/EOS
    """
    # Correct save path to pickle if necessary
    if savepath is not None:
        if savepath[-4::] != '.pkl': savepath += '.pkl'
        
    # If no max length was passed, just set it to the maximum in the dataset
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length=max_seq_len+2 #+2 for BOS, EOS
    
    # Initialize an empty dict to store the ESM embeddings
    embedding_dict = {}
    # Iterate through the seqs
    for i in range(len(sequences)):
        sequence = sequences[i]
        # Get the embeddings
        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt",padding=True, truncation=True,max_length=max_length)
            inputs = {k:v.to(device) for k, v in inputs.items()}
            
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

            # remove extra dimension
            embedding = embedding.squeeze(0)
            # remove BOS and EOS tokens
            embedding = embedding[1:-1, :]

            # Convert embeddings to numpy array (if needed)
            embedding = embedding.cpu().numpy()
         
        # Average (if necessary)   
        if average:
            embedding = embedding.mean(0)
        
        # Add to dictionary
        embedding_dict[sequence] = embedding
            
        # Save individual embedding (if necessary)
        if not(savepath is None) and not(save_at_end): 
            with open(savepath, 'ab+') as f:
                d = {sequence: embedding}
                pickle.dump(d, f)

        # Print update (if necessary)
        if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")
    
    # Dump all at once at the end (if necessary)
    if not(savepath is None):
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)
    
    # Return the dictionary
    return embedding_dict
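# Example usage (a minimal sketch; the sequences and save path below are illustrative placeholders):
#
#   model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
#   seqs = ["MKTAYIAKQR", "MENSDSNDKG"]
#   embeddings = get_esm_embeddings(model, tokenizer, seqs, device, average=True,
#                                   savepath="esm2_embeddings.pkl", save_at_end=True)
#   # with average=True, each value is a single vector (1280-dim for the 650M checkpoint),
#   # keyed by the raw sequence string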

def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Computes ProtT5 embeddings for a list of sequences. Arguments mirror get_esm_embeddings.
    """
    # Correct save path to pickle if necessary
    if savepath is not None:
        if savepath[-4::] != '.pkl': savepath += '.pkl'

    # If no max length was passed, just set it to the maximum in the dataset
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len + 2  # +2 leaves room for special tokens
        
    # The ProtT5 tokenizer requires spaces between residues
    spaced_sequences = [' '.join(list(seq)) for seq in sequences]

    # Store embeddings here
    embedding_dict = {}
    
    for i in range(0, len(spaced_sequences)):
        spaced_sequence = spaced_sequences[i]  # current sequence, with spaces between residues
        seq = spaced_sequence.replace(" ", "")

        with torch.no_grad():
            # Tokenize; no padding needed because the batch size is 1
            inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Pass through the model with no gradient to get embeddings
            embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

            # Process the embedding
            seq_length = len(seq)  # length of the real sequence, after spaces are removed
            embedding = embedding_repr.last_hidden_state.squeeze(0)  # remove batch dimension
            embedding = embedding[0:-1]  # remove the EOS token (ProtT5 adds no BOS token)
            embedding = embedding.cpu().numpy()  # move to CPU and convert to numpy
            embedding_log = f"\tembedding shape: {embedding.shape}"
            # Sanity-check the shape: embedding dimension should be 1024 and length should match the real sequence length
            assert embedding.shape[1] == 1024
            assert embedding.shape[0] == seq_length

            # Average (if necessary)
            if average:
                dim_before = embedding.shape
                embedding = embedding.mean(0)
                embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"

            # Add the embedding to the dictionary
            embedding_dict[seq] = embedding

            # Save individual embedding (if necessary)
            if not(savepath is None) and not(save_at_end):
                with open(savepath, 'ab+') as f:
                    d = {seq: embedding}
                    pickle.dump(d, f)

            # Print update (if necessary)
            if print_updates: log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")
      
    # Dump all at once at the end (if necessary)
    if not(savepath is None):
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)
    
    # Return the dictionary
    return embedding_dict
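# Example usage (a minimal sketch; the sequences and save path are illustrative placeholders):
#
#   model, tokenizer, device = load_prott5()
#   seqs = ["MKTAYIAKQR", "MENSDSNDKG"]
#   embeddings = get_prott5_embeddings(model, tokenizer, seqs, device, average=True,
#                                      savepath="prott5_embeddings.pkl", save_at_end=True)
#   # ProtT5 embeddings are 1024-dim per residue; with average=True each sequence maps to one 1024-dim vector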