Witold Wydmański commited on
Commit
8ee7dbf
·
1 Parent(s): 137a7d5

feat: Add get_esm2_embeddings function

Browse files
Files changed (2) hide show
  1. app.py +18 -1
  2. client.py +14 -0
app.py CHANGED
@@ -38,6 +38,22 @@ def fold_prot_locally(sequence):
38
  pdb = convert_outputs_to_pdb(output)
39
  return pdb
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def get_esmfold_embeddings(sequence):
42
  logger.info("Getting embeddings for: " + sequence)
43
  tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
@@ -165,11 +181,12 @@ with gr.Blocks() as demo:
165
  with gr.Row(visible=False):
166
  with gr.Column():
167
  gr.Markdown("## Embeddings")
168
- embs = gr.JSON(label="Embeddings", interactive=False)
169
 
170
  name.change(fn=suggest, inputs=name, outputs=inp)
171
  btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
172
  btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
 
173
  out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")
174
 
175
  demo.launch()
 
38
  pdb = convert_outputs_to_pdb(output)
39
  return pdb
40
 
41
+ def get_esm2_embeddings(sequence):
42
+ logger.info("Getting embeddings for: " + sequence)
43
+ tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
44
+
45
+ with torch.no_grad():
46
+ aa = tokenized_input
47
+ L = aa.shape[1]
48
+ device = tokenized_input.device
49
+ attention_mask = torch.ones_like(aa, device=device)
50
+
51
+ # === ESM ===
52
+ esmaa = model.af2_idx_to_esm_idx(aa, attention_mask)
53
+ esm_s = model.compute_language_model_representations(esmaa)
54
+
55
+ return {"res": esm_s.cpu().tolist()}
56
+
57
  def get_esmfold_embeddings(sequence):
58
  logger.info("Getting embeddings for: " + sequence)
59
  tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
 
181
  with gr.Row(visible=False):
182
  with gr.Column():
183
  gr.Markdown("## Embeddings")
184
+ embs = gr.JSON(label="Embeddings")
185
 
186
  name.change(fn=suggest, inputs=name, outputs=inp)
187
  btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
188
  btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
189
+ btn.click(get_esm2_embeddings, inputs=[inp], outputs=[embs], api_name="esm2_embeddings")
190
  out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")
191
 
192
  demo.launch()
client.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ from gradio_client import Client
3
+
4
+ #%%
5
+ # client = Client("https://wwydmanski-esmfold.hf.space/")
6
+ client = Client("http://localhost:7860")
7
+
8
+ # %%
9
+ result = client.predict(
10
+ "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", # str in 'sequence' Textbox component
11
+ api_name="/esm2_embeddings")
12
+
13
+ # %%
14
+ result