from rdkit.Chem import AllChem
from mhnreact.data import load_dataset_from_csv
from mhnreact.molutils import convert_smiles_to_fp
from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
import torch

# Display names of the reaction superclasses, keyed by integer id.
reaction_superclass_names = {
    1: 'Heteroatom alkylation and arylation',
    2: 'Acylation and related processes',
    3: 'C-C bond formation',
    4: 'Heterocycle formation',  # TODO check
    5: 'Protections',
    6: 'Deprotections',
    7: 'Reductions',
    8: 'Oxidations',
    9: 'Functional group interconversoin (FGI)',
    10: 'Functional group addition (FGA)',
}


def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'):
    """Fingerprint the left-hand side of every template in ``t``.

    Args:
        t: Mapping whose values are template strings of the form
            ``lhs>>rhs``; only the part before the first ``>>`` is used.
        fp_size: Fingerprint length, forwarded to ``convert_smiles_to_fp``.
        fp_type: Fingerprint kind, forwarded as ``which``.

    Returns:
        The result of ``convert_smiles_to_fp`` (called with
        ``is_smarts=True``) for the left-hand sides, in the iteration order
        of ``t.values()``.
    """
    left_sides = [template.split('>>')[0] for template in t.values()]
    return convert_smiles_to_fp(left_sides, is_smarts=True, which=fp_type, fp_size=fp_size)


def FPF(smi, templates, fp_size=8096, fp_type='pattern'):
    """Fingerprint-Filter for applicability.

    A template is flagged as applicable when the number of bits it shares
    with the molecule fingerprint equals the number of bits set in the
    template fingerprint.

    Args:
        smi: A single SMILES string or a list of SMILES strings.
        templates: Mapping of templates, see
            ``getTemplateApplicabilityMatrix``.
        fp_size: Fingerprint length used for both templates and molecules.
        fp_type: Fingerprint kind used for both templates and molecules.

    Returns:
        The element-wise comparison result, one entry per template.
    """
    # NOTE(review): assumes both fingerprints are binary arrays of shape
    # (n, fp_size) and that ``mfp`` broadcasts against ``tfp`` (e.g. a single
    # molecule) -- confirm against convert_smiles_to_fp.
    tfp = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type)
    if not isinstance(smi, list):
        smi = [smi]
    mfp = convert_smiles_to_fp(smi, which=fp_type, fp_size=fp_size)
    return (tfp & mfp).sum(1) == tfp.sum(1)


def ssretro(target_smiles: str, clf, num_paths=5, try_max_temp=10, viz=False, use_FPF=False):
    """Run single-step retrosynthesis for one target molecule.

    Templates are tried in descending order of the model score until at
    least ``num_paths`` reactions were collected or ``try_max_temp``
    templates were attempted. One template may contribute several
    reactions, so the result can hold more than ``num_paths`` entries.

    Args:
        target_smiles: SMILES of the product to disconnect.
        clf: Model exposing ``forward_smiles`` and ``softmax``; if it has a
            ``templates`` attribute and ``clf.X`` is ``None``, ``clf.X`` is
            filled via ``clf.template_encoder`` first.
        num_paths: Stop once this many reactions have been collected.
        try_max_temp: Maximum number of ranked templates to attempt.
        viz: If true, print the template used for every reaction found.
        use_FPF: If true, multiply the raw predictions by the
            fingerprint-filter mask before the softmax.

    Returns:
        A list of dicts with the keys ``'template_rank'`` (1-based rank of
        the applied template, counting templates that failed to execute),
        ``'reaction'`` (``product>>reactants`` string) and ``'prob'``
        (model score of the applied template, times 100). An empty list is
        returned when ``target_smiles`` cannot be parsed by rdchiral.
    """
    # NOTE(review): the dataset is re-read on every call only to obtain the
    # template mapping ``t``; callers running many targets may want to cache.
    _, _, t, _ = load_dataset_from_csv('data/USPTO_50k_MHN_prepro.csv.gz', ssretroeval=True)

    if hasattr(clf, 'templates') and clf.X is None:
        clf.X = clf.template_encoder(clf.templates)

    preds = clf.forward_smiles([target_smiles])
    if use_FPF:
        # NOTE(review): this zeroes masked scores *before* the softmax, which
        # presumably does not drive their probability to zero -- confirm.
        appl = FPF(target_smiles, t)
        preds = preds * torch.tensor(appl)
    preds = clf.softmax(preds)

    # Template indices sorted from highest to lowest score.
    idxs = preds.argsort().detach().numpy().flatten()[::-1]
    preds = preds.detach().numpy().flatten()

    try:
        prod_rct = rdchiralReactants(target_smiles)
    except Exception:
        print('target_smiles', target_smiles, 'not computebale')
        return []

    reactions = []
    i = 0  # position of the next template to try within ``idxs``
    while len(reactions) < num_paths and i < try_max_temp:
        resu = []
        pos = i  # position of the template that produced ``resu``
        while not len(resu) and i < try_max_temp:
            pos = i
            try:
                rxn = rdchiralReaction(t[idxs[i]])
                resu = rdchiralRun(rxn, prod_rct, keep_mapnums=True,
                                   combine_enantiomers=True, return_mapped=True)
            except Exception:
                # Best effort: a template that cannot be built or applied is
                # skipped; the one-element marker ends the inner loop without
                # satisfying the ``len(resu) == 2`` check below.
                resu = ['err']
            i += 1

        if len(resu) == 2:  # if there is a result
            _, mapped_res = resu
            # ``i`` already points at the *next* template here, so all
            # bookkeeping must use ``pos``; indexing with ``i`` reported the
            # wrong template and could run past the end of ``idxs``.
            template_idx = idxs[pos]
            product = AllChem.MolToSmiles(prod_rct.reactants)
            rs = [product + '>>' + k[0] for k in mapped_res.values()]
            for r in rs:
                # Currently disabled fields: template_used, template_idx,
                # reaction_canonical, template_class,
                # template_num_train_samples.
                reactions.append({
                    # actual rank, counting non-executable templates
                    'template_rank': pos + 1,
                    'reaction': r,
                    'prob': preds[template_idx] * 100,
                })
            if viz:
                for r in rs:
                    print('with template #', template_idx, t[template_idx])
                    # smarts2svg(r, useSmiles=True, highlightByReactant=True)
    return reactions