# MHN-React / ssretro_template.py
# Uploaded by uragankatrrin ("Upload 8 files"), commit 6bd991c, 3.63 kB.
from rdkit.Chem import AllChem
from mhnreact.data import load_dataset_from_csv
from mhnreact.molutils import convert_smiles_to_fp
from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
import torch
# Human-readable names of the ten reaction superclasses, keyed by class id 1-10.
reaction_superclass_names = {
    1: 'Heteroatom alkylation and arylation',
    2: 'Acylation and related processes',
    3: 'C-C bond formation',
    4: 'Heterocycle formation',
    5: 'Protections',
    6: 'Deprotections',
    7: 'Reductions',
    8: 'Oxidations',
    9: 'Functional group interconversion (FGI)',
    10: 'Functional group addition (FGA)',
}
def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
    """Fingerprint the left-hand side (text before '>>') of every template.

    Args:
        t: mapping whose values are reaction SMARTS strings of the form
            "lhs>>rhs"; only the values are used, in ``t.values()`` order.
        fp_size: fingerprint length, forwarded as ``fp_size``.
            NOTE(review): 8096 is not a power of two (8192 intended?) — the
            default is kept unchanged for compatibility.
        fp_type: fingerprint kind, forwarded as ``which``.

    Returns:
        The result of ``convert_smiles_to_fp`` on the left-hand-side SMARTS
        (parsed with ``is_smarts=True``), one entry per template.
    """
    lhs_patterns = []
    for template_smarts in t.values():
        lhs, _sep, _rhs = template_smarts.partition('>>')
        lhs_patterns.append(lhs)
    return convert_smiles_to_fp(
        lhs_patterns,
        is_smarts=True,
        which=fp_type,
        fp_size=fp_size,
    )
def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
    """Fingerprint-Filter for template applicability.

    A template passes the filter for a molecule when every bit set in the
    fingerprint of the template's left-hand side is also set in the molecule's
    fingerprint.

    Args:
        smi: a single SMILES string, or an iterable (list, tuple, ...) of
            SMILES strings.
        templates: mapping whose values are reaction SMARTS ("lhs>>rhs"),
            forwarded to ``getTemplateApplicabilityMatrix``.
        fp_size: fingerprint length used for templates and molecules.
        fp_type: fingerprint kind, forwarded as ``which``.

    Returns:
        For exactly one molecule: a boolean vector with one entry per template
        (same shape as before). For several molecules: a boolean matrix of
        shape (n_molecules, n_templates). For no molecules: an empty
        (0, n_templates) boolean array.
    """
    import numpy as np

    tfp = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
    # Only a bare string is a single molecule; any other iterable is a batch.
    smi = [smi] if isinstance(smi, str) else list(smi)
    mfp = convert_smiles_to_fp(smi, which=fp_type, fp_size=fp_size)
    # Bits set per template; identical for every molecule, so compute once.
    n_template_bits = tfp.sum(1)
    # Compare molecule by molecule. A single `tfp & mfp` broadcast only works
    # for one molecule; for a batch it either fails on shape mismatch or, when
    # n_molecules == n_templates, silently pairs molecule i with template i.
    # NOTE(review): assumes convert_smiles_to_fp returns one row per input.
    rows = [(tfp & mol_fp).sum(1) == n_template_bits for mol_fp in mfp]
    if not rows:
        return np.zeros((0, len(tfp)), dtype=bool)
    if len(rows) == 1:
        return rows[0]
    return np.stack(rows)
def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
    """Single-step retrosynthesis for one target molecule.

    Scores every template of 'data/USPTO_50k_MHN_prepro.csv.gz' with ``clf``.
    It then tries the templates with rdchiral in descending score order. It
    stops once at least ``num_paths`` reactions are collected or
    ``try_max_temp`` templates have been tried.

    Args:
        target_smiles: SMILES of the molecule to disconnect.
        clf: model providing ``forward_smiles`` and ``softmax``. If it has a
            ``templates`` attribute, ``clf.X`` is lazily filled with
            ``clf.template_encoder(clf.templates)``.
        num_paths: stop trying templates once this many reactions were
            collected. One template can yield several reactions, so the
            returned list may be longer.
        try_max_temp: maximum number of top-ranked templates to try.
        viz: print the template behind every proposed reaction.
        use_FPF: exclude templates that fail the fingerprint filter ``FPF``
            by masking their scores before the softmax.

    Returns:
        A list of dicts with these keys:
            'template_rank': 1-based rank of the template that produced the
                reaction. Templates that failed to run are counted.
            'reaction': "<target SMILES>>><rdchiral mapped outcome>".
            'prob': that template's softmax score, times 100.
        The list is empty if rdchiral cannot parse the target.
    """
    # NOTE(review): the dataset is re-read on every call; only the template
    # table `t` is used here. Consider caching when calling in a loop.
    _, _, t, _ = load_dataset_from_csv('data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)
    if hasattr(clf, 'templates'):
        if clf.X is None:
            clf.X = clf.template_encoder(clf.templates)
    preds = clf.forward_smiles([target_smiles])
    if use_FPF:
        appl = FPF(target_smiles, t)
        mask = torch.as_tensor(appl, dtype=torch.bool, device=preds.device)
        # Mask with the most negative finite value *before* the softmax.
        # Multiplying by 0 (the old approach) leaves filtered templates a
        # score of 0, which still gets probability mass and can even outrank
        # templates with negative scores.
        preds = preds.masked_fill(~mask, torch.finfo(preds.dtype).min)
    preds = clf.softmax(preds)
    # Template indices, best first (ascending argsort reversed).
    idxs = preds.argsort().detach().cpu().numpy().flatten()[::-1]
    preds = preds.detach().cpu().numpy().flatten()

    try:
        prod_rct = rdchiralReactants(target_smiles)
    except Exception:
        print('target_smiles', target_smiles, 'not computable')
        return []
    # Left-hand side of every proposed reaction; the same for all templates.
    product_smiles = AllChem.MolToSmiles(prod_rct.reactants)

    reactions = []
    # Slicing also guards against try_max_temp exceeding the template count.
    for rank, template_idx in enumerate(idxs[:max(try_max_temp, 0)], start=1):
        if len(reactions) >= num_paths:
            break
        try:
            rxn = rdchiralReaction(t[template_idx])
            resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
        except Exception:
            continue  # template could not be built or applied; try the next one
        if len(resu) != 2:  # no (outcomes, mapped_outcomes) pair returned
            continue
        _, mapped_res = resu
        rs = [product_smiles + '>>' + k[0] for k in mapped_res.values()]
        for r in rs:
            reactions.append({
                # Rank/probability of the template that was actually applied.
                'template_rank': rank,
                'reaction': r,
                'prob': preds[template_idx] * 100,
            })
        if viz:
            for r in rs:
                print('with template #', template_idx, t[template_idx])
    return reactions