File size: 5,479 Bytes
a099a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from rdkit.Chem import AllChem
from mhnreact.data import load_dataset_from_csv
from mhnreact.molutils import convert_smiles_to_fp
from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
import torch
import pickle

# Human-readable names of the reaction (super)classes, keyed by integer class id 1-10.
# Presumably the USPTO-50k reaction class labels used by the dataset loaded below -- confirm.
reaction_superclass_names = {
    1: 'Heteroatom alkylation and arylation',
    2: 'Acylation and related processes',
    3: 'C-C bond formation',
    4: 'Heterocycle formation',  # TODO check
    5: 'Protections',
    6: 'Deprotections',
    7: 'Reductions',
    8: 'Oxidations',
    9: 'Functional group interconversion (FGI)',
    10: 'Functional group addition (FGA)',
}

def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
    """Fingerprint the left-hand side of every template in ``t``.

    Args:
        t: mapping whose values are reaction-template SMARTS strings of the
            form ``'lhs>>rhs'``; only the part before ``'>>'`` is used.
        fp_size: fingerprint length forwarded to ``convert_smiles_to_fp``.
            NOTE(review): 8096 is not a power of two (8192 may have been
            intended) -- left unchanged because callers depend on it.
        fp_type: fingerprint kind forwarded as ``which=`` to
            ``convert_smiles_to_fp``.

    Returns:
        The result of ``convert_smiles_to_fp`` on the left-hand-side SMARTS,
        in ``t.values()`` order (presumably one fingerprint row per template).
    """
    lhs_smarts = [template.split('>>')[0] for template in t.values()]
    return convert_smiles_to_fp(lhs_smarts, is_smarts=True, which=fp_type, fp_size=fp_size)


def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
    """Fingerprint-Filter for applicability.

    A template is flagged as applicable when every bit set in its
    left-hand-side fingerprint is also set in the molecule fingerprint. This is
    intended as a cheap pre-screen for template applicability.

    Args:
        smi: a SMILES string, or a list of SMILES strings.
        templates: mapping of template id -> template SMARTS
            (see ``getTemplateApplicabilityMatrix``).
        fp_size: fingerprint length used for both templates and molecules.
        fp_type: fingerprint kind used for both templates and molecules.

    Returns:
        Boolean result of comparing, per template row, the number of bits
        shared with the molecule fingerprint against the template's own bit
        count. NOTE(review): with more than one SMILES the ``&`` relies on
        broadcasting against the template matrix -- confirm shapes before
        passing lists longer than one.
    """
    template_fp = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
    smiles_list = smi if isinstance(smi, list) else [smi]
    mol_fp = convert_smiles_to_fp(smiles_list, which=fp_type, fp_size=fp_size)
    bits_covered = (template_fp & mol_fp).sum(1)
    bits_required = template_fp.sum(1)
    return bits_covered == bits_required


def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
    """Single-step retrosynthesis of ``target_smiles`` with a template-scoring model.

    Templates are scored by ``clf``, ranked by descending softmax score and
    applied to the target with rdchiral in that order. The search stops once
    ``num_paths`` reactions have been collected or ``try_max_temp``
    top-ranked templates have been tried.

    Args:
        target_smiles: SMILES of the target (product) molecule.
        clf: model exposing ``forward_smiles`` and ``softmax``. If it has a
            ``templates`` attribute and ``clf.X`` is None, ``clf.X`` is set
            to ``clf.template_encoder(clf.templates)`` (cached on the model).
        num_paths: stop trying further templates once at least this many
            reactions were collected. All outcomes of the template that
            crosses the threshold are kept, so more may be returned.
        try_max_temp: maximum number of top-ranked templates to try.
        viz: if True, print each template that produced outcomes.
        use_FPF: if True, multiply the raw scores by the ``FPF`` applicability
            mask before the softmax.

    Returns:
        List of dicts with keys:
            ``template_idx``: index of the template that produced the outcome;
            ``template_rank``: its 1-based rank among all scored templates,
                counting templates that failed to apply;
            ``reaction``: ``'<target>>' + <outcome>`` SMILES string;
            ``prob``: the template's softmax score in percent.
        An empty list if rdchiral cannot parse ``target_smiles``.
    """
    # Only the template library is needed; the other returned values are unused.
    _X, _y, t, _test_reactants_can = load_dataset_from_csv(
        'data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)
    if hasattr(clf, 'templates') and clf.X is None:
        # Encode the template library once and cache it on the model.
        clf.X = clf.template_encoder(clf.templates)
    preds = clf.forward_smiles([target_smiles])

    if use_FPF:
        # NOTE(review): multiplying pre-softmax scores by a 0/1 mask sets
        # inapplicable templates to 0 rather than -inf, so they still take
        # part in the softmax ranking -- confirm this is intended.
        appl = FPF(target_smiles, t)
        preds = preds * torch.tensor(appl)
    preds = clf.softmax(preds)

    # Template indices ordered by descending score.
    idxs = preds.argsort().detach().numpy().flatten()[::-1]
    preds = preds.detach().numpy().flatten()

    try:
        prod_rct = rdchiralReactants(target_smiles)
    except Exception:
        print('target_smiles', target_smiles, 'not computable')
        return []

    reactions = []
    # Never index past the end of the ranking.
    n_tries = min(try_max_temp, len(idxs))
    for rank, tidx in enumerate(idxs[:n_tries], start=1):
        if len(reactions) >= num_paths:
            break
        try:
            rxn = rdchiralReaction(t[tidx])
            resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
        except Exception:
            # A template that fails to parse or run is treated as not applicable.
            continue
        if len(resu) != 2:
            # No (outcomes, mapped_outcomes) pair: the template did not apply.
            continue
        _res, mapped_res = resu

        # Each value's first element is used as the outcome SMILES
        # (presumably the atom-mapped precursors -- confirm against rdchiral).
        target = AllChem.MolToSmiles(prod_rct.reactants)
        rs = [target + '>>' + k[0] for k in mapped_res.values()]
        for r in rs:
            reactions.append({
                'template_idx': tidx,
                'template_rank': rank,
                'reaction': r,
                'prob': preds[tidx] * 100,
            })
        if viz:
            for r in rs:
                print('with template #', tidx, t[tidx])

    return reactions

def ssretro_custom(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
    """Single-step retrosynthesis using a pickled template library.

    Same search as ``ssretro``, but the templates are read from
    ``saved_dictionary.pkl`` (resolved relative to the current working
    directory) instead of the USPTO-50k CSV. The pickle is indexed by template
    index and its values are passed to ``rdchiralReaction``.

    Args:
        target_smiles: SMILES of the target (product) molecule.
        clf: model exposing ``forward_smiles`` and ``softmax``. If it has a
            ``templates`` attribute and ``clf.X`` is None, ``clf.X`` is set
            to ``clf.template_encoder(clf.templates)`` (cached on the model).
        num_paths: stop trying further templates once at least this many
            reactions were collected. All outcomes of the template that
            crosses the threshold are kept, so more may be returned.
        try_max_temp: maximum number of top-ranked templates to try.
        viz: if True, print each template that produced outcomes.
        use_FPF: if True, multiply the raw scores by the ``FPF`` applicability
            mask before the softmax.

    Returns:
        List of dicts with keys ``template_idx``, ``template_rank`` (1-based
        rank among all scored templates, counting templates that failed to
        apply), ``reaction`` (``'<target>>' + <outcome>`` SMILES) and ``prob``
        (softmax score in percent). An empty list if rdchiral cannot parse
        ``target_smiles``.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load a trusted,
    # locally produced 'saved_dictionary.pkl'.
    with open('saved_dictionary.pkl', 'rb') as f:
        t = pickle.load(f)

    if hasattr(clf, 'templates') and clf.X is None:
        # Encode the template library once and cache it on the model.
        clf.X = clf.template_encoder(clf.templates)
    preds = clf.forward_smiles([target_smiles])

    if use_FPF:
        # NOTE(review): multiplying pre-softmax scores by a 0/1 mask sets
        # inapplicable templates to 0 rather than -inf, so they still take
        # part in the softmax ranking -- confirm this is intended.
        appl = FPF(target_smiles, t)
        preds = preds * torch.tensor(appl)
    preds = clf.softmax(preds)

    # Template indices ordered by descending score.
    idxs = preds.argsort().detach().numpy().flatten()[::-1]
    preds = preds.detach().numpy().flatten()

    try:
        prod_rct = rdchiralReactants(target_smiles)
    except Exception:
        print('target_smiles', target_smiles, 'not computable')
        return []

    reactions = []
    # Never index past the end of the ranking.
    n_tries = min(try_max_temp, len(idxs))
    for rank, tidx in enumerate(idxs[:n_tries], start=1):
        if len(reactions) >= num_paths:
            break
        try:
            rxn = rdchiralReaction(t[tidx])
            resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
        except Exception:
            # A template that fails to parse or run is treated as not applicable.
            continue
        if len(resu) != 2:
            # No (outcomes, mapped_outcomes) pair: the template did not apply.
            continue
        _res, mapped_res = resu

        # Each value's first element is used as the outcome SMILES
        # (presumably the atom-mapped precursors -- confirm against rdchiral).
        target = AllChem.MolToSmiles(prod_rct.reactants)
        rs = [target + '>>' + k[0] for k in mapped_res.values()]
        for r in rs:
            reactions.append({
                'template_idx': tidx,
                'template_rank': rank,
                'reaction': r,
                'prob': preds[tidx] * 100,
            })
        if viz:
            for r in rs:
                print('with template #', tidx, t[tidx])
    return reactions