|
from rdkit.Chem import AllChem |
|
from mhnreact.data import load_dataset_from_csv |
|
from mhnreact.molutils import convert_smiles_to_fp |
|
from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants |
|
import torch |
|
|
|
# Human-readable names of the reaction superclasses, keyed by class number
# (1-10); presumably the USPTO-50k reaction classes (TODO confirm).
reaction_superclass_names = {
    1: 'Heteroatom alkylation and arylation',
    2: 'Acylation and related processes',
    3: 'C-C bond formation',
    4: 'Heterocycle formation',
    5: 'Protections',
    6: 'Deprotections',
    7: 'Reductions',
    8: 'Oxidations',
    9: 'Functional group interconversion (FGI)',
    10: 'Functional group addition (FGA)',
}
|
|
|
def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
    """Fingerprint the pattern to the left of '>>' of every template in ``t``.

    Args:
        t: mapping whose values are reaction-template SMARTS strings of the
            form '<left>>><right>'; only the part before the first '>>' is used.
        fp_size: fingerprint length in bits, forwarded to ``convert_smiles_to_fp``.
        fp_type: fingerprint kind, forwarded to ``convert_smiles_to_fp`` as ``which``.

    Returns:
        The result of ``convert_smiles_to_fp`` for the left-hand sides (parsed as
        SMARTS), in the iteration order of ``t.values()`` -- presumably one
        fingerprint row per template (TODO confirm against molutils).
    """
    left_hand_sides = []
    for template in t.values():
        lhs, _sep, _rest = template.partition('>>')
        left_hand_sides.append(lhs)
    return convert_smiles_to_fp(left_hand_sides, is_smarts=True, which=fp_type, fp_size=fp_size)
|
|
|
|
|
def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
    """Fingerprint-Filter for applicability.

    A template is marked applicable to a molecule when every bit set in the
    fingerprint of the template's left-hand side is also set in the molecule's
    fingerprint. This is a cheap screen; presumably it is necessary but not
    sufficient for the template to actually match (confirm for ``fp_type``).

    Args:
        smi: a single SMILES string, or a list of SMILES strings.
        templates: mapping of template index -> template SMARTS (see
            ``getTemplateApplicabilityMatrix``).
        fp_size: fingerprint length in bits; must match for molecules and templates.
        fp_type: fingerprint kind used for both molecules and templates.

    Returns:
        Boolean array of shape (n_templates,) for a single molecule (a string or
        a one-element list), or (n_molecules, n_templates) for longer lists.
        Assumes ``convert_smiles_to_fp`` returns numpy-compatible bit arrays.
    """
    import numpy as np

    tfp = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
    if not isinstance(smi, list):
        smi = [smi]
    if not smi:
        return np.zeros((0, len(tfp)), dtype=bool)

    mfp = convert_smiles_to_fp(smi, which=fp_type, fp_size=fp_size)

    # Number of bits each template requires; computed once for all molecules.
    template_bits = tfp.sum(1)
    # Check each molecule against all templates separately. Broadcasting the
    # whole (n_molecules, F) matrix against (n_templates, F) only works for one
    # molecule, and silently pairs rows up when n_molecules == n_templates.
    applicable = np.stack([(tfp & mol_fp).sum(1) == template_bits for mol_fp in mfp])

    # Keep the original 1-D result for a single molecule.
    return applicable[0] if len(smi) == 1 else applicable
|
|
|
|
|
def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False, templates=None):
    """single-step-retrosynthesis

    Scores all templates for ``target_smiles`` with ``clf``. It then applies the
    templates with rdchiral in order of decreasing probability. It stops once at
    least ``num_paths`` reactions have been collected or ``try_max_temp``
    templates have been tried.

    Args:
        target_smiles: SMILES of the molecule to disconnect.
        clf: template-scoring model. It must provide ``forward_smiles`` and
            ``softmax``. If it has ``templates``, it must also have ``X`` and
            ``template_encoder``.
        num_paths: stop trying further templates once this many reactions are
            collected. All outcomes of the last successful template are kept,
            so more than ``num_paths`` reactions may be returned.
        try_max_temp: maximum number of top-ranked templates to try.
        viz: print the template used for every collected reaction.
        use_FPF: drop templates rejected by the fingerprint filter ``FPF`` and
            renormalize the probabilities over the remaining ones.
        templates: mapping template index -> template SMARTS. Its keys are
            presumably aligned with clf's output columns (TODO confirm). If
            None, it is loaded from 'data/USPTO_50k_MHN_prepro.csv.gz'.

    Returns:
        List of dicts with keys:
            'template_rank': 1-based position of the template in the ranking.
            'reaction': '<mapped target>>><mapped outcome>' SMILES.
            'prob': template probability in percent.
        Returns an empty list if rdchiral cannot process the target.
    """
    if templates is None:
        # Loading the dataset is slow; callers can pass ``templates`` to skip it.
        _, _, templates, _ = load_dataset_from_csv('data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)

    if hasattr(clf, 'templates') and clf.X is None:
        # Encode the model's templates lazily, once.
        clf.X = clf.template_encoder(clf.templates)

    probs = clf.softmax(clf.forward_smiles([target_smiles]))

    if use_FPF:
        # Mask after the softmax. Multiplying raw scores by 0 does not exclude a
        # template, because a score of 0 can outrank the negative scores of
        # applicable templates. Zeroing the probabilities and renormalizing
        # equals a softmax over applicable templates only, assuming clf.softmax
        # is an ordinary softmax.
        mask = torch.as_tensor(FPF(target_smiles, templates), device=probs.device).to(probs.dtype)
        probs = probs * mask
        total = probs.sum(-1, keepdim=True)
        # Leave the values unchanged if no template passes the filter (avoids 0/0).
        probs = torch.where(total > 0, probs / total, probs)

    # Template indices sorted by decreasing probability.
    ranked = probs.argsort().detach().cpu().numpy().flatten()[::-1]
    probs = probs.detach().cpu().numpy().flatten()

    try:
        prod_rct = rdchiralReactants(target_smiles)
    except Exception:
        print('target_smiles', target_smiles, 'not computebale')
        return []

    reactions = []
    # Slicing also guards against try_max_temp exceeding the number of templates.
    for rank, template_idx in enumerate(ranked[:max(try_max_temp, 0)], start=1):
        if len(reactions) >= num_paths:
            break
        try:
            rxn = rdchiralReaction(templates[template_idx])
            outcome = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
        except Exception:
            # Best effort: skip templates rdchiral cannot parse or apply.
            continue
        if len(outcome) != 2:
            # Not an (outcomes, mapped_outcomes) pair, so there is nothing to record.
            continue

        _, mapped_res = outcome
        target_mapped = AllChem.MolToSmiles(prod_rct.reactants)
        for mapped in mapped_res.values():
            reactions.append({
                'template_rank': rank,
                'reaction': target_mapped + '>>' + mapped[0],
                'prob': probs[template_idx] * 100,
            })
            if viz:
                print('with template #', template_idx, templates[template_idx])

    return reactions
|
|
|
|