abondrn committed on
Commit
bc4ccb5
1 Parent(s): 64a6606

Added msa and go stubs

Browse files
Files changed (2) hide show
  1. app.py +43 -1
  2. requirements.txt +3 -2
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py
2
-
3
  import os
4
  import sys
5
  from urllib import request
@@ -9,6 +9,9 @@ import requests
9
  from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmModel, AutoModel
10
  import torch
11
  import progres as pg
 
 
 
12
 
13
 
14
  tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
@@ -23,6 +26,11 @@ tokenizer_se = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-ba
23
  model_se = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
24
  model_se.eval()
25
 
 
 
 
 
 
26
 
27
  def nt_embed(sequence: str):
28
  tokens_ids = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")["input_ids"]
@@ -51,6 +59,17 @@ def se_embed(sentence: str):
51
  return model_output[0]
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
54
  def download_data_if_required():
55
  url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
56
  fps = [pg.trained_model_fp]
@@ -181,6 +200,14 @@ def update_se(inp):
181
  return str(se_embed(inp))
182
 
183
 
 
 
 
 
 
 
 
 
184
  demo = gr.Blocks()
185
 
186
  with demo:
@@ -222,6 +249,21 @@ with demo:
222
  btn = gr.Button("View embeddings")
223
  emb = gr.Textbox(interactive=False)
224
  btn.click(fn=update_se, inputs=[inp], outputs=emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  if __name__ == "__main__":
227
  download_data_if_required()
 
1
  # credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py
2
+ from typing import Tuple
3
  import os
4
  import sys
5
  from urllib import request
 
9
  from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmModel, AutoModel
10
  import torch
11
  import progres as pg
12
+ import esm
13
+
14
+ import msa
15
 
16
 
17
  tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
 
26
  model_se = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
27
  model_se.eval()
28
 
29
+ msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
30
+ msa_transformer = msa_transformer.eval()
31
+ msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()
32
+
33
+
34
 
35
  def nt_embed(sequence: str):
36
  tokens_ids = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")["input_ids"]
 
59
  return model_output[0]
60
 
61
 
62
def msa_embed(msa):
    """Embed a multiple sequence alignment with the ESM MSA Transformer.

    Args:
        msa: The parsed alignment, in whatever form ``msa.read_msa`` returns
            (presumably a list of ``(label, sequence)`` pairs -- confirm
            against the ``msa`` helper module).

    Returns:
        A ``torch.Tensor`` holding the layer-12 representation taken at
        index 0 of the third axis and averaged over the first two axes.
    """
    # The parameter ``msa`` shadows the module-level ``import msa`` inside
    # this function, so the helper module is imported under a distinct name.
    import msa as msa_utils

    # The original passed the not-yet-assigned local ``inputs`` to
    # ``greedy_select`` (an UnboundLocalError); the alignment itself is what
    # has to be subsampled.
    inputs = msa_utils.greedy_select(msa, num_seqs=128)  # can change this to pass more/fewer sequences
    _labels, _strs, batch_tokens = msa_transformer_batch_converter([inputs])
    # Move the tokens onto whichever device the model weights live on.
    batch_tokens = batch_tokens.to(next(msa_transformer.parameters()).device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        representations = msa_transformer(batch_tokens, repr_layers=[12])["representations"]

    # Keep index 0 along the third axis (presumably the start-token column of
    # a (batch, rows, columns, dim) tensor -- TODO confirm), then average over
    # the two leading axes to get a single vector.
    first_column = representations[12][:, :, 0, :]
    return torch.mean(first_column, (0, 1))
71
+
72
+
73
  def download_data_if_required():
74
  url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
75
  fps = [pg.trained_model_fp]
 
200
  return str(se_embed(inp))
201
 
202
 
203
def update_go(inp):
    """Return the GO embedding of ``inp`` rendered as a string.

    NOTE(review): ``go_embed`` is not defined in the visible part of this
    file -- confirm it exists, otherwise this call raises NameError.
    """
    embedding = go_embed(inp)
    return str(embedding)
205
+
206
+
207
def update_msa(inp):
    """Read an MSA from ``inp``, embed it, and return the result as a string.

    ``inp`` is handed unchanged to ``msa.read_msa``; the parsed alignment is
    then passed to ``msa_embed``.
    """
    alignment = msa.read_msa(inp)
    embedding = msa_embed(alignment)
    return str(embedding)
209
+
210
+
211
  demo = gr.Blocks()
212
 
213
  with demo:
 
249
  btn = gr.Button("View embeddings")
250
  emb = gr.Textbox(interactive=False)
251
  btn.click(fn=update_se, inputs=[inp], outputs=emb)
252
+ with gr.TabItem("MSA Embeddings"):
253
+ with gr.Box():
254
+ inp = gr.File(file_count="single", label="Input MSA")
255
+ btn = gr.Button("View embeddings")
256
+ emb = gr.Textbox(interactive=False)
257
+ btn.click(fn=update_msa, inputs=[inp], outputs=emb)
258
+ with gr.TabItem("GO Embeddings"):
259
+ with gr.Box():
260
+ inp = gr.Textbox(
261
+ placeholder="", label="Input GO Terms"
262
+ )
263
+ btn = gr.Button("View embeddings")
264
+ emb = gr.Textbox(interactive=False)
265
+ btn.click(fn=update_go, inputs=[inp], outputs=emb)
266
+
267
 
268
  if __name__ == "__main__":
269
  download_data_if_required()
requirements.txt CHANGED
@@ -5,8 +5,9 @@ requests==2.31.0
5
  torch==2.0.1
6
  --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-cluster==1.6.1
7
  torch-geometric==2.3.1
8
- --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-scatter==2.1.1
9
  --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-sparse==0.6.17
10
  --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-spline-conv==1.2.2
11
  transformers==4.29.2
12
- progres
 
 
5
  torch==2.0.1
6
  --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-cluster==1.6.1
7
  torch-geometric==2.3.1
8
+ torch-scatter==2.1.1
9
  --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-sparse==0.6.17
10
  --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-spline-conv==1.2.2
11
  transformers==4.29.2
12
+ progres
13
+ fair-esm