Spaces:
Sleeping
Sleeping
saicharan2804
commited on
Commit
•
ddc012e
1
Parent(s):
6d9565d
Many additions
Browse files- molgenevalmetric.py +76 -42
molgenevalmetric.py
CHANGED
@@ -4,10 +4,11 @@ import datasets
|
|
4 |
# import moses
|
5 |
# from moses import metrics
|
6 |
import pandas as pd
|
7 |
-
|
8 |
-
|
9 |
# from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
|
10 |
-
|
|
|
11 |
import os
|
12 |
from collections import Counter
|
13 |
from functools import partial
|
@@ -41,6 +42,8 @@ from fcd_torch import FCD
|
|
41 |
# from SCScore import SCScorer
|
42 |
|
43 |
from myscscore.SCScore import SCScorer
|
|
|
|
|
44 |
|
45 |
def get_mol(smiles_or_mol):
|
46 |
"""
|
@@ -174,29 +177,6 @@ def novelty(gen, train, n_jobs=1):
|
|
174 |
return len(gen_smiles_set - train_set) / len(gen_smiles_set)
|
175 |
|
176 |
|
177 |
-
# def SAscore(gen):
|
178 |
-
# """
|
179 |
-
# Calculate the average Synthetic Accessibility Score (SAscore) for a list of molecules represented by their SMILES strings.
|
180 |
-
|
181 |
-
# Parameters:
|
182 |
-
# - smiles_list (list of str): A list containing the SMILES representations of the molecules.
|
183 |
-
|
184 |
-
# Returns:
|
185 |
-
# - float: The average Synthetic Accessibility Score for the valid molecules in the list. Returns None if no valid molecules are found.
|
186 |
-
# """
|
187 |
-
# scores = []
|
188 |
-
# for smiles in gen:
|
189 |
-
# mol = Chem.MolFromSmiles(smiles)
|
190 |
-
# if mol: # Ensures the molecule could be parsed from the SMILES string
|
191 |
-
# score = sascorer.calculateScore(mol)
|
192 |
-
# scores.append(score)
|
193 |
-
|
194 |
-
# if scores: # Checks if there are any scores calculated
|
195 |
-
# return np.mean(scores)
|
196 |
-
# else:
|
197 |
-
# return None
|
198 |
-
|
199 |
-
|
200 |
def synthetic_complexity_score(gen):
|
201 |
"""
|
202 |
Calculate the average Synthetic Complexity Score (SCScore) for a list of molecules represented by their SMILES strings.
|
@@ -448,6 +428,60 @@ def fcd_metric(gen, train, n_jobs = 1, device = None):
|
|
448 |
# else:
|
449 |
# return None # Or handle empty list or all failed predictions as needed
|
450 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
def oracles(gen, train):
|
452 |
|
453 |
"""
|
@@ -461,20 +495,18 @@ def oracles(gen, train):
|
|
461 |
- Dict[str, Any]: A dictionary with oracle names as keys and their corresponding scores as values.
|
462 |
"""
|
463 |
|
464 |
-
|
465 |
-
evaluator = Evaluator(name = 'KL_Divergence')
|
466 |
-
KL_Divergence = evaluator(gen, train)
|
467 |
-
|
468 |
-
Result["KL_Divergence"] = KL_Divergence
|
469 |
|
|
|
|
|
|
|
|
|
|
|
470 |
|
471 |
-
oracle_list = [
|
472 |
-
'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
|
473 |
-
'DRD2', 'LogP', 'Rediscovery', 'Similarity',
|
474 |
-
'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
|
475 |
-
]
|
476 |
|
477 |
for oracle_name in oracle_list:
|
|
|
478 |
oracle = Oracle(name=oracle_name)
|
479 |
if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
|
480 |
score = oracle(gen)
|
@@ -485,9 +517,9 @@ def oracles(gen, train):
|
|
485 |
if isinstance(score, list):
|
486 |
score = sum(score) / len(score)
|
487 |
|
488 |
-
|
489 |
|
490 |
-
return
|
491 |
|
492 |
|
493 |
|
@@ -564,12 +596,14 @@ class molgenevalmetric(evaluate.Metric):
|
|
564 |
def _compute(self, gensmi, trainsmi):
|
565 |
|
566 |
metrics = {}
|
567 |
-
metrics['
|
568 |
-
metrics['
|
569 |
-
metrics['
|
570 |
metrics['IntDiv'] = internal_diversity(gen=gensmi)
|
571 |
metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
|
572 |
-
|
|
|
|
|
573 |
|
574 |
# print('computing')
|
575 |
|
|
|
4 |
# import moses
|
5 |
# from moses import metrics
|
6 |
import pandas as pd
|
7 |
+
from tdc import Evaluator
|
8 |
+
from tdc import Oracle
|
9 |
# from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
|
10 |
+
from rdkit.Chem.QED import qed
|
11 |
+
from rdkit.Chem.Crippen import MolLogP
|
12 |
import os
|
13 |
from collections import Counter
|
14 |
from functools import partial
|
|
|
42 |
# from SCScore import SCScorer
|
43 |
|
44 |
from myscscore.SCScore import SCScorer
|
45 |
+
import warnings
|
46 |
+
|
47 |
|
48 |
def get_mol(smiles_or_mol):
|
49 |
"""
|
|
|
177 |
return len(gen_smiles_set - train_set) / len(gen_smiles_set)
|
178 |
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
def synthetic_complexity_score(gen):
|
181 |
"""
|
182 |
Calculate the average Synthetic Complexity Score (SCScore) for a list of molecules represented by their SMILES strings.
|
|
|
428 |
# else:
|
429 |
# return None # Or handle empty list or all failed predictions as needed
|
430 |
|
431 |
+
def qed_metric(gen):
|
432 |
+
"""
|
433 |
+
Computes RDKit's QED score
|
434 |
+
"""
|
435 |
+
if not gen:
|
436 |
+
return 0.0 # Return 0 or suitable value for empty list
|
437 |
+
|
438 |
+
# Convert SMILES strings to RDKit molecule objects and calculate QED scores
|
439 |
+
qed_scores = []
|
440 |
+
for smiles in gen:
|
441 |
+
try:
|
442 |
+
mol = Chem.MolFromSmiles(smiles)
|
443 |
+
if mol: # Ensure molecule is valid
|
444 |
+
qed_scores.append(qed(mol))
|
445 |
+
except Exception as e:
|
446 |
+
print(f"Error processing molecule {smiles}: {str(e)}")
|
447 |
+
|
448 |
+
# Calculate the average QED score
|
449 |
+
if qed_scores:
|
450 |
+
return sum(qed_scores) / len(qed_scores)
|
451 |
+
else:
|
452 |
+
return 0.0 # Return 0 or suitable value if no valid molecules are processed
|
453 |
+
|
454 |
+
def logP_metric(gen):
|
455 |
+
"""
|
456 |
+
Computes the average RDKit's logP value for a list of SMILES strings.
|
457 |
+
|
458 |
+
Parameters:
|
459 |
+
- mols (List[str]): List of SMILES strings representing the molecules.
|
460 |
+
|
461 |
+
Returns:
|
462 |
+
- float: Average logP value for the list of molecules.
|
463 |
+
"""
|
464 |
+
# Check if the input list is empty
|
465 |
+
if not gen:
|
466 |
+
return 0.0 # Return 0 or suitable value for empty list
|
467 |
+
|
468 |
+
# Convert SMILES strings to RDKit molecule objects and calculate logP values
|
469 |
+
logP_values = []
|
470 |
+
for smiles in gen:
|
471 |
+
try:
|
472 |
+
mol = Chem.MolFromSmiles(smiles)
|
473 |
+
if mol: # Ensure molecule is valid
|
474 |
+
logP_values.append(MolLogP(mol))
|
475 |
+
except Exception as e:
|
476 |
+
print(f"Error processing molecule {smiles}: {str(e)}")
|
477 |
+
|
478 |
+
# Calculate the average logP value
|
479 |
+
if logP_values:
|
480 |
+
return sum(logP_values) / len(logP_values)
|
481 |
+
else:
|
482 |
+
return 0.0 # Return 0 or suitable value if no valid molecules are processed
|
483 |
+
|
484 |
+
|
485 |
def oracles(gen, train):
|
486 |
|
487 |
"""
|
|
|
495 |
- Dict[str, Any]: A dictionary with oracle names as keys and their corresponding scores as values.
|
496 |
"""
|
497 |
|
498 |
+
result = {}
|
|
|
|
|
|
|
|
|
499 |
|
500 |
+
# oracle_list = [
|
501 |
+
# 'QED', 'MPO', 'GSK3B', 'JNK3',
|
502 |
+
# 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
|
503 |
+
# 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
|
504 |
+
# ]
|
505 |
|
506 |
+
oracle_list = ['QED', 'LogP', 'SA']
|
|
|
|
|
|
|
|
|
507 |
|
508 |
for oracle_name in oracle_list:
|
509 |
+
print(oracle_name)
|
510 |
oracle = Oracle(name=oracle_name)
|
511 |
if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
|
512 |
score = oracle(gen)
|
|
|
517 |
if isinstance(score, list):
|
518 |
score = sum(score) / len(score)
|
519 |
|
520 |
+
result[f"{oracle_name}"] = score
|
521 |
|
522 |
+
return result
|
523 |
|
524 |
|
525 |
|
|
|
596 |
def _compute(self, gensmi, trainsmi):
|
597 |
|
598 |
metrics = {}
|
599 |
+
metrics['Novelty'] = novelty(gen = gensmi, train = trainsmi)
|
600 |
+
metrics['Valid'] = fraction_valid(gen=gensmi)
|
601 |
+
metrics['Unique'] = fraction_unique(gen=gensmi)
|
602 |
metrics['IntDiv'] = internal_diversity(gen=gensmi)
|
603 |
metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
|
604 |
+
metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
|
605 |
+
metrics['QED'] = qed_metric(gen=gensmi)
|
606 |
+
metrics['LogP'] = logP_metric(gen=gensmi)
|
607 |
|
608 |
# print('computing')
|
609 |
|