# Source: DiffLinker, src/delinker.py
# Author: igashov; commit 95ba5bc ("DiffLinker code"); file size 8.69 kB
import csv
import numpy as np
from rdkit import Chem
from rdkit.Chem import MolStandardize
from src import metrics
from src.delinker_utils import sascorer, calc_SC_RDKit
from tqdm import tqdm
from pdb import set_trace
def get_valid_as_in_delinker(data, progress=False):
    """Filter samples down to those that are valid by the DeLinker criterion.

    A sample is kept when (1) all three SMILES strings parse, (2) the largest
    connected component of the predicted molecule, the true molecule and the
    fragments all pass ``Chem.SanitizeMol``, and (3) that largest component
    contains the fragments as a substructure.

    Args:
        data: sized iterable of dicts with the keys ``pred_mol``,
            ``true_mol``, ``pred_mol_smi``, ``true_mol_smi`` and ``frag_smi``.
        progress: if true, show a ``tqdm`` progress bar while iterating.

    Returns:
        A list of new dicts (one per valid sample) carrying the original
        ``pred_mol``/``true_mol`` objects and SMILES strings regenerated from
        the sanitized molecules. ``pred_mol_smi`` describes only the largest
        component of the prediction.
    """
    valid = []
    generator = tqdm(data, total=len(data)) if progress else data
    for m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=False)
        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=False)
        frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=False)

        # MolFromSmiles yields None for a string it cannot parse; such a
        # sample is simply invalid rather than a reason to abort the run.
        if pred_mol is None or true_mol is None or frag is None:
            continue

        # Only the largest connected component of the prediction is scored.
        pred_mol_frags = Chem.GetMolFrags(pred_mol, asMols=True, sanitizeFrags=False)
        pred_mol_filtered = max(pred_mol_frags, default=pred_mol, key=lambda mol: mol.GetNumAtoms())

        try:
            Chem.SanitizeMol(pred_mol_filtered)
            Chem.SanitizeMol(true_mol)
            Chem.SanitizeMol(frag)
        except Exception:
            # Sanitization failure marks the sample as invalid. Catching
            # Exception (not a bare except) lets Ctrl-C / SystemExit through.
            continue

        if len(pred_mol_filtered.GetSubstructMatch(frag)) > 0:
            valid.append({
                'pred_mol': m['pred_mol'],
                'true_mol': m['true_mol'],
                'pred_mol_smi': Chem.MolToSmiles(pred_mol_filtered),
                'true_mol_smi': Chem.MolToSmiles(true_mol),
                'frag_smi': Chem.MolToSmiles(frag)
            })

    return valid
def extract_linker_smiles(molecule, fragments):
    """Return the SMILES of ``molecule`` with the ``fragments`` match removed.

    The atoms of the first substructure match of ``fragments`` in ``molecule``
    are deleted, stereochemistry is removed from the remainder, and the result
    is written as SMILES. If ``fragments`` does not match, no atom is removed
    and the SMILES of the whole (stereo-free) molecule is returned.

    Args:
        molecule: RDKit molecule to cut the linker out of.
        fragments: RDKit molecule used as the substructure query.

    Returns:
        The tautomer-canonicalized SMILES of the remainder when
        ``MolStandardize.canonicalize_tautomer_smiles`` succeeds, otherwise
        the plain ``Chem.MolToSmiles`` output.
    """
    match = molecule.GetSubstructMatch(fragments)
    elinker = Chem.EditableMol(molecule)
    # Remove from the highest index downwards -- presumably so that earlier
    # removals do not shift the indices that are still to be removed.
    for atom_id in sorted(match, reverse=True):
        elinker.RemoveAtom(atom_id)
    linker = elinker.GetMol()
    Chem.RemoveStereochemistry(linker)

    linker_smiles = Chem.MolToSmiles(linker)
    try:
        return MolStandardize.canonicalize_tautomer_smiles(linker_smiles)
    except Exception:
        # Best-effort standardization: fall back to the plain SMILES, but do
        # not swallow KeyboardInterrupt / SystemExit as a bare except would.
        return linker_smiles
def compute_and_add_linker_smiles(data, progress=False):
    """Attach predicted and true linker SMILES to every sample.

    Each sample's ``pred_mol_smi``, ``true_mol_smi`` and ``frag_smi`` are
    parsed (with sanitization) and ``extract_linker_smiles`` is applied to the
    predicted and to the true molecule, both against the same fragments.

    Args:
        data: iterable of sample dicts.
        progress: if true, show a ``tqdm`` progress bar while iterating.

    Returns:
        A new list of dicts: a copy of each input sample extended with the
        keys ``pred_linker`` and ``true_linker``.
    """
    smiles_keys = ('pred_mol_smi', 'true_mol_smi', 'frag_smi')
    samples = tqdm(data) if progress else data

    augmented = []
    for sample in samples:
        predicted, reference, fragments = (
            Chem.MolFromSmiles(sample[key], sanitize=True) for key in smiles_keys
        )
        entry = dict(sample)
        entry['pred_linker'] = extract_linker_smiles(predicted, fragments)
        entry['true_linker'] = extract_linker_smiles(reference, fragments)
        augmented.append(entry)
    return augmented
def compute_uniqueness(data, progress=False):
    """Return the fraction of predictions that are unique within their group.

    Samples are grouped by the pair (``true_mol_smi``, ``frag_smi``). Within
    each group the distinct ``pred_mol_smi`` strings are counted; the result
    is the sum of distinct counts divided by the total number of samples.

    Args:
        data: iterable of dicts with the keys ``frag_smi``, ``pred_mol_smi``
            and ``true_mol_smi``.
        progress: if true, show a ``tqdm`` progress bar while iterating.

    Returns:
        A float in [0, 1]; ``0.0`` when ``data`` is empty (instead of raising
        ``ZeroDivisionError``).
    """
    unique_per_group = {}
    total_mol = 0
    generator = tqdm(data) if progress else data
    for m in generator:
        frag = m['frag_smi']
        pred_mol = m['pred_mol_smi']
        true_mol = m['true_mol_smi']
        key = f'{true_mol}.{frag}'
        unique_per_group.setdefault(key, set()).add(pred_mol)
        total_mol += 1

    if total_mol == 0:
        return 0.0

    unique_mol = sum(len(molecules) for molecules in unique_per_group.values())
    return unique_mol / total_mol
def compute_novelty(data, progress=False):
    """Return the fraction of predicted linkers not seen among the true linkers.

    The reference set is built from the ``true_linker`` values of ``data``
    itself; a sample counts as novel when its ``pred_linker`` is not in that
    set.

    Args:
        data: sized collection of dicts with the keys ``true_linker`` and
            ``pred_linker``.
        progress: if true, show a ``tqdm`` progress bar while iterating.

    Returns:
        A float in [0, 1]; ``0.0`` when ``data`` is empty (instead of raising
        ``ZeroDivisionError``).
    """
    if len(data) == 0:
        return 0.0

    true_linkers = {m['true_linker'] for m in data}
    generator = tqdm(data) if progress else data
    novel = sum(1 for m in generator if m['pred_linker'] not in true_linkers)
    return novel / len(data)
def _stereo_free_smiles(smiles):
    """Parse ``smiles``, drop stereochemistry and hydrogens, return SMILES."""
    mol = Chem.MolFromSmiles(smiles, sanitize=True)
    Chem.RemoveStereochemistry(mol)
    return Chem.MolToSmiles(Chem.RemoveHs(mol))


def compute_recovery_rate(data, progress=False):
    """Return the fraction of reference molecules reproduced by a prediction.

    Every sample contributes the target ``(true molecule, true linker)`` to
    the set of all targets. A target counts as recovered when at least one
    sample for it has a predicted molecule whose stereo-free, hydrogen-free
    SMILES equals that of the true molecule.

    Args:
        data: iterable of dicts with the keys ``pred_mol_smi``,
            ``true_mol_smi`` and ``true_linker``.
        progress: if true, show a ``tqdm`` progress bar while iterating.

    Returns:
        ``len(recovered) / len(total)`` as a float; ``0.0`` when ``data`` is
        empty (instead of raising ``ZeroDivisionError``).
    """
    total = set()
    recovered = set()
    generator = tqdm(data) if progress else data
    for m in generator:
        pred_mol = _stereo_free_smiles(m['pred_mol_smi'])
        true_mol = _stereo_free_smiles(m['true_mol_smi'])

        true_link = m['true_linker']
        target = f'{true_mol}.{true_link}'
        total.add(target)
        if pred_mol == true_mol:
            recovered.add(target)

    if not total:
        return 0.0
    return len(recovered) / len(total)
def calc_sa_score_mol(mol):
    """Return ``sascorer.calculateScore(mol)``, or ``None`` if ``mol`` is ``None``."""
    return None if mol is None else sascorer.calculateScore(mol)
def check_ring_filter(linker):
    """Check that no ring of ``linker`` has a type-2 bond between ring atoms.

    Every ring reported by ``Chem.GetSymmSSSR`` is inspected. The filter
    fails as soon as a ring atom has a bond whose type compares equal to 2
    (the numeric value of RDKit's double bond type -- TODO confirm) and whose
    two end atoms both belong to that ring.

    Returns:
        ``True`` if no such bond exists, ``False`` otherwise.
    """
    for ring in Chem.GetSymmSSSR(linker):
        ring_atoms = set(ring)
        for atom_idx in ring_atoms:
            for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
                if bond.GetBondType() != 2:
                    continue
                if bond.GetBeginAtomIdx() in ring_atoms and bond.GetEndAtomIdx() in ring_atoms:
                    return False
    return True
def check_pains(mol, pains_smarts):
    """Return ``True`` when ``mol`` matches none of the ``pains_smarts`` patterns.

    Args:
        mol: object exposing ``HasSubstructMatch(pattern)``.
        pains_smarts: iterable of substructure patterns to test against.
    """
    return not any(mol.HasSubstructMatch(pattern) for pattern in pains_smarts)
def calc_2d_filters(toks, pains_smarts):
    """Evaluate the SA, ring and PAINS filters for one sample.

    Args:
        toks: dict with the SMILES keys ``pred_mol_smi``, ``frag_smi`` and
            ``pred_linker``.
        pains_smarts: patterns forwarded to ``check_pains``.

    Returns:
        ``[sa, ra, pains]`` as booleans. All three are ``False`` when the
        fragments are not a substructure of the predicted molecule. A filter
        whose computation raises is reported on stdout and counted as failed.
    """
    pred_mol = Chem.MolFromSmiles(toks['pred_mol_smi'])
    frag = Chem.MolFromSmiles(toks['frag_smi'])
    linker = Chem.MolFromSmiles(toks['pred_linker'])

    # Nothing is evaluated unless the fragments occur in the prediction.
    if len(pred_mol.GetSubstructMatch(frag)) == 0:
        return [False, False, False]

    passed_sa = passed_ra = passed_pains = False

    # SA: the full molecule must score lower than the bare fragments.
    try:
        passed_sa = calc_sa_score_mol(pred_mol) < calc_sa_score_mol(frag)
    except Exception as e:
        print(f'Could not compute SA score: {e}')

    # RA: ring filter evaluated on the linker alone.
    try:
        passed_ra = check_ring_filter(linker)
    except Exception as e:
        print(f'Could not compute RA score: {e}')

    # PAINS: the full molecule must match none of the patterns.
    try:
        passed_pains = check_pains(pred_mol, pains_smarts)
    except Exception as e:
        print(f'Could not compute PAINS score: {e}')

    return [passed_sa, passed_ra, passed_pains]
def calc_filters_2d_dataset(data):
    """Compute the pass rates of the 2D filters over a dataset.

    PAINS patterns are read from the first column of
    ``models/wehi_pains.csv`` (path relative to the working directory) and
    ``calc_2d_filters`` is applied to every sample.

    Args:
        data: sized collection of sample dicts accepted by ``calc_2d_filters``.

    Returns:
        A 4-tuple of fractions ``(all, SA, RA, PAINS)``, each the number of
        passing samples divided by ``len(data)``.
    """
    with open('models/wehi_pains.csv', 'r') as f:
        pains_smarts = [Chem.MolFromSmarts(row[0], mergeHs=True) for row in csv.reader(f)]

    n_all = 0
    n_sa = 0
    n_ra = 0
    n_pains = 0
    for sample in data:
        sa_ok, ra_ok, pains_ok = calc_2d_filters(sample, pains_smarts)
        n_all += sa_ok & ra_ok & pains_ok
        n_sa += sa_ok
        n_ra += ra_ok
        n_pains += pains_ok

    n_samples = len(data)
    return n_all / n_samples, n_sa / n_samples, n_ra / n_samples, n_pains / n_samples
def calc_sc_rdkit_full_mol(gen_mol, ref_mol):
    """Return the SC_RDKit score of ``gen_mol`` against ``ref_mol``.

    Args:
        gen_mol: generated molecule passed to ``calc_SC_RDKit_score``.
        ref_mol: reference molecule passed to ``calc_SC_RDKit_score``.

    Returns:
        The score from ``calc_SC_RDKit.calc_SC_RDKit_score``, or the sentinel
        ``-0.5`` if the computation raises.
    """
    try:
        return calc_SC_RDKit.calc_SC_RDKit_score(gen_mol, ref_mol)
    except Exception:
        # Best-effort scoring: a failure yields the sentinel, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        return -0.5
def sc_rdkit_score(data):
    """Return the mean SC_RDKit score over all samples.

    Args:
        data: iterable of dicts holding the molecules under the keys
            ``pred_mol`` and ``true_mol``.

    Returns:
        ``np.mean`` of the per-sample results of ``calc_sc_rdkit_full_mol``.
    """
    per_sample = [
        calc_sc_rdkit_full_mol(sample['pred_mol'], sample['true_mol'])
        for sample in data
    ]
    return np.mean(per_sample)
def get_delinker_metrics(pred_molecules, true_molecules, true_fragments):
    """Compute the DeLinker-style metrics for a batch of generated molecules.

    Args:
        pred_molecules: sized sequence of generated RDKit molecules.
        true_molecules: reference molecules, aligned with ``pred_molecules``.
        true_fragments: input fragments, aligned with ``pred_molecules``.

    Returns:
        A dict keyed by ``DeLinker/<metric>`` holding validity, uniqueness,
        novelty, recovery, the four 2D-filter pass rates and the SC_RDKit
        score. Every value is 0 when there are no predictions or when no
        prediction is valid.
    """
    all_zero = {
        'DeLinker/validity': 0,
        'DeLinker/uniqueness': 0,
        'DeLinker/novelty': 0,
        'DeLinker/recovery': 0,
        'DeLinker/2D_filters': 0,
        'DeLinker/2D_filters_SA': 0,
        'DeLinker/2D_filters_RA': 0,
        'DeLinker/2D_filters_PAINS': 0,
        'DeLinker/SC_RDKit': 0,
    }
    if len(pred_molecules) == 0:
        return all_zero

    samples = [
        {
            'pred_mol': predicted,
            'true_mol': reference,
            'pred_mol_smi': Chem.MolToSmiles(predicted),
            'true_mol_smi': Chem.MolToSmiles(reference),
            'frag_smi': Chem.MolToSmiles(fragments)
        }
        for predicted, reference, fragments in zip(pred_molecules, true_molecules, true_fragments)
    ]

    # Validity as in the DeLinker paper: the molecule passes RDKit
    # sanitization and its biggest fragment contains both input fragments.
    valid_samples = get_valid_as_in_delinker(samples)
    validity = len(valid_samples) / len(samples)
    if len(valid_samples) == 0:
        return all_zero

    # The remaining metrics are computed on valid samples only, and the
    # linker-based ones need the linker SMILES attached first.
    valid_samples = compute_and_add_linker_smiles(valid_samples)

    uniqueness = compute_uniqueness(valid_samples)
    novelty = compute_novelty(valid_samples)
    recovery = compute_recovery_rate(valid_samples)

    # 2D filters
    filters_all, filters_sa, filters_ra, filters_pains = calc_filters_2d_dataset(valid_samples)

    # 3D filters
    sc_rdkit = sc_rdkit_score(valid_samples)

    return {
        'DeLinker/validity': validity,
        'DeLinker/uniqueness': uniqueness,
        'DeLinker/novelty': novelty,
        'DeLinker/recovery': recovery,
        'DeLinker/2D_filters': filters_all,
        'DeLinker/2D_filters_SA': filters_sa,
        'DeLinker/2D_filters_RA': filters_ra,
        'DeLinker/2D_filters_PAINS': filters_pains,
        'DeLinker/SC_RDKit': sc_rdkit,
    }