Spaces:
Sleeping
Sleeping
Saicharan
committed on
Commit
•
d8a04dc
1
Parent(s):
99bdc4b
SA score added
Browse files- molgenevalmetric.py +42 -2
molgenevalmetric.py
CHANGED
@@ -27,6 +27,7 @@ from collections import UserList, defaultdict
|
|
27 |
import numpy as np
|
28 |
import pandas as pd
|
29 |
from rdkit import rdBase
|
|
|
30 |
import sys
|
31 |
|
32 |
from rdkit.Chem import RDConfig
|
@@ -213,6 +214,45 @@ def synthetic_complexity_score(gen):
|
|
213 |
return average_score
|
214 |
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
def average_agg_tanimoto(stock_vecs, gen_vecs,
|
217 |
batch_size=5000, agg='max',
|
218 |
device='cpu', p=1):
|
@@ -522,12 +562,12 @@ class molgenevalmetric(evaluate.Metric):
|
|
522 |
metrics['valid'] = fraction_valid(gen=gensmi)
|
523 |
metrics['unique'] = fraction_unique(gen=gensmi)
|
524 |
metrics['IntDiv'] = internal_diversity(gen=gensmi)
|
525 |
-
metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
|
526 |
# metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
|
527 |
|
528 |
print('computing')
|
529 |
|
530 |
-
|
531 |
metrics['SCS'] = synthetic_complexity_score(gen=gensmi)
|
532 |
|
533 |
return metrics
|
|
|
27 |
import numpy as np
|
28 |
import pandas as pd
|
29 |
from rdkit import rdBase
|
30 |
+
from rdkit.Contrib.SA_Score import sascorer
|
31 |
import sys
|
32 |
|
33 |
from rdkit.Chem import RDConfig
|
|
|
214 |
return average_score
|
215 |
|
216 |
|
217 |
+
def calculate_sa_score(smiles):
    """
    Compute the synthetic accessibility (SA) score of one molecule.

    Parameters:
    - smiles (str): SMILES string handed to ``Chem.MolFromSmiles``.

    Returns:
    - The value produced by ``sascorer.calculateScore`` for the parsed
      molecule, or None when no molecule could be created from the SMILES.
    """
    parsed = Chem.MolFromSmiles(smiles)
    # A falsy result means the molecule could not be created; signal that
    # to the caller with None instead of attempting to score it.
    if not parsed:
        return None
    return sascorer.calculateScore(parsed)
|
232 |
+
|
233 |
+
def average_sascore(gen, n_jobs=1):
    """
    Computes the average synthetic accessibility score for a list of molecules
    using parallel or sequential execution based on the n_jobs parameter.

    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - n_jobs (int): Number of parallel jobs; forwarded unchanged to ``mapper``.

    Returns:
    - float: Average SA score over the molecules that could be scored,
      or None if no scores could be computed.
    """
    # Score every generated SMILES. The argument must be ``gen`` (the
    # function's parameter); the previous body referenced an undefined
    # name ``molecules`` and raised NameError on every call.
    scores = mapper(n_jobs)(calculate_sa_score, gen)

    # Drop None entries, which mark SMILES that could not be turned into
    # a molecule by calculate_sa_score.
    valid_scores = [score for score in scores if score is not None]

    # Guard against an empty list so we never divide by zero.
    if not valid_scores:
        return None
    return sum(valid_scores) / len(valid_scores)
|
255 |
+
|
256 |
def average_agg_tanimoto(stock_vecs, gen_vecs,
|
257 |
batch_size=5000, agg='max',
|
258 |
device='cpu', p=1):
|
|
|
562 |
metrics['valid'] = fraction_valid(gen=gensmi)
|
563 |
metrics['unique'] = fraction_unique(gen=gensmi)
|
564 |
metrics['IntDiv'] = internal_diversity(gen=gensmi)
|
565 |
+
# metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
|
566 |
# metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
|
567 |
|
568 |
print('computing')
|
569 |
|
570 |
+
metrics['SA'] = average_sascore(gen=gensmi)
|
571 |
metrics['SCS'] = synthetic_complexity_score(gen=gensmi)
|
572 |
|
573 |
return metrics
|