saicharan2804 commited on
Commit
ddc012e
1 Parent(s): 6d9565d

Many additions

Browse files
Files changed (1) hide show
  1. molgenevalmetric.py +76 -42
molgenevalmetric.py CHANGED
@@ -4,10 +4,11 @@ import datasets
4
  # import moses
5
  # from moses import metrics
6
  import pandas as pd
7
- # from tdc import Evaluator
8
- # from tdc import Oracle
9
  # from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
10
-
 
11
  import os
12
  from collections import Counter
13
  from functools import partial
@@ -41,6 +42,8 @@ from fcd_torch import FCD
41
  # from SCScore import SCScorer
42
 
43
  from myscscore.SCScore import SCScorer
 
 
44
 
45
  def get_mol(smiles_or_mol):
46
  """
@@ -174,29 +177,6 @@ def novelty(gen, train, n_jobs=1):
174
  return len(gen_smiles_set - train_set) / len(gen_smiles_set)
175
 
176
 
177
- # def SAscore(gen):
178
- # """
179
- # Calculate the average Synthetic Accessibility Score (SAscore) for a list of molecules represented by their SMILES strings.
180
-
181
- # Parameters:
182
- # - smiles_list (list of str): A list containing the SMILES representations of the molecules.
183
-
184
- # Returns:
185
- # - float: The average Synthetic Accessibility Score for the valid molecules in the list. Returns None if no valid molecules are found.
186
- # """
187
- # scores = []
188
- # for smiles in gen:
189
- # mol = Chem.MolFromSmiles(smiles)
190
- # if mol: # Ensures the molecule could be parsed from the SMILES string
191
- # score = sascorer.calculateScore(mol)
192
- # scores.append(score)
193
-
194
- # if scores: # Checks if there are any scores calculated
195
- # return np.mean(scores)
196
- # else:
197
- # return None
198
-
199
-
200
  def synthetic_complexity_score(gen):
201
  """
202
  Calculate the average Synthetic Complexity Score (SCScore) for a list of molecules represented by their SMILES strings.
@@ -448,6 +428,60 @@ def fcd_metric(gen, train, n_jobs = 1, device = None):
448
  # else:
449
  # return None # Or handle empty list or all failed predictions as needed
450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  def oracles(gen, train):
452
 
453
  """
@@ -461,20 +495,18 @@ def oracles(gen, train):
461
  - Dict[str, Any]: A dictionary with oracle names as keys and their corresponding scores as values.
462
  """
463
 
464
- Result = {}
465
- evaluator = Evaluator(name = 'KL_Divergence')
466
- KL_Divergence = evaluator(gen, train)
467
-
468
- Result["KL_Divergence"] = KL_Divergence
469
 
 
 
 
 
 
470
 
471
- oracle_list = [
472
- 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
473
- 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
474
- 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
475
- ]
476
 
477
  for oracle_name in oracle_list:
 
478
  oracle = Oracle(name=oracle_name)
479
  if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
480
  score = oracle(gen)
@@ -485,9 +517,9 @@ def oracles(gen, train):
485
  if isinstance(score, list):
486
  score = sum(score) / len(score)
487
 
488
- Result[f"{oracle_name}"] = score
489
 
490
- return Result
491
 
492
 
493
 
@@ -564,12 +596,14 @@ class molgenevalmetric(evaluate.Metric):
564
  def _compute(self, gensmi, trainsmi):
565
 
566
  metrics = {}
567
- metrics['novelty'] = novelty(gen = gensmi, train = trainsmi)
568
- metrics['valid'] = fraction_valid(gen=gensmi)
569
- metrics['unique'] = fraction_unique(gen=gensmi)
570
  metrics['IntDiv'] = internal_diversity(gen=gensmi)
571
  metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
572
- # metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
 
 
573
 
574
  # print('computing')
575
 
 
4
  # import moses
5
  # from moses import metrics
6
  import pandas as pd
7
+ from tdc import Evaluator
8
+ from tdc import Oracle
9
  # from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
10
+ from rdkit.Chem.QED import qed
11
+ from rdkit.Chem.Crippen import MolLogP
12
  import os
13
  from collections import Counter
14
  from functools import partial
 
42
  # from SCScore import SCScorer
43
 
44
  from myscscore.SCScore import SCScorer
45
+ import warnings
46
+
47
 
48
  def get_mol(smiles_or_mol):
49
  """
 
177
  return len(gen_smiles_set - train_set) / len(gen_smiles_set)
178
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def synthetic_complexity_score(gen):
181
  """
182
  Calculate the average Synthetic Complexity Score (SCScore) for a list of molecules represented by their SMILES strings.
 
428
  # else:
429
  # return None # Or handle empty list or all failed predictions as needed
430
 
431
+ def qed_metric(gen):
432
+ """
433
+ Computes RDKit's QED score
434
+ """
435
+ if not gen:
436
+ return 0.0 # Return 0 or suitable value for empty list
437
+
438
+ # Convert SMILES strings to RDKit molecule objects and calculate QED scores
439
+ qed_scores = []
440
+ for smiles in gen:
441
+ try:
442
+ mol = Chem.MolFromSmiles(smiles)
443
+ if mol: # Ensure molecule is valid
444
+ qed_scores.append(qed(mol))
445
+ except Exception as e:
446
+ print(f"Error processing molecule {smiles}: {str(e)}")
447
+
448
+ # Calculate the average QED score
449
+ if qed_scores:
450
+ return sum(qed_scores) / len(qed_scores)
451
+ else:
452
+ return 0.0 # Return 0 or suitable value if no valid molecules are processed
453
+
454
+ def logP_metric(gen):
455
+ """
456
+ Computes the average RDKit's logP value for a list of SMILES strings.
457
+
458
+ Parameters:
459
+ - mols (List[str]): List of SMILES strings representing the molecules.
460
+
461
+ Returns:
462
+ - float: Average logP value for the list of molecules.
463
+ """
464
+ # Check if the input list is empty
465
+ if not gen:
466
+ return 0.0 # Return 0 or suitable value for empty list
467
+
468
+ # Convert SMILES strings to RDKit molecule objects and calculate logP values
469
+ logP_values = []
470
+ for smiles in gen:
471
+ try:
472
+ mol = Chem.MolFromSmiles(smiles)
473
+ if mol: # Ensure molecule is valid
474
+ logP_values.append(MolLogP(mol))
475
+ except Exception as e:
476
+ print(f"Error processing molecule {smiles}: {str(e)}")
477
+
478
+ # Calculate the average logP value
479
+ if logP_values:
480
+ return sum(logP_values) / len(logP_values)
481
+ else:
482
+ return 0.0 # Return 0 or suitable value if no valid molecules are processed
483
+
484
+
485
  def oracles(gen, train):
486
 
487
  """
 
495
  - Dict[str, Any]: A dictionary with oracle names as keys and their corresponding scores as values.
496
  """
497
 
498
+ result = {}
 
 
 
 
499
 
500
+ # oracle_list = [
501
+ # 'QED', 'MPO', 'GSK3B', 'JNK3',
502
+ # 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
503
+ # 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
504
+ # ]
505
 
506
+ oracle_list = ['QED', 'LogP', 'SA']
 
 
 
 
507
 
508
  for oracle_name in oracle_list:
509
+ print(oracle_name)
510
  oracle = Oracle(name=oracle_name)
511
  if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
512
  score = oracle(gen)
 
517
  if isinstance(score, list):
518
  score = sum(score) / len(score)
519
 
520
+ result[f"{oracle_name}"] = score
521
 
522
+ return result
523
 
524
 
525
 
 
596
  def _compute(self, gensmi, trainsmi):
597
 
598
  metrics = {}
599
+ metrics['Novelty'] = novelty(gen = gensmi, train = trainsmi)
600
+ metrics['Valid'] = fraction_valid(gen=gensmi)
601
+ metrics['Unique'] = fraction_unique(gen=gensmi)
602
  metrics['IntDiv'] = internal_diversity(gen=gensmi)
603
  metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
604
+ metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
605
+ metrics['QED'] = qed_metric(gen=gensmi)
606
+ metrics['LogP'] = logP_metric(gen=gensmi)
607
 
608
  # print('computing')
609