import os
import random
import sys
import warnings
from collections import Counter, UserList, defaultdict
from functools import partial
from multiprocessing import Pool

import datasets
import evaluate
import numpy as np
import pandas as pd
import scipy.sparse
import torch
from rdkit import Chem, rdBase
from rdkit.Chem import AllChem, Descriptors, MACCSkeys, RDConfig
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
from rdkit.Chem.QED import qed
from rdkit.Chem.Scaffolds import MurckoScaffold

# `sascorer` lives in RDKit's Contrib directory rather than the importable
# package, so its folder has to be on sys.path before the import below.
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer  # noqa: E402

from fcd_torch import FCD  # noqa: E402
from syba.syba import SybaClassifier  # noqa: E402
from tdc import Evaluator  # noqa: E402
from tdc import Oracle  # noqa: E402


def get_mol(smiles_or_mol):
    """Load a SMILES string into an RDKit molecule.

    Strings are parsed and sanitized; an empty string, an unparsable SMILES or
    a molecule that fails sanitization yields None. Any non-string input (e.g.
    an already-built RDKit Mol) is returned unchanged.
    """
    if isinstance(smiles_or_mol, str):
        if len(smiles_or_mol) == 0:
            return None
        mol = Chem.MolFromSmiles(smiles_or_mol)
        if mol is None:
            return None
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return mol
    return smiles_or_mol


def mapper(n_jobs):
    """Return a function with the signature of ``map`` that returns a list.

    If n_jobs == 1, the builtin map is used.
    If n_jobs is an int > 1, a multiprocessing Pool of that size is created;
    the pool is terminated after its single map call.
    Otherwise n_jobs is assumed to be a pool-like object and its ``map`` is
    returned.
    """
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _mapper
    if isinstance(n_jobs, int):
        pool = Pool(n_jobs)

        def _mapper(*args, **kwargs):
            try:
                result = pool.map(*args, **kwargs)
            finally:
                pool.terminate()
            return result

        return _mapper
    return n_jobs.map


def fraction_valid(gen, n_jobs=1):
    """Return the fraction of entries in ``gen`` that parse to a valid molecule.

    Parameters:
        gen: list of SMILES
        n_jobs: number of worker processes (see ``mapper``)
    """
    gen = mapper(n_jobs)(get_mol, gen)
    return 1 - gen.count(None) / len(gen)


def canonic_smiles(smiles_or_mol):
    """Return RDKit's canonical SMILES for the input, or None if it is invalid."""
    mol = get_mol(smiles_or_mol)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol)


def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """Return the fraction of unique molecules (by canonical SMILES) in ``gen``.

    Parameters:
        gen: list of SMILES
        k: if given, compute unique@k on the first k entries
        n_jobs: number of worker processes (see ``mapper``)
        check_validity: raise ValueError if invalid molecules are present

    Raises:
        ValueError: if ``check_validity`` is set and ``gen`` contains an
            invalid molecule.
    """
    if k is not None:
        if len(gen) < k:
            warnings.warn(
                "Can't compute unique@{}.".format(k) +
                "gen contains only {} molecules".format(len(gen))
            )
        gen = gen[:k]
    canonic = set(mapper(n_jobs)(canonic_smiles, gen))
    if None in canonic and check_validity:
        raise ValueError("Invalid molecule passed to unique@k")
    return len(canonic) / len(gen)


def novelty(gen, train, n_jobs=1):
    """Return the fraction of unique valid generated molecules absent from ``train``.

    Generated SMILES are canonicalized (invalid ones dropped); ``train`` is
    compared as given, so it is expected to already hold canonical SMILES.
    """
    gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
    gen_smiles_set = set(gen_smiles) - {None}
    train_set = set(train)
    return len(gen_smiles_set - train_set) / len(gen_smiles_set)


def SAscore(gen):
    """Return the mean Synthetic Accessibility (SA) score of ``gen``.

    Parameters:
        gen: iterable of SMILES strings (RDKit Mol objects are also accepted).

    Returns:
        The mean SA score over the entries that load as valid molecules, or
        None when no entry is valid.
    """
    scores = []
    for smiles in gen:
        # get_mol (rather than Chem.MolFromSmiles) rejects empty strings,
        # which MolFromSmiles would turn into an atom-less molecule.
        mol = get_mol(smiles)
        if mol is not None:
            scores.append(sascorer.calculateScore(mol))
    if scores:
        return np.mean(scores)
    return None


def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """For each row of ``gen_vecs``, aggregate its Tanimoto similarity to ``stock_vecs``.

    Returns the mean over ``gen_vecs`` of the per-row aggregate.

    Parameters:
        stock_vecs: numpy array of fingerprints (one per row)
        gen_vecs: numpy array of fingerprints (one per row)
        batch_size: number of rows processed per matrix multiplication
        agg: 'max' (closest stock molecule) or 'mean'
        device: torch device used for the batched computation
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))
    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        for i in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            tp = torch.mm(x_stock, y_gen)
            jac = (tp / (x_stock.sum(1, keepdim=True) +
                         y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
            # 0/0 arises when both fingerprints are all-zero; count them as identical.
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac ** p
            if agg == 'max':
                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
            elif agg == 'mean':
                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
                total[i:i + y_gen.shape[1]] += jac.shape[0]
    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = (agg_tanimoto) ** (1 / p)
    return np.mean(agg_tanimoto)


def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """Generate a fingerprint for a SMILES string or molecule.

    Parameters:
        smiles_or_mol: SMILES string or RDKit molecule
        fp_type: type of fingerprint: 'maccs' (166 bits) or 'morgan'
        dtype: if not None, the dtype of the returned array
        morgan__r: Morgan radius
        morgan__n: Morgan bit-vector length

    Returns:
        numpy array of fingerprint bits, or None if the input is invalid.

    Raises:
        ValueError: for an unknown ``fp_type``.
    """
    fp_type = fp_type.lower()
    molecule = get_mol(smiles_or_mol, *args, **kwargs)
    if molecule is None:
        return None
    if fp_type == 'maccs':
        keys = MACCSkeys.GenMACCSKeys(molecule)
        keys = np.array(keys.GetOnBits())
        fingerprint = np.zeros(166, dtype='uint8')
        if len(keys) != 0:
            fingerprint[keys - 1] = 1  # We drop 0-th key that is always zero
    elif fp_type == 'morgan':
        fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
                                 dtype='uint8')
    else:
        raise ValueError("Unknown fingerprint type {}".format(fp_type))
    if dtype is not None:
        fingerprint = fingerprint.astype(dtype)
    return fingerprint


def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    """Compute fingerprints for an array of SMILES/molecules.

    e.g. fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)

    Rows corresponding to invalid SMILES are filled with np.nan.
    IMPORTANT: if there is at least one np.nan, the dtype will be float.

    Parameters:
        smiles_mols_array: list/array/pd.Series of SMILES or already computed
            RDKit molecules
        n_jobs: number of parallel workers to execute
        already_unique: flag for performance reasons, if the SMILES array is
            big and already unique. Its value is set to True if
            smiles_mols_array contains RDKit molecules already.
        *args, **kwargs: forwarded to ``fingerprint``
    """
    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True

    if not already_unique:
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)

    fps = mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    )

    # Width of the NaN placeholder rows is taken from the first valid
    # fingerprint; first_fp stays None when every input is invalid.
    length = 1
    first_fp = None
    for fp in fps:
        if fp is not None:
            length = fp.shape[-1]
            first_fp = fp
            break
    # np.nan (not np.NaN, which NumPy 2.0 removed).
    fps = [fp if fp is not None else np.array([np.nan]).repeat(length)[None, :]
           for fp in fps]
    if scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        return fps[inv_index]
    return fps


def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
                       gen_fps=None, p=1):
    """Compute internal diversity of ``gen``.

    For p == 1 this is 1/|A|^2 sum_{x, y in AxA} (1 - tanimoto(x, y)); other
    values of p use the power mean of ``average_agg_tanimoto``.
    ``gen_fps`` may be supplied to skip fingerprint computation.
    """
    if gen_fps is None:
        gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
    return 1 - (average_agg_tanimoto(gen_fps, gen_fps,
                                     agg='mean', device=device, p=p)).mean()


def fcd_metric(gen, train, n_jobs=8, device=None):
    """Return the Frechet ChemNet Distance between ``gen`` and ``train`` (fcd_torch).

    Parameters:
        gen: generated SMILES
        train: reference SMILES
        n_jobs: worker count passed to fcd_torch
        device: torch device string; when None, 'cuda:0' is used if CUDA is
            available and 'cpu' otherwise (the previous hard default of
            'cuda:0' failed on CPU-only machines).
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    fcd = FCD(device=device, n_jobs=n_jobs)
    return fcd(gen, train)


def SYBAscore(gen):
    """Return the mean SYBA score for a list of SMILES strings.

    SMILES for which prediction raises are reported on stdout and skipped.

    Parameters:
        gen: list of SMILES strings.

    Returns:
        The mean SYBA score, or None if no prediction succeeded.
    """
    syba = SybaClassifier()
    syba.fitDefaultScore()
    scores = []
    for smiles in gen:
        try:
            score = syba.predict(smi=smiles)
            scores.append(score)
        except Exception as e:
            # Deliberate best-effort: one bad SMILES should not abort the batch.
            print(f"Error processing SMILES '{smiles}': {e}")
            continue
    if scores:
        return sum(scores) / len(scores)
    return None


def oracles(gen, train):
    """Score ``gen`` with a fixed list of TDC oracles.

    Parameters:
        gen: list of SMILES to score.
        train: currently unused; kept for interface compatibility.

    Returns:
        dict mapping oracle name to its mean score. When an oracle returns a
        dict of per-target score lists, the value is a dict of per-target means.
    """
    results = {}
    oracle_list = [
        'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
        'DRD2', 'LogP', 'Rediscovery', 'Similarity',
        'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
    ]

    for oracle_name in oracle_list:
        oracle = Oracle(name=oracle_name)
        score = oracle(gen)
        # Reduce whichever shape comes back (dict of lists or a list of
        # per-molecule scores) to means, for every oracle alike.
        if isinstance(score, dict):
            score = {key: sum(values) / len(values) for key, values in score.items()}
        elif isinstance(score, list):
            score = sum(score) / len(score)
        results[oracle_name] = score

    return results


_DESCRIPTION = """
Comprehensive suite of metrics designed to assess the performance of molecular
generation models, for understanding how well a model can produce novel,
chemically valid molecules that are relevant to specific research objectives.
"""


_KWARGS_DESCRIPTION = """
Args:
    generated_smiles (`list` of `string`): A collection of SMILES (Simplified Molecular Input Line Entry System) strings generated by the model, ideally encompassing more than 30,000 samples.
    train_smiles (`list` of `string`): The dataset of SMILES strings used to train the model, serving as a reference to evaluate the novelty and diversity of the generated molecules.

Returns:
    Dictionary containing various metrics to evaluate model performance
"""


# Raw string so the BibTeX accent commands (\' and \i) are kept verbatim
# instead of being treated as Python escape sequences.
_CITATION = r"""
@article{DBLP:journals/corr/abs-1811-12823,
  author       = {Daniil Polykovskiy and
                  Alexander Zhebrak and
                  Benjam{\'{\i}}n S{\'{a}}nchez{-}Lengeling and
                  Sergey Golovanov and
                  Oktai Tatanov and
                  Stanislav Belyaev and
                  Rauf Kurbanov and
                  Aleksey Artamonov and
                  Vladimir Aladinskiy and
                  Mark Veselov and
                  Artur Kadurin and
                  Sergey I. Nikolenko and
                  Al{\'{a}}n Aspuru{-}Guzik and
                  Alex Zhavoronkov},
  title        = {Molecular Sets {(MOSES):} {A} Benchmarking Platform for Molecular
                  Generation Models},
  journal      = {CoRR},
  volume       = {abs/1811.12823},
  year         = {2018},
  url          = {http://arxiv.org/abs/1811.12823},
  eprinttype    = {arXiv},
  eprint       = {1811.12823},
  timestamp    = {Fri, 26 Nov 2021 15:34:30 +0100},
  biburl       = {https://dblp.org/rec/journals/corr/abs-1811-12823.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}
"""


# `evaluate` metric wrapping the MOSES-style helpers above.
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class molgenevalmetric(evaluate.Metric):

    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "generated_smiles": datasets.Sequence(datasets.Value("string")),
                    "train_smiles": datasets.Sequence(datasets.Value("string")),
                }
                if self.config_name == "multilabel"
                else {
                    "generated_smiles": datasets.Value("string"),
                    "train_smiles": datasets.Value("string"),
                }
            ),
            reference_urls=["https://github.com/molecularsets/moses",
                            "https://tdcommons.ai/functions/oracles/"],
        )

    def _compute(self, generated_smiles, train_smiles):
        """Compute the metric suite.

        Parameter names must match the feature names declared in ``_info``:
        ``evaluate`` passes the collected inputs to ``_compute`` as keyword
        arguments named after those features.

        Args:
            generated_smiles: SMILES strings produced by the model.
            train_smiles: SMILES strings of the training set, used as the
                reference for novelty and FCD.

        Returns:
            dict with keys 'novelty', 'valid', 'unique', 'IntDiv', 'FCD', 'SA'
            and 'SCS'. Metrics that need at least one valid generated molecule
            ('novelty', 'unique', 'IntDiv', 'FCD') are None when there is none.

        Raises:
            ValueError: if ``generated_smiles`` is empty.
        """
        if len(generated_smiles) == 0:
            raise ValueError("generated_smiles must contain at least one SMILES string")

        # MOSES convention: validity is measured on the raw samples, every
        # other metric on the valid molecules in canonical form. Using the raw
        # list made fraction_unique raise on the first invalid SMILES and put
        # NaN fingerprint rows into the internal-diversity computation.
        valid_smiles = [
            smi for smi in (canonic_smiles(s) for s in generated_smiles)
            if smi is not None
        ]
        has_valid = len(valid_smiles) > 0

        metrics = {}
        # NOTE(review): novelty compares canonical generated SMILES against
        # train_smiles as given, so train_smiles should already be canonical.
        metrics['novelty'] = (novelty(gen=valid_smiles, train=train_smiles)
                              if has_valid else None)
        metrics['valid'] = fraction_valid(gen=generated_smiles)
        metrics['unique'] = fraction_unique(gen=valid_smiles) if has_valid else None
        metrics['IntDiv'] = internal_diversity(gen=valid_smiles) if has_valid else None
        metrics['FCD'] = (fcd_metric(gen=valid_smiles, train=train_smiles)
                          if has_valid else None)
        metrics['SA'] = SAscore(gen=valid_smiles)
        # SA score of the training set, reported as a reference value.
        metrics['SCS'] = SAscore(gen=train_smiles)
        # KL divergence and the TDC oracle scores are not part of this suite;
        # the oracle scores can be obtained separately via `oracles()`.
        return metrics