import argparse
import random
import re

import numpy as np
import textdistance
from rdkit import Chem
from rdkit import RDLogger

RDLogger.DisableLog('rdApp.*')


def smi_tokenizer(smi):
    """Tokenize a SMILES string into space-separated tokens.

    Splits on bracket atoms, two-letter halogens, organic-subset atoms,
    bonds, ring closures (including %NN), and structural punctuation.
    Asserts the tokenization is lossless (tokens rejoin to the input).
    """
    # Raw string: the original used a plain string with invalid escape
    # sequences (\[, \), ...), which is a SyntaxWarning on Python >= 3.12.
    # The pattern itself is unchanged ('\\\\' in the old plain string is the
    # same two-character regex escape as '\\' here).
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)
    tokens = [token for token in regex.findall(smi)]
    assert smi == ''.join(tokens)
    return ' '.join(tokens)


def clear_map_canonical_smiles(smi, canonical=True, root=-1):
    """Strip atom-map numbers from *smi* and return its (canonical) SMILES.

    root is the RDKit atom index used as rootedAtAtom (-1 means no root).
    If *smi* cannot be parsed, it is returned unchanged.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is not None:
        for atom in mol.GetAtoms():
            if atom.HasProp('molAtomMapNumber'):
                atom.ClearProp('molAtomMapNumber')
        return Chem.MolToSmiles(mol, isomericSmiles=True, rootedAtAtom=root, canonical=canonical)
    else:
        return smi


def get_cano_map_number(smi, root=-1):
    """Return atom-map numbers of *smi* ordered by its rooted canonical atom order.

    Matches the atom-mapped molecule against its map-free canonical form via
    substructure matching; falls back to the reverse match when the forward
    match is incomplete or symbol-inconsistent. Returns None when no
    consistent atom correspondence can be established.
    """
    atommap_mol = Chem.MolFromSmiles(smi)
    canonical_mol = Chem.MolFromSmiles(clear_map_canonical_smiles(smi, root=root))
    cano2atommapIdx = atommap_mol.GetSubstructMatch(canonical_mol)
    correct_mapped = [canonical_mol.GetAtomWithIdx(i).GetSymbol() == atommap_mol.GetAtomWithIdx(index).GetSymbol()
                      for i, index in enumerate(cano2atommapIdx)]
    atom_number = len(canonical_mol.GetAtoms())
    if np.sum(correct_mapped) < atom_number or len(cano2atommapIdx) < atom_number:
        # Forward match failed; try matching the other way round and invert.
        cano2atommapIdx = [0] * atom_number
        atommap2canoIdx = canonical_mol.GetSubstructMatch(atommap_mol)
        if len(atommap2canoIdx) != atom_number:
            return None
        for i, index in enumerate(atommap2canoIdx):
            cano2atommapIdx[index] = i
    id2atommap = [atom.GetAtomMapNum() for atom in atommap_mol.GetAtoms()]
    return [id2atommap[cano2atommapIdx[i]] for i in range(atom_number)]


def get_root_id(mol, root_map_number):
    """Return the index of the atom carrying *root_map_number* (-1 if absent)."""
    root = -1
    for i, atom in enumerate(mol.GetAtoms()):
        if atom.GetAtomMapNum() == root_map_number:
            root = i
            break
    return root


def get_forward_rsmiles(data):
    """Build root-aligned (reactant -> product) training pairs.

    data keys: 'product', 'reactant' (atom-mapped SMILES), 'augmentation'
    (number of pairs to emit), 'separated' (append non-contributing reagents
    after the reactants when True).

    Returns a dict with 'status', tokenized 'src_data'/'tgt_data' lists and
    the mean token-level 'edit_distance' between src and tgt.
    """
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    separated = data['separated']
    return_status = {
        "status": 0,
        "src_data": [],
        "tgt_data": [],
        "edit_distance": 0,
    }
    reactant = reactant.split(".")
    product = product.split(".")
    rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
    # Upper bound on distinct root combinations across all reactants.
    max_times = np.prod([len(map_numbers) for map_numbers in rea_atom_map_numbers])
    times = min(augmentation, max_times)
    # First root set is all -1 (canonical, no explicit root).
    reactant_roots = [[-1 for _ in reactant]]
    j = 0
    while j < times:
        reactant_roots.append([random.sample(rea_atom_map_numbers[k], 1)[0] for k in range(len(reactant))])
        if reactant_roots[-1] in reactant_roots[:-1]:
            reactant_roots.pop()  # reject duplicate root combination
        else:
            j += 1
    if j < augmentation:
        # Not enough distinct combinations: pad by resampling existing ones.
        reactant_roots.extend(random.choices(reactant_roots, k=augmentation - times))
        times = augmentation
    reversable = False  # no reverse
    assert times == augmentation
    if reversable:
        times = int(times / 2)
    pro_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", pro))) for pro in product]
    full_pro_atom_map_numbers = set(map(int, re.findall(r"(?<=:)\d+", ".".join(product))))
    for k in range(times):
        # Shuffle reactant order together with the matching roots/map numbers.
        tmp = list(zip(reactant, reactant_roots[k], rea_atom_map_numbers))
        random.shuffle(tmp)
        reactant_k, reactant_roots_k, rea_atom_map_numbers_k = \
            [i[0] for i in tmp], [i[1] for i in tmp], [i[2] for i in tmp]
        aligned_reactants = []
        aligned_products = []
        aligned_products_order = []
        all_atom_map = []
        for i, rea in enumerate(reactant_k):
            rea_root_atom_map = reactant_roots_k[i]
            rea_root = get_root_id(Chem.MolFromSmiles(rea), root_map_number=rea_root_atom_map)
            cano_atom_map = get_cano_map_number(rea, rea_root)
            if cano_atom_map is None:
                # NOTE(review): skipping here desynchronizes aligned_reactants
                # from rea_atom_map_numbers_k indices used below — kept as in
                # the original; only triggers on pathological mappings.
                print(f"Reactant Failed to find Canonical Mol with Atom MapNumber")
                continue
            rea_smi = clear_map_canonical_smiles(rea, canonical=True, root=rea_root)
            aligned_reactants.append(rea_smi)
            all_atom_map.extend(cano_atom_map)
        for i, pro_map_number in enumerate(pro_atom_map_numbers):
            reactant_candidates = []
            selected_reactant = []
            for j, map_number in enumerate(all_atom_map):
                if map_number in pro_map_number:
                    for rea_index, rea_atom_map_number in enumerate(rea_atom_map_numbers_k):
                        if map_number in rea_atom_map_number and rea_index not in selected_reactant:
                            selected_reactant.append(rea_index)
                            reactant_candidates.append((map_number, j, len(rea_atom_map_number)))
            # select maximal reactant
            reactant_candidates.sort(key=lambda x: x[2], reverse=True)
            map_number = reactant_candidates[0][0]
            j = reactant_candidates[0][1]
            pro_root = get_root_id(Chem.MolFromSmiles(product[i]), root_map_number=map_number)
            pro_smi = clear_map_canonical_smiles(product[i], canonical=True, root=pro_root)
            aligned_products.append(pro_smi)
            aligned_products_order.append(j)
        # Order product fragments by the position of their root in the reactants.
        sorted_products = sorted(list(zip(aligned_products, aligned_products_order)), key=lambda x: x[1])
        aligned_products = [item[0] for item in sorted_products]
        pro_smi = ".".join(aligned_products)
        if separated:
            # Split reactants that contribute atoms to the product from reagents.
            reactants = []
            reagents = []
            for i, cano_atom_map in enumerate(rea_atom_map_numbers_k):
                if len(set(cano_atom_map) & full_pro_atom_map_numbers) > 0:
                    reactants.append(aligned_reactants[i])
                else:
                    reagents.append(aligned_reactants[i])
            rea_smi = ".".join(reactants)
            reactant_tokens = smi_tokenizer(rea_smi)
            if len(reagents) > 0:
                reactant_tokens += " " + smi_tokenizer(".".join(reagents))
        else:
            rea_smi = ".".join(aligned_reactants)
            reactant_tokens = smi_tokenizer(rea_smi)
        product_tokens = smi_tokenizer(pro_smi)
        return_status['src_data'].append(reactant_tokens)
        return_status['tgt_data'].append(product_tokens)
        if reversable:
            aligned_reactants.reverse()
            aligned_products.reverse()
            pro_smi = ".".join(aligned_products)
            rea_smi = ".".join(aligned_reactants)
            product_tokens = smi_tokenizer(pro_smi)
            reactant_tokens = smi_tokenizer(rea_smi)
            return_status['src_data'].append(reactant_tokens)
            return_status['tgt_data'].append(product_tokens)
    edit_distances = []
    for src, tgt in zip(return_status['src_data'], return_status['tgt_data']):
        edit_distances.append(textdistance.levenshtein.distance(src.split(), tgt.split()))
    # Guard: np.mean([]) would be nan with a RuntimeWarning.
    return_status['edit_distance'] = np.mean(edit_distances) if edit_distances else 0
    return return_status


def get_retro_rsmiles(data):
    """Build root-aligned (product -> reactant) training pairs.

    data keys: 'product', 'reactant' (atom-mapped SMILES) and 'augmentation'.
    augmentation == 999 is a special mode that uses every product atom as a
    root exactly once. Returns the same result dict shape as
    get_forward_rsmiles; status 'error_mapping' when no canonical atom-map
    correspondence exists for the product.
    """
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    pro_mol = Chem.MolFromSmiles(product)
    return_status = {
        "status": 0,
        "src_data": [],
        "tgt_data": [],
        "edit_distance": 0,
    }
    pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
    reactant = reactant.split(".")
    reversable = False  # no shuffle
    if augmentation == 999:
        product_roots = pro_atom_map_numbers
        times = len(product_roots)
    else:
        product_roots = [-1]  # first variant: canonical, no explicit root
        max_times = len(pro_atom_map_numbers)
        times = min(augmentation, max_times)
        if times < augmentation:
            # Fewer atoms than requested variants: use all roots, then pad.
            product_roots.extend(pro_atom_map_numbers)
            product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
        else:
            # Sample distinct roots until we have enough.
            while len(product_roots) < times:
                product_roots.append(random.sample(pro_atom_map_numbers, 1)[0])
                if product_roots[-1] in product_roots[:-1]:
                    product_roots.pop()
        times = len(product_roots)
        assert times == augmentation
    if reversable:
        times = int(times / 2)
    for k in range(times):
        pro_root_atom_map = product_roots[k]
        pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
        cano_atom_map = get_cano_map_number(product, root=pro_root)
        if cano_atom_map is None:
            return_status["status"] = "error_mapping"
            return return_status
        pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)
        aligned_reactants = []
        aligned_reactants_order = []
        rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
        for i, rea_map_number in enumerate(rea_atom_map_numbers):
            for j, map_number in enumerate(cano_atom_map):
                # select mapping reactans
                if map_number in rea_map_number:
                    rea_root = get_root_id(Chem.MolFromSmiles(reactant[i]), root_map_number=map_number)
                    rea_smi = clear_map_canonical_smiles(reactant[i], canonical=True, root=rea_root)
                    aligned_reactants.append(rea_smi)
                    aligned_reactants_order.append(j)
                    break
        # Order reactants by where their shared atom appears in the product.
        sorted_reactants = sorted(list(zip(aligned_reactants, aligned_reactants_order)), key=lambda x: x[1])
        aligned_reactants = [item[0] for item in sorted_reactants]
        reactant_smi = ".".join(aligned_reactants)
        product_tokens = smi_tokenizer(pro_smi)
        reactant_tokens = smi_tokenizer(reactant_smi)
        return_status['src_data'].append(product_tokens)
        return_status['tgt_data'].append(reactant_tokens)
        if reversable:
            aligned_reactants.reverse()
            reactant_smi = ".".join(aligned_reactants)
            product_tokens = smi_tokenizer(pro_smi)
            reactant_tokens = smi_tokenizer(reactant_smi)
            return_status['src_data'].append(product_tokens)
            return_status['tgt_data'].append(reactant_tokens)
    assert len(return_status['src_data']) == data['augmentation']
    edit_distances = []
    for src, tgt in zip(return_status['src_data'], return_status['tgt_data']):
        edit_distances.append(textdistance.levenshtein.distance(src.split(), tgt.split()))
    # Guard: np.mean([]) would be nan with a RuntimeWarning.
    return_status['edit_distance'] = np.mean(edit_distances) if edit_distances else 0
    return return_status


def multi_process(data):
    """Validate one reaction record and produce augmented src/tgt pairs.

    Quality checks set return_status['status'] to a short error code
    ('empty_p', 'invalid_r', 'small_r', ...); pairs are only generated when
    all checks pass. data['root_aligned'] selects between root-aligned
    augmentation (same algorithm as get_retro_rsmiles) and plain random
    SMILES enumeration (doRandom).
    """
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    pro_mol = Chem.MolFromSmiles(product)
    rea_mol = Chem.MolFromSmiles(reactant)
    return_status = {
        "status": 0,
        "src_data": [],
        "tgt_data": [],
        "edit_distance": 0,
    }
    if "" == product:
        return_status["status"] = "empty_p"
    if "" == reactant:
        return_status["status"] = "empty_r"
    # BUGFIX: the original dereferenced rea_mol/pro_mol even when
    # MolFromSmiles returned None, raising AttributeError on invalid SMILES.
    # The elif/None guards keep the original last-check-wins status order.
    if rea_mol is None:
        return_status["status"] = "invalid_r"
    elif len(rea_mol.GetAtoms()) < 5:
        return_status["status"] = "small_r"
    if pro_mol is None:
        return_status["status"] = "invalid_p"
    elif len(pro_mol.GetAtoms()) == 1:
        return_status["status"] = "small_p"
    if pro_mol is not None and not all([a.HasProp('molAtomMapNumber') for a in pro_mol.GetAtoms()]):
        return_status["status"] = "error_mapping_p"
    """finishing checking data quality"""
    if return_status['status'] == 0:
        pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
        reactant = reactant.split(".")
        if data['root_aligned']:
            # Root-aligned augmentation (mirrors get_retro_rsmiles).
            reversable = False  # no shuffle
            if augmentation == 999:
                product_roots = pro_atom_map_numbers
                times = len(product_roots)
            else:
                product_roots = [-1]
                max_times = len(pro_atom_map_numbers)
                times = min(augmentation, max_times)
                if times < augmentation:
                    product_roots.extend(pro_atom_map_numbers)
                    product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
                else:
                    while len(product_roots) < times:
                        product_roots.append(random.sample(pro_atom_map_numbers, 1)[0])
                        if product_roots[-1] in product_roots[:-1]:
                            product_roots.pop()
                times = len(product_roots)
                assert times == augmentation
            if reversable:
                times = int(times / 2)
            for k in range(times):
                pro_root_atom_map = product_roots[k]
                pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
                cano_atom_map = get_cano_map_number(product, root=pro_root)
                if cano_atom_map is None:
                    return_status["status"] = "error_mapping"
                    return return_status
                pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)
                aligned_reactants = []
                aligned_reactants_order = []
                rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
                for i, rea_map_number in enumerate(rea_atom_map_numbers):
                    for j, map_number in enumerate(cano_atom_map):
                        # select mapping reactans
                        if map_number in rea_map_number:
                            rea_root = get_root_id(Chem.MolFromSmiles(reactant[i]), root_map_number=map_number)
                            rea_smi = clear_map_canonical_smiles(reactant[i], canonical=True, root=rea_root)
                            aligned_reactants.append(rea_smi)
                            aligned_reactants_order.append(j)
                            break
                sorted_reactants = sorted(list(zip(aligned_reactants, aligned_reactants_order)), key=lambda x: x[1])
                aligned_reactants = [item[0] for item in sorted_reactants]
                reactant_smi = ".".join(aligned_reactants)
                product_tokens = smi_tokenizer(pro_smi)
                reactant_tokens = smi_tokenizer(reactant_smi)
                return_status['src_data'].append(product_tokens)
                return_status['tgt_data'].append(reactant_tokens)
                if reversable:
                    aligned_reactants.reverse()
                    reactant_smi = ".".join(aligned_reactants)
                    product_tokens = smi_tokenizer(pro_smi)
                    reactant_tokens = smi_tokenizer(reactant_smi)
                    return_status['src_data'].append(product_tokens)
                    return_status['tgt_data'].append(reactant_tokens)
            assert len(return_status['src_data']) == data['augmentation']
        else:
            # Plain augmentation: canonical pair + (augmentation-1) random
            # SMILES enumerations. Reagents that share no mapped atoms with
            # the product are dropped from the target.
            cano_product = clear_map_canonical_smiles(product)
            cano_reactanct = ".".join([clear_map_canonical_smiles(rea) for rea in reactant
                                       if len(set(map(int, re.findall(r"(?<=:)\d+", rea))) & set(pro_atom_map_numbers)) > 0])
            return_status['src_data'].append(smi_tokenizer(cano_product))
            return_status['tgt_data'].append(smi_tokenizer(cano_reactanct))
            pro_mol = Chem.MolFromSmiles(cano_product)
            rea_mols = [Chem.MolFromSmiles(rea) for rea in cano_reactanct.split(".")]
            for i in range(int(augmentation - 1)):
                pro_smi = Chem.MolToSmiles(pro_mol, doRandom=True)
                rea_smi = [Chem.MolToSmiles(rea_mol, doRandom=True) for rea_mol in rea_mols]
                rea_smi = ".".join(rea_smi)
                return_status['src_data'].append(smi_tokenizer(pro_smi))
                return_status['tgt_data'].append(smi_tokenizer(rea_smi))
    edit_distances = []
    for src, tgt in zip(return_status['src_data'], return_status['tgt_data']):
        edit_distances.append(textdistance.levenshtein.distance(src.split(), tgt.split()))
    # Guard: np.mean([]) would be nan with a RuntimeWarning (hit whenever a
    # quality check failed and no pairs were generated).
    return_status['edit_distance'] = np.mean(edit_distances) if edit_distances else 0
    return return_status


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-rxn', type=str, required=True)
    parser.add_argument('-mode', type=str, default="retro")
    parser.add_argument('-forward_mode', type=str, default="separated")
    parser.add_argument("-augmentation", type=int, default=1)
    parser.add_argument("-seed", type=int, default=33)
    args = parser.parse_args()
    print(args)
    # BUGFIX: -seed was parsed but never applied, so runs were not
    # reproducible; seed the RNG before any sampling happens.
    random.seed(args.seed)
    reactant, reagent, product = args.rxn.split(">")
    pt = re.compile(r':(\d+)]')
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    if len(rids) == 0 or len(pids) == 0:
        print("No atom mapping found!")
        exit(1)
    if args.mode == "retro":
        args.input = product
        args.output = reactant
    else:
        args.input = reactant
        args.output = product
    print("Original input:", args.input)
    print("Original output:", args.output)
    src_smi = clear_map_canonical_smiles(args.input)
    tgt_smi = clear_map_canonical_smiles(args.output)
    if src_smi == "" or tgt_smi == "":
        print("Invalid SMILES!")
        exit(1)
    print("Canonical input:", src_smi)
    print("Canonical output:", tgt_smi)
    # A good mapping is 1:1 between reactant and product atom-map IDs.
    mapping_check = True
    if ",".join(rids) != ",".join(pids):  # mapping is not 1:1
        mapping_check = False
    if len(set(rids)) != len(rids):  # mapping is not 1:1
        mapping_check = False
    if len(set(pids)) != len(pids):  # mapping is not 1:1
        mapping_check = False
    if not mapping_check:
        print("The quality of the atom mapping may not be good enough, which can affect the effect of root alignment.")
    data = {
        'product': product,
        'reactant': reactant,
        'augmentation': args.augmentation,
        'separated': args.forward_mode == "separated"
    }
    if args.mode == "retro":
        res = get_retro_rsmiles(data)
    else:
        res = get_forward_rsmiles(data)
    for index, (src, tgt) in enumerate(zip(res['src_data'], res['tgt_data'])):
        print(f"ID:{index}")
        print(f"R-SMILES input:{''.join(src.split())}")
        print(f"R-SMILES output:{''.join(tgt.split())}")
    print("Avg. edit distance:", res['edit_distance'])