File size: 3,626 Bytes
6bd991c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from rdkit.Chem import AllChem
from mhnreact.data import load_dataset_from_csv
from mhnreact.molutils import convert_smiles_to_fp
from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
import torch

# Human-readable names of the reaction superclasses, keyed by integer class id (1-10).
# NOTE(review): presumably these ids match the "class" column of the USPTO-50k
# dataset loaded elsewhere in this module -- confirm against the data file.
reaction_superclass_names = {
    1: 'Heteroatom alkylation and arylation',
    2: 'Acylation and related processes',
    3: 'C-C bond formation',
    4: 'Heterocycle formation',  # TODO check
    5: 'Protections',
    6: 'Deprotections',
    7: 'Reductions',
    8: 'Oxidations',
    9: 'Functional group interconversion (FGI)',
    10: 'Functional group addition (FGA)',
}

def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
    """Fingerprint the left-hand side of every reaction template in ``t``.

    Args:
        t: mapping whose values are template strings of the form
            ``'<left>>><right>'``; only the part before ``'>>'`` is used.
        fp_size: fingerprint length forwarded to ``convert_smiles_to_fp``.
        fp_type: fingerprint kind forwarded as ``which``.

    Returns:
        Whatever ``convert_smiles_to_fp`` returns for the left-hand sides
        treated as SMARTS, in the iteration order of ``t.values()``.
    """
    left_hand_sides = [template.split('>>')[0] for template in t.values()]
    return convert_smiles_to_fp(
        left_hand_sides,
        is_smarts=True,
        which=fp_type,
        fp_size=fp_size,
    )


def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
    """Fingerprint filter: flag templates whose bits are all set in the molecule.

    A template counts as applicable when the bitwise AND of its fingerprint
    with the molecule fingerprint keeps as many set bits as the template
    fingerprint has on its own.

    Args:
        smi: a SMILES string or a list of SMILES strings; a non-list value is
            wrapped into a one-element list.
        templates: mapping of templates, passed to
            ``getTemplateApplicabilityMatrix``.
        fp_size: fingerprint length used for both templates and molecules.
        fp_type: fingerprint kind used for both templates and molecules.

    Returns:
        The element-wise comparison result, one entry per template.
        NOTE(review): the ``&`` presumably relies on broadcasting a single
        molecule row against all template rows -- confirm behaviour for
        several input molecules.
    """
    molecules = smi if isinstance(smi, list) else [smi]
    template_fps = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
    molecule_fps = convert_smiles_to_fp(molecules, which=fp_type, fp_size=fp_size)

    shared_bit_counts = (template_fps & molecule_fps).sum(1)
    template_bit_counts = template_fps.sum(1)
    return shared_bit_counts == template_bit_counts


def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
    """Single-step retrosynthesis: propose precursor reactions for one target.

    The templates are loaded from ``data/USPTO_50k_MHN_prepro.csv.gz`` (path
    relative to the working directory) on every call, scored with ``clf`` and
    applied to the target with rdchiral, from the highest score downwards.

    Args:
        target_smiles: SMILES of the molecule to disconnect.
        clf: model providing ``forward_smiles`` and ``softmax``. If it has a
            ``templates`` attribute and ``clf.X`` is None, ``clf.X`` is filled
            with ``clf.template_encoder(clf.templates)`` first.
        num_paths: stop trying further templates once at least this many
            reactions were collected (one template may contribute several).
        try_max_temp: maximum number of top-ranked templates to try.
        viz: if true, print the template used for every proposed reaction.
        use_FPF: if true, multiply the raw scores by the applicability mask
            returned by ``FPF`` before the softmax.

    Returns:
        A list of dicts with the keys ``'template_rank'`` (1-based rank of the
        template that produced the reaction, counting non-executable
        templates as well), ``'reaction'`` (target SMILES + ``'>>'`` + mapped
        outcome) and ``'prob'`` (score of that template after the softmax,
        times 100). An empty list is returned if rdchiral cannot process
        ``target_smiles``.
    """
    # Only the template mapping is needed from the dataset loader.
    _, _, t, _ = load_dataset_from_csv('data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)
    if hasattr(clf, 'templates'):
        if clf.X is None:
            clf.X = clf.template_encoder(clf.templates)
    preds = clf.forward_smiles([target_smiles])

    if use_FPF:
        # NOTE(review): the mask is applied to the raw scores, so filtered
        # templates get score 0 before the softmax instead of being excluded
        # -- confirm this is intended.
        appl = FPF(target_smiles, t)
        preds = preds * torch.tensor(appl)
    preds = clf.softmax(preds)

    # Template indices sorted by descending score.
    idxs = preds.argsort().detach().numpy().flatten()[::-1]
    preds = preds.detach().numpy().flatten()

    try:
        prod_rct = rdchiralReactants(target_smiles)
    except Exception:
        print('target_smiles', target_smiles, 'not computebale')
        return []

    # The product side is the same for every proposed reaction.
    product_smiles = AllChem.MolToSmiles(prod_rct.reactants)
    reactions = []

    # Never look past the number of ranked templates.
    n_candidates = min(try_max_temp, len(idxs))
    for rank0 in range(n_candidates):
        if len(reactions) >= num_paths:
            break

        # Remember which template is being applied so that rank, probability
        # and the viz output all describe this template (not the next one).
        template_idx = idxs[rank0]
        try:
            rxn = rdchiralReaction(t[template_idx])
            resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
        except Exception:
            # Template could not be parsed or applied; try the next one.
            continue

        # With return_mapped=True a successful run yields a pair
        # (outcomes, mapped_outcomes); anything else means "no result".
        if len(resu) != 2:
            continue
        _, mapped_res = resu

        rs = [product_smiles + '>>' + k[0] for k in mapped_res.values()]
        for r in rs:
            reactions.append({
                'template_rank': rank0 + 1,  # actual rank, including non-executable templates
                'reaction': r,
                'prob': preds[template_idx] * 100,
            })
        if viz:
            for r in rs:
                print('with template #', template_idx, t[template_idx])

    return reactions