Saicharan commited on
Commit
d8a04dc
1 Parent(s): 99bdc4b

SA score added

Browse files
Files changed (1) hide show
  1. molgenevalmetric.py +42 -2
molgenevalmetric.py CHANGED
@@ -27,6 +27,7 @@ from collections import UserList, defaultdict
27
  import numpy as np
28
  import pandas as pd
29
  from rdkit import rdBase
 
30
  import sys
31
 
32
  from rdkit.Chem import RDConfig
@@ -213,6 +214,45 @@ def synthetic_complexity_score(gen):
213
  return average_score
214
 
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  def average_agg_tanimoto(stock_vecs, gen_vecs,
217
  batch_size=5000, agg='max',
218
  device='cpu', p=1):
@@ -522,12 +562,12 @@ class molgenevalmetric(evaluate.Metric):
522
  metrics['valid'] = fraction_valid(gen=gensmi)
523
  metrics['unique'] = fraction_unique(gen=gensmi)
524
  metrics['IntDiv'] = internal_diversity(gen=gensmi)
525
- metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
526
  # metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
527
 
528
  print('computing')
529
 
530
- # metrics['SA'] = SAscore(gen=gensmi)
531
  metrics['SCS'] = synthetic_complexity_score(gen=gensmi)
532
 
533
  return metrics
 
27
  import numpy as np
28
  import pandas as pd
29
  from rdkit import rdBase
30
+ from rdkit.Contrib.SA_Score import sascorer
31
  import sys
32
 
33
  from rdkit.Chem import RDConfig
 
214
  return average_score
215
 
216
 
217
+ def calculate_sa_score(smiles):
218
+ """
219
+ Calculates the SA score for a single SMILES string.
220
+
221
+ Parameters:
222
+ - smiles (str): SMILES string of the molecule.
223
+
224
+ Returns:
225
+ - float: SA score of the molecule, or None if the molecule couldn't be created.
226
+ """
227
+ mol = Chem.MolFromSmiles(smiles)
228
+ if mol:
229
+ return sascorer.calculateScore(mol)
230
+ else:
231
+ return None
232
+
233
+ def average_sascore(gen, n_jobs=1):
234
+ """
235
+ Computes the average synthetic accessibility score for a list of molecules
236
+ using parallel or sequential execution based on the n_jobs parameter.
237
+
238
+ Parameters:
239
+ - molecules (List[str]): List of generated SMILES strings.
240
+ - n_jobs (int): Number of parallel jobs to use for computation.
241
+
242
+ Returns:
243
+ - float: Average SA score, or None if no scores could be computed.
244
+ """
245
+
246
+ scores = mapper(n_jobs)(calculate_sa_score, molecules)
247
+
248
+ # Filter out None values which indicate failed molecule creation
249
+ valid_scores = [score for score in scores if score is not None]
250
+
251
+ if valid_scores:
252
+ return sum(valid_scores) / len(valid_scores)
253
+ else:
254
+ return None
255
+
256
  def average_agg_tanimoto(stock_vecs, gen_vecs,
257
  batch_size=5000, agg='max',
258
  device='cpu', p=1):
 
562
  metrics['valid'] = fraction_valid(gen=gensmi)
563
  metrics['unique'] = fraction_unique(gen=gensmi)
564
  metrics['IntDiv'] = internal_diversity(gen=gensmi)
565
+ # metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
566
  # metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
567
 
568
  print('computing')
569
 
570
+ metrics['SA'] = average_sascore(gen=gensmi)
571
  metrics['SCS'] = synthetic_complexity_score(gen=gensmi)
572
 
573
  return metrics