Spaces:
Sleeping
Sleeping
saicharan2804
commited on
Commit
•
6d9565d
1
Parent(s):
0a798d8
FCD added
Browse files- molgenevalmetric.py +8 -2
molgenevalmetric.py
CHANGED
@@ -398,7 +398,7 @@ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
|
|
398 |
agg='mean', device=device, p=p)).mean()
|
399 |
|
400 |
|
401 |
-
def fcd_metric(gen, train, n_jobs =
|
402 |
"""
|
403 |
Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.
|
404 |
|
@@ -412,6 +412,12 @@ def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
|
|
412 |
- float: FCD score.
|
413 |
"""
|
414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
fcd = FCD(device=device, n_jobs= n_jobs)
|
416 |
return fcd(gen, train)
|
417 |
|
@@ -562,7 +568,7 @@ class molgenevalmetric(evaluate.Metric):
|
|
562 |
metrics['valid'] = fraction_valid(gen=gensmi)
|
563 |
metrics['unique'] = fraction_unique(gen=gensmi)
|
564 |
metrics['IntDiv'] = internal_diversity(gen=gensmi)
|
565 |
-
|
566 |
# metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
|
567 |
|
568 |
# print('computing')
|
|
|
398 |
agg='mean', device=device, p=p)).mean()
|
399 |
|
400 |
|
401 |
+
def fcd_metric(gen, train, n_jobs = 1, device = None):
|
402 |
"""
|
403 |
Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.
|
404 |
|
|
|
412 |
- float: FCD score.
|
413 |
"""
|
414 |
|
415 |
+
# Determine the device dynamically based on CUDA availability
|
416 |
+
if device is None:
|
417 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
418 |
+
else:
|
419 |
+
device = torch.device(device if torch.cuda.is_available() and 'cuda' in device else 'cpu')
|
420 |
+
|
421 |
fcd = FCD(device=device, n_jobs= n_jobs)
|
422 |
return fcd(gen, train)
|
423 |
|
|
|
568 |
metrics['valid'] = fraction_valid(gen=gensmi)
|
569 |
metrics['unique'] = fraction_unique(gen=gensmi)
|
570 |
metrics['IntDiv'] = internal_diversity(gen=gensmi)
|
571 |
+
metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
|
572 |
# metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
|
573 |
|
574 |
# print('computing')
|