wenkai committed on
Commit 77b966b
1 Parent(s): 3705c34

Update app.py

Files changed (1)
app.py +5 -5
app.py CHANGED
@@ -21,9 +21,9 @@ model.to('cuda')
 # model_esm.to('cuda')
 # model_esm.eval()
 tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
-model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
-model.to('cuda')
-model.eval()
+model_esm = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
+model_esm.to('cuda')
+model_esm.eval()
 
 @spaces.GPU
 def generate_caption(protein, prompt):
@@ -94,9 +94,9 @@ def generate_caption(protein, prompt):
     result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
     esm_emb = result['representations'][36]
     '''
-    inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True)
+    inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
     with torch.no_grad():
-        outputs = model(**inputs)
+        outputs = model_esm(**inputs)
     esm_emb = outputs.last_hidden_state.detach()[0]
 
     print("esm embedding generated")