Update app.py
Browse files
app.py
CHANGED
@@ -37,15 +37,17 @@ def generate_caption(protein, prompt):
|
|
37 |
print("dataset prepared")
|
38 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
39 |
print("batches prepared")
|
|
|
|
|
|
|
|
|
|
|
40 |
data_loader = torch.utils.data.DataLoader(
|
41 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
42 |
)
|
43 |
print(f"Read sequences")
|
44 |
return_contacts = "contacts" in include
|
45 |
|
46 |
-
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
47 |
-
model_esm.to('cuda')
|
48 |
-
model_esm.eval()
|
49 |
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
50 |
repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
|
51 |
|
|
|
37 |
print("dataset prepared")
|
38 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
39 |
print("batches prepared")
|
40 |
+
|
41 |
+
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
42 |
+
model_esm.to('cuda')
|
43 |
+
model_esm.eval()
|
44 |
+
|
45 |
data_loader = torch.utils.data.DataLoader(
|
46 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
47 |
)
|
48 |
print(f"Read sequences")
|
49 |
return_contacts = "contacts" in include
|
50 |
|
|
|
|
|
|
|
51 |
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
52 |
repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
|
53 |
|