added visualisation of generated molecules and restart button
Browse files
app.py
CHANGED
@@ -6,8 +6,20 @@ import pygad
|
|
6 |
|
7 |
from VQGAE.models import VQGAE, OrderingNetwork
|
8 |
from CGRtools.containers import QueryContainer
|
|
|
9 |
from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# define groups to filter
|
12 |
allene = QueryContainer()
|
13 |
allene.add_atom("C")
|
@@ -121,7 +133,7 @@ def load_data(batch_size):
|
|
121 |
st.title('Inverse QSAR of Tubulin with VQGAE')
|
122 |
|
123 |
with st.sidebar:
|
124 |
-
with st.form("
|
125 |
num_generations = st.slider(
|
126 |
'Number of generations for GA',
|
127 |
min_value=3,
|
@@ -190,7 +202,7 @@ with st.sidebar:
|
|
190 |
)
|
191 |
# 2/3 of num_parents_mating
|
192 |
use_ordering_score = st.toggle('Use ordering score', value=True)
|
193 |
-
batch_size = int(st.number_input("
|
194 |
random_seed = int(st.number_input("Random seed", value=42, placeholder="Type a number..."))
|
195 |
submit = st.form_submit_button('Start optimisation')
|
196 |
|
@@ -291,31 +303,53 @@ if submit:
|
|
291 |
decoding_bar.progress(decoding_step / total_decoding_steps, text=decoding_progress_text)
|
292 |
gen_stats = pd.DataFrame(results)
|
293 |
decoding_bar.empty()
|
294 |
-
full_stats = pd.concat([gen_stats, chosen_gen[["similarity_score", "rf_score"]]
|
295 |
-
|
296 |
-
|
297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
st.download_button(
|
299 |
label="Download results as CSV",
|
300 |
-
data=convert_df(
|
301 |
-
file_name='
|
302 |
mime='text/csv',
|
303 |
)
|
304 |
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
|
|
|
|
|
6 |
|
7 |
from VQGAE.models import VQGAE, OrderingNetwork
|
8 |
from CGRtools.containers import QueryContainer
|
9 |
+
from CGRtools.utils import grid_depict
|
10 |
from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules
|
11 |
|
12 |
+
import base64
|
13 |
+
from streamlit.components.v1 import html
|
14 |
+
|
15 |
+
|
16 |
+
def render_svg(svg_string):
|
17 |
+
"""Renders the given svg string."""
|
18 |
+
c = st.container()
|
19 |
+
with c:
|
20 |
+
html(svg_string)
|
21 |
+
|
22 |
+
|
23 |
# define groups to filter
|
24 |
allene = QueryContainer()
|
25 |
allene.add_atom("C")
|
|
|
133 |
st.title('Inverse QSAR of Tubulin with VQGAE')
|
134 |
|
135 |
with st.sidebar:
|
136 |
+
with st.form("ga_options"):
|
137 |
num_generations = st.slider(
|
138 |
'Number of generations for GA',
|
139 |
min_value=3,
|
|
|
202 |
)
|
203 |
# 2/3 of num_parents_mating
|
204 |
use_ordering_score = st.toggle('Use ordering score', value=True)
|
205 |
+
batch_size = int(st.number_input("Batch size", value=200, placeholder="Type a number..."))
|
206 |
random_seed = int(st.number_input("Random seed", value=42, placeholder="Type a number..."))
|
207 |
submit = st.form_submit_button('Start optimisation')
|
208 |
|
|
|
303 |
decoding_bar.progress(decoding_step / total_decoding_steps, text=decoding_progress_text)
|
304 |
gen_stats = pd.DataFrame(results)
|
305 |
decoding_bar.empty()
|
306 |
+
full_stats = pd.concat([gen_stats, chosen_gen.reset_index()[["similarity_score", "rf_score"]]], axis=1)
|
307 |
+
valid_gen_stats = full_stats[full_stats.validity == 1]
|
308 |
+
|
309 |
+
valid_gen_mols = []
|
310 |
+
for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")):
|
311 |
+
valid_gen_mols.append(gen_molecules[i])
|
312 |
+
|
313 |
+
filtered_gen_mols = []
|
314 |
+
filtered_indices = []
|
315 |
+
for mol_i, mol in enumerate(valid_gen_mols):
|
316 |
+
is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
|
317 |
+
is_ring = False
|
318 |
+
for ring in mol.sssr:
|
319 |
+
if len(ring) > 8 or len(ring) < 4:
|
320 |
+
is_ring = True
|
321 |
+
break
|
322 |
+
if not is_frag and not is_ring:
|
323 |
+
filtered_gen_mols.append(mol)
|
324 |
+
filtered_indices.append(mol_i)
|
325 |
+
|
326 |
+
filtered_gen_stats = valid_gen_stats.iloc[filtered_indices]
|
327 |
+
|
328 |
+
st.subheader('Generation results', divider='rainbow')
|
329 |
+
st.dataframe(filtered_gen_stats)
|
330 |
st.download_button(
|
331 |
label="Download results as CSV",
|
332 |
+
data=convert_df(filtered_gen_stats),
|
333 |
+
file_name='vqgae_tubulin_inhibitors_valid.csv',
|
334 |
mime='text/csv',
|
335 |
)
|
336 |
|
337 |
+
st.subheader('Examples of generated molecules')
|
338 |
+
examples_inds = sorted(filtered_gen_stats.sort_values(by=["rf_score"], ascending=False).index[:9])
|
339 |
+
examples = [filtered_gen_mols[i] for i in examples_inds]
|
340 |
+
svg = grid_depict(examples, 2)
|
341 |
+
render_svg(svg)
|
342 |
+
|
343 |
+
show_full_stats = st.checkbox('Show full stats')
|
344 |
+
if show_full_stats:
|
345 |
+
st.dataframe(full_stats)
|
346 |
+
|
347 |
+
st.download_button(
|
348 |
+
label="Download full results as CSV",
|
349 |
+
data=convert_df(full_stats),
|
350 |
+
file_name='vqgae_tubulin_inhibitors_full.csv',
|
351 |
+
mime='text/csv',
|
352 |
+
)
|
353 |
+
|
354 |
+
if st.button("Restart"):
|
355 |
+
st.rerun()
|