uragankatrrin commited on
Commit
a099a32
1 Parent(s): b18e18b

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +126 -0
  2. saved_dictionary.pkl +3 -0
  3. ssretro_template.py +146 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pickle
3
+ from mhnreact.inspect import list_models, load_clf
4
+ from rdkit.Chem import rdChemReactions as Reaction
5
+ from rdkit.Chem.Draw import rdMolDraw2D
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ from ssretro_template import ssretro, ssretro_custom
8
+
9
+ def custom_template_file(template: str):
10
+ temp = [x.strip() for x in template.split(',')]
11
+ template_dict = {}
12
+ for i in range(len(temp)):
13
+ template_dict[i] = temp[i]
14
+ with open('saved_dictionary.pkl', 'wb') as f:
15
+ pickle.dump(template_dict, f)
16
+ return template_dict
17
+
18
+
19
+ def get_output(p):
20
+ rxn = Reaction.ReactionFromSmarts(p, useSmiles=False)
21
+ d = rdMolDraw2D.MolDraw2DCairo(800, 200)
22
+ d.DrawReaction(rxn, highlightByReactant=False)
23
+ d.FinishDrawing()
24
+ text = d.GetDrawingText()
25
+
26
+ return text
27
+
28
+
29
+ def ssretro_prediction(molecule, custom_template=False):
30
+ model_fn = list_models()[0]
31
+ retro_clf = load_clf(model_fn)
32
+ predict, txt = [], []
33
+
34
+ if custom_template:
35
+ outputs = ssretro_custom(molecule, retro_clf)
36
+ else:
37
+ outputs = ssretro(molecule, retro_clf)
38
+
39
+ for pred in outputs:
40
+ txt.append(
41
+ f'predicted top-{pred["template_rank"] - 1}, template index: {pred["template_idx"]}, prob: {pred["prob"]: 2.1f}%;')
42
+ predict.append(get_output(pred["reaction"]))
43
+
44
+ return predict, txt
45
+
46
+
47
+ def mhn_react_backend(mol, use_custom: bool):
48
+ output_dir = "outputs"
49
+ formatter = "03d"
50
+ images = []
51
+
52
+ predictions, comments = ssretro_prediction(mol, use_custom)
53
+
54
+ for i in range(len(predictions)):
55
+ output_im = f"{str(output_dir)}/{format(i, formatter)}.png"
56
+
57
+ with open(output_im, "wb") as fh:
58
+ fh.write(predictions[i])
59
+ fh.close()
60
+ font = ImageFont.truetype(r'tools/arial.ttf', 20)
61
+ img = Image.open(output_im)
62
+ right = 10
63
+ left = 10
64
+ top = 50
65
+ bottom = 1
66
+
67
+ width, height = img.size
68
+
69
+ new_width = width + right + left
70
+ new_height = height + top + bottom
71
+
72
+ result = Image.new(img.mode, (new_width, new_height), (255, 255, 255))
73
+ result.paste(img, (left, top))
74
+
75
+ I1 = ImageDraw.Draw(result)
76
+ I1.text((20, 20), comments[i], font=font, fill=(0, 0, 0))
77
+ images.append(result)
78
+ result.save(output_im)
79
+
80
+ return images
81
+
82
+
83
+ with gr.Blocks() as demo:
84
+ gr.Markdown(
85
+ """
86
+ __[Github](https://github.com/ml-jku/mhn-react)__
87
+ __[acs.jcim](https://pubs.acs.org/doi/10.1021/acs.jcim.1c01065)__
88
+ __[Google Colab](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb)__
89
+ ### MHN-react
90
+ Adapting modern Hopfield networks (Ramsauer et al., 2021) (MHN) to associate different data modalities,
91
+ molecules and reaction templates, to improve predictive performance for rare templates and single-step retrosynthesis.
92
+ """
93
+ )
94
+
95
+ with gr.Accordion("Guide"):
96
+ gr.Markdown("Information (add) <br> "
97
+ "In case the output is empty => No suitable templates?"
98
+ "use one of example molecules: <br> CC(=O)NCCC1=CNc2c1cc(OC)cc2"
99
+ )
100
+
101
+ with gr.Tab("Generate Templates"):
102
+ with gr.Row():
103
+ with gr.Column(scale = 1):
104
+ inp = gr.Textbox(placeholder="Input molecule in SMILES format", label="input molecule")
105
+ radio = gr.Radio([False, True], label="use custom templates")
106
+
107
+ btn = gr.Button(value="Generate")
108
+
109
+ with gr.Column(scale=2):
110
+ out = gr.Gallery(label="retro-synthesis")
111
+
112
+ btn.click(mhn_react_backend, [inp, radio], out)
113
+
114
+ with gr.Tab("Create custom templates"):
115
+ gr.Markdown(
116
+ """
117
+ Input the templates separated by comma. <br> Please do not upload templates one-by-one
118
+ """
119
+ )
120
+ with gr.Column():
121
+ inp_t = gr.Textbox(placeholder="custom template", label="add custom template(s)")
122
+ btn = gr.Button(value="upload")
123
+ out_t = gr.Textbox(label = "added templates")
124
+ btn.click(custom_template_file, inp_t, out_t)
125
+
126
+ demo.launch()
saved_dictionary.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87ccfca32bc3f8a4bad6c3fe20b97d47bb0f55f4913920b4f8c707d8b4e3344e
3
+ size 766
ssretro_template.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit.Chem import AllChem
2
+ from mhnreact.data import load_dataset_from_csv
3
+ from mhnreact.molutils import convert_smiles_to_fp
4
+ from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
5
+ import torch
6
+ import pickle
7
+
8
+ reaction_superclass_names = {
9
+ 1: 'Heteroatom alkylation and arylation',
10
+ 2: 'Acylation and related processes',
11
+ 3: 'C-C bond formation',
12
+ 4: 'Heterocycle formation', # TODO check
13
+ 5: 'Protections',
14
+ 6: 'Deprotections',
15
+ 7: 'Reductions',
16
+ 8: 'Oxidations',
17
+ 9: 'Functional group interconversoin (FGI)',
18
+ 10: 'Functional group addition (FGA)'
19
+ }
20
+
21
+ def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
22
+ only_left_side_of_templates = list(map(lambda k: k.split('>>')[0], t.values()))
23
+ return convert_smiles_to_fp(only_left_side_of_templates, is_smarts=True, which=fp_type, fp_size=fp_size)
24
+
25
+
26
+ def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
27
+ """Fingerprint-Filter for applicability"""
28
+ tfp = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
29
+ if not isinstance(smi, list):
30
+ smi = [smi]
31
+ mfp = convert_smiles_to_fp(smi, which=fp_type, fp_size=fp_size)
32
+ applicable = ((tfp & mfp).sum(1) == (tfp.sum(1)))
33
+ return applicable
34
+
35
+
36
+ def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
37
+ """single-step-retrosynthesis"""
38
+ X, y, t, test_reactants_can = load_dataset_from_csv('data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)
39
+ if hasattr(clf, 'templates'):
40
+ if clf.X is None:
41
+ clf.X = clf.template_encoder(clf.templates)
42
+ preds = clf.forward_smiles([target_smiles])
43
+
44
+ if use_FPF:
45
+ appl = FPF(target_smiles, t)
46
+ preds = preds * torch.tensor(appl)
47
+ preds = clf.softmax(preds)
48
+
49
+ idxs = preds.argsort().detach().numpy().flatten()[::-1]
50
+ preds = preds.detach().numpy().flatten()
51
+
52
+ try:
53
+ prod_rct = rdchiralReactants(target_smiles)
54
+ except:
55
+ print('target_smiles', target_smiles, 'not computable')
56
+ return []
57
+ reactions = []
58
+
59
+ i = 0
60
+ while len(reactions) < num_paths and (i < try_max_temp):
61
+ resu = []
62
+ while (not len(resu)) and (i < try_max_temp): # continue
63
+ # print(i, end=' \r')
64
+ try:
65
+ rxn = rdchiralReaction(t[idxs[i]])
66
+ resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
67
+ except:
68
+ resu = ['err']
69
+ i += 1
70
+
71
+ if len(resu) == 2: # if there is a result
72
+ res, mapped_res = resu
73
+
74
+ rs = [AllChem.MolToSmiles(prod_rct.reactants) + '>>' + k[0] for k in list(mapped_res.values())]
75
+ for r in rs:
76
+ di = {
77
+ # 'template_used': t[idxs[i]],
78
+ 'template_idx': idxs[i],
79
+ 'template_rank': i + 1, # get the acutal rank, not the one without non-executable
80
+ 'reaction': r,
81
+ 'prob': preds[idxs[i]] * 100
82
+ }
83
+ # di['template_num_train_samples'] = (y['train'] == di['template_idx']).sum()
84
+ reactions.append(di)
85
+ if viz:
86
+ for r in rs:
87
+ print('with template #', idxs[i], t[idxs[i]])
88
+ # smarts2svg(r, useSmiles=True, highlightByReactant=True);
89
+
90
+ return reactions
91
+
92
+ def ssretro_custom(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
93
+ """single-step-retrosynthesis"""
94
+ # X, y, t, test_reactants_can = load_dataset_from_csv('data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)
95
+ with open('saved_dictionary.pkl', 'rb') as f:
96
+ t = pickle.load(f)
97
+
98
+ if hasattr(clf, 'templates'):
99
+ if clf.X is None:
100
+ clf.X = clf.template_encoder(clf.templates)
101
+ preds = clf.forward_smiles([target_smiles])
102
+
103
+ if use_FPF:
104
+ appl = FPF(target_smiles, t)
105
+ preds = preds * torch.tensor(appl)
106
+ preds = clf.softmax(preds)
107
+
108
+ idxs = preds.argsort().detach().numpy().flatten()[::-1]
109
+ preds = preds.detach().numpy().flatten()
110
+
111
+ try:
112
+ prod_rct = rdchiralReactants(target_smiles)
113
+ except:
114
+ print('target_smiles', target_smiles, 'not computable')
115
+ return []
116
+ reactions = []
117
+
118
+ i = 0
119
+ while len(reactions) < num_paths and (i < try_max_temp):
120
+ resu = []
121
+ while (not len(resu)) and (i < try_max_temp): # continue
122
+ # print(i, end=' \r')
123
+ try:
124
+ rxn = rdchiralReaction(t[idxs[i]])
125
+ resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
126
+ except:
127
+ resu = ['err']
128
+ i += 1
129
+
130
+ if len(resu) == 2: # if there is a result
131
+ res, mapped_res = resu
132
+
133
+ rs = [AllChem.MolToSmiles(prod_rct.reactants) + '>>' + k[0] for k in list(mapped_res.values())]
134
+ for r in rs:
135
+ di = {
136
+ # 'template_used': t[idxs[i]],
137
+ 'template_idx': idxs[i],
138
+ 'template_rank': i + 1, # get the acutal rank, not the one without non-executable
139
+ 'reaction': r,
140
+ 'prob': preds[idxs[i]] * 100
141
+ }
142
+ reactions.append(di)
143
+ if viz:
144
+ for r in rs:
145
+ print('with template #', idxs[i], t[idxs[i]])
146
+ return reactions