# DiffLinker: src/metrics.py
# Source: igashov, commit 95ba5bc ("DiffLinker code"), 5.17 kB.
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from src import const
from src.molecule_builder import get_bond_order
from scipy.stats import wasserstein_distance
from pdb import set_trace
def is_valid(mol):
    """Return True if RDKit can sanitize ``mol``, otherwise False.

    ``Chem.SanitizeMol`` works on ``mol`` in place, so a molecule that passes
    this check is left sanitized. Only ``ValueError`` (and its subclasses) is
    treated as "invalid"; any other exception propagates.
    """
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return False
    else:
        return True
def is_connected(mol):
    """Return True if ``mol`` splits into exactly one fragment.

    Returns False when RDKit raises ``AtomValenceException`` while extracting
    the fragments, or when the fragment count is anything other than one.
    """
    try:
        fragments = Chem.GetMolFrags(mol, asMols=True)
    except Chem.rdchem.AtomValenceException:
        return False
    return len(fragments) == 1
def get_valid_molecules(molecules):
    """Return the molecules accepted by ``is_valid``, in input order.

    ``is_valid`` calls ``Chem.SanitizeMol``, so every molecule checked here is
    sanitized in place as a side effect.
    """
    return [candidate for candidate in molecules if is_valid(candidate)]
def get_connected_molecules(molecules):
    """Return the molecules accepted by ``is_connected``, in input order."""
    return list(filter(is_connected, molecules))
def get_unique_smiles(valid_molecules):
    """Return the distinct SMILES strings produced by ``Chem.MolToSmiles``.

    The strings are de-duplicated through a set, so the order of the returned
    list is arbitrary.
    """
    distinct = {Chem.MolToSmiles(molecule) for molecule in valid_molecules}
    return list(distinct)
def get_novel_smiles(unique_true_smiles, unique_pred_smiles):
    """Return predicted SMILES that do not occur among the reference SMILES.

    Duplicates in ``unique_pred_smiles`` collapse to one entry, and the order
    of the returned list is arbitrary.
    """
    predicted = set(unique_pred_smiles)
    reference = set(unique_true_smiles)
    return list(predicted - reference)
def compute_energy(mol):
    """Return the MMFF force-field energy of conformer 0 of ``mol``.

    No error handling happens here: any exception raised by RDKit while
    setting up the MMFF properties/force field or evaluating the energy
    propagates to the caller.
    """
    mmff_props = AllChem.MMFFGetMoleculeProperties(mol)
    force_field = AllChem.MMFFGetMoleculeForceField(mol, mmff_props, confId=0)
    return force_field.CalcEnergy()
def wasserstein_distance_between_energies(true_molecules, pred_molecules):
true_energy_dist = []
for mol in true_molecules:
try:
energy = compute_energy(mol)
true_energy_dist.append(energy)
except:
continue
pred_energy_dist = []
for mol in pred_molecules:
try:
energy = compute_energy(mol)
pred_energy_dist.append(energy)
except:
continue
if len(true_energy_dist) > 0 and len(pred_energy_dist) > 0:
return wasserstein_distance(true_energy_dist, pred_energy_dist)
else:
return 0
def compute_metrics(pred_molecules, true_molecules):
    """Compute generation-quality metrics for predicted vs. reference molecules.

    The returned dict always has the same keys:

    * ``validity``: fraction of ``pred_molecules`` that pass RDKit
      sanitization (``get_valid_molecules``).
    * ``validity_and_connectivity``: fraction of ``pred_molecules`` that are
      valid and consist of a single fragment (``get_connected_molecules``).
    * ``uniqueness``: number of distinct SMILES among the valid+connected
      predictions divided by the number of valid+connected predictions.
    * ``novelty``: fraction of those distinct predicted SMILES that do not
      occur among the valid+connected reference molecules.
    * ``energies``: ``wasserstein_distance_between_energies`` between the
      valid+connected reference and predicted molecules.

    Sanitization is performed in place, so both input collections may be
    modified.

    Returns all-zero metrics when ``pred_molecules`` is empty.
    """
    if len(pred_molecules) == 0:
        # Keep the same key set as the non-empty path so callers that iterate
        # over / log the dict see a consistent schema. No
        # 'validity_as_in_delinker' value is ever computed for non-empty
        # input, so that key is not returned here either.
        return {
            'validity': 0,
            'validity_and_connectivity': 0,
            'uniqueness': 0,
            'novelty': 0,
            'energies': 0,
        }

    # Passing rdkit.Chem.Sanitize filter
    true_valid = get_valid_molecules(true_molecules)
    pred_valid = get_valid_molecules(pred_molecules)
    validity = len(pred_valid) / len(pred_molecules)

    # Checking if molecule consists of a single connected part
    true_valid_and_connected = get_connected_molecules(true_valid)
    pred_valid_and_connected = get_connected_molecules(pred_valid)
    validity_and_connectivity = len(pred_valid_and_connected) / len(pred_molecules)

    # Unique molecules
    true_unique = get_unique_smiles(true_valid_and_connected)
    pred_unique = get_unique_smiles(pred_valid_and_connected)
    if pred_valid_and_connected:
        uniqueness = len(pred_unique) / len(pred_valid_and_connected)
    else:
        uniqueness = 0

    # Novel molecules
    pred_novel = get_novel_smiles(true_unique, pred_unique)
    novelty = len(pred_novel) / len(pred_unique) if pred_unique else 0

    # Difference between Energy distributions
    energies = wasserstein_distance_between_energies(
        true_valid_and_connected, pred_valid_and_connected
    )

    return {
        'validity': validity,
        'validity_and_connectivity': validity_and_connectivity,
        'uniqueness': uniqueness,
        'novelty': novelty,
        'energies': energies,
    }
# def check_stability(positions, atom_types):
# assert len(positions.shape) == 2
# assert positions.shape[1] == 3
# x = positions[:, 0]
# y = positions[:, 1]
# z = positions[:, 2]
#
# nr_bonds = np.zeros(len(x), dtype='int')
# for i in range(len(x)):
# for j in range(i + 1, len(x)):
# p1 = np.array([x[i], y[i], z[i]])
# p2 = np.array([x[j], y[j], z[j]])
# dist = np.sqrt(np.sum((p1 - p2) ** 2))
# atom1, atom2 = const.IDX2ATOM[atom_types[i].item()], const.IDX2ATOM[atom_types[j].item()]
# order = get_bond_order(atom1, atom2, dist)
# nr_bonds[i] += order
# nr_bonds[j] += order
# nr_stable_bonds = 0
# for atom_type_i, nr_bonds_i in zip(atom_types, nr_bonds):
# possible_bonds = const.ALLOWED_BONDS[const.IDX2ATOM[atom_type_i.item()]]
# if type(possible_bonds) == int:
# is_stable = possible_bonds == nr_bonds_i
# else:
# is_stable = nr_bonds_i in possible_bonds
# nr_stable_bonds += int(is_stable)
#
# molecule_stable = nr_stable_bonds == len(x)
# return molecule_stable, nr_stable_bonds, len(x)
#
#
# def count_stable_molecules(one_hot, x, node_mask):
# stable_molecules = 0
# for i in range(len(one_hot)):
# mol_size = node_mask[i].sum()
# atom_types = one_hot[i][:mol_size, :].argmax(dim=1).detach().cpu()
# positions = x[i][:mol_size, :].detach().cpu()
# stable, _, _ = check_stability(positions, atom_types)
# stable_molecules += int(stable)
#
# return stable_molecules