uragankatrrin committed on
Commit
14ebfa4
1 Parent(s): 78a4a8b

Delete ssretro_template.py

Browse files
Files changed (1) hide show
  1. ssretro_template.py +0 -93
ssretro_template.py DELETED
@@ -1,93 +0,0 @@
1
- from rdkit.Chem import AllChem
2
- from mhnreact.data import load_dataset_from_csv
3
- from mhnreact.molutils import convert_smiles_to_fp
4
- from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
5
- import torch
6
-
7
# Human-readable names of the reaction superclasses, keyed by class id.
reaction_superclass_names = {
    1: 'Heteroatom alkylation and arylation',
    2: 'Acylation and related processes',
    3: 'C-C bond formation',
    4: 'Heterocycle formation',  # TODO check
    5: 'Protections',
    6: 'Deprotections',
    7: 'Reductions',
    8: 'Oxidations',
    9: 'Functional group interconversion (FGI)',
    10: 'Functional group addition (FGA)',
}
19
-
20
def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
    """Fingerprint the left-hand side of every template in ``t``.

    Args:
        t: mapping whose values are template strings of the form
            ``'<left>>><right>'``; only the text before the first ``'>>'``
            is used.
        fp_size: forwarded to ``convert_smiles_to_fp`` as ``fp_size``.
        fp_type: forwarded to ``convert_smiles_to_fp`` as ``which``.

    Returns:
        Whatever ``convert_smiles_to_fp`` returns for the left-hand sides,
        called with ``is_smarts=True``.
    """
    left_sides = [template.split('>>')[0] for template in t.values()]
    return convert_smiles_to_fp(
        left_sides,
        is_smarts=True,
        which=fp_type,
        fp_size=fp_size,
    )
23
-
24
-
25
def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
    """Fingerprint filter for template applicability.

    A template is flagged as applicable when every bit set in its
    fingerprint is also set in the molecule fingerprint.

    Args:
        smi: a single SMILES string or a list of SMILES strings.
        templates: mapping of templates, passed on to
            ``getTemplateApplicabilityMatrix``.
        fp_size: fingerprint size used for both templates and molecules.
        fp_type: fingerprint type used for both templates and molecules.

    Returns:
        The element-wise result of comparing, per row, the number of bits
        shared by template and molecule with the number of template bits.
    """
    template_fps = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
    molecules = smi if isinstance(smi, list) else [smi]
    molecule_fps = convert_smiles_to_fp(molecules, which=fp_type, fp_size=fp_size)

    shared_bits = (template_fps & molecule_fps).sum(1)
    required_bits = template_fps.sum(1)
    return shared_bits == required_bits
33
-
34
-
35
def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
    """Single-step retrosynthesis: apply the best-scoring templates to a target.

    Templates are ranked by the score ``clf`` assigns them for
    ``target_smiles`` and tried in descending order until at least
    ``num_paths`` reactions were collected or ``try_max_temp`` templates were
    attempted. Templates that fail to run or yield no outcome are skipped.

    The template dictionary is loaded from
    ``'data/USPTO_50k_MHN_prepro.csv.gz'`` on every call.

    Args:
        target_smiles: SMILES of the molecule to disconnect.
        clf: model exposing ``forward_smiles`` and ``softmax``. If it has a
            ``templates`` attribute and ``clf.X`` is ``None``, ``clf.X`` is
            filled with ``clf.template_encoder(clf.templates)`` first.
        num_paths: stop as soon as this many reactions were collected. One
            template may contribute several reactions, so the result can be
            longer than ``num_paths``.
        try_max_temp: maximum number of ranked templates to try.
        viz: if true, print the applied template once per proposed reaction.
        use_FPF: if true, multiply the raw scores with the fingerprint-filter
            mask from ``FPF`` before the softmax.

    Returns:
        A list of dicts with the keys ``'template_rank'`` (1-based rank of
        the template that was applied), ``'reaction'`` (``target>>outcome``)
        and ``'prob'`` (score of the applied template in percent). An empty
        list is returned when rdchiral cannot process ``target_smiles``.
    """
    # Only the template dictionary of the dataset is needed here.
    _, _, t, _ = load_dataset_from_csv('data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)
    if hasattr(clf, 'templates') and clf.X is None:
        clf.X = clf.template_encoder(clf.templates)
    preds = clf.forward_smiles([target_smiles])

    if use_FPF:
        appl = FPF(target_smiles, t)
        # NOTE(review): masked-out templates end up with a raw score of 0
        # and presumably still receive probability mass from the softmax
        # below - confirm that this is intended.
        preds = preds * torch.tensor(appl)
    preds = clf.softmax(preds)

    idxs = preds.argsort().detach().numpy().flatten()[::-1]
    preds = preds.detach().numpy().flatten()

    try:
        prod_rct = rdchiralReactants(target_smiles)
    except Exception:
        print('target_smiles', target_smiles, 'not computebale')
        return []

    reactions = []
    # Slicing bounds the attempts by both try_max_temp and the number of
    # ranked templates, so the ranking is never indexed past its end.
    candidates = idxs[:max(try_max_temp, 0)]
    for rank, template_idx in enumerate(candidates, start=1):
        if len(reactions) >= num_paths:
            break
        try:
            rxn = rdchiralReaction(t[template_idx])
            resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True,
                               return_mapped=True)
        except Exception:
            # Template could not be parsed or applied; try the next one.
            continue

        if len(resu) != 2:  # no (outcomes, mapped_outcomes) pair -> no result
            continue
        _, mapped_res = resu

        target_side = AllChem.MolToSmiles(prod_rct.reactants)
        rs = [target_side + '>>' + k[0] for k in mapped_res.values()]
        for r in rs:
            # Rank and probability refer to the template that was actually
            # applied, i.e. the actual rank including non-executable templates.
            # Further fields (template_used, template_idx, reaction_canonical,
            # template_class, template_num_train_samples) are disabled.
            reactions.append({
                'template_rank': rank,
                'reaction': r,
                'prob': preds[template_idx] * 100,
            })
        if viz:
            for _ in rs:
                print('with template #', template_idx, t[template_idx])

    return reactions
93
-