Spaces:
Runtime error
Runtime error
uragankatrrin
commited on
Commit
·
69c560c
1
Parent(s):
c8a5819
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pickle
|
3 |
+
from mhnreact.inspect import list_models, load_clf
|
4 |
+
from rdkit.Chem import rdChemReactions as Reaction
|
5 |
+
from rdkit.Chem.Draw import rdMolDraw2D
|
6 |
+
from PIL import Image, ImageDraw, ImageFont
|
7 |
+
from ssretro_template import ssretro, ssretro_custom
|
8 |
+
|
9 |
+
def custom_template_file(template: str):
|
10 |
+
temp = [x.strip() for x in template.split(',')]
|
11 |
+
template_dict = {}
|
12 |
+
for i in range(len(temp)):
|
13 |
+
template_dict[i] = temp[i]
|
14 |
+
with open('saved_dictionary.pkl', 'wb') as f:
|
15 |
+
pickle.dump(template_dict, f)
|
16 |
+
return template_dict
|
17 |
+
|
18 |
+
|
19 |
+
def get_output(p):
|
20 |
+
rxn = Reaction.ReactionFromSmarts(p, useSmiles=False)
|
21 |
+
d = rdMolDraw2D.MolDraw2DCairo(800, 200)
|
22 |
+
d.DrawReaction(rxn, highlightByReactant=False)
|
23 |
+
d.FinishDrawing()
|
24 |
+
text = d.GetDrawingText()
|
25 |
+
|
26 |
+
return text
|
27 |
+
|
28 |
+
|
29 |
+
def ssretro_prediction(molecule, custom_template=False):
|
30 |
+
model_fn = list_models()[0]
|
31 |
+
retro_clf = load_clf(model_fn)
|
32 |
+
predict, txt = [], []
|
33 |
+
|
34 |
+
if custom_template:
|
35 |
+
outputs = ssretro_custom(molecule, retro_clf)
|
36 |
+
else:
|
37 |
+
outputs = ssretro(molecule, retro_clf)
|
38 |
+
|
39 |
+
for pred in outputs:
|
40 |
+
txt.append(
|
41 |
+
f'predicted top-{pred["template_rank"] - 1}, template index: {pred["template_idx"]}, prob: {pred["prob"]: 2.1f}%;')
|
42 |
+
predict.append(get_output(pred["reaction"]))
|
43 |
+
|
44 |
+
return predict, txt
|
45 |
+
|
46 |
+
|
47 |
+
def mhn_react_backend(mol, use_custom: bool):
|
48 |
+
output_dir = "outputs"
|
49 |
+
formatter = "03d"
|
50 |
+
images = []
|
51 |
+
|
52 |
+
predictions, comments = ssretro_prediction(mol, use_custom)
|
53 |
+
|
54 |
+
for i in range(len(predictions)):
|
55 |
+
output_im = f"{str(output_dir)}/{format(i, formatter)}.png"
|
56 |
+
|
57 |
+
with open(output_im, "wb") as fh:
|
58 |
+
fh.write(predictions[i])
|
59 |
+
fh.close()
|
60 |
+
font = ImageFont.truetype(r'tools/arial.ttf', 20)
|
61 |
+
img = Image.open(output_im)
|
62 |
+
right = 10
|
63 |
+
left = 10
|
64 |
+
top = 50
|
65 |
+
bottom = 1
|
66 |
+
|
67 |
+
width, height = img.size
|
68 |
+
|
69 |
+
new_width = width + right + left
|
70 |
+
new_height = height + top + bottom
|
71 |
+
|
72 |
+
result = Image.new(img.mode, (new_width, new_height), (255, 255, 255))
|
73 |
+
result.paste(img, (left, top))
|
74 |
+
|
75 |
+
I1 = ImageDraw.Draw(result)
|
76 |
+
I1.text((20, 20), comments[i], font=font, fill=(0, 0, 0))
|
77 |
+
images.append(result)
|
78 |
+
result.save(output_im)
|
79 |
+
|
80 |
+
return images
|
81 |
+
|
82 |
+
|
83 |
+
with gr.Blocks() as demo:
|
84 |
+
gr.Markdown(
|
85 |
+
"""
|
86 |
+
[![Github](https://img.shields.io/badge/github-%20mhn--react-blue)](https://img.shields.io/badge/github-%20mhn--react-blue)
|
87 |
+
[![arXiv](https://img.shields.io/badge/acs.jcim-1c01065-yellow.svg)](https://doi.org/10.1021/acs.jcim.1c01065)
|
88 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb)
|
89 |
+
### MHN-react
|
90 |
+
Adapting modern Hopfield networks (Ramsauer et al., 2021) (MHN) to associate different data modalities,
|
91 |
+
molecules and reaction templates, to improve predictive performance for rare templates and single-step retrosynthesis.
|
92 |
+
"""
|
93 |
+
)
|
94 |
+
|
95 |
+
with gr.Accordion("Information"):
|
96 |
+
gr.Markdown("use one of example molecules <br> CC(=O)NCCC1=CNc2c1cc(OC)cc2, <br> CN1CCC[C@H]1c2cccnc2, <br> OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N"
|
97 |
+
"In case the output is empty, no applicable templates were found"
|
98 |
+
)
|
99 |
+
|
100 |
+
with gr.Tab("Generate Templates"):
|
101 |
+
with gr.Row():
|
102 |
+
with gr.Column(scale = 1):
|
103 |
+
inp = gr.Textbox(placeholder="Input molecule in SMILES format", label="input molecule")
|
104 |
+
radio = gr.Radio([False, True], label="use custom templates")
|
105 |
+
|
106 |
+
btn = gr.Button(value="Generate")
|
107 |
+
|
108 |
+
with gr.Column(scale=2):
|
109 |
+
out = gr.Gallery(label="retro-synthesis")
|
110 |
+
|
111 |
+
btn.click(mhn_react_backend, [inp, radio], out)
|
112 |
+
|
113 |
+
with gr.Tab("Create custom templates"):
|
114 |
+
gr.Markdown(
|
115 |
+
"""
|
116 |
+
Input the templates separated by comma. <br> Please do not upload templates one-by-one
|
117 |
+
"""
|
118 |
+
)
|
119 |
+
with gr.Column():
|
120 |
+
inp_t = gr.Textbox(placeholder="custom template", label="add custom template(s)")
|
121 |
+
btn = gr.Button(value="upload")
|
122 |
+
out_t = gr.Textbox(label = "added templates")
|
123 |
+
btn.click(custom_template_file, inp_t, out_t)
|
124 |
+
|
125 |
+
demo.launch()
|