tagirshin commited on
Commit
845fb10
1 Parent(s): cea7e20

added visualisation of generated molecules and restart button

Browse files
Files changed (1) hide show
  1. app.py +59 -25
app.py CHANGED
@@ -6,8 +6,20 @@ import pygad
6
 
7
  from VQGAE.models import VQGAE, OrderingNetwork
8
  from CGRtools.containers import QueryContainer
 
9
  from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules
10
 
 
 
 
 
 
 
 
 
 
 
 
11
  # define groups to filter
12
  allene = QueryContainer()
13
  allene.add_atom("C")
@@ -121,7 +133,7 @@ def load_data(batch_size):
121
  st.title('Inverse QSAR of Tubulin with VQGAE')
122
 
123
  with st.sidebar:
124
- with st.form("my_form"):
125
  num_generations = st.slider(
126
  'Number of generations for GA',
127
  min_value=3,
@@ -190,7 +202,7 @@ with st.sidebar:
190
  )
191
  # 2/3 of num_parents_mating
192
  use_ordering_score = st.toggle('Use ordering score', value=True)
193
- batch_size = int(st.number_input("Random seed", value=200, placeholder="Type a number..."))
194
  random_seed = int(st.number_input("Random seed", value=42, placeholder="Type a number..."))
195
  submit = st.form_submit_button('Start optimisation')
196
 
@@ -291,31 +303,53 @@ if submit:
291
  decoding_bar.progress(decoding_step / total_decoding_steps, text=decoding_progress_text)
292
  gen_stats = pd.DataFrame(results)
293
  decoding_bar.empty()
294
- full_stats = pd.concat([gen_stats, chosen_gen[["similarity_score", "rf_score"]].reset_index(), ], axis=1, ignore_index=False)
295
-
296
- st.dataframe(full_stats)
297
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  st.download_button(
299
  label="Download results as CSV",
300
- data=convert_df(full_stats),
301
- file_name='vqgae_tubulin_inhibitors.csv',
302
  mime='text/csv',
303
  )
304
 
305
- # valid_gen_stats = full_stats[full_stats.valid == 1]
306
- #
307
- # valid_gen_mols = []
308
- # for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")):
309
- # mol = gen_molecules[i]
310
- # valid_gen_mols.append(mol)
311
- #
312
- # filtered_gen_mols = []
313
- # for mol in valid_gen_mols:
314
- # is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
315
- # is_macro = False
316
- # for ring in mol.sssr:
317
- # if len(ring) > 8 or len(ring) < 4:
318
- # is_macro = True
319
- # break
320
- # if not is_frag and not is_macro:
321
- # filtered_gen_mols.append(mol)
 
 
 
6
 
7
  from VQGAE.models import VQGAE, OrderingNetwork
8
  from CGRtools.containers import QueryContainer
9
+ from CGRtools.utils import grid_depict
10
  from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules
11
 
12
+ import base64
13
+ from streamlit.components.v1 import html
14
+
15
+
16
+ def render_svg(svg_string):
17
+ """Renders the given svg string."""
18
+ c = st.container()
19
+ with c:
20
+ html(svg_string)
21
+
22
+
23
  # define groups to filter
24
  allene = QueryContainer()
25
  allene.add_atom("C")
 
133
  st.title('Inverse QSAR of Tubulin with VQGAE')
134
 
135
  with st.sidebar:
136
+ with st.form("ga_options"):
137
  num_generations = st.slider(
138
  'Number of generations for GA',
139
  min_value=3,
 
202
  )
203
  # 2/3 of num_parents_mating
204
  use_ordering_score = st.toggle('Use ordering score', value=True)
205
+ batch_size = int(st.number_input("Batch size", value=200, placeholder="Type a number..."))
206
  random_seed = int(st.number_input("Random seed", value=42, placeholder="Type a number..."))
207
  submit = st.form_submit_button('Start optimisation')
208
 
 
303
  decoding_bar.progress(decoding_step / total_decoding_steps, text=decoding_progress_text)
304
  gen_stats = pd.DataFrame(results)
305
  decoding_bar.empty()
306
+ full_stats = pd.concat([gen_stats, chosen_gen.reset_index()[["similarity_score", "rf_score"]]], axis=1)
307
+ valid_gen_stats = full_stats[full_stats.validity == 1]
308
+
309
+ valid_gen_mols = []
310
+ for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")):
311
+ valid_gen_mols.append(gen_molecules[i])
312
+
313
+ filtered_gen_mols = []
314
+ filtered_indices = []
315
+ for mol_i, mol in enumerate(valid_gen_mols):
316
+ is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
317
+ is_ring = False
318
+ for ring in mol.sssr:
319
+ if len(ring) > 8 or len(ring) < 4:
320
+ is_ring = True
321
+ break
322
+ if not is_frag and not is_ring:
323
+ filtered_gen_mols.append(mol)
324
+ filtered_indices.append(mol_i)
325
+
326
+ filtered_gen_stats = valid_gen_stats.iloc[filtered_indices]
327
+
328
+ st.subheader('Generation results', divider='rainbow')
329
+ st.dataframe(filtered_gen_stats)
330
  st.download_button(
331
  label="Download results as CSV",
332
+ data=convert_df(filtered_gen_stats),
333
+ file_name='vqgae_tubulin_inhibitors_valid.csv',
334
  mime='text/csv',
335
  )
336
 
337
+ st.subheader('Examples of generated molecules')
338
+ examples_inds = sorted(filtered_gen_stats.sort_values(by=["rf_score"], ascending=False).index[:9])
339
+ examples = [filtered_gen_mols[i] for i in examples_inds]
340
+ svg = grid_depict(examples, 2)
341
+ render_svg(svg)
342
+
343
+ show_full_stats = st.checkbox('Show full stats')
344
+ if show_full_stats:
345
+ st.dataframe(full_stats)
346
+
347
+ st.download_button(
348
+ label="Download full results as CSV",
349
+ data=convert_df(full_stats),
350
+ file_name='vqgae_tubulin_inhibitors_full.csv',
351
+ mime='text/csv',
352
+ )
353
+
354
+ if st.button("Restart"):
355
+ st.rerun()