saicharan2804 committed on
Commit
4e6d2e7
1 Parent(s): 14684b8

Changes to code

Files changed (1)
  1. molgenevalmetric.py +134 -0
molgenevalmetric.py ADDED
@@ -0,0 +1,134 @@
+ import evaluate
+ import datasets
+ from moses import metrics
+ from tdc import Evaluator, Oracle
+
+
+ _DESCRIPTION = """
+ A comprehensive suite of metrics for assessing the performance of molecular
+ generation models: it measures how well a model produces novel, chemically
+ valid molecules that are relevant to specific research objectives.
+ """
+
+
+ _KWARGS_DESCRIPTION = """
+ Args:
+     generated_smiles (`list` of `str`): SMILES (Simplified Molecular Input Line
+         Entry System) strings generated by the model, ideally more than 30,000
+         samples.
+     train_smiles (`list` of `str`): SMILES strings used to train the model,
+         serving as the reference for evaluating the novelty and diversity of
+         the generated molecules.
+
+ Returns:
+     Dictionary containing the computed metrics for evaluating model performance.
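+
+ Examples:
+     A minimal usage sketch (illustrative only: the repo id below is an
+     assumption, and MOSES statistics are only meaningful on much larger
+     samples):
+
+     >>> metric = evaluate.load("saicharan2804/molgenevalmetric")
+     >>> results = metric.compute(
+     ...     generated_smiles=["CCO", "c1ccccc1", "CC(=O)O"],
+     ...     train_smiles=["CCN", "CCC", "CCCl"],
+     ... )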
+ """
+
+
+ _CITATION = """
+ @article{DBLP:journals/corr/abs-1811-12823,
+   author     = {Daniil Polykovskiy and
+                 Alexander Zhebrak and
+                 Benjam{\'{\i}}n S{\'{a}}nchez{-}Lengeling and
+                 Sergey Golovanov and
+                 Oktai Tatanov and
+                 Stanislav Belyaev and
+                 Rauf Kurbanov and
+                 Aleksey Artamonov and
+                 Vladimir Aladinskiy and
+                 Mark Veselov and
+                 Artur Kadurin and
+                 Sergey I. Nikolenko and
+                 Al{\'{a}}n Aspuru{-}Guzik and
+                 Alex Zhavoronkov},
+   title      = {Molecular Sets {(MOSES):} {A} Benchmarking Platform for Molecular
+                 Generation Models},
+   journal    = {CoRR},
+   volume     = {abs/1811.12823},
+   year       = {2018},
+   url        = {http://arxiv.org/abs/1811.12823},
+   eprinttype = {arXiv},
+   eprint     = {1811.12823},
+   timestamp  = {Fri, 26 Nov 2021 15:34:30 +0100},
+   biburl     = {https://dblp.org/rec/journals/corr/abs-1811-12823.bib},
+   bibsource  = {dblp computer science bibliography, https://dblp.org}
+ }
+ """
+
+
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class molgenevalmetric(evaluate.Metric):
+     def _info(self):
+         return evaluate.MetricInfo(
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             # In the default configuration each example is a single SMILES
+             # string; the "multilabel" configuration takes a sequence of
+             # strings per example.
+             features=datasets.Features(
+                 {
+                     "generated_smiles": datasets.Sequence(datasets.Value("string")),
+                     "train_smiles": datasets.Sequence(datasets.Value("string")),
+                 }
+                 if self.config_name == "multilabel"
+                 else {
+                     "generated_smiles": datasets.Value("string"),
+                     "train_smiles": datasets.Value("string"),
+                 }
+             ),
+             reference_urls=[
+                 "https://github.com/molecularsets/moses",
+                 "https://tdcommons.ai/functions/oracles/",
+             ],
+         )
+
+     def _compute(self, generated_smiles, train_smiles=None):
+         # MOSES computes its full metric suite (validity, uniqueness, novelty,
+         # FCD, internal diversity, etc.) in one call.
+         results = metrics.get_all_metrics(gen=generated_smiles, train=train_smiles)
+
+         # evaluator = Evaluator(name='Diversity')
+         # diversity = evaluator(generated_smiles)
+
+         # KL divergence between property distributions of the generated and
+         # training sets, via the PyTDC evaluator.
+         evaluator = Evaluator(name='KL_Divergence')
+         kl_divergence = evaluator(generated_smiles, train_smiles)
+
+         # evaluator = Evaluator(name='FCD_Distance')
+         # fcd_distance = evaluator(generated_smiles, train_smiles)
+
+         # evaluator = Evaluator(name='Novelty')
+         # novelty = evaluator(generated_smiles, train_smiles)
+
+         # evaluator = Evaluator(name='Validity')
+         # validity = evaluator(generated_smiles)
+
+         results.update({
+             # "PyTDC_Diversity": diversity,
+             "KL_Divergence": kl_divergence,
+             # "PyTDC_FCD_Distance": fcd_distance,
+             # "PyTDC_Novelty": novelty,
+             # "PyTDC_Validity": validity,
+         })
+
+         oracle_list = [
+             'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
+             'DRD2', 'LogP', 'Rediscovery', 'Similarity',
+             'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
+         ]
+
+         # Iterate through each oracle, score the generated molecules, and
+         # reduce per-molecule scores to a mean.
+         for oracle_name in oracle_list:
+             oracle = Oracle(name=oracle_name)
+             score = oracle(generated_smiles)
+             if isinstance(score, dict):
+                 # Oracles such as Rediscovery, MPO, Similarity, Median, Isomers,
+                 # and Hop return a dictionary mapping sub-task names to lists of
+                 # scores; average each list.
+                 score = {key: sum(values) / len(values) for key, values in score.items()}
+             elif isinstance(score, list):
+                 # Other oracles return one score per molecule; average them.
+                 score = sum(score) / len(score)
+             results[f"PyTDC_{oracle_name}"] = score
+
+         return {"results": results}
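+
+
+ if __name__ == "__main__":
+     # Minimal local smoke test (an illustrative sketch, not part of the metric's
+     # API contract; the tiny SMILES lists below are arbitrary, and MOSES
+     # statistics are only meaningful on much larger samples).
+     metric = molgenevalmetric()
+     out = metric.compute(
+         generated_smiles=["CCO", "c1ccccc1", "CC(=O)O"],
+         train_smiles=["CCN", "CCC", "CCCl"],
+     )
+     print(out["results"])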