import evaluate
import datasets
from moses import metrics
from tdc import Evaluator, Oracle


_DESCRIPTION = """
Comprehensive suite of metrics for assessing the performance of molecular generation
models: how well a model produces novel, chemically valid molecules that are relevant
to specific research objectives.
"""


_KWARGS_DESCRIPTION = """
Args:
    generated_smiles (`list` of `string`): SMILES (Simplified Molecular Input Line Entry
        System) strings generated by the model, ideally more than 30,000 samples.
    train_smiles (`list` of `string`): SMILES strings the model was trained on, used as a
        reference when evaluating the novelty and diversity of the generated molecules.

Returns:
    Dictionary containing various metrics that evaluate model performance.
"""


_CITATION = """
@article{DBLP:journals/corr/abs-1811-12823,
  author       = {Daniil Polykovskiy and
                  Alexander Zhebrak and
                  Benjam{\'{\i}}n S{\'{a}}nchez{-}Lengeling and
                  Sergey Golovanov and
                  Oktai Tatanov and
                  Stanislav Belyaev and
                  Rauf Kurbanov and
                  Aleksey Artamonov and
                  Vladimir Aladinskiy and
                  Mark Veselov and
                  Artur Kadurin and
                  Sergey I. Nikolenko and
                  Al{\'{a}}n Aspuru{-}Guzik and
                  Alex Zhavoronkov},
  title        = {Molecular Sets {(MOSES):} {A} Benchmarking Platform for Molecular
                  Generation Models},
  journal      = {CoRR},
  volume       = {abs/1811.12823},
  year         = {2018},
  url          = {http://arxiv.org/abs/1811.12823},
  eprinttype    = {arXiv},
  eprint       = {1811.12823},
  timestamp    = {Fri, 26 Nov 2021 15:34:30 +0100},
  biburl       = {https://dblp.org/rec/journals/corr/abs-1811-12823.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class MolecularGenerationMetrics(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "generated_smiles": datasets.Sequence(datasets.Value("string")),
                    "train_smiles": datasets.Sequence(datasets.Value("string")),
                }
                if self.config_name == "multilabel"
                else {
                    "generated_smiles": datasets.Value("string"),
                    "train_smiles": datasets.Value("string"),
                }
            ),
            reference_urls=["https://github.com/molecularsets/moses", "https://tdcommons.ai/functions/oracles/"],
        )

    def _compute(self, generated_smiles, train_smiles=None):
        # MOSES computes its full benchmark suite in one call (validity,
        # uniqueness, novelty, FCD, SNN, fragment/scaffold similarity, ...).
        results = metrics.get_all_metrics(gen=generated_smiles, train=train_smiles)

        # PyTDC distribution-learning evaluators. Most overlap with the MOSES
        # suite above and stay disabled; KL divergence is kept active.
        # evaluator = Evaluator(name='Diversity')
        # diversity = evaluator(generated_smiles)

        evaluator = Evaluator(name='KL_Divergence')
        kl_divergence = evaluator(generated_smiles, train_smiles)

        # evaluator = Evaluator(name='FCD_Distance')
        # fcd_distance = evaluator(generated_smiles, train_smiles)

        # evaluator = Evaluator(name='Novelty')
        # novelty = evaluator(generated_smiles, train_smiles)

        # evaluator = Evaluator(name='Validity')
        # validity = evaluator(generated_smiles)

        results.update({
            # "PyTDC_Diversity": diversity,
            "KL_Divergence": kl_divergence,
            # "PyTDC_FCD_Distance": fcd_distance,
            # "PyTDC_Novelty": novelty,
            # "PyTDC_Validity": validity,
        })


        # PyTDC oracles scoring drug-likeness, synthesizability, target
        # activity, and other goal-directed objectives.
        oracle_list = [
            'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
            'DRD2', 'LogP', 'Rediscovery', 'Similarity',
            'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop',
        ]

        # Compute each oracle's score, reducing per-molecule scores to means.
        for oracle_name in oracle_list:
            oracle = Oracle(name=oracle_name)
            score = oracle(generated_smiles)
            if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
                # These oracle groups are expected to return a dictionary
                # mapping sub-oracle names to lists of per-molecule scores;
                # average each list.
                if isinstance(score, dict):
                    score = {key: sum(values) / len(values) for key, values in score.items()}
            else:
                # The remaining oracles are expected to return a flat list of
                # per-molecule scores; average it.
                if isinstance(score, list):
                    score = sum(score) / len(score)
            results[f"PyTDC_{oracle_name}"] = score

        return {"results": results}