Update app.py
Browse files
app.py
CHANGED
@@ -21,9 +21,9 @@ model.to('cuda')
|
|
21 |
# model_esm.to('cuda')
|
22 |
# model_esm.eval()
|
23 |
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
|
28 |
@spaces.GPU
|
29 |
def generate_caption(protein, prompt):
|
@@ -94,9 +94,9 @@ def generate_caption(protein, prompt):
|
|
94 |
result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
|
95 |
esm_emb = result['representations'][36]
|
96 |
'''
|
97 |
-
inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True)
|
98 |
with torch.no_grad():
|
99 |
-
outputs =
|
100 |
esm_emb = outputs.last_hidden_state.detach()[0]
|
101 |
|
102 |
print("esm embedding generated")
|
|
|
21 |
# model_esm.to('cuda')
|
22 |
# model_esm.eval()
|
23 |
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
24 |
+
model_esm = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
25 |
+
model_esm.to('cuda')
|
26 |
+
model_esm.eval()
|
27 |
|
28 |
@spaces.GPU
|
29 |
def generate_caption(protein, prompt):
|
|
|
94 |
result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
|
95 |
esm_emb = result['representations'][36]
|
96 |
'''
|
97 |
+
inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
|
98 |
with torch.no_grad():
|
99 |
+
outputs = model_esm(**inputs)
|
100 |
esm_emb = outputs.last_hidden_state.detach()[0]
|
101 |
|
102 |
print("esm embedding generated")
|