# Spaces: Sleeping, Sleeping
# (Hugging Face Space status banner captured together with the file; not part of the module.)
import evaluate | |
import datasets | |
# import moses | |
# from moses import metrics | |
import pandas as pd | |
# from tdc import Evaluator | |
# from tdc import Oracle | |
# from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles | |
import os | |
from collections import Counter | |
from functools import partial | |
import numpy as np | |
import pandas as pd | |
import scipy.sparse | |
import torch | |
from rdkit import Chem | |
from rdkit.Chem import AllChem | |
from rdkit.Chem import MACCSkeys | |
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan | |
from rdkit.Chem.QED import qed | |
from rdkit.Chem.Scaffolds import MurckoScaffold | |
from rdkit.Chem import Descriptors | |
import random | |
from multiprocessing import Pool | |
from collections import UserList, defaultdict | |
import numpy as np | |
import pandas as pd | |
from rdkit import rdBase | |
import sys | |
from rdkit.Chem import RDConfig | |
import os | |
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score')) | |
import sascorer | |
import pandas as pd | |
from fcd_torch import FCD | |
from syba.syba import SybaClassifier | |
from tdc import Evaluator | |
from tdc import Oracle | |
def get_mol(smiles_or_mol):
    """
    Turn a SMILES string into an RDKit molecule object.

    Anything that is not a ``str`` is handed back unchanged (it is taken
    to be a molecule already).  ``None`` is returned for an empty string,
    for a string RDKit cannot parse, and for a molecule whose
    sanitization raises ``ValueError``.
    """
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if len(smiles_or_mol) == 0:
        return None
    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None
    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        return None
    return parsed
def mapper(n_jobs):
    """
    Return a ``map``-like callable suited to ``n_jobs``.

    Parameters:
        n_jobs: selects the mapping strategy.
            * ``1``: the builtin ``map`` is used and its result is
              materialised into a list.
            * any other ``int``: a ``multiprocessing.Pool`` with that many
              workers is used.
            * anything else: treated as a pool-like object and its bound
              ``map`` method is returned as is.

    Returns:
        A callable with the signature of ``map`` / ``Pool.map``.
    """
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))
        return _mapper
    if isinstance(n_jobs, int):
        def _mapper(*args, **kwargs):
            # The pool is created per call (not when mapper() is called) so
            # that no worker processes leak if the returned callable is never
            # used, and so that the callable can be invoked more than once.
            # Leaving the ``with`` block terminates the pool.
            with Pool(n_jobs) as pool:
                return pool.map(*args, **kwargs)
        return _mapper
    return n_jobs.map
def fraction_valid(gen, n_jobs=1):
    """
    Compute the fraction of valid molecules in ``gen``.

    Parameters:
        gen: list of SMILES
        n_jobs: number of workers (or a pool object) passed to ``mapper``

    Returns:
        ``1 - (#entries for which get_mol gives None) / len(gen)``;
        ``0.0`` when ``gen`` is empty (instead of dividing by zero).
    """
    if len(gen) == 0:
        return 0.0
    mols = mapper(n_jobs)(get_mol, gen)
    return 1 - mols.count(None) / len(mols)
def canonic_smiles(smiles_or_mol):
    """
    Return the SMILES string RDKit writes for the input, or ``None`` when
    ``get_mol`` cannot produce a molecule from it.
    """
    molecule = get_mol(smiles_or_mol)
    return None if molecule is None else Chem.MolToSmiles(molecule)
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """
    Compute the fraction of unique molecules in ``gen``.

    Parameters:
        gen: list of SMILES
        k: compute unique@k (only the first ``k`` entries are considered);
            a warning is issued when ``gen`` holds fewer than ``k`` entries
        n_jobs: number of workers (or a pool object) passed to ``mapper``
        check_validity: raises ValueError if invalid molecules are present

    Returns:
        Number of distinct canonical SMILES divided by the number of
        entries considered; ``0.0`` when there are no entries.

    Raises:
        ValueError: if ``check_validity`` is true and at least one entry
            cannot be canonicalized.
    """
    # Imported here because the module's import section does not provide it.
    import warnings

    if k is not None:
        if len(gen) < k:
            warnings.warn(
                "Can't compute unique@{}. ".format(k) +
                "gen contains only {} molecules".format(len(gen))
            )
        gen = gen[:k]
    if len(gen) == 0:
        # Nothing to canonicalize; avoid dividing by zero below.
        return 0.0
    canonic = set(mapper(n_jobs)(canonic_smiles, gen))
    if None in canonic and check_validity:
        raise ValueError("Invalid molecule passed to unique@k")
    return len(canonic) / len(gen)
def novelty(gen, train, n_jobs=1):
    """
    Compute the fraction of generated molecules absent from ``train``.

    Generated SMILES are canonicalized with ``canonic_smiles`` and
    de-duplicated; entries that cannot be canonicalized are ignored.
    ``train`` is compared as given (it is not canonicalized here).

    Parameters:
        gen: list of generated SMILES
        train: iterable of training SMILES
        n_jobs: number of workers (or a pool object) passed to ``mapper``

    Returns:
        ``|gen_set - train_set| / |gen_set|``; ``0.0`` when ``gen`` is
        empty or contains no valid molecule (instead of dividing by zero).
    """
    if len(gen) == 0:
        return 0.0
    gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
    gen_smiles_set = set(gen_smiles) - {None}
    if not gen_smiles_set:
        return 0.0
    train_set = set(train)
    return len(gen_smiles_set - train_set) / len(gen_smiles_set)
def SAscore(gen):
    """
    Calculate the average Synthetic Accessibility score (SAscore) for a
    list of molecules given as SMILES strings.

    Parameters:
        gen (list of str): SMILES representations of the molecules.

    Returns:
        The mean of ``sascorer.calculateScore`` over the entries that could
        be parsed, or ``None`` if no entry could be scored.
    """
    scores = []
    for smiles in gen:
        if not smiles:
            # Skip None / empty strings up front, consistent with get_mol,
            # which also treats an empty SMILES as invalid.
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # The SMILES string could not be parsed.
            continue
        scores.append(sascorer.calculateScore(mol))
    if not scores:
        return None
    return np.mean(scores)
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    Aggregate Tanimoto similarity of every row of ``gen_vecs`` against the
    rows of ``stock_vecs`` and return the average over ``gen_vecs``.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        batch_size: number of rows of each array processed per step
        agg: 'max' (similarity to the closest stock row) or 'mean'
            (mean similarity over all stock rows)
        device: torch device the batches are moved to
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    n_gen = len(gen_vecs)
    aggregated = np.zeros(n_gen)
    counts = np.zeros(n_gen)
    for s_start in range(0, stock_vecs.shape[0], batch_size):
        stock_batch = torch.tensor(
            stock_vecs[s_start:s_start + batch_size]).to(device).float()
        stock_bits = stock_batch.sum(1, keepdim=True)
        for g_start in range(0, gen_vecs.shape[0], batch_size):
            gen_batch = torch.tensor(
                gen_vecs[g_start:g_start + batch_size]).to(device).float()
            gen_batch = gen_batch.transpose(0, 1)
            g_stop = g_start + gen_batch.shape[1]
            common = torch.mm(stock_batch, gen_batch)
            union = stock_bits + gen_batch.sum(0, keepdim=True) - common
            sim = (common / union).cpu().numpy()
            # 0/0 (both vectors all-zero) is counted as similarity 1.
            sim[np.isnan(sim)] = 1
            if p != 1:
                sim = sim**p
            if agg == 'max':
                aggregated[g_start:g_stop] = np.maximum(
                    aggregated[g_start:g_stop], sim.max(0))
            elif agg == 'mean':
                aggregated[g_start:g_stop] += sim.sum(0)
                counts[g_start:g_stop] += sim.shape[0]
    if agg == 'mean':
        aggregated /= counts
    if p != 1:
        aggregated = (aggregated)**(1/p)
    return np.mean(aggregated)
def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """
    Build a bit-vector fingerprint for one molecule.

    Returns a numpy array of fingerprint bits, or None when ``get_mol``
    cannot produce a molecule from the input.

    Parameters:
        smiles_or_mol: SMILES string (or molecule object)
        fp_type: type of fingerprint: [MACCS|morgan] (case-insensitive)
        dtype: if not None, specifies the dtype of returned array
        morgan__r: radius passed to the Morgan fingerprint
        morgan__n: number of bits of the Morgan fingerprint
        *args, **kwargs: forwarded to ``get_mol``

    Raises:
        ValueError: for an unknown ``fp_type`` (only reached when the
            molecule itself is valid).
    """
    fp_type = fp_type.lower()
    molecule = get_mol(smiles_or_mol, *args, **kwargs)
    if molecule is None:
        return None
    if fp_type == 'maccs':
        on_bits = np.array(MACCSkeys.GenMACCSKeys(molecule).GetOnBits())
        bits = np.zeros(166, dtype='uint8')
        if len(on_bits) != 0:
            # Shift by one: the 0-th key is always zero and is dropped.
            bits[on_bits - 1] = 1
    elif fp_type == 'morgan':
        bits = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
                          dtype='uint8')
    else:
        raise ValueError("Unknown fingerprint type {}".format(fp_type))
    return bits if dtype is None else bits.astype(dtype)
def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    """
    Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
    e.g. fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)
    Inserts np.nan to rows corresponding to incorrect smiles.
    IMPORTANT: if there is at least one np.nan, the dtype would be float

    Parameters:
        smiles_mols_array: list/array/pd.Series of smiles or already computed
            RDKit molecules
        n_jobs: number of parallel workers to execute
        already_unique: flag for performance reasons, if smiles array is big
            and already unique. Its value is set to True if smiles_mols_array
            contain RDKit molecules already.
        *args, **kwargs: forwarded to ``fingerprint``

    Returns:
        Stacked fingerprints, one row per input entry (a CSR matrix when the
        first valid fingerprint is a scipy sparse matrix, otherwise a numpy
        array).
    """
    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True

    if not already_unique:
        # Fingerprint each distinct SMILES once; inv_index maps the unique
        # rows back to the original order at the end.
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)

    fps = mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    )

    # Take the row width (and container type) from the first valid
    # fingerprint. first_fp stays None when every entry is invalid, in which
    # case the result is a dense single-column array of NaN.
    length = 1
    first_fp = None
    for fp in fps:
        if fp is not None:
            length = fp.shape[-1]
            first_fp = fp
            break
    # np.nan (lower case) is used because the np.NaN alias no longer exists
    # in NumPy 2.
    fps = [fp if fp is not None else np.full((1, length), np.nan)
           for fp in fps]
    if scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        return fps[inv_index]
    return fps
def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
                       gen_fps=None, p=1):
    """
    Internal diversity of a set of molecules:
        1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))

    Fingerprints are computed from ``gen`` with ``fingerprints`` unless
    ``gen_fps`` is supplied, in which case ``gen`` is not used.
    """
    if gen_fps is None:
        gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
    mean_similarity = average_agg_tanimoto(
        gen_fps, gen_fps, agg='mean', device=device, p=p)
    return 1 - mean_similarity.mean()
def fcd_metric(gen, train, n_jobs=8, device='cuda:0'):
    """
    Evaluate the ``fcd_torch.FCD`` metric on ``gen`` and ``train``.

    Parameters:
        gen: list of generated SMILES
        train: list of training SMILES
        n_jobs: number of workers handed to ``FCD``
        device: torch device handed to ``FCD``; a CUDA device is replaced
            by ``'cpu'`` when CUDA is not available on this host

    Returns:
        Whatever the ``FCD`` object returns when called with
        ``(gen, train)``.
    """
    # The default is a GPU device; fall back rather than fail on CPU-only
    # machines.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'
    fcd = FCD(device=device, n_jobs=n_jobs)
    return fcd(gen, train)
def SYBAscore(gen):
    """
    Average SYBA score over a list of SMILES strings.

    Parameters:
        gen (list of str): SMILES strings representing molecules.

    Returns:
        The arithmetic mean of the scores of all entries that could be
        scored, or None if no entry could be scored.  An entry whose
        prediction raises is reported on stdout and left out.
    """
    classifier = SybaClassifier()
    classifier.fitDefaultScore()
    collected = []
    for smiles in gen:
        try:
            collected.append(classifier.predict(smi=smiles))
        except Exception as e:
            print(f"Error processing SMILES '{smiles}': {e}")
    if not collected:
        return None
    return sum(collected) / len(collected)
def oracles(gen, train):
    """
    Run a fixed list of TDC oracles on ``gen`` and collect averaged scores.

    For the oracles listed in ``dict_valued`` a dict result is reduced to
    the mean of each value list; for all other oracles a list result is
    reduced to its mean.  Any other result is stored unchanged.

    Parameters:
        gen: list of generated SMILES
        train: accepted for signature compatibility; not used here

    Returns:
        dict mapping oracle name to its (averaged) score.
    """
    oracle_names = (
        'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
        'DRD2', 'LogP', 'Rediscovery', 'Similarity',
        'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop',
    )
    dict_valued = ('Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop')
    collected = {}
    for oracle_name in oracle_names:
        score = Oracle(name=oracle_name)(gen)
        if oracle_name in dict_valued:
            if isinstance(score, dict):
                score = {key: sum(values)/len(values) for key, values in score.items()}
        elif isinstance(score, list):
            score = sum(score) / len(score)
        collected[f"{oracle_name}"] = score
    return collected
# Human-readable summary of the metric, passed to MetricInfo(description=...).
_DESCRIPTION = """
Comprehensive suite of metrics designed to assess the performance of molecular generation models, for understanding how well a model can produce novel, chemically valid molecules that are relevant to specific research objectives.
"""

# Description of the metric's inputs and output, passed to
# MetricInfo(inputs_description=...).
_KWARGS_DESCRIPTION = """
Args:
    generated_smiles (`list` of `string`): A collection of SMILES (Simplified Molecular Input Line Entry System) strings generated by the model, ideally encompassing more than 30,000 samples.
    train_smiles (`list` of `string`): The dataset of SMILES strings used to train the model, serving as a reference to evaluate the novelty and diversity of the generated molecules.
Returns:
    Dictionary containing various metrics to evaluate model performance
"""
# BibTeX entry for the MOSES paper, passed to MetricInfo(citation=...).
# Kept as a raw string so the LaTeX accent macros (\' and \i) reach the
# reader unchanged instead of being interpreted as Python escape sequences.
_CITATION = r"""
@article{DBLP:journals/corr/abs-1811-12823,
  author    = {Daniil Polykovskiy and
               Alexander Zhebrak and
               Benjam{\'{\i}}n S{\'{a}}nchez{-}Lengeling and
               Sergey Golovanov and
               Oktai Tatanov and
               Stanislav Belyaev and
               Rauf Kurbanov and
               Aleksey Artamonov and
               Vladimir Aladinskiy and
               Mark Veselov and
               Artur Kadurin and
               Sergey I. Nikolenko and
               Al{\'{a}}n Aspuru{-}Guzik and
               Alex Zhavoronkov},
  title     = {Molecular Sets {(MOSES):} {A} Benchmarking Platform for Molecular
               Generation Models},
  journal   = {CoRR},
  volume    = {abs/1811.12823},
  year      = {2018},
  url       = {http://arxiv.org/abs/1811.12823},
  eprinttype = {arXiv},
  eprint    = {1811.12823},
  timestamp = {Fri, 26 Nov 2021 15:34:30 +0100},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1811-12823.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
class molgenevalmetric(evaluate.Metric):
    """
    ``evaluate`` metric bundling several molecular-generation scores
    (novelty, validity, uniqueness, internal diversity, FCD and SA score).
    """

    def _info(self):
        """
        Describe the metric: texts, reference URLs and input features.

        With the ``"multilabel"`` config each example is a sequence of
        strings; otherwise each example is a single string.
        """
        if self.config_name == "multilabel":
            feature_spec = {
                "generated_smiles": datasets.Sequence(datasets.Value("string")),
                "train_smiles": datasets.Sequence(datasets.Value("string")),
            }
        else:
            feature_spec = {
                "generated_smiles": datasets.Value("string"),
                "train_smiles": datasets.Value("string"),
            }
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(feature_spec),
            reference_urls=["https://github.com/molecularsets/moses", "https://tdcommons.ai/functions/oracles/"],
        )

    def _compute(self, generated_smiles=None, train_smiles=None, *,
                 gensmi=None, trainsmi=None):
        """
        Compute all metrics for the generated and training SMILES.

        Parameters:
            generated_smiles: list of generated SMILES
            train_smiles: list of training SMILES
            gensmi, trainsmi: former names of the two parameters above,
                still accepted as keyword arguments

        Returns:
            dict with the keys 'novelty', 'valid', 'unique', 'IntDiv',
            'FCD', 'SA' and 'SCS'.

        Raises:
            TypeError: if either input is missing.
        """
        # The parameter names must equal the feature names declared in
        # _info(): the inputs arrive as keyword arguments named after them.
        if generated_smiles is None:
            generated_smiles = gensmi
        if train_smiles is None:
            train_smiles = trainsmi
        if generated_smiles is None or train_smiles is None:
            raise TypeError(
                "_compute() requires both 'generated_smiles' and 'train_smiles'"
            )

        metrics = {}
        metrics['novelty'] = novelty(gen=generated_smiles, train=train_smiles)
        metrics['valid'] = fraction_valid(gen=generated_smiles)
        metrics['unique'] = fraction_unique(gen=generated_smiles)
        metrics['IntDiv'] = internal_diversity(gen=generated_smiles)
        metrics['FCD'] = fcd_metric(gen=generated_smiles, train=train_smiles)
        metrics['SA'] = SAscore(gen=generated_smiles)
        # NOTE(review): 'SCS' is the SA score of the *training* set; confirm
        # whether a different scorer or input was intended here.
        metrics['SCS'] = SAscore(gen=train_smiles)
        return metrics
# generated_smiles = [s for s in generated_smiles if s != ''] | |
# evaluator = Evaluator(name = 'KL_Divergence') | |
# KL_Divergence = evaluator(generated_smiles, train_smiles) | |
# Results.update({ | |
# "KL_Divergence": KL_Divergence, | |
# }) | |
# oracle_list = [ | |
# 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3', | |
# 'DRD2', 'LogP', 'Rediscovery', 'Similarity', | |
# 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop' | |
# ] | |
# for oracle_name in oracle_list: | |
# oracle = Oracle(name=oracle_name) | |
# if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']: | |
# score = oracle(generated_smiles) | |
# if isinstance(score, dict): | |
# score = {key: sum(values)/len(values) for key, values in score.items()} | |
# else: | |
# score = oracle(generated_smiles) | |
# if isinstance(score, list): | |
# score = sum(score) / len(score) | |
# Results.update({f"{oracle_name}": score}) | |
# # keys_to_remove = ["FCD/TestSF", "SNN/TestSF", "Frag/TestSF", "Scaf/TestSF"] | |
# # for key in keys_to_remove: | |
# # Results.pop(key, None) | |
# return {"results": Results} | |