uragankatrrin commited on
Commit
69c560c
1 Parent(s): c8a5819

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pickle
3
+ from mhnreact.inspect import list_models, load_clf
4
+ from rdkit.Chem import rdChemReactions as Reaction
5
+ from rdkit.Chem.Draw import rdMolDraw2D
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ from ssretro_template import ssretro, ssretro_custom
8
+
9
+ def custom_template_file(template: str):
10
+ temp = [x.strip() for x in template.split(',')]
11
+ template_dict = {}
12
+ for i in range(len(temp)):
13
+ template_dict[i] = temp[i]
14
+ with open('saved_dictionary.pkl', 'wb') as f:
15
+ pickle.dump(template_dict, f)
16
+ return template_dict
17
+
18
+
19
+ def get_output(p):
20
+ rxn = Reaction.ReactionFromSmarts(p, useSmiles=False)
21
+ d = rdMolDraw2D.MolDraw2DCairo(800, 200)
22
+ d.DrawReaction(rxn, highlightByReactant=False)
23
+ d.FinishDrawing()
24
+ text = d.GetDrawingText()
25
+
26
+ return text
27
+
28
+
29
+ def ssretro_prediction(molecule, custom_template=False):
30
+ model_fn = list_models()[0]
31
+ retro_clf = load_clf(model_fn)
32
+ predict, txt = [], []
33
+
34
+ if custom_template:
35
+ outputs = ssretro_custom(molecule, retro_clf)
36
+ else:
37
+ outputs = ssretro(molecule, retro_clf)
38
+
39
+ for pred in outputs:
40
+ txt.append(
41
+ f'predicted top-{pred["template_rank"] - 1}, template index: {pred["template_idx"]}, prob: {pred["prob"]: 2.1f}%;')
42
+ predict.append(get_output(pred["reaction"]))
43
+
44
+ return predict, txt
45
+
46
+
47
+ def mhn_react_backend(mol, use_custom: bool):
48
+ output_dir = "outputs"
49
+ formatter = "03d"
50
+ images = []
51
+
52
+ predictions, comments = ssretro_prediction(mol, use_custom)
53
+
54
+ for i in range(len(predictions)):
55
+ output_im = f"{str(output_dir)}/{format(i, formatter)}.png"
56
+
57
+ with open(output_im, "wb") as fh:
58
+ fh.write(predictions[i])
59
+ fh.close()
60
+ font = ImageFont.truetype(r'tools/arial.ttf', 20)
61
+ img = Image.open(output_im)
62
+ right = 10
63
+ left = 10
64
+ top = 50
65
+ bottom = 1
66
+
67
+ width, height = img.size
68
+
69
+ new_width = width + right + left
70
+ new_height = height + top + bottom
71
+
72
+ result = Image.new(img.mode, (new_width, new_height), (255, 255, 255))
73
+ result.paste(img, (left, top))
74
+
75
+ I1 = ImageDraw.Draw(result)
76
+ I1.text((20, 20), comments[i], font=font, fill=(0, 0, 0))
77
+ images.append(result)
78
+ result.save(output_im)
79
+
80
+ return images
81
+
82
+
83
+ with gr.Blocks() as demo:
84
+ gr.Markdown(
85
+ """
86
+ [![Github](https://img.shields.io/badge/github-%20mhn--react-blue)](https://img.shields.io/badge/github-%20mhn--react-blue)
87
+ [![arXiv](https://img.shields.io/badge/acs.jcim-1c01065-yellow.svg)](https://doi.org/10.1021/acs.jcim.1c01065)
88
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb)
89
+ ### MHN-react
90
+ Adapting modern Hopfield networks (Ramsauer et al., 2021) (MHN) to associate different data modalities,
91
+ molecules and reaction templates, to improve predictive performance for rare templates and single-step retrosynthesis.
92
+ """
93
+ )
94
+
95
+ with gr.Accordion("Information"):
96
+ gr.Markdown("use one of example molecules <br> CC(=O)NCCC1=CNc2c1cc(OC)cc2, <br> CN1CCC[C@H]1c2cccnc2, <br> OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N"
97
+ "In case the output is empty, no applicable templates were found"
98
+ )
99
+
100
+ with gr.Tab("Generate Templates"):
101
+ with gr.Row():
102
+ with gr.Column(scale = 1):
103
+ inp = gr.Textbox(placeholder="Input molecule in SMILES format", label="input molecule")
104
+ radio = gr.Radio([False, True], label="use custom templates")
105
+
106
+ btn = gr.Button(value="Generate")
107
+
108
+ with gr.Column(scale=2):
109
+ out = gr.Gallery(label="retro-synthesis")
110
+
111
+ btn.click(mhn_react_backend, [inp, radio], out)
112
+
113
+ with gr.Tab("Create custom templates"):
114
+ gr.Markdown(
115
+ """
116
+ Input the templates separated by comma. <br> Please do not upload templates one-by-one
117
+ """
118
+ )
119
+ with gr.Column():
120
+ inp_t = gr.Textbox(placeholder="custom template", label="add custom template(s)")
121
+ btn = gr.Button(value="upload")
122
+ out_t = gr.Textbox(label = "added templates")
123
+ btn.click(custom_template_file, inp_t, out_t)
124
+
125
+ demo.launch()