|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from rdkit import Chem, DataStructs |
|
|
from rdkit.Chem import Descriptors, Crippen, Lipinski, QED |
|
|
from analysis.SA_Score.sascorer import calculateScore |
|
|
|
|
|
from analysis.molecule_builder import build_molecule |
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
|
class CategoricalDistribution:
    """Normalised categorical distribution built from a histogram.

    Attributes:
        p: 1-D numpy array of probabilities, one entry per index in
            ``mapping``; the input histogram divided by its sum.
        mapping: deep copy of the ``mapping`` passed to the constructor.
    """

    # Small constant added inside the logarithm so that categories missing
    # from the compared sample (q == 0) do not produce log(0).
    EPS = 1e-10

    def __init__(self, histogram_dict, mapping):
        """Build the reference distribution.

        Args:
            histogram_dict: dict from category key to its count/weight.
            mapping: dict from category key to the integer position of that
                category in the probability vector.

        Note:
            If all counts are zero the normalisation divides by zero and
            ``p`` ends up filled with NaN.
        """
        histogram = np.zeros(len(mapping))
        for k, v in histogram_dict.items():
            histogram[mapping[k]] = v

        self.p = histogram / histogram.sum()
        self.mapping = deepcopy(mapping)

    def kl_divergence(self, other_sample):
        """Return ``-sum(p * log(q / p + EPS))`` for an empirical sample.

        ``q`` is the normalised histogram of ``other_sample``.

        Args:
            other_sample: iterable of integer category indices (each value
                is used directly as an index into the histogram).

        Returns:
            The divergence as a numpy float. Categories with ``p == 0`` are
            excluded from the sum (convention ``0 * log(.) = 0``); without
            this they would evaluate to ``0 * log(nan or inf)`` and turn the
            whole result into NaN.

        Note:
            An empty ``other_sample`` still yields NaN because the sample
            histogram cannot be normalised.
        """
        sample_histogram = np.zeros(len(self.mapping))
        for x in other_sample:
            sample_histogram[x] += 1

        q = sample_histogram / sample_histogram.sum()

        # Restrict to the support of p to avoid division by zero.
        support = self.p > 0
        p = self.p[support]
        return -np.sum(p * np.log(q[support] / p + self.EPS))
|
|
|
|
|
|
|
|
def rdmol_to_smiles(rdmol):
    """Convert an RDKit molecule to a SMILES string.

    The input is copied first, so the stereochemistry and hydrogen removal
    below are applied to the copy rather than to ``rdmol`` itself.

    Args:
        rdmol: RDKit molecule.

    Returns:
        The string produced by ``Chem.MolToSmiles`` for the stripped copy.
    """
    working_copy = Chem.Mol(rdmol)
    Chem.RemoveStereochemistry(working_copy)
    return Chem.MolToSmiles(Chem.RemoveHs(working_copy))
|
|
|
|
|
|
|
|
class BasicMolecularMetrics(object):
    """Validity, connectivity, uniqueness and novelty of generated molecules."""

    def __init__(self, dataset_info, dataset_smiles_list=None,
                 connectivity_thresh=1.0):
        """
        Args:
            dataset_info: dict that must contain the key ``'atom_decoder'``;
                it is also forwarded unchanged to ``build_molecule``.
            dataset_smiles_list: optional iterable of reference SMILES used
                for the novelty check; stored as a set for O(1) lookups.
            connectivity_thresh: minimum fraction of atoms the largest
                fragment must hold for a molecule to count as connected.
        """
        self.atom_decoder = dataset_info['atom_decoder']
        if dataset_smiles_list is not None:
            dataset_smiles_list = set(dataset_smiles_list)
        self.dataset_smiles_list = dataset_smiles_list
        self.dataset_info = dataset_info
        self.connectivity_thresh = connectivity_thresh

    def compute_validity(self, generated):
        """ generated: list of RDKit molecules.

        A molecule is valid if ``Chem.SanitizeMol`` does not raise
        ``ValueError`` for it.

        Returns:
            (list of valid molecules, fraction of valid molecules)
        """
        if len(generated) < 1:
            return [], 0.0

        valid = []
        for mol in generated:
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                continue

            valid.append(mol)

        return valid, len(valid) / len(generated)

    def compute_connectivity(self, valid):
        """ Consider molecule connected if its largest fragment contains at
        least x% of all atoms, where x is determined by
        self.connectivity_thresh (defaults to 100%).

        Returns:
            (largest fragments of the connected molecules,
             fraction of connected molecules,
             SMILES strings of those fragments)
        """
        if len(valid) < 1:
            # Same arity as the regular return so callers can always unpack
            # three values.
            return [], 0.0, []

        connected = []
        connected_smiles = []
        for mol in valid:
            num_atoms = mol.GetNumAtoms()
            if num_atoms == 0:
                # An empty molecule has no fragment ratio (0 / 0); treat it
                # as not connected instead of raising ZeroDivisionError.
                continue

            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
            largest_mol = \
                max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
            if largest_mol.GetNumAtoms() / num_atoms >= self.connectivity_thresh:
                smiles = rdmol_to_smiles(largest_mol)
                if smiles is not None:
                    connected_smiles.append(smiles)
                    connected.append(largest_mol)

        return connected, len(connected_smiles) / len(valid), connected_smiles

    def compute_uniqueness(self, connected):
        """ connected: list of SMILES strings.

        Returns:
            (list of distinct SMILES, fraction of distinct SMILES)
        """
        # NOTE(review): uniqueness is reported as 0 when no reference SMILES
        # were supplied, although it does not use them. This also keeps
        # evaluate_rdmols from running the novelty check on a non-empty list
        # without a reference set -- confirm whether this is intended.
        if len(connected) < 1 or self.dataset_smiles_list is None:
            return [], 0.0

        distinct = set(connected)
        return list(distinct), len(distinct) / len(connected)

    def compute_novelty(self, unique):
        """ unique: list of SMILES strings.

        Returns:
            (SMILES not contained in the reference set, their fraction).
            ``([], 0.0)`` if ``unique`` is empty or no reference set exists.
        """
        if len(unique) < 1 or self.dataset_smiles_list is None:
            return [], 0.0

        novel = [smiles for smiles in unique
                 if smiles not in self.dataset_smiles_list]
        return novel, len(novel) / len(unique)

    def evaluate_rdmols(self, rdmols):
        """ Compute and print all four metrics for a list of RDKit molecules.

        Returns:
            ([validity, connectivity, uniqueness, novelty],
             [valid molecules, connected molecules])
        """
        valid, validity = self.compute_validity(rdmols)
        print(f"Validity over {len(rdmols)} molecules: {validity * 100 :.2f}%")

        connected, connectivity, connected_smiles = \
            self.compute_connectivity(valid)
        print(f"Connectivity over {len(valid)} valid molecules: "
              f"{connectivity * 100 :.2f}%")

        unique, uniqueness = self.compute_uniqueness(connected_smiles)
        print(f"Uniqueness over {len(connected)} connected molecules: "
              f"{uniqueness * 100 :.2f}%")

        _, novelty = self.compute_novelty(unique)
        print(f"Novelty over {len(unique)} unique connected molecules: "
              f"{novelty * 100 :.2f}%")

        return [validity, connectivity, uniqueness, novelty], [valid, connected]

    def evaluate(self, generated):
        """ generated: list of pairs (positions: n x 3, atom_types: n [int])
        the positions and atom types should already be masked. """

        rdmols = [build_molecule(*graph, self.dataset_info)
                  for graph in generated]
        return self.evaluate_rdmols(rdmols)
|
|
|
|
|
|
|
|
class MoleculeProperties:
    """Per-molecule property scores and per-set diversity for RDKit molecules."""

    @staticmethod
    def calculate_qed(rdmol):
        """Return ``QED.qed(rdmol)``."""
        return QED.qed(rdmol)

    @staticmethod
    def calculate_sa(rdmol):
        """Return the rescaled synthetic accessibility score.

        The raw ``calculateScore`` value ``sa`` is mapped to
        ``(10 - sa) / 9`` and rounded to two decimals, so lower raw scores
        give higher return values.
        """
        sa = calculateScore(rdmol)
        return round((10 - sa) / 9, 2)

    @staticmethod
    def calculate_logp(rdmol):
        """Return ``Crippen.MolLogP(rdmol)``."""
        return Crippen.MolLogP(rdmol)

    @staticmethod
    def calculate_lipinski(rdmol):
        """Count how many of the five drug-likeness rules ``rdmol`` satisfies.

        Rules: molecular weight < 500, H-bond donors <= 5, H-bond acceptors
        <= 10, -2 <= logP <= 5, rotatable bonds <= 10.

        Returns:
            Number of satisfied rules (0-5) as a numpy integer.
        """
        rule_1 = Descriptors.ExactMolWt(rdmol) < 500
        rule_2 = Lipinski.NumHDonors(rdmol) <= 5
        rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
        # Compute logP first and test both bounds explicitly. Binding it
        # with ``:=`` inside the comparison would capture the boolean result
        # of ``>= -2`` instead of the logP value, making the upper bound
        # check always true.
        logp = Crippen.MolLogP(rdmol)
        rule_4 = -2 <= logp <= 5
        rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
        return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])

    @classmethod
    def calculate_diversity(cls, pocket_mols):
        """Return the mean pairwise dissimilarity ``1 - similarity``.

        Args:
            pocket_mols: sequence of molecules accepted by ``cls.similarity``.

        Returns:
            Average over all unordered pairs; ``0.0`` when fewer than two
            molecules are given.
        """
        if len(pocket_mols) < 2:
            return 0.0

        div = 0
        total = 0
        for i in range(len(pocket_mols)):
            for j in range(i + 1, len(pocket_mols)):
                div += 1 - cls.similarity(pocket_mols[i], pocket_mols[j])
                total += 1
        return div / total

    @staticmethod
    def similarity(mol_a, mol_b):
        """Return the Tanimoto similarity of the RDKit fingerprints of two molecules."""
        fp1 = Chem.RDKFingerprint(mol_a)
        fp2 = Chem.RDKFingerprint(mol_b)
        return DataStructs.TanimotoSimilarity(fp1, fp2)

    def evaluate(self, pocket_rdmols):
        """
        Run full evaluation
        Args:
            pocket_rdmols: list of lists, the inner list contains all RDKit
                molecules generated for a pocket
        Returns:
            QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket)
        """

        for pocket in pocket_rdmols:
            for mol in pocket:
                # Check before sanitising; after the call the assertion
                # could never be reached for a missing molecule.
                assert mol is not None, "only evaluate valid molecules"
                Chem.SanitizeMol(mol)

        all_qed = []
        all_sa = []
        all_logp = []
        all_lipinski = []
        per_pocket_diversity = []
        for pocket in tqdm(pocket_rdmols):
            all_qed.append([self.calculate_qed(mol) for mol in pocket])
            all_sa.append([self.calculate_sa(mol) for mol in pocket])
            all_logp.append([self.calculate_logp(mol) for mol in pocket])
            all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
            per_pocket_diversity.append(self.calculate_diversity(pocket))

        print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
              f"{len(pocket_rdmols)} pockets evaluated.")

        # "\\pm" prints a literal backslash followed by "pm" without relying
        # on an invalid escape sequence.
        qed_flattened = [x for px in all_qed for x in px]
        print(f"QED: {np.mean(qed_flattened):.3f} \\pm {np.std(qed_flattened):.2f}")

        sa_flattened = [x for px in all_sa for x in px]
        print(f"SA: {np.mean(sa_flattened):.3f} \\pm {np.std(sa_flattened):.2f}")

        logp_flattened = [x for px in all_logp for x in px]
        print(f"LogP: {np.mean(logp_flattened):.3f} \\pm {np.std(logp_flattened):.2f}")

        lipinski_flattened = [x for px in all_lipinski for x in px]
        print(f"Lipinski: {np.mean(lipinski_flattened):.3f} \\pm {np.std(lipinski_flattened):.2f}")

        print(f"Diversity: {np.mean(per_pocket_diversity):.3f} \\pm {np.std(per_pocket_diversity):.2f}")

        return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity

    def evaluate_mean(self, rdmols):
        """
        Run full evaluation and return mean of each property
        Args:
            rdmols: list of RDKit molecules
        Returns:
            QED, SA, LogP, Lipinski, and Diversity
            (all 0.0 when ``rdmols`` is empty)
        """

        if len(rdmols) < 1:
            return 0.0, 0.0, 0.0, 0.0, 0.0

        for mol in rdmols:
            # Check before sanitising; after the call the assertion could
            # never be reached for a missing molecule.
            assert mol is not None, "only evaluate valid molecules"
            Chem.SanitizeMol(mol)

        qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
        sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
        logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
        lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
        diversity = self.calculate_diversity(rdmols)

        return qed, sa, logp, lipinski, diversity
|
|
|