import csv
import numpy as np

from rdkit import Chem
from rdkit.Chem import MolStandardize
from src import metrics
from src.delinker_utils import sascorer, calc_SC_RDKit
from tqdm import tqdm
from pdb import set_trace


def get_valid_as_in_delinker(data, progress=False):
    """Keep only the samples that are valid in the DeLinker sense.

    A sample is kept when the largest connected component of the predicted
    molecule, the true molecule and the fragments all pass RDKit
    sanitization, and that largest component contains the fragments as a
    substructure.

    Args:
        data: sequence of dicts with keys 'pred_mol', 'true_mol',
            'pred_mol_smi', 'true_mol_smi' and 'frag_smi'.
        progress: when True, wrap the iteration in a tqdm progress bar.

    Returns:
        A list of dicts with the same keys; the SMILES entries are
        regenerated from the sanitized molecules ('pred_mol_smi' from the
        largest component of the prediction).
    """
    valid = []
    generator = tqdm(data, total=len(data)) if progress else data
    for m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=False)
        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=False)
        frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=False)

        # MolFromSmiles returns None for SMILES it cannot parse; such a
        # sample is simply invalid and must not crash the evaluation.
        if pred_mol is None or true_mol is None or frag is None:
            continue

        pred_mol_frags = Chem.GetMolFrags(pred_mol, asMols=True, sanitizeFrags=False)
        pred_mol_filtered = max(pred_mol_frags, default=pred_mol, key=lambda mol: mol.GetNumAtoms())

        try:
            Chem.SanitizeMol(pred_mol_filtered)
            Chem.SanitizeMol(true_mol)
            Chem.SanitizeMol(frag)
        except Exception:
            # Sanitization failure means the sample is not valid. A bare
            # `except:` here would also swallow KeyboardInterrupt/SystemExit.
            continue

        if len(pred_mol_filtered.GetSubstructMatch(frag)) > 0:
            valid.append({
                'pred_mol': m['pred_mol'],
                'true_mol': m['true_mol'],
                'pred_mol_smi': Chem.MolToSmiles(pred_mol_filtered),
                'true_mol_smi': Chem.MolToSmiles(true_mol),
                'frag_smi': Chem.MolToSmiles(frag)
            })

    return valid


def extract_linker_smiles(molecule, fragments):
    """Return the SMILES of what remains of `molecule` once `fragments` is removed.

    The atoms of the first substructure match of `fragments` are deleted and
    stereochemistry is stripped from the remainder. The result is passed
    through `MolStandardize.canonicalize_tautomer_smiles`; if that fails for
    any reason, the plain RDKit SMILES of the remainder is returned instead.
    When `fragments` does not match, no atom is removed.
    """
    match = molecule.GetSubstructMatch(fragments)
    elinker = Chem.EditableMol(molecule)
    # Remove from the highest index down so the remaining indices stay valid.
    for atom_id in sorted(match, reverse=True):
        elinker.RemoveAtom(atom_id)
    linker = elinker.GetMol()
    Chem.RemoveStereochemistry(linker)
    try:
        linker = MolStandardize.canonicalize_tautomer_smiles(Chem.MolToSmiles(linker))
    except Exception:
        # Best-effort canonicalization: fall back to the raw SMILES.
        linker = Chem.MolToSmiles(linker)
    return linker


def compute_and_add_linker_smiles(data, progress=False):
    """Return copies of the samples extended with 'pred_linker' and 'true_linker'.

    Both linkers are obtained with `extract_linker_smiles` from the predicted
    and true molecule SMILES respectively, using the sample's 'frag_smi'.
    """
    data_with_linkers = []
    generator = tqdm(data) if progress else data
    for m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=True)
        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=True)
        frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=True)

        data_with_linkers.append({
            **m,
            'pred_linker': extract_linker_smiles(pred_mol, frag),
            'true_linker': extract_linker_smiles(true_mol, frag),
        })

    return data_with_linkers


def compute_uniqueness(data, progress=False):
    """Return the fraction of distinct predicted SMILES per (true molecule, fragments) pair.

    Predictions are grouped by the key '<true_mol_smi>.<frag_smi>'; the result
    is the number of distinct predictions summed over groups divided by the
    total number of predictions. Returns 0.0 for empty input.
    """
    mol_dictionary = {}
    generator = tqdm(data) if progress else data
    for m in generator:
        key = f"{m['true_mol_smi']}.{m['frag_smi']}"
        mol_dictionary.setdefault(key, []).append(m['pred_mol_smi'])

    total_mol = sum(len(molecules) for molecules in mol_dictionary.values())
    if total_mol == 0:
        return 0.0
    unique_mol = sum(len(set(molecules)) for molecules in mol_dictionary.values())

    return unique_mol / total_mol


def compute_novelty(data, progress=False):
    """Return the fraction of samples whose 'pred_linker' is not any sample's 'true_linker'.

    Returns 0.0 for empty input.
    """
    if len(data) == 0:
        return 0.0

    true_linkers = {m['true_linker'] for m in data}
    generator = tqdm(data) if progress else data
    novel = sum(1 for m in generator if m['pred_linker'] not in true_linkers)

    return novel / len(data)


def compute_recovery_rate(data, progress=False):
    """Return the fraction of distinct (true molecule, true linker) pairs that were recovered.

    A pair counts as recovered when at least one of its predictions has the
    same SMILES as the true molecule after removing stereochemistry and
    hydrogens from both. Returns 0.0 for empty input.
    """
    total = set()
    recovered = set()
    generator = tqdm(data) if progress else data
    for m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=True)
        Chem.RemoveStereochemistry(pred_mol)
        pred_mol = Chem.MolToSmiles(Chem.RemoveHs(pred_mol))

        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=True)
        Chem.RemoveStereochemistry(true_mol)
        true_mol = Chem.MolToSmiles(Chem.RemoveHs(true_mol))

        key = f"{true_mol}.{m['true_linker']}"
        total.add(key)
        if pred_mol == true_mol:
            recovered.add(key)

    if not total:
        return 0.0
    return len(recovered) / len(total)


def calc_sa_score_mol(mol):
    """Return `sascorer.calculateScore(mol)`, or None when `mol` is None."""
    if mol is None:
        return None
    return sascorer.calculateScore(mol)


def check_ring_filter(linker):
    """Return False if any ring of `linker` contains a double bond, True otherwise.

    Only bonds whose type is DOUBLE and whose two atoms both belong to the
    same ring (as reported by `Chem.GetSymmSSSR`) are considered.
    """
    for ring in Chem.GetSymmSSSR(linker):
        for atom_idx in ring:
            for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
                if (
                    bond.GetBondType() == Chem.rdchem.BondType.DOUBLE
                    and bond.GetBeginAtomIdx() in ring
                    and bond.GetEndAtomIdx() in ring
                ):
                    return False
    return True


def check_pains(mol, pains_smarts):
    """Return True when `mol` matches none of the patterns in `pains_smarts`."""
    return not any(mol.HasSubstructMatch(pain) for pain in pains_smarts)


def calc_2d_filters(toks, pains_smarts):
    """Evaluate the three 2D filters for one sample.

    Args:
        toks: dict with keys 'pred_mol_smi', 'frag_smi' and 'pred_linker'.
        pains_smarts: iterable of query molecules passed to `check_pains`.

    Returns:
        A list `[sa, ra, pains]` of pass flags. All three are False when a
        SMILES cannot be parsed or the prediction does not contain the
        fragments. A filter whose computation raises is reported as failed
        (False) after printing the error.
    """
    pred_mol = Chem.MolFromSmiles(toks['pred_mol_smi'])
    frag = Chem.MolFromSmiles(toks['frag_smi'])
    linker = Chem.MolFromSmiles(toks['pred_linker'])

    # Unparseable prediction/fragments cannot pass any filter.
    if pred_mol is None or frag is None:
        return [False, False, False]

    if len(pred_mol.GetSubstructMatch(frag)) == 0:
        return [False, False, False]

    sa_score = False
    ra_score = False
    pains_score = False

    try:
        # Passes when the full molecule scores lower than the fragments alone.
        sa_score = calc_sa_score_mol(pred_mol) < calc_sa_score_mol(frag)
    except Exception as e:
        print(f'Could not compute SA score: {e}')
    try:
        ra_score = check_ring_filter(linker)
    except Exception as e:
        print(f'Could not compute RA score: {e}')
    try:
        pains_score = check_pains(pred_mol, pains_smarts)
    except Exception as e:
        print(f'Could not compute PAINS score: {e}')

    return [sa_score, ra_score, pains_score]


def calc_filters_2d_dataset(data, pains_smarts_path='models/wehi_pains.csv'):
    """Return the pass rates of the 2D filters over `data`.

    Args:
        data: sequence of sample dicts accepted by `calc_2d_filters`.
        pains_smarts_path: CSV file whose first column holds the PAINS SMARTS
            patterns. Blank rows are skipped.

    Returns:
        A tuple `(pass_all, pass_SA, pass_RA, pass_PAINS)` of fractions of
        `len(data)`; all zeros for empty input.
    """
    if len(data) == 0:
        return 0.0, 0.0, 0.0, 0.0

    with open(pains_smarts_path, 'r') as f:
        pains_smarts = [
            Chem.MolFromSmarts(line[0], mergeHs=True)
            for line in csv.reader(f)
            if line  # csv.reader yields [] for blank lines
        ]

    pass_all = pass_SA = pass_RA = pass_PAINS = 0
    for m in data:
        sa_ok, ra_ok, pains_ok = calc_2d_filters(m, pains_smarts)
        pass_all += bool(sa_ok and ra_ok and pains_ok)
        pass_SA += sa_ok
        pass_RA += ra_ok
        pass_PAINS += pains_ok

    n = len(data)
    return pass_all / n, pass_SA / n, pass_RA / n, pass_PAINS / n


def calc_sc_rdkit_full_mol(gen_mol, ref_mol):
    """Return `calc_SC_RDKit.calc_SC_RDKit_score(gen_mol, ref_mol)`, or -0.5 if it raises."""
    try:
        return calc_SC_RDKit.calc_SC_RDKit_score(gen_mol, ref_mol)
    except Exception:
        # Sentinel kept from the original implementation for failed scoring.
        return -0.5


def sc_rdkit_score(data):
    """Return the mean SC_RDKit score between 'pred_mol' and 'true_mol' over `data`."""
    scores = [calc_sc_rdkit_full_mol(m['pred_mol'], m['true_mol']) for m in data]
    return np.mean(scores)


def get_delinker_metrics(pred_molecules, true_molecules, true_fragments):
    """Compute the DeLinker evaluation metrics for a batch of predictions.

    Args:
        pred_molecules: predicted RDKit molecules.
        true_molecules: corresponding ground-truth RDKit molecules.
        true_fragments: corresponding fragment RDKit molecules.

    Returns:
        A dict with the keys 'DeLinker/validity', 'DeLinker/uniqueness',
        'DeLinker/novelty', 'DeLinker/recovery', 'DeLinker/2D_filters',
        'DeLinker/2D_filters_SA', 'DeLinker/2D_filters_RA',
        'DeLinker/2D_filters_PAINS' and 'DeLinker/SC_RDKit'. Every value is 0
        when there are no predictions or none of them is valid.
    """
    default_values = {
        'DeLinker/validity': 0,
        'DeLinker/uniqueness': 0,
        'DeLinker/novelty': 0,
        'DeLinker/recovery': 0,
        'DeLinker/2D_filters': 0,
        'DeLinker/2D_filters_SA': 0,
        'DeLinker/2D_filters_RA': 0,
        'DeLinker/2D_filters_PAINS': 0,
        'DeLinker/SC_RDKit': 0,
    }
    if len(pred_molecules) == 0:
        return default_values

    data = [
        {
            'pred_mol': pred_mol,
            'true_mol': true_mol,
            'pred_mol_smi': Chem.MolToSmiles(pred_mol),
            'true_mol_smi': Chem.MolToSmiles(true_mol),
            'frag_smi': Chem.MolToSmiles(true_frag)
        }
        for pred_mol, true_mol, true_frag in zip(pred_molecules, true_molecules, true_fragments)
    ]

    # Validity according to DeLinker paper:
    # Passing rdkit.Chem.Sanitize and the biggest fragment contains both fragments
    valid_data = get_valid_as_in_delinker(data)
    validity_as_in_delinker = len(valid_data) / len(data)

    if len(valid_data) == 0:
        return default_values

    # Compute linkers and add to results
    valid_data = compute_and_add_linker_smiles(valid_data)

    # Compute uniqueness
    uniqueness = compute_uniqueness(valid_data)

    # Compute novelty
    novelty = compute_novelty(valid_data)

    # Compute recovered molecules
    recovery_rate = compute_recovery_rate(valid_data)

    # 2D filters
    pass_all, pass_SA, pass_RA, pass_PAINS = calc_filters_2d_dataset(valid_data)

    # 3D Filters
    sc_rdkit = sc_rdkit_score(valid_data)

    return {
        'DeLinker/validity': validity_as_in_delinker,
        'DeLinker/uniqueness': uniqueness,
        'DeLinker/novelty': novelty,
        'DeLinker/recovery': recovery_rate,
        'DeLinker/2D_filters': pass_all,
        'DeLinker/2D_filters_SA': pass_SA,
        'DeLinker/2D_filters_RA': pass_RA,
        'DeLinker/2D_filters_PAINS': pass_PAINS,
        'DeLinker/SC_RDKit': sc_rdkit,
    }