wenkai commited on
Commit
0f66ac3
1 Parent(s): c8e59d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -37,15 +37,17 @@ def generate_caption(protein, prompt):
37
  print("dataset prepared")
38
  batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
39
  print("batches prepared")
 
 
 
 
 
40
  data_loader = torch.utils.data.DataLoader(
41
  dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
42
  )
43
  print(f"Read sequences")
44
  return_contacts = "contacts" in include
45
 
46
- model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
47
- model_esm.to('cuda')
48
- model_esm.eval()
49
  assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
50
  repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
51
 
 
37
  print("dataset prepared")
38
  batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
39
  print("batches prepared")
40
+
41
+ model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
42
+ model_esm.to('cuda')
43
+ model_esm.eval()
44
+
45
  data_loader = torch.utils.data.DataLoader(
46
  dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
47
  )
48
  print(f"Read sequences")
49
  return_contacts = "contacts" in include
50
 
 
 
 
51
  assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
52
  repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
53