Spaces:
Runtime error
Runtime error
from rdkit.Chem import AllChem | |
from mhnreact.data import load_dataset_from_csv | |
from mhnreact.molutils import convert_smiles_to_fp | |
from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants | |
import torch | |
# Human-readable names for the ten integer reaction (super)class ids.
# NOTE(review): ids/labels presumably follow the USPTO-50k reaction-class
# annotation; confirm class 4 ("Heterocycle formation") against the dataset.
reaction_superclass_names = {
    1: 'Heteroatom alkylation and arylation',
    2: 'Acylation and related processes',
    3: 'C-C bond formation',
    4: 'Heterocycle formation',
    5: 'Protections',
    6: 'Deprotections',
    7: 'Reductions',
    8: 'Oxidations',
    9: 'Functional group interconversion (FGI)',  # fixed typo ("interconversoin")
    10: 'Functional group addition (FGA)',
}
def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
    """Fingerprint the left-hand side of every template in ``t``.

    Each value of the mapping ``t`` is cut at its first ``'>>'`` and the part
    before it is fingerprinted as SMARTS via ``convert_smiles_to_fp``.

    Args:
        t: mapping whose values are reaction-template strings containing
            ``'>>'`` (only ``t.values()`` is used).
        fp_size: fingerprint length passed through to ``convert_smiles_to_fp``.
            NOTE(review): 8096 is not a power of two (8192 may have been
            intended); kept unchanged for compatibility.
        fp_type: fingerprint kind passed as ``which=``.

    Returns:
        Whatever ``convert_smiles_to_fp`` returns for the list of left-hand
        sides -- presumably one fingerprint row per template (TODO confirm).
    """
    # For retro templates applied to a product (see ssretro), the left-hand
    # side is presumably the product-side pattern -- confirm.
    left_sides = [template.partition('>>')[0] for template in t.values()]
    return convert_smiles_to_fp(
        left_sides,
        is_smarts=True,
        which=fp_type,
        fp_size=fp_size,
    )
def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
    """Fingerprint-Filter for applicability.

    A template is flagged applicable when every bit set in its left-hand-side
    fingerprint is also set in the molecule's fingerprint (a necessary, not
    sufficient, condition for a substructure match).

    Args:
        smi: a SMILES string, or a list of SMILES strings.
        templates: mapping of template strings, as accepted by
            ``getTemplateApplicabilityMatrix``.
        fp_size: fingerprint length used for both templates and molecules.
        fp_type: fingerprint kind used for both templates and molecules.

    Returns:
        Element-wise comparison result: True where the template's bits are a
        subset of the molecule's bits.

    Note:
        Template fingerprints are recomputed on every call.
        NOTE(review): with more than one molecule in ``smi`` the molecule and
        template fingerprint arrays presumably only broadcast in special cases
        -- confirm before passing lists longer than one.
    """
    template_fps = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
    molecules = smi if isinstance(smi, list) else [smi]
    mol_fps = convert_smiles_to_fp(molecules, which=fp_type, fp_size=fp_size)
    matched_bits = (template_fps & mol_fps).sum(1)
    required_bits = template_fps.sum(1)
    return matched_bits == required_bits
def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False,
            data_path='data/USPTO_50k_MHN_prepro.csv.gz'):
    """single-step-retrosynthesis

    Ranks all templates with the model ``clf`` and applies them with rdchiral
    in rank order until ``num_paths`` reactions are collected or
    ``try_max_temp`` templates have been tried.

    Args:
        target_smiles: SMILES of the molecule to disconnect.
        clf: model exposing ``forward_smiles`` and ``softmax`` (and, when it
            has ``templates``, also ``X`` and ``template_encoder``).
        num_paths: stop once at least this many reactions were collected. All
            outcomes of a successful template are kept, so more may be returned.
        try_max_temp: maximum number of top-ranked templates to try.
        viz: if True, print the template used for each proposed reaction.
        use_FPF: if True, exclude templates rejected by the fingerprint
            filter ``FPF`` before the softmax.
        data_path: CSV from which the template list is loaded.

    Returns:
        List of dicts with keys
          'template_rank': 1-based rank of the template that produced the
              reaction, counting templates that failed to apply;
          'reaction': ``"<mapped target>>><mapped precursors>"`` SMILES;
          'prob': model probability of that template, in percent.
        An empty list when rdchiral cannot parse ``target_smiles``.
    """
    _, _, t, _ = load_dataset_from_csv(data_path, ssretroeval=True)

    if hasattr(clf, 'templates'):
        if clf.X is None:
            # Encode the template library once and cache it on the model.
            clf.X = clf.template_encoder(clf.templates)

    preds = clf.forward_smiles([target_smiles])
    if use_FPF:
        applicable = torch.as_tensor(FPF(target_smiles, t), dtype=torch.bool, device=preds.device)
        # preds are logits (softmax follows), so multiplying by a 0/1 mask does
        # not exclude anything: a zeroed logit can outrank applicable templates
        # with negative logits. Use the lowest finite value instead, which keeps
        # the softmax NaN-free even if every template is masked.
        preds = preds.masked_fill(~applicable, torch.finfo(preds.dtype).min)
    preds = clf.softmax(preds)
    idxs = preds.argsort().detach().numpy().flatten()[::-1]  # best template first
    preds = preds.detach().numpy().flatten()

    try:
        prod_rct = rdchiralReactants(target_smiles)
    except Exception:
        print('target_smiles', target_smiles, 'not computebale')
        return []

    max_tries = min(try_max_temp, len(idxs))  # never index past the ranking
    reactions = []
    i = 0  # number of templates tried so far == ranking position of the next one
    while len(reactions) < num_paths and i < max_tries:
        resu = []
        # Try templates in rank order until one yields a result or an error.
        while not len(resu) and i < max_tries:
            try:
                rxn = rdchiralReaction(t[idxs[i]])
                resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True, combine_enantiomers=True, return_mapped=True)
            except Exception:
                # Unparsable/inapplicable template: a length-1 marker ends the
                # inner loop without counting as a result.
                resu = ['err']
            i += 1
        if len(resu) == 2:  # (outcomes, mapped_outcomes) from rdchiralRun
            _, mapped_res = resu
            # i was already advanced past the template that produced resu.
            rank_pos = i - 1
            template_idx = idxs[rank_pos]
            target_mapped = AllChem.MolToSmiles(prod_rct.reactants)
            rs = [target_mapped + '>>' + k[0] for k in mapped_res.values()]
            for r in rs:
                reactions.append({
                    'template_rank': rank_pos + 1,
                    'reaction': r,
                    'prob': preds[template_idx] * 100,
                })
            if viz:
                for r in rs:
                    print('with template #', template_idx, t[template_idx])
    return reactions