Spaces:
Sleeping
Sleeping
saicharan2804
committed on
Commit
•
efabb66
1
Parent(s):
aae0b32
Improved compute
Browse files: molgenevalmetric.py (+30 -15)
molgenevalmetric.py
CHANGED
@@ -223,7 +223,7 @@ def average_sascore(gen, n_jobs=1):
|
|
223 |
|
224 |
def average_agg_tanimoto(stock_vecs, gen_vecs,
|
225 |
batch_size=5000, agg='max',
|
226 |
-
device=
|
227 |
"""
|
228 |
Calculates the average aggregate Tanimoto similarity between two sets of molecule fingerprints.
|
229 |
|
@@ -232,7 +232,7 @@ def average_agg_tanimoto(stock_vecs, gen_vecs,
|
|
232 |
- gen_vecs (numpy array): Fingerprint vectors for the generated molecule set.
|
233 |
- batch_size (int): The size of batches to process similarities (reduces memory usage).
|
234 |
- agg (str): Aggregation method, either 'max' or 'mean'.
|
235 |
-
- device (str): The computation device ('cpu' or 'cuda:0', etc.).
|
236 |
- p (float): The power for averaging, used in generalized mean calculation.
|
237 |
|
238 |
Returns:
|
@@ -240,6 +240,11 @@ def average_agg_tanimoto(stock_vecs, gen_vecs,
|
|
240 |
"""
|
241 |
|
242 |
assert agg in ['max', 'mean'], "Can aggregate only max or mean"
|
|
|
|
|
|
|
|
|
|
|
243 |
agg_tanimoto = np.zeros(len(gen_vecs))
|
244 |
total = np.zeros(len(gen_vecs))
|
245 |
for j in range(0, stock_vecs.shape[0], batch_size):
|
@@ -553,24 +558,34 @@ class molgenevalmetric(evaluate.Metric):
|
|
553 |
)
|
554 |
|
555 |
def _compute(self, gensmi, trainsmi):
|
556 |
-
|
557 |
metrics = {}
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
|
571 |
return metrics
|
572 |
|
573 |
|
|
|
574 |
# def get_n_rings(mol):
|
575 |
# """
|
576 |
# Computes the number of rings in a molecule
|
|
|
223 |
|
224 |
def average_agg_tanimoto(stock_vecs, gen_vecs,
|
225 |
batch_size=5000, agg='max',
|
226 |
+
device=None, p=1):
|
227 |
"""
|
228 |
Calculates the average aggregate Tanimoto similarity between two sets of molecule fingerprints.
|
229 |
|
|
|
232 |
- gen_vecs (numpy array): Fingerprint vectors for the generated molecule set.
|
233 |
- batch_size (int): The size of batches to process similarities (reduces memory usage).
|
234 |
- agg (str): Aggregation method, either 'max' or 'mean'.
|
235 |
+
- device (str or None): The computation device ('cpu' or 'cuda:0', etc.). If None, automatically detect.
|
236 |
- p (float): The power for averaging, used in generalized mean calculation.
|
237 |
|
238 |
Returns:
|
|
|
240 |
"""
|
241 |
|
242 |
assert agg in ['max', 'mean'], "Can aggregate only max or mean"
|
243 |
+
|
244 |
+
# Automatically detect device if not provided
|
245 |
+
if device is None:
|
246 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
247 |
+
|
248 |
agg_tanimoto = np.zeros(len(gen_vecs))
|
249 |
total = np.zeros(len(gen_vecs))
|
250 |
for j in range(0, stock_vecs.shape[0], batch_size):
|
|
|
558 |
)
|
559 |
|
560 |
def _compute(self, gensmi, trainsmi):
|
|
|
561 |
metrics = {}
|
562 |
+
metric_functions = {
|
563 |
+
'Novelty': lambda: novelty(gen=gensmi, train=trainsmi),
|
564 |
+
'Valid': lambda: fraction_valid(gen=gensmi),
|
565 |
+
'Unique': lambda: fraction_unique(gen=gensmi),
|
566 |
+
'IntDiv': lambda: internal_diversity(gen=gensmi),
|
567 |
+
'FCD': lambda: fcd_metric(gen=gensmi, train=trainsmi),
|
568 |
+
'QED': lambda: qed_metric(gen=gensmi),
|
569 |
+
'LogP': lambda: logP_metric(gen=gensmi),
|
570 |
+
'Penalized LogP': lambda: penalized_logp(gen=gensmi),
|
571 |
+
'SA': lambda: average_sascore(gen=gensmi),
|
572 |
+
'SCScore': lambda: synthetic_complexity_score(gen=gensmi),
|
573 |
+
'SYBA': lambda: SYBAscore(gen=gensmi),
|
574 |
+
# 'Oracles': lambda: oracles(gen=gensmi, train=trainsmi)
|
575 |
+
}
|
576 |
+
|
577 |
+
for metric_name, compute_func in metric_functions.items():
|
578 |
+
print(f"Computing {metric_name}...")
|
579 |
+
try:
|
580 |
+
metrics[metric_name] = compute_func()
|
581 |
+
except Exception as e:
|
582 |
+
print(f"Error computing {metric_name}: {e}")
|
583 |
+
metrics[metric_name] = 0
|
584 |
|
585 |
return metrics
|
586 |
|
587 |
|
588 |
+
|
589 |
# def get_n_rings(mol):
|
590 |
# """
|
591 |
# Computes the number of rings in a molecule
|