saicharan2804 committed on
Commit
efabb66
1 Parent(s): aae0b32

Improved compute

Browse files
Files changed (1) hide show
  1. molgenevalmetric.py +30 -15
molgenevalmetric.py CHANGED
@@ -223,7 +223,7 @@ def average_sascore(gen, n_jobs=1):
223
 
224
  def average_agg_tanimoto(stock_vecs, gen_vecs,
225
  batch_size=5000, agg='max',
226
- device='cpu', p=1):
227
  """
228
  Calculates the average aggregate Tanimoto similarity between two sets of molecule fingerprints.
229
 
@@ -232,7 +232,7 @@ def average_agg_tanimoto(stock_vecs, gen_vecs,
232
  - gen_vecs (numpy array): Fingerprint vectors for the generated molecule set.
233
  - batch_size (int): The size of batches to process similarities (reduces memory usage).
234
  - agg (str): Aggregation method, either 'max' or 'mean'.
235
- - device (str): The computation device ('cpu' or 'cuda:0', etc.).
236
  - p (float): The power for averaging, used in generalized mean calculation.
237
 
238
  Returns:
@@ -240,6 +240,11 @@ def average_agg_tanimoto(stock_vecs, gen_vecs,
240
  """
241
 
242
  assert agg in ['max', 'mean'], "Can aggregate only max or mean"
 
 
 
 
 
243
  agg_tanimoto = np.zeros(len(gen_vecs))
244
  total = np.zeros(len(gen_vecs))
245
  for j in range(0, stock_vecs.shape[0], batch_size):
@@ -553,24 +558,34 @@ class molgenevalmetric(evaluate.Metric):
553
  )
554
 
555
  def _compute(self, gensmi, trainsmi):
556
-
557
  metrics = {}
558
- metrics['Novelty'] = novelty(gen = gensmi, train = trainsmi)
559
- metrics['Valid'] = fraction_valid(gen=gensmi)
560
- metrics['Unique'] = fraction_unique(gen=gensmi)
561
- metrics['IntDiv'] = internal_diversity(gen=gensmi)
562
- metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
563
- metrics['QED'] = qed_metric(gen=gensmi)
564
- metrics['LogP'] = logP_metric(gen=gensmi)
565
- metrics['Penalized LogP'] = penalized_logp(gen=gensmi)
566
- metrics['SA'] = average_sascore(gen=gensmi)
567
- metrics['SCScore'] = synthetic_complexity_score(gen=gensmi)
568
- metrics['SYBA'] = SYBAscore(gen=gensmi)
569
- # metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
 
 
 
 
 
 
 
 
 
 
570
 
571
  return metrics
572
 
573
 
 
574
  # def get_n_rings(mol):
575
  # """
576
  # Computes the number of rings in a molecule
 
223
 
224
  def average_agg_tanimoto(stock_vecs, gen_vecs,
225
  batch_size=5000, agg='max',
226
+ device=None, p=1):
227
  """
228
  Calculates the average aggregate Tanimoto similarity between two sets of molecule fingerprints.
229
 
 
232
  - gen_vecs (numpy array): Fingerprint vectors for the generated molecule set.
233
  - batch_size (int): The size of batches to process similarities (reduces memory usage).
234
  - agg (str): Aggregation method, either 'max' or 'mean'.
235
+ - device (str or None): The computation device ('cpu' or 'cuda:0', etc.). If None, automatically detect.
236
  - p (float): The power for averaging, used in generalized mean calculation.
237
 
238
  Returns:
 
240
  """
241
 
242
  assert agg in ['max', 'mean'], "Can aggregate only max or mean"
243
+
244
+ # Automatically detect device if not provided
245
+ if device is None:
246
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
247
+
248
  agg_tanimoto = np.zeros(len(gen_vecs))
249
  total = np.zeros(len(gen_vecs))
250
  for j in range(0, stock_vecs.shape[0], batch_size):
 
558
  )
559
 
560
  def _compute(self, gensmi, trainsmi):
 
561
  metrics = {}
562
+ metric_functions = {
563
+ 'Novelty': lambda: novelty(gen=gensmi, train=trainsmi),
564
+ 'Valid': lambda: fraction_valid(gen=gensmi),
565
+ 'Unique': lambda: fraction_unique(gen=gensmi),
566
+ 'IntDiv': lambda: internal_diversity(gen=gensmi),
567
+ 'FCD': lambda: fcd_metric(gen=gensmi, train=trainsmi),
568
+ 'QED': lambda: qed_metric(gen=gensmi),
569
+ 'LogP': lambda: logP_metric(gen=gensmi),
570
+ 'Penalized LogP': lambda: penalized_logp(gen=gensmi),
571
+ 'SA': lambda: average_sascore(gen=gensmi),
572
+ 'SCScore': lambda: synthetic_complexity_score(gen=gensmi),
573
+ 'SYBA': lambda: SYBAscore(gen=gensmi),
574
+ # 'Oracles': lambda: oracles(gen=gensmi, train=trainsmi)
575
+ }
576
+
577
+ for metric_name, compute_func in metric_functions.items():
578
+ print(f"Computing {metric_name}...")
579
+ try:
580
+ metrics[metric_name] = compute_func()
581
+ except Exception as e:
582
+ print(f"Error computing {metric_name}: {e}")
583
+ metrics[metric_name] = 0
584
 
585
  return metrics
586
 
587
 
588
+
589
  # def get_n_rings(mol):
590
  # """
591
  # Computes the number of rings in a molecule