Spaces:
Sleeping
Sleeping
import evaluate | |
import datasets | |
import pandas as pd | |
from tdc import Evaluator | |
from tdc import Oracle | |
from rdkit.Chem.QED import qed | |
from rdkit.Chem.Crippen import MolLogP | |
import os | |
from collections import Counter | |
from functools import partial | |
import numpy as np | |
import pandas as pd | |
import scipy.sparse | |
import torch | |
from rdkit import Chem | |
from rdkit.Chem import AllChem | |
from rdkit.Chem import MACCSkeys | |
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan | |
from rdkit.Chem.QED import qed | |
from rdkit.Chem.Scaffolds import MurckoScaffold | |
from rdkit.Chem import Descriptors | |
from multiprocessing import Pool | |
from collections import UserList, defaultdict | |
import numpy as np | |
import pandas as pd | |
from rdkit import rdBase | |
from rdkit.Contrib.SA_Score import sascorer | |
import sys | |
from rdkit.Chem import RDConfig | |
import os | |
import pandas as pd | |
from fcd_torch import FCD | |
from syba.syba import SybaClassifier | |
from myscscore.SCScore import SCScorer | |
import warnings | |
def get_mol(smiles_or_mol):
    """
    Turn a SMILES string into a sanitized RDKit molecule.
    Anything that is not a string is handed back untouched, on the
    assumption that it already is an RDKit molecule.
    Parameters:
    - smiles_or_mol (str or Mol): SMILES string, or an object to pass through.
    Returns:
    - Mol or None: The parsed molecule; None for an empty string, a string
      RDKit cannot parse, or a molecule whose sanitization raises ValueError.
    """
    # Non-string input is returned as-is, without any validation.
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is not None:
        try:
            Chem.SanitizeMol(parsed)
        except ValueError:
            parsed = None
    return parsed
def mapper(n_jobs):
    """
    Build a `map`-like callable that runs sequentially or on a process pool.
    Parameters:
    - n_jobs (int or Pool): 1 for sequential execution, another integer for a
      freshly created multiprocessing Pool of that size, or any object that
      exposes a `map` method (e.g. an existing Pool).
    Returns:
    - Function: Callable with the signature of `map` that returns the results.
    """
    if n_jobs == 1:
        def _sequential_map(*args, **kwargs):
            return list(map(*args, **kwargs))
        return _sequential_map
    if not isinstance(n_jobs, int):
        # Caller supplied its own pool-like object; its lifetime is theirs.
        return n_jobs.map
    # The pool is created right away and is single-use: it is terminated
    # after the first call of the returned function, success or failure.
    workers = Pool(n_jobs)

    def _pool_map(*args, **kwargs):
        try:
            return workers.map(*args, **kwargs)
        finally:
            workers.terminate()
    return _pool_map
def fraction_valid(gen, n_jobs=1):
    """
    Share of the given SMILES strings that `get_mol` can turn into a molecule.
    Parameters:
    - gen (list of str): List of SMILES strings.
    - n_jobs (int): Number of parallel jobs, forwarded to `mapper`.
    Returns:
    - float: Fraction of valid molecules (raises ZeroDivisionError if `gen` is empty).
    """
    molecules = mapper(n_jobs)(get_mol, gen)
    invalid_count = molecules.count(None)
    return 1 - invalid_count / len(molecules)
def canonic_smiles(smiles_or_mol):
    """
    Produce the SMILES string RDKit writes for a molecule.
    Parameters:
    - smiles_or_mol (str or Mol): SMILES string or RDKit molecule object.
    Returns:
    - str or None: Output of `Chem.MolToSmiles`, or None when `get_mol` fails.
    """
    molecule = get_mol(smiles_or_mol)
    return None if molecule is None else Chem.MolToSmiles(molecule)
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """
    Calculates the fraction of unique molecules in a list of SMILES strings.
    Uniqueness is decided on the canonical SMILES returned by `canonic_smiles`.
    Parameters:
    - gen (list of str): List of SMILES strings.
    - k (int, optional): Only the first k entries are considered. If None, considers all.
    - n_jobs (int): Number of parallel jobs to use for computation.
    - check_validity (bool): If True, an invalid molecule raises ValueError.
      If False, all invalid molecules together count as one distinct entry.
    Returns:
    - float: Number of distinct canonical SMILES divided by the number of
      molecules considered.
    Raises:
    - ValueError: If `check_validity` is True and any molecule is invalid.
    - ZeroDivisionError: If there are no molecules to consider.
    """
    if k is not None:
        if len(gen) < k:
            # The two sentences used to be glued together without a space.
            warnings.warn(
                f"Can't compute unique@{k}. "
                f"gen contains only {len(gen)} molecules"
            )
        gen = gen[:k]
    canonic = set(mapper(n_jobs)(canonic_smiles, gen))
    if None in canonic and check_validity:
        raise ValueError("Invalid molecule passed to unique@k")
    return len(canonic) / len(gen)
def novelty(gen, train, n_jobs=1):
    """
    Fraction of distinct valid generated molecules that do not occur in the training set.
    Generated SMILES are canonicalized via `canonic_smiles`; training SMILES
    are compared exactly as given.
    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - train (List[str]): List of SMILES strings from the training set.
    - n_jobs (int): Number of parallel jobs to use for computation.
    Returns:
    - float: Novelty score (raises ZeroDivisionError if no generated molecule is valid).
    """
    canonical = mapper(n_jobs)(canonic_smiles, gen)
    # Invalid molecules show up as None and are excluded from the ratio.
    distinct_valid = set(canonical)
    distinct_valid.discard(None)
    known = set(train)
    unseen = distinct_valid - known
    return len(unseen) / len(distinct_valid)
def synthetic_complexity_score(gen):
    """
    Average Synthetic Complexity Score (SCScore) of a list of SMILES strings.
    A new `SCScorer` is created and restored on every call, and the averaging
    is delegated to its `get_avg_score` method.
    Parameters:
    - gen (list of str): A list containing the SMILES representations of the molecules.
    Returns:
    - Whatever `SCScorer.get_avg_score` returns for `gen` (presumably a float
      average; handling of invalid SMILES is decided by that library).
    """
    scorer = SCScorer()
    # NOTE(review): restore() presumably loads the pretrained weights - confirm.
    scorer.restore()
    return scorer.get_avg_score(gen)
def calculate_sa_score(smiles):
    """
    Synthetic accessibility (SA) score of one molecule.
    Parameters:
    - smiles (str): SMILES string of the molecule.
    Returns:
    - float or None: Result of `sascorer.calculateScore`, or None when RDKit
      could not build a molecule from the SMILES string.
    """
    molecule = Chem.MolFromSmiles(smiles)
    return sascorer.calculateScore(molecule) if molecule else None
def average_sascore(gen, n_jobs=1):
    """
    Mean synthetic accessibility score over a list of molecules.
    Scores are obtained with `calculate_sa_score`, sequentially or in
    parallel depending on n_jobs.
    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - n_jobs (int): Number of parallel jobs to use for computation.
    Returns:
    - float or None: Average SA score, or None if no score could be computed.
    """
    per_molecule = mapper(n_jobs)(calculate_sa_score, gen)
    # None marks a SMILES string that could not be parsed; leave it out.
    usable = [value for value in per_molecule if value is not None]
    if not usable:
        return None
    return sum(usable) / len(usable)
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    Average, over the generated set, of an aggregated Tanimoto similarity to the reference set.
    For every generated fingerprint the similarities to all reference
    fingerprints are reduced with `agg`, and the reduced values are averaged.
    Parameters:
    - stock_vecs (numpy array): Fingerprint vectors for the reference molecule set.
    - gen_vecs (numpy array): Fingerprint vectors for the generated molecule set.
    - batch_size (int): Number of rows of each set processed per step (limits memory usage).
    - agg (str): Aggregation method, either 'max' or 'mean'.
    - device (str): The computation device ('cpu' or 'cuda:0', etc.).
    - p (float): Power of the generalized mean; similarities are raised to p
      before aggregation and the aggregate is raised to 1/p afterwards.
    Returns:
    - float: Average aggregate Tanimoto similarity score.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    n_generated = len(gen_vecs)
    aggregated = np.zeros(n_generated)
    counts = np.zeros(n_generated)
    for stock_start in range(0, stock_vecs.shape[0], batch_size):
        stock_batch = torch.tensor(
            stock_vecs[stock_start:stock_start + batch_size]
        ).to(device).float()
        stock_bits = stock_batch.sum(1, keepdim=True)
        for gen_start in range(0, gen_vecs.shape[0], batch_size):
            gen_batch = torch.tensor(
                gen_vecs[gen_start:gen_start + batch_size]
            ).to(device).float()
            # Columns are generated molecules after the transpose.
            gen_batch = gen_batch.transpose(0, 1)
            shared = torch.mm(stock_batch, gen_batch)
            union = stock_bits + gen_batch.sum(0, keepdim=True) - shared
            sims = (shared / union).cpu().numpy()
            # 0/0 happens for two all-zero fingerprints; count them as identical.
            sims[np.isnan(sims)] = 1
            if p != 1:
                sims = sims**p
            window = slice(gen_start, gen_start + gen_batch.shape[1])
            if agg == 'max':
                aggregated[window] = np.maximum(aggregated[window],
                                                sims.max(0))
            elif agg == 'mean':
                aggregated[window] += sims.sum(0)
                counts[window] += sims.shape[0]
    if agg == 'mean':
        aggregated /= counts
    if p != 1:
        aggregated = (aggregated)**(1/p)
    return np.mean(aggregated)
def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """
    Compute a bit fingerprint for one molecule.
    Returns None when the molecule cannot be obtained via `get_mol`.
    Parameters:
    - smiles_or_mol (str or Mol): SMILES string or RDKit molecule object.
    - fp_type (str): Fingerprint type, case-insensitive: 'maccs' or 'morgan'.
    - dtype: If not None, the dtype the returned array is cast to.
    - morgan__r (int): Radius passed to the Morgan fingerprint.
    - morgan__n (int): Number of bits of the Morgan fingerprint.
    - *args, **kwargs: Forwarded unchanged to `get_mol`.
    Returns:
    - numpy array or None: Fingerprint bits (166 for MACCS, morgan__n for Morgan).
    Raises:
    - ValueError: For an unknown fp_type (only checked once the molecule parsed).
    """
    fp_type = fp_type.lower()
    molecule = get_mol(smiles_or_mol, *args, **kwargs)
    if molecule is None:
        return None
    if fp_type == 'morgan':
        bits = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
                          dtype='uint8')
    elif fp_type == 'maccs':
        on_bits = np.array(MACCSkeys.GenMACCSKeys(molecule).GetOnBits())
        bits = np.zeros(166, dtype='uint8')
        if len(on_bits) != 0:
            # Shift by one: the 0-th key is always zero and is dropped.
            bits[on_bits - 1] = 1
    else:
        raise ValueError("Unknown fingerprint type {}".format(fp_type))
    if dtype is None:
        return bits
    return bits.astype(dtype)
def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    '''
    Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
    e.g. fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)
    Inserts a row of np.nan for every entry whose fingerprint is None.
    IMPORTANT: if there is at least one np.nan, the dtype would be float
    Parameters:
        smiles_mols_array: non-empty list/array/pd.Series of smiles or already
            computed RDKit molecules
        n_jobs: number of parallel workers to execute
        already_unique: flag for performance reasons, if smiles array is big
            and already unique. Its value is set to True if the first element
            of smiles_mols_array is not a string (e.g. an RDKit molecule).
        *args, **kwargs: forwarded to `fingerprint`
    Returns:
        2-D array (scipy CSR matrix if the fingerprints are sparse) with one
        row per input entry, in input order. If no entry yields a
        fingerprint, the result is a single column of np.nan.
    '''
    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True

    if not already_unique:
        # Fingerprint each distinct SMILES once; inv_index restores the
        # original order and multiplicity at the end.
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)

    fps = mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    )

    # first_fp stays None when every fingerprint failed; previously the name
    # was left unbound in that case and the sparse check below crashed.
    first_fp = next((fp for fp in fps if fp is not None), None)
    length = 1 if first_fp is None else first_fp.shape[-1]
    # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
    nan_row = np.full((1, length), np.nan)
    fps = [fp if fp is not None else nan_row for fp in fps]

    if scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        return fps[inv_index]
    return fps
def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
                       gen_fps=None, p=1):
    """
    Internal diversity of a set of molecules:
    1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - n_jobs (int): Number of parallel jobs for fingerprint computation.
    - device (str): Computation device ('cpu' or 'cuda:0', etc.).
    - fp_type (str): Type of fingerprint to use ('morgan', etc.).
    - gen_fps (Optional[np.ndarray]): Precomputed fingerprints of generated molecules. If None, will be computed.
    - p (float): Power forwarded to `average_agg_tanimoto`.
    Returns:
    - float: Internal diversity score.
    """
    fps = gen_fps
    if fps is None:
        fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
    # The set is compared against itself, using the mean aggregation.
    mean_similarity = average_agg_tanimoto(fps, fps,
                                           agg='mean', device=device, p=p)
    return 1 - mean_similarity.mean()
def fcd_metric(gen, train, n_jobs = 1, device = None):
    """
    Fréchet ChemNet Distance (FCD) between generated and training molecules.
    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - train (List[str]): List of training set SMILES strings.
    - n_jobs (int): Number of parallel jobs for computation.
    - device (str, optional): Requested computation device. None picks 'cuda'
      when available and 'cpu' otherwise; a CUDA device name is honoured only
      when CUDA is available, and anything else resolves to the CPU.
    Returns:
    - float: FCD score.
    """
    cuda_available = torch.cuda.is_available()
    if device is None:
        target = 'cuda' if cuda_available else 'cpu'
    else:
        use_requested = cuda_available and 'cuda' in device
        target = torch.device(device if use_requested else 'cpu')
    scorer = FCD(device=target, n_jobs= n_jobs)
    return scorer(gen, train)
def SYBAscore(gen):
    """
    Average SYBA score over a list of SMILES strings.
    A new `SybaClassifier` with its default score is set up on every call.
    Molecules whose prediction raises are reported on stdout and skipped.
    Parameters:
    - gen (list of str): A list of SMILES strings representing molecules.
    Returns:
    - float or None: Mean score of the molecules that could be scored, or
      None when there is none (empty input or every prediction failed).
    """
    classifier = SybaClassifier()
    classifier.fitDefaultScore()
    collected = []
    for smi in gen:
        try:
            collected.append(classifier.predict(smi=smi))
        except Exception as err:
            print(f"Error processing SMILES '{smi}': {err}")
    if not collected:
        return None
    return sum(collected) / len(collected)
def qed_metric(gen):
    """
    Average RDKit QED score over a list of SMILES strings.
    Parameters:
    - gen (List[str]): List of SMILES strings representing the molecules.
    Returns:
    - float: Mean QED of the molecules RDKit could parse; 0.0 for empty input
      or when no molecule could be scored.
    """
    if not gen:
        return 0.0
    collected = []
    for smi in gen:
        try:
            molecule = Chem.MolFromSmiles(smi)
            # Unparsable SMILES are skipped silently.
            if molecule:
                collected.append(qed(molecule))
        except Exception as err:
            print(f"Error processing molecule {smi}: {str(err)}")
    return sum(collected) / len(collected) if collected else 0.0
def logP_metric(gen):
    """
    Average RDKit logP (Crippen `MolLogP`) over a list of SMILES strings.
    Parameters:
    - gen (List[str]): List of SMILES strings representing the molecules.
    Returns:
    - float: Mean logP of the molecules RDKit could parse; 0.0 for empty input
      or when no molecule could be scored.
    """
    if not gen:
        return 0.0
    collected = []
    for smi in gen:
        try:
            molecule = Chem.MolFromSmiles(smi)
            # Unparsable SMILES are skipped silently.
            if molecule:
                collected.append(MolLogP(molecule))
        except Exception as err:
            print(f"Error processing molecule {smi}: {str(err)}")
    return sum(collected) / len(collected) if collected else 0.0
def oracles(gen, train):
    """
    Score the generated molecules with a fixed set of TDC oracles.
    Each oracle name is printed before it is evaluated. List-valued scores
    are averaged; for the multi-objective oracle names, dict-valued scores
    are averaged per key.
    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - train (List[str]): List of training set SMILES strings (currently unused).
    Returns:
    - Dict[str, Any]: A dictionary with oracle names as keys and their corresponding scores as values.
    """
    # Only these three are active; the wider candidate list was
    # 'QED', 'MPO', 'GSK3B', 'JNK3', 'DRD2', 'LogP', 'Rediscovery',
    # 'Similarity', 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'.
    active_oracles = ['QED', 'LogP', 'SA']
    multi_objective = ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']
    scores = {}
    for oracle_name in active_oracles:
        print(oracle_name)
        score = Oracle(name=oracle_name)(gen)
        if oracle_name in multi_objective:
            if isinstance(score, dict):
                score = {key: sum(values)/len(values) for key, values in score.items()}
        elif isinstance(score, list):
            score = sum(score) / len(score)
        scores[f"{oracle_name}"] = score
    return scores
_DESCRIPTION = """ | |
Comprehensive suite of metrics designed to assess the performance of molecular generation models, for understanding how well a model can produce novel, chemically valid molecules that are relevant to specific research objectives. | |
""" | |
_KWARGS_DESCRIPTION = """ | |
Args: | |
generated_smiles (`list` of `string`): A collection of SMILES (Simplified Molecular Input Line Entry System) strings generated by the model, ideally encompassing more than 30,000 samples. | |
train_smiles (`list` of `string`): The dataset of SMILES strings used to train the model, serving as a reference to evaluate the novelty and diversity of the generated molecules. | |
Returns: | |
Dectionary item containing various metrics to evaluate model performance | |
""" | |
_CITATION = """ | |
@article{DBLP:journals/corr/abs-1811-12823, | |
author = {Daniil Polykovskiy and | |
Alexander Zhebrak and | |
Benjam{\'{\i}}n S{\'{a}}nchez{-}Lengeling and | |
Sergey Golovanov and | |
Oktai Tatanov and | |
Stanislav Belyaev and | |
Rauf Kurbanov and | |
Aleksey Artamonov and | |
Vladimir Aladinskiy and | |
Mark Veselov and | |
Artur Kadurin and | |
Sergey I. Nikolenko and | |
Al{\'{a}}n Aspuru{-}Guzik and | |
Alex Zhavoronkov}, | |
title = {Molecular Sets {(MOSES):} {A} Benchmarking Platform for Molecular | |
Generation Models}, | |
journal = {CoRR}, | |
volume = {abs/1811.12823}, | |
year = {2018}, | |
url = {http://arxiv.org/abs/1811.12823}, | |
eprinttype = {arXiv}, | |
eprint = {1811.12823}, | |
timestamp = {Fri, 26 Nov 2021 15:34:30 +0100}, | |
biburl = {https://dblp.org/rec/journals/corr/abs-1811-12823.bib}, | |
bibsource = {dblp computer science bibliography, https://dblp.org} | |
} | |
""" | |
class molgenevalmetric(evaluate.Metric):
    """`evaluate` metric that bundles the molecule-generation scores of this module."""

    def _info(self):
        """
        Describe the metric: texts, expected input features and reference URLs.
        With the "multilabel" configuration each input is a sequence of
        strings; otherwise each input is a single string.
        """
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "gensmi": datasets.Sequence(datasets.Value("string")),
                    "trainsmi": datasets.Sequence(datasets.Value("string")),
                }
                if self.config_name == "multilabel"
                else {
                    "gensmi": datasets.Value("string"),
                    "trainsmi": datasets.Value("string"),
                }
            ),
            reference_urls=["https://github.com/molecularsets/moses", "https://tdcommons.ai/functions/oracles/"],
        )

    def _compute(self, gensmi, trainsmi):
        """
        Compute every metric for the generated molecules.
        Parameters:
        - gensmi (List[str]): Generated SMILES strings.
        - trainsmi (List[str]): Training set SMILES strings.
        Returns:
        - dict: Scores keyed by 'Novelty', 'Valid', 'Unique', 'IntDiv', 'FCD',
          'Oracles', 'QED', 'LogP', 'SA', 'SCS' and 'SYBA'.
        """
        metrics = {}
        metrics['Novelty'] = novelty(gen = gensmi, train = trainsmi)
        metrics['Valid'] = fraction_valid(gen=gensmi)
        # fraction_unique() raises ValueError as soon as one SMILES is invalid,
        # which used to abort the whole evaluation and made 'Valid' pointless.
        # Uniqueness is therefore measured over the parsable molecules only.
        valid_smiles = [smi for smi in gensmi if get_mol(smi) is not None]
        metrics['Unique'] = (
            fraction_unique(gen=valid_smiles) if valid_smiles else 0.0
        )
        metrics['IntDiv'] = internal_diversity(gen=gensmi)
        metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
        metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
        metrics['QED'] = qed_metric(gen=gensmi)
        metrics['LogP'] = logP_metric(gen=gensmi)
        metrics['SA'] = average_sascore(gen=gensmi)
        metrics['SCS'] = synthetic_complexity_score(gen=gensmi)
        metrics['SYBA'] = SYBAscore(gen=gensmi)
        return metrics
# generated_smiles = [s for s in generated_smiles if s != ''] | |
# evaluator = Evaluator(name = 'KL_Divergence') | |
# KL_Divergence = evaluator(generated_smiles, train_smiles) | |
# Results.update({ | |
# "KL_Divergence": KL_Divergence, | |
# }) | |
# oracle_list = [ | |
# 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3', | |
# 'DRD2', 'LogP', 'Rediscovery', 'Similarity', | |
# 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop' | |
# ] | |
# for oracle_name in oracle_list: | |
# oracle = Oracle(name=oracle_name) | |
# if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']: | |
# score = oracle(generated_smiles) | |
# if isinstance(score, dict): | |
# score = {key: sum(values)/len(values) for key, values in score.items()} | |
# else: | |
# score = oracle(generated_smiles) | |
# if isinstance(score, list): | |
# score = sum(score) / len(score) | |
# Results.update({f"{oracle_name}": score}) | |
# # keys_to_remove = ["FCD/TestSF", "SNN/TestSF", "Frag/TestSF", "Scaf/TestSF"] | |
# # for key in keys_to_remove: | |
# # Results.pop(key, None) | |
# return {"results": Results} | |