uragankatrrin commited on
Commit
6bd991c
1 Parent(s): 088913d

Upload 8 files

Browse files
Files changed (8) hide show
  1. .dockerignore +1 -0
  2. .gitignore +131 -0
  3. LICENSE +23 -0
  4. README.md +1 -12
  5. app.py +61 -0
  6. env.yml +24 -0
  7. requirements.txt +16 -0
  8. ssretro_template.py +93 -0
.dockerignore ADDED
@@ -0,0 +1 @@
 
 
1
+ **/.git
.gitignore ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **/__pycache__
2
+ /mhnreact/.ipynb_checkpoints/
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ pip-wheel-metadata/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ .python-version
88
+
89
+ # pipenv
90
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
92
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
93
+ # install all needed dependencies.
94
+ #Pipfile.lock
95
+
96
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97
+ __pypackages__/
98
+
99
+ # Celery stuff
100
+ celerybeat-schedule
101
+ celerybeat.pid
102
+
103
+ # SageMath parsed files
104
+ *.sage.py
105
+
106
+ # Environments
107
+ .env
108
+ .venv
109
+ env/
110
+ venv/
111
+ ENV/
112
+ env.bak/
113
+ venv.bak/
114
+
115
+ # Spyder project settings
116
+ .spyderproject
117
+ .spyproject
118
+
119
+ # Rope project settings
120
+ .ropeproject
121
+
122
+ # mkdocs documentation
123
+ /site
124
+
125
+ # mypy
126
+ .mypy_cache/
127
+ .dmypy.json
128
+ dmypy.json
129
+
130
+ # Pyre type checker
131
+ .pyre/
LICENSE ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ NeurIPS | 2021 License
2
+
3
+ Copyright (c) 2021, by the authors
4
+ All rights reserved.
5
+
6
+ Source code is for reviewing purpose only.
7
+
8
+ Redistribution of the source code is not permitted.
9
+
10
+ Use in source and binary forms, with or without
11
+ modification, are permitted provided that the following conditions are met:
12
+
13
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
16
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
17
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
18
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
19
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
20
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
21
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23
+
README.md CHANGED
@@ -1,12 +1 @@
1
- ---
2
- title: MHN React
3
- emoji: 👀
4
- colorFrom: indigo
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 3.4.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # mhn-react-app
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from mhnreact.inspect import list_models, load_clf
3
+ from rdkit.Chem import rdChemReactions as Reaction
4
+ from rdkit.Chem.Draw import rdMolDraw2D
5
+ from PIL import Image, ImageDraw
6
+ from ssretro_template import ssretro
7
+
8
+
9
+ def get_output(p):
10
+ rxn = Reaction.ReactionFromSmarts(p, useSmiles=False)
11
+ d = rdMolDraw2D.MolDraw2DCairo(800, 200)
12
+ d.DrawReaction(rxn, highlightByReactant=False)
13
+ d.FinishDrawing()
14
+ text = d.GetDrawingText()
15
+
16
+ return text
17
+
18
+ def ssretro_prediction(molecule):
19
+
20
+ model_fn = list_models()[0]
21
+ retro_clf = load_clf(model_fn)
22
+
23
+ outputs = ssretro(molecule, retro_clf)
24
+ predict, txt = [], []
25
+ for pred in outputs:
26
+ txt.append(f'predicted top-{pred["template_rank"]-1}, prob: {pred["prob"]:2.1f}%; {pred["reaction"]}')
27
+ predict.append(get_output(pred["reaction"]))
28
+
29
+ return predict, txt
30
+
31
+
32
+ def mhn_react_backend(mol):
33
+
34
+ output_dir = "outputs"
35
+ formatter = "03d"
36
+ images = []
37
+
38
+ predictions, comments = ssretro_prediction(mol)
39
+
40
+ for i in range(len(predictions)):
41
+ output_im = f"{str(output_dir)}/{format(i, formatter)}.png"
42
+
43
+ with open(output_im, "wb") as fh:
44
+ fh.write(predictions[i])
45
+ fh.close()
46
+
47
+ img = Image.open(output_im)
48
+ I1 = ImageDraw.Draw(img)
49
+
50
+ I1.text((20, 10), comments[i], fill=(30, 0, 44))
51
+
52
+ images.append(img)
53
+ img.save(output_im)
54
+
55
+ return images
56
+
57
+
58
+ demo = gr.Interface(fn=mhn_react_backend, inputs="text", outputs="gallery")
59
+ demo.launch()
60
+
61
+
env.yml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mhnreact_env
2
+ channels:
3
+ - bioconda
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ # change to your cuda version
8
+ - cudatoolkit=10.2
9
+ - torch==1.6
10
+ - torchvision==0.7
11
+ - pandas=1.0.5
12
+ - pip=20.1.1=py_1
13
+ - python=3.7
14
+ - rdkit=2021.03.1 #2020.03.4
15
+ # optionally
16
+ - ipython
17
+ - jupyterlab
18
+ - pip:
19
+ - numpy==1.19
20
+ - scikit-learn==0.23.1
21
+ - scipy==1.4
22
+ - hydra-core
23
+ - tqdm
24
+ - rdchiral==1.1.0
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy~=1.23.2
2
+ pandas~=1.4.4
3
+ scipy~=1.9.1
4
+ joblib~=1.1.0
5
+ sklearn~=0.0
6
+ scikit-learn~=1.1.2
7
+ requests~=2.28.1
8
+ rdchiral~=1.1.0
9
+ setuptools~=60.2.0
10
+ gradio~=3.3.1
11
+ Pillow~=9.2.0
12
+ matplotlib~=3.5.3
13
+ torch~=1.12.1
14
+ wandb~=0.13.2
15
+ tqdm~=4.64.1
16
+ swifter~=1.3.4
ssretro_template.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit.Chem import AllChem
2
+ from mhnreact.data import load_dataset_from_csv
3
+ from mhnreact.molutils import convert_smiles_to_fp
4
+ from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
5
+ import torch
6
+
7
+ reaction_superclass_names = {
8
+ 1: 'Heteroatom alkylation and arylation',
9
+ 2: 'Acylation and related processes',
10
+ 3: 'C-C bond formation',
11
+ 4: 'Heterocycle formation', # TODO check
12
+ 5: 'Protections',
13
+ 6: 'Deprotections',
14
+ 7: 'Reductions',
15
+ 8: 'Oxidations',
16
+ 9: 'Functional group interconversoin (FGI)',
17
+ 10: 'Functional group addition (FGA)'
18
+ }
19
+
20
+ def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
21
+ only_left_side_of_templates = list(map(lambda k: k.split('>>')[0], t.values()))
22
+ return convert_smiles_to_fp(only_left_side_of_templates, is_smarts=True, which=fp_type, fp_size=fp_size)
23
+
24
+
25
+ def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
26
+ """Fingerprint-Filter for applicability"""
27
+ tfp = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
28
+ if not isinstance(smi, list):
29
+ smi = [smi]
30
+ mfp = convert_smiles_to_fp(smi, which=fp_type, fp_size=fp_size)
31
+ applicable = ((tfp & mfp).sum(1) == (tfp.sum(1)))
32
+ return applicable
33
+
34
+
35
+ def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
36
+ """single-step-retrosynthesis"""
37
+ X, y, t, test_reactants_can = load_dataset_from_csv('data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)
38
+ if hasattr(clf, 'templates'):
39
+ if clf.X is None:
40
+ clf.X = clf.template_encoder(clf.templates)
41
+ preds = clf.forward_smiles([target_smiles])
42
+
43
+ if use_FPF:
44
+ appl = FPF(target_smiles, t)
45
+ preds = preds * torch.tensor(appl)
46
+ preds = clf.softmax(preds)
47
+
48
+ idxs = preds.argsort().detach().numpy().flatten()[::-1]
49
+ preds = preds.detach().numpy().flatten()
50
+
51
+ try:
52
+ prod_rct = rdchiralReactants(target_smiles)
53
+ except:
54
+ print('target_smiles', target_smiles, 'not computebale')
55
+ return []
56
+ reactions = []
57
+
58
+ i = 0
59
+ while len(reactions) < num_paths and (i < try_max_temp):
60
+ resu = []
61
+ while (not len(resu)) and (i < try_max_temp): # continue
62
+ # print(i, end=' \r')
63
+ try:
64
+ rxn = rdchiralReaction(t[idxs[i]])
65
+ resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
66
+ except:
67
+ resu = ['err']
68
+ i += 1
69
+
70
+ if len(resu) == 2: # if there is a result
71
+ res, mapped_res = resu
72
+
73
+ rs = [AllChem.MolToSmiles(prod_rct.reactants) + '>>' + k[0] for k in list(mapped_res.values())]
74
+ for r in rs:
75
+ di = {
76
+ # 'template_used': t[idxs[i]],
77
+ # 'template_idx': idxs[i],
78
+ 'template_rank': i + 1, # get the acutal rank, not the one without non-executable
79
+ 'reaction': r,
80
+ # 'reaction_canonical': canonicalize_template(r),
81
+ 'prob': preds[idxs[i]] * 100
82
+ # 'template_class': reaction_superclass_names[
83
+ # df[df.reaction_smarts == t[idxs[i]]]["class"].unique()[0]]
84
+ }
85
+ # di['template_num_train_samples'] = (y['train'] == di['template_idx']).sum()
86
+ reactions.append(di)
87
+ if viz:
88
+ for r in rs:
89
+ print('with template #', idxs[i], t[idxs[i]])
90
+ # smarts2svg(r, useSmiles=True, highlightByReactant=True);
91
+
92
+ return reactions
93
+