# DiffLinker / src/metrics.py
# igashov · DiffLinker code · 95ba5bc · 5.17 kB
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from src import const
from src.molecule_builder import get_bond_order
from scipy.stats import wasserstein_distance
from pdb import set_trace
def is_valid(mol):
    """Report whether ``mol`` passes RDKit sanitization.

    ``Chem.SanitizeMol`` is called on the object that was passed in (not on
    a copy), so whatever sanitization does to the molecule is visible to
    the caller.

    :param mol: molecule to check; forwarded unchanged to ``Chem.SanitizeMol``.
    :return: ``False`` if sanitization raised ``ValueError``, else ``True``.
    """
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        # Sanitization rejected the molecule.
        return False
    else:
        return True
def is_connected(mol):
    """Report whether ``mol`` is made of exactly one fragment.

    The molecule is split with ``Chem.GetMolFrags(mol, asMols=True)``.

    :param mol: molecule to check; forwarded unchanged to ``Chem.GetMolFrags``.
    :return: ``True`` only when exactly one fragment is returned. ``False``
        when there are zero or several fragments, or when fragment
        extraction raises ``Chem.rdchem.AtomValenceException``.
    """
    try:
        fragments = Chem.GetMolFrags(mol, asMols=True)
    except Chem.rdchem.AtomValenceException:
        # Fragment extraction failed on a valence problem: treat as not connected.
        return False
    return len(fragments) == 1
def get_valid_molecules(molecules):
    """Return the molecules accepted by ``is_valid``, preserving input order.

    :param molecules: iterable of molecules to filter.
    :return: new list holding only the entries for which ``is_valid`` is true.
    """
    return [candidate for candidate in molecules if is_valid(candidate)]
def get_connected_molecules(molecules):
    """Return the molecules accepted by ``is_connected``, preserving input order.

    :param molecules: iterable of molecules to filter.
    :return: new list holding only the entries for which ``is_connected``
        is true.
    """
    return [candidate for candidate in molecules if is_connected(candidate)]
def get_unique_smiles(valid_molecules):
    """Return the distinct SMILES strings of the given molecules.

    Each molecule is converted with ``Chem.MolToSmiles`` and duplicates are
    dropped through a set, so the order of the returned list is not
    meaningful.

    :param valid_molecules: iterable of molecules to convert.
    :return: list of distinct SMILES strings.
    """
    distinct_smiles = {Chem.MolToSmiles(mol) for mol in valid_molecules}
    return list(distinct_smiles)
def get_novel_smiles(unique_true_smiles, unique_pred_smiles):
    """Return the predicted SMILES that do not occur among the true SMILES.

    :param unique_true_smiles: iterable of reference SMILES strings.
    :param unique_pred_smiles: iterable of predicted SMILES strings.
    :return: list (in no particular order, without duplicates) of the
        predicted strings that are absent from the reference collection.
    """
    known = set(unique_true_smiles)
    predicted = set(unique_pred_smiles)
    return list(predicted - known)
def compute_energy(mol):
    """Return the MMFF force-field energy of conformer 0 of ``mol``.

    Builds MMFF molecule properties for ``mol``, constructs the MMFF force
    field for ``confId=0`` and returns the result of its ``CalcEnergy()``.
    No errors are handled here; any failure in those calls propagates to
    the caller.

    :param mol: molecule forwarded to the ``AllChem`` MMFF helpers.
    :return: value returned by ``CalcEnergy()``.
    """
    mmff_properties = AllChem.MMFFGetMoleculeProperties(mol)
    force_field = AllChem.MMFFGetMoleculeForceField(mol, mmff_properties, confId=0)
    return force_field.CalcEnergy()
def _collect_energies(molecules):
    """Return ``compute_energy`` results for the molecules where it succeeds."""
    energies = []
    for mol in molecules:
        try:
            energies.append(compute_energy(mol))
        except Exception:
            # Best-effort: a molecule whose energy cannot be computed is
            # skipped. ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt / SystemExit still propagate.
            continue
    return energies


def wasserstein_distance_between_energies(true_molecules, pred_molecules):
    """Compare the energy distributions of two sets of molecules.

    Energies are computed with ``compute_energy`` for every molecule in each
    collection; molecules for which the computation raises are silently
    left out.

    :param true_molecules: iterable of reference molecules.
    :param pred_molecules: iterable of predicted molecules.
    :return: ``wasserstein_distance`` between the two lists of energies, or
        ``0`` when either list ends up empty (no molecules given, or no
        energy could be computed).
    """
    true_energy_dist = _collect_energies(true_molecules)
    pred_energy_dist = _collect_energies(pred_molecules)

    # The distance is undefined for an empty sample; keep the original
    # convention of reporting 0 in that case.
    if not true_energy_dist or not pred_energy_dist:
        return 0
    return wasserstein_distance(true_energy_dist, pred_energy_dist)
def compute_metrics(pred_molecules, true_molecules):
    """Compute generation-quality metrics for predicted molecules.

    :param pred_molecules: sized collection of predicted molecules.
    :param true_molecules: collection of reference molecules.
    :return: dict with the keys

        * ``'validity'`` - fraction of predictions accepted by
          ``get_valid_molecules``;
        * ``'validity_and_connectivity'`` - fraction of predictions that are
          valid and also accepted by ``get_connected_molecules``;
        * ``'uniqueness'`` - number of distinct SMILES among the valid and
          connected predictions, divided by the number of such predictions
          (``0`` if there are none);
        * ``'novelty'`` - fraction of the distinct predicted SMILES that do
          not appear among the distinct reference SMILES (``0`` if there are
          no predicted SMILES);
        * ``'energies'`` - result of
          ``wasserstein_distance_between_energies`` on the valid and
          connected reference and predicted molecules.

        When ``pred_molecules`` is empty every metric is ``0``. The same five
        keys are returned in both cases.
    """
    if len(pred_molecules) == 0:
        # Same key set as the regular result below. The previous
        # empty-input result carried an extra 'validity_as_in_delinker'
        # key that the regular result never produced.
        return {
            'validity': 0,
            'validity_and_connectivity': 0,
            'uniqueness': 0,
            'novelty': 0,
            'energies': 0,
        }

    n_pred = len(pred_molecules)

    # Passing rdkit.Chem.Sanitize filter
    true_valid = get_valid_molecules(true_molecules)
    pred_valid = get_valid_molecules(pred_molecules)
    validity = len(pred_valid) / n_pred

    # Checking if molecule consists of a single connected part
    true_valid_and_connected = get_connected_molecules(true_valid)
    pred_valid_and_connected = get_connected_molecules(pred_valid)
    validity_and_connectivity = len(pred_valid_and_connected) / n_pred

    # Unique molecules
    true_unique = get_unique_smiles(true_valid_and_connected)
    pred_unique = get_unique_smiles(pred_valid_and_connected)
    if len(pred_valid_and_connected) > 0:
        uniqueness = len(pred_unique) / len(pred_valid_and_connected)
    else:
        uniqueness = 0

    # Novel molecules
    pred_novel = get_novel_smiles(true_unique, pred_unique)
    novelty = len(pred_novel) / len(pred_unique) if len(pred_unique) > 0 else 0

    # Difference between Energy distributions
    energies = wasserstein_distance_between_energies(
        true_valid_and_connected, pred_valid_and_connected
    )

    return {
        'validity': validity,
        'validity_and_connectivity': validity_and_connectivity,
        'uniqueness': uniqueness,
        'novelty': novelty,
        'energies': energies,
    }
# def check_stability(positions, atom_types):
# assert len(positions.shape) == 2
# assert positions.shape[1] == 3
# x = positions[:, 0]
# y = positions[:, 1]
# z = positions[:, 2]
#
# nr_bonds = np.zeros(len(x), dtype='int')
# for i in range(len(x)):
# for j in range(i + 1, len(x)):
# p1 = np.array([x[i], y[i], z[i]])
# p2 = np.array([x[j], y[j], z[j]])
# dist = np.sqrt(np.sum((p1 - p2) ** 2))
# atom1, atom2 = const.IDX2ATOM[atom_types[i].item()], const.IDX2ATOM[atom_types[j].item()]
# order = get_bond_order(atom1, atom2, dist)
# nr_bonds[i] += order
# nr_bonds[j] += order
# nr_stable_bonds = 0
# for atom_type_i, nr_bonds_i in zip(atom_types, nr_bonds):
# possible_bonds = const.ALLOWED_BONDS[const.IDX2ATOM[atom_type_i.item()]]
# if type(possible_bonds) == int:
# is_stable = possible_bonds == nr_bonds_i
# else:
# is_stable = nr_bonds_i in possible_bonds
# nr_stable_bonds += int(is_stable)
#
# molecule_stable = nr_stable_bonds == len(x)
# return molecule_stable, nr_stable_bonds, len(x)
#
#
# def count_stable_molecules(one_hot, x, node_mask):
# stable_molecules = 0
# for i in range(len(one_hot)):
# mol_size = node_mask[i].sum()
# atom_types = one_hot[i][:mol_size, :].argmax(dim=1).detach().cpu()
# positions = x[i][:mol_size, :].detach().cpu()
# stable, _, _ = check_stability(positions, atom_types)
# stable_molecules += int(stable)
#
# return stable_molecules