wenkai commited on
Commit
3e891e6
1 Parent(s): 38f9971

Update esm_scripts/extract.py

Browse files
Files changed (1) hide show
  1. esm_scripts/extract.py +65 -0
esm_scripts/extract.py CHANGED
@@ -131,6 +131,71 @@ def run(args):
131
  )
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def main():
135
  parser = create_parser()
136
  args = parser.parse_args()
 
131
  )
132
 
133
 
134
def run_demo(model_location, fasta_file, output_dir, include, nogpu,
             repr_layers=-1, truncation_seq_length=1022, toks_per_batch=4096):
    """Extract ESM embeddings for every sequence in a FASTA file.

    Parameters
    ----------
    model_location : str
        Name or path of a pretrained ESM model, resolved by
        ``pretrained.load_model_and_alphabet``.
    fasta_file : path-like
        FASTA file of input sequences.
    output_dir : pathlib.Path
        Directory created (with parents) if missing.  NOTE(review): nothing is
        written to it in this function; presumably callers expect per-sequence
        files here — confirm against the upstream extract.py.
    include : collection of str
        Which outputs to keep per sequence; any of ``"per_tok"``, ``"mean"``,
        ``"bos"``, ``"contacts"``.
    nogpu : bool
        If True, stay on CPU even when CUDA is available.
    repr_layers : int or iterable of int, default -1
        Layer indices to extract (negative indices count from the end).
        A bare int is accepted and treated as a single layer.
    truncation_seq_length : int, default 1022
        Sequences are truncated to this length by the batch converter.
    toks_per_batch : int, default 4096
        Token budget per batch for ``get_batch_indices``.

    Returns
    -------
    list of dict
        One dict per input sequence with key ``"label"`` plus whichever of
        ``"representations"``, ``"mean_representations"``,
        ``"bos_representations"``, ``"contacts"`` were requested via *include*.

    Raises
    ------
    ValueError
        If the loaded model is an MSA Transformer (unsupported here).
    """
    model, alphabet = pretrained.load_model_and_alphabet(model_location)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )
    if torch.cuda.is_available() and not nogpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
    )
    print(f"Read {fasta_file} with {len(dataset)} sequences")

    output_dir.mkdir(parents=True, exist_ok=True)
    return_contacts = "contacts" in include

    # Bug fix: the default value -1 is a bare int, which made the
    # `all(... for i in repr_layers)` check below raise TypeError whenever the
    # default was used.  Normalize a single int into a one-element list.
    if isinstance(repr_layers, int):
        repr_layers = [repr_layers]
    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
    # Map negative layer indices (Python-style, from the end) to absolute ones.
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]

    # Bug fix: the original built each per-sequence `result` dict and then
    # discarded it (no save, no return).  Accumulate and return them instead.
    results = []
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available() and not nogpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)

            # Move requested layers back to CPU so results outlive GPU tensors.
            # (The original also pulled out["logits"] to CPU but never used it;
            # that dead copy is removed.)
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")

            for i, label in enumerate(labels):
                result = {"label": label}
                truncate_len = min(truncation_seq_length, len(strs[i]))
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                if "per_tok" in include:
                    # Token 0 is BOS; slice 1..truncate_len+1 covers the residues.
                    result["representations"] = {
                        layer: t[i, 1 : truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in include:
                    result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
                results.append(result)
    return results
+
199
  def main():
200
  parser = create_parser()
201
  args = parser.parse_args()