saicharan2804 commited on
Commit
6d9565d
1 Parent(s): 0a798d8
Files changed (1) hide show
  1. molgenevalmetric.py +8 -2
molgenevalmetric.py CHANGED
@@ -398,7 +398,7 @@ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
398
  agg='mean', device=device, p=p)).mean()
399
 
400
 
401
- def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
402
  """
403
  Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.
404
 
@@ -412,6 +412,12 @@ def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
412
  - float: FCD score.
413
  """
414
 
 
 
 
 
 
 
415
  fcd = FCD(device=device, n_jobs= n_jobs)
416
  return fcd(gen, train)
417
 
@@ -562,7 +568,7 @@ class molgenevalmetric(evaluate.Metric):
562
  metrics['valid'] = fraction_valid(gen=gensmi)
563
  metrics['unique'] = fraction_unique(gen=gensmi)
564
  metrics['IntDiv'] = internal_diversity(gen=gensmi)
565
- # metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
566
  # metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
567
 
568
  # print('computing')
 
398
  agg='mean', device=device, p=p)).mean()
399
 
400
 
401
+ def fcd_metric(gen, train, n_jobs = 1, device = None):
402
  """
403
  Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.
404
 
 
412
  - float: FCD score.
413
  """
414
 
415
+ # Determine the device dynamically based on CUDA availability
416
+ if device is None:
417
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
418
+ else:
419
+ device = torch.device(device if torch.cuda.is_available() and 'cuda' in device else 'cpu')
420
+
421
  fcd = FCD(device=device, n_jobs= n_jobs)
422
  return fcd(gen, train)
423
 
 
568
  metrics['valid'] = fraction_valid(gen=gensmi)
569
  metrics['unique'] = fraction_unique(gen=gensmi)
570
  metrics['IntDiv'] = internal_diversity(gen=gensmi)
571
+ metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
572
  # metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
573
 
574
  # print('computing')