saicharan2804 committed on
Commit
816dc36
1 Parent(s): eaefc6b

Added oracles

Browse files
Files changed (1) hide show
  1. my_metric.py +46 -5
my_metric.py CHANGED
@@ -8,16 +8,18 @@ from tdc import Evaluator
8
 
9
  _DESCRIPTION = """
10
  Moses and PyTDC metrics
 
 
11
  """
12
 
13
 
14
  _KWARGS_DESCRIPTION = """
15
  Args:
16
- generated_smiles (`list` of `string`): Predicted labels.
17
- train_smiles (`list` of `string`): test.
18
 
19
  Returns:
20
- All moses metrics
21
  """
22
 
23
 
@@ -70,7 +72,7 @@ class my_metric(evaluate.Metric):
70
  "train_smiles": datasets.Value("string"),
71
  }
72
  ),
73
- reference_urls=["https://github.com/molecularsets/moses"],
74
  )
75
 
76
  def _compute(self, generated_smiles, train_smiles):
@@ -91,13 +93,52 @@ class my_metric(evaluate.Metric):
91
 
92
  evaluator = Evaluator(name = 'Validity')
93
  Validity = evaluator(generated_smiles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  Results.update({
96
  "PyTDC_Diversity": Diversity,
97
  "PyTDC_KL_Divergence": KL_Divergence,
98
  "PyTDC_FCD_Distance": FCD_Distance,
99
  "PyTDC_Novelty": Novelty,
100
- "PyTDC_Validity": Validity
 
 
 
 
101
  })
102
 
103
  return {"results": Results}
 
8
 
9
  _DESCRIPTION = """
10
  Moses and PyTDC metrics
11
+ Comprehensive suite of metrics designed to assess the performance of molecular generation models, for understanding how well a model can produce novel, chemically valid molecules that are relevant to specific research objectives.
12
+
13
  """
14
 
15
 
16
  _KWARGS_DESCRIPTION = """
17
  Args:
18
+ generated_smiles (`list` of `string`): A collection of SMILES (Simplified Molecular Input Line Entry System) strings generated by the model, ideally encompassing more than 30,000 samples.
19
+ train_smiles (`list` of `string`): The dataset of SMILES strings used to train the model, serving as a reference to evaluate the novelty and diversity of the generated molecules.
20
 
21
  Returns:
22
+ Dictionary containing various metrics to evaluate model performance
23
  """
24
 
25
 
 
72
  "train_smiles": datasets.Value("string"),
73
  }
74
  ),
75
+ reference_urls=["https://github.com/molecularsets/moses", "https://tdcommons.ai/functions/oracles/"],
76
  )
77
 
78
  def _compute(self, generated_smiles, train_smiles):
 
93
 
94
  evaluator = Evaluator(name = 'Validity')
95
  Validity = evaluator(generated_smiles)
96
+
97
+ oracle = Oracle(name = 'QED')
98
+ QED = oracle(generated_smiles)
99
+
100
+ oracle = Oracle(name = 'SA')
101
+ SA = oracle(generated_smiles)
102
+
103
+ oracle = Oracle(name = 'MPO')
104
+ MPO = oracle(generated_smiles)
105
+ MPO = {key: sum(values)/len(values) for key, values in MPO.items()}
106
+
107
+ oracle_list = [
108
+ 'QED', 'SA', 'MPO', '3pbl_docking', 'GSK3B', 'JNK3',
109
+ 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
110
+ 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
111
+ ]
112
+
113
+ # Iterate through each oracle and compute its score
114
+ for oracle_name in oracle_list:
115
+ oracle = Oracle(name=oracle_name)
116
+ if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
117
+ # Assuming these oracles return a dictionary where values are lists of scores
118
+ score = oracle(generated_smiles)
119
+ if isinstance(score, dict):
120
+ # Convert lists of scores to average score for these specific metrics
121
+ score = {key: sum(values)/len(values) for key, values in score.items()}
122
+ else:
123
+ # Assuming other oracles return a list of scores
124
+ score = oracle(generated_smiles)
125
+ if isinstance(score, list):
126
+ # Convert list of scores to average score
127
+ score = sum(score) / len(score)
128
+
129
+ Results.update({f"PyTDC_{oracle_name}": score})
130
+
131
 
132
  Results.update({
133
  "PyTDC_Diversity": Diversity,
134
  "PyTDC_KL_Divergence": KL_Divergence,
135
  "PyTDC_FCD_Distance": FCD_Distance,
136
  "PyTDC_Novelty": Novelty,
137
+ "PyTDC_Validity": Validity,
138
+
139
+ "PyTDC_QED": sum(QED)/len(QED),
140
+ "PyTDC_SA": sum(SA)/len(SA),
141
+ "PyTDC_MPO": MPO
142
  })
143
 
144
  return {"results": Results}