abondrn commited on
Commit
463ec0f
1 Parent(s): a9799e9

First commit

Browse files
Files changed (3) hide show
  1. README.md +11 -5
  2. app.py +225 -73
  3. requirements.txt +10 -4
README.md CHANGED
@@ -1,12 +1,18 @@
1
  ---
 
2
  title: SVM
3
- emoji: 🔥
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.21.0
8
  app_file: app.py
9
  pinned: false
 
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ # https://huggingface.co/docs/hub/spaces-config-reference
3
  title: SVM
4
+ emoji: 🧬
5
+ colorFrom: green
6
+ colorTo: green
7
+ sdk: gradio
 
8
  app_file: app.py
9
  pinned: false
10
+ models:
11
+ - InstaDeepAI/nucleotide-transformer-500m-1000g
12
+ - facebook/esmfold_v1
13
+ - sentence-transformers/all-mpnet-base-v2
14
+ python_version: 3.10.4
15
+ license: mit
16
  ---
17
 
18
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,76 +1,228 @@
 
 
 
 
 
 
 
 
 
1
  import torch
2
- import streamlit as st
3
- from transformers import AutoTokenizer, OPTForCausalLM
4
-
5
-
6
- @st.cache_resource
7
- def load_model():
8
- tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-30b")
9
- model = OPTForCausalLM.from_pretrained("facebook/galactica-30b", device_map='auto', low_cpu_mem_usage=True, torch_dtype=torch.float16)
10
- model.gradient_checkpointing_enable()
11
- return tokenizer, model
12
-
13
-
14
- st.set_page_config(
15
- page_title='BioML-SVM',
16
- layout="wide"
17
- )
18
-
19
- with st.spinner("Loading Models and Tokens..."):
20
- tokenizer, model = load_model()
21
-
22
- with st.form(key='my_form'):
23
- col1, col2 = st.columns([10, 1])
24
- text_input = col1.text_input(label='Enter the amino sequence')
25
- with col2:
26
- st.text('')
27
- st.text('')
28
- submit_button = st.form_submit_button(label='Submit')
29
-
30
- if submit_button:
31
- st.session_state['result_done'] = False
32
- # input_text = "[START_AMINO]GHMQSITAGQKVISKHKNGRFYQCEVVRLTTETFYEVNFDDGSFSDNLYPEDIVSQDCLQFGPPAEGEVVQVRWTDGQVYGAKFVASHPIQMYQVEFEDGSQLVVKRDDVYTLDEELP[END_AMINO]"
33
- with st.spinner('Generating...'):
34
- # formatted_text = f"[START_AMINO]{text_input}[END_AMINO]"
35
- # formatted_text = f"Here is the sequence: [START_AMINO]{text_input}[END_AMINO]"
36
- formatted_text = f"{text_input}"
37
- input_ids = tokenizer(formatted_text, return_tensors="pt").input_ids.to("cuda")
38
- outputs = model.generate(
39
- input_ids=input_ids,
40
- max_new_tokens=500
41
- )
42
- result = tokenizer.decode(outputs[0]).replace(formatted_text, "")
43
- st.markdown(result)
44
-
45
- if 'result_done' not in st.session_state or not st.session_state.result_done:
46
- st.session_state['result_done'] = True
47
- st.session_state['previous_state'] = result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  else:
49
- if 'result_done' in st.session_state and st.session_state.result_done:
50
- st.markdown(st.session_state.previous_state)
51
-
52
- if 'result_done' in st.session_state and st.session_state.result_done:
53
- with st.form(key='ask_more'):
54
- col1, col2 = st.columns([10, 1])
55
- text_input = col1.text_input(label='Ask more question')
56
- with col2:
57
- st.text('')
58
- st.text('')
59
- submit_button = st.form_submit_button(label='Submit')
60
-
61
- if submit_button:
62
- with st.spinner('Generating...'):
63
- # formatted_text = f"[START_AMINO]{text_input}[END_AMINO]"
64
- formatted_text = f"Q:{text_input}\n\nA:\n\n"
65
- input_ids = tokenizer(formatted_text, return_tensors="pt").input_ids.to("cuda")
66
-
67
- outputs = model.generate(
68
- input_ids=input_ids,
69
- max_length=len(formatted_text) + 500,
70
- do_sample=True,
71
- top_k=40,
72
- num_beams=1,
73
- num_return_sequences=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  )
75
- result = tokenizer.decode(outputs[0]).replace(formatted_text, "")
76
- st.markdown(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py
2
+
3
+ import os
4
+ import sys
5
+ from urllib import request
6
+
7
+ import gradio as gr
8
+ import requests
9
+ from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmModel, AutoModel
10
  import torch
11
+ import progres as pg
12
+
13
+
14
+ tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
15
+ model_nt = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
16
+ model_nt.eval()
17
+
18
+ tokenizer_aa = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
19
+ model_aa = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
20
+ model_aa.eval()
21
+
22
+ tokenizer_se = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
23
+ model_se = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
24
+ model_se.eval()
25
+
26
+
27
+ def nt_embed(sequence: str):
28
+ tokens_ids = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")["input_ids"]
29
+ attention_mask = tokens_ids != tokenizer_nt.pad_token_id
30
+ with torch.no_grad():
31
+ torch_outs = model_nt(
32
+ tokens_ids,#.to('cuda'),
33
+ attention_mask=attention_mask,#.to('cuda'),
34
+ output_hidden_states=True
35
+ )
36
+ last_layer_CLS = torch_outs.hidden_states[-1].detach()[:, 0, :][0]
37
+ return last_layer_CLS
38
+
39
+
40
+ def aa_embed(sequence: str):
41
+ tokens = tokenizer_aa([sequence], return_tensors="pt")
42
+ with torch.no_grad():
43
+ torch_outs = model_aa(**tokens)
44
+ return torch_outs
45
+
46
+
47
+ def se_embed(sentence: str):
48
+ encoded_input = tokenizer_se([sentence], return_tensors='pt')
49
+ with torch.no_grad():
50
+ model_output = model_se(**encoded_input)
51
+ return model_output
52
+
53
+
54
+ def download_data_if_required():
55
+ url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
56
+ fps = [pg.trained_model_fp]
57
+ urls = [f"{url_base}/trained_model.pt"]
58
+ #for targetdb in pre_embedded_dbs:
59
+ # fps.append(os.path.join(database_dir, targetdb + ".pt"))
60
+ # urls.append(f"{url_base}/{targetdb}.pt")
61
+
62
+ if not os.path.isdir(pg.trained_model_dir):
63
+ os.makedirs(pg.trained_model_dir)
64
+ #if not os.path.isdir(database_dir):
65
+ # os.makedirs(database_dir)
66
+
67
+ printed = False
68
+ for fp, url in zip(fps, urls):
69
+ if not os.path.isfile(fp):
70
+ if not printed:
71
+ print("Downloading data as first time setup (~340 MB) to ", pg.progres_dir,
72
+ ", internet connection required, this can take a few minutes",
73
+ sep="", file=sys.stderr)
74
+ printed = True
75
+ try:
76
+ request.urlretrieve(url, fp)
77
+ d = torch.load(fp, map_location="cpu")
78
+ if fp == pg.trained_model_fp:
79
+ assert "model" in d
80
+ else:
81
+ assert "embeddings" in d
82
+ except:
83
+ if os.path.isfile(fp):
84
+ os.remove(fp)
85
+ print("Failed to download from", url, "and save to", fp, file=sys.stderr)
86
+ print("Exiting", file=sys.stderr)
87
+ sys.exit(1)
88
+
89
+ if printed:
90
+ print("Data downloaded successfully", file=sys.stderr)
91
+
92
+
93
+ def get_pdb(pdb_code="", filepath=""):
94
+ if pdb_code is None or pdb_code == "":
95
+ try:
96
+ with open(filepath.name) as f:
97
+ return f.read()
98
+ except AttributeError as e:
99
+ return None
100
  else:
101
+ return requests.get(f"https://files.rcsb.org/view/{pdb_code}.pdb").content.decode()
102
+
103
+
104
+ def molecule(pdb):
105
+
106
+ x = (
107
+ """<!DOCTYPE html>
108
+ <html>
109
+ <head>
110
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
111
+ <style>
112
+ body{
113
+ font-family:sans-serif
114
+ }
115
+ .mol-container {
116
+ width: 100%;
117
+ height: 600px;
118
+ position: relative;
119
+ }
120
+ .mol-container select{
121
+ background-image:None;
122
+ }
123
+ </style>
124
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
125
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
126
+ </head>
127
+ <body>
128
+ <div id="container" class="mol-container"></div>
129
+
130
+ <script>
131
+ let pdb = `"""
132
+ + pdb
133
+ + """`
134
+
135
+ $(document).ready(function () {
136
+ let element = $("#container");
137
+ let config = { backgroundColor: "black" };
138
+ let viewer = $3Dmol.createViewer(element, config);
139
+ viewer.addModel(pdb, "pdb");
140
+ viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
141
+ viewer.addSurface("MS", { opacity: .5, color: "white" });
142
+ viewer.zoomTo();
143
+ viewer.render();
144
+ viewer.zoom(0.8, 2000);
145
+ })
146
+ </script>
147
+ </body></html>"""
148
+ )
149
+
150
+ return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
151
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
152
+ allow-scripts allow-same-origin allow-popups
153
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
154
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
155
+
156
+
157
+ def str2coords(s):
158
+ coords = []
159
+ for line in s.split('\n'):
160
+ if (line.startswith("ATOM ") or line.startswith("HETATM")) and line[12:16].strip() == "CA":
161
+ coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
162
+ elif line.startswith("ENDMDL"):
163
+ break
164
+ return coords
165
+
166
+
167
+ def update_st(inp, file):
168
+ pdb = get_pdb(inp, file)
169
+ return (molecule(pdb), pg.embed_coords(str2coords(pdb)))
170
+
171
+
172
+ def update_nt(inp):
173
+ return str(nt_embed(inp or ''))
174
+
175
+
176
+ def update_aa(inp):
177
+ return str(aa_embed(inp))
178
+
179
+
180
+ def update_se(inp):
181
+ return str(se_embed(inp))
182
+
183
+
184
+ demo = gr.Blocks()
185
+
186
+ with demo:
187
+ with gr.Tabs():
188
+ with gr.TabItem("PDB Structural Embeddings"):
189
+ with gr.Row():
190
+ with gr.Box():
191
+ inp = gr.Textbox(
192
+ placeholder="PDB Code or upload file below", label="Input structure"
193
+ )
194
+ file = gr.File(file_count="single")
195
+ gr.Examples(["2CBA", "6VXX"], inp)
196
+ btn = gr.Button("View structure")
197
+ gr.Markdown("# PDB viewer using 3Dmol.js")
198
+ mol = gr.HTML()
199
+ emb = gr.Textbox(interactive=False)
200
+ btn.click(fn=update_st, inputs=[inp, file], outputs=[mol, emb])
201
+ with gr.TabItem("Nucleotide Sequence Embeddings"):
202
+ with gr.Box():
203
+ inp = gr.Textbox(
204
+ placeholder="ATCGCTGCCCGTAGATAATAAGAGACACTGAGGCC", label="Input Nucleotide Sequence"
205
+ )
206
+ btn = gr.Button("View embeddings")
207
+ emb = gr.Textbox(interactive=False)
208
+ btn.click(fn=update_nt, inputs=[inp], outputs=emb)
209
+ with gr.TabItem("Amino Acid Sequence Embeddings"):
210
+ with gr.Box():
211
+ inp = gr.Textbox(
212
+ placeholder="AAGQCYRGRCSGGLCCSKYGYCGSGPAYCG", label="Input Amino Acid Sequence"
213
  )
214
+ btn = gr.Button("View embeddings")
215
+ emb = gr.Textbox(interactive=False)
216
+ btn.click(fn=update_aa, inputs=[inp], outputs=emb)
217
+ with gr.TabItem("Sentence Embeddings"):
218
+ with gr.Box():
219
+ inp = gr.Textbox(
220
+ placeholder="Your text here", label="Input Sentence"
221
+ )
222
+ btn = gr.Button("View embeddings")
223
+ emb = gr.Textbox(interactive=False)
224
+ btn.click(fn=update_se, inputs=[inp], outputs=emb)
225
+
226
+ if __name__ == "__main__":
227
+ download_data_if_required()
228
+ demo.launch()
requirements.txt CHANGED
@@ -1,5 +1,11 @@
1
- transformers
2
  accelerate
3
- streamlit
4
- # bitsandbytes
5
- # scipy
 
 
 
 
 
 
 
 
 
1
  accelerate
2
+ gradio==3.33.1
3
+ pyg-lib==0.2.0+pt20
4
+ requests==2.31.0
5
+ torch==2.0.1
6
+ torch-cluster==1.6.1
7
+ torch-geometric==2.3.1
8
+ torch-scatter==2.1.1
9
+ torch-sparse==0.6.17
10
+ torch-spline-conv==1.2.2
11
+ transformers==4.29.2