wenkai commited on
Commit
52ea231
1 Parent(s): 3e891e6

Update esm_scripts/extract.py

Browse files
Files changed (1) hide show
  1. esm_scripts/extract.py +5 -3
esm_scripts/extract.py CHANGED
@@ -131,7 +131,7 @@ def run(args):
131
  )
132
 
133
 
134
- def run_demo(model_location, fasta_file, output_dir, include, nogpu,
135
  repr_layers=-1, truncation_seq_length=1022, toks_per_batch=4096):
136
  model, alphabet = pretrained.load_model_and_alphabet(model_location)
137
  model.eval()
@@ -143,14 +143,14 @@ def run_demo(model_location, fasta_file, output_dir, include, nogpu,
143
  model = model.cuda()
144
  print("Transferred model to GPU")
145
 
146
- dataset = FastaBatchedDataset.from_file(fasta_file)
147
  batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
148
  data_loader = torch.utils.data.DataLoader(
149
  dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
150
  )
151
  print(f"Read {fasta_file} with {len(dataset)} sequences")
152
 
153
- output_dir.mkdir(parents=True, exist_ok=True)
154
  return_contacts = "contacts" in include
155
 
156
  assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
@@ -194,6 +194,8 @@ def run_demo(model_location, fasta_file, output_dir, include, nogpu,
194
  }
195
  if return_contacts:
196
  result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
 
 
197
 
198
 
199
  def main():
 
131
  )
132
 
133
 
134
+ def run_demo(protein_name, protein_seq, model_location, include, nogpu,
135
  repr_layers=-1, truncation_seq_length=1022, toks_per_batch=4096):
136
  model, alphabet = pretrained.load_model_and_alphabet(model_location)
137
  model.eval()
 
143
  model = model.cuda()
144
  print("Transferred model to GPU")
145
 
146
+ dataset = FastaBatchedDataset([protein_name], [protein_seq])
147
  batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
148
  data_loader = torch.utils.data.DataLoader(
149
  dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
150
  )
151
  print(f"Read {fasta_file} with {len(dataset)} sequences")
152
 
153
+ # output_dir.mkdir(parents=True, exist_ok=True)
154
  return_contacts = "contacts" in include
155
 
156
  assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
 
194
  }
195
  if return_contacts:
196
  result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
197
+
198
+ return result['representations'][36]
199
 
200
 
201
  def main():