Spaces:
Running
Running
Update Space (evaluate main: 828c6327)
Browse files- README.md +190 -5
- app.py +6 -0
- requirements.txt +4 -0
- roc_auc.py +191 -0
README.md
CHANGED
@@ -1,12 +1,197 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.0.2
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: ROC AUC
|
3 |
+
emoji: 🤗
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.0.2
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
tags:
|
11 |
+
- evaluate
|
12 |
+
- metric
|
13 |
---
|
14 |
|
15 |
+
# Metric Card for ROC AUC
|
16 |
+
|
17 |
+
|
18 |
+
## Metric Description
|
19 |
+
This metric computes the area under the curve (AUC) for the Receiver Operating Characteristic Curve (ROC). The return values represent how well the model used is predicting the correct classes, based on the input data. A score of `0.5` means that the model is predicting exactly at chance, i.e. the model's predictions are correct at the same rate as if the predictions were being decided by the flip of a fair coin or the roll of a fair die. A score above `0.5` indicates that the model is doing better than chance, while a score below `0.5` indicates that the model is doing worse than chance.
|
20 |
+
|
21 |
+
This metric has three separate use cases:
|
22 |
+
- **binary**: The case in which there are only two different label classes, and each example gets only one label. This is the default implementation.
|
23 |
+
- **multiclass**: The case in which there can be more than two different label classes, but each example still gets only one label.
|
24 |
+
- **multilabel**: The case in which there can be more than two different label classes, and each example can have more than one label.
|
25 |
+
|
26 |
+
|
27 |
+
## How to Use
|
28 |
+
At minimum, this metric requires references and prediction scores:
|
29 |
+
```python
|
30 |
+
>>> roc_auc_score = evaluate.load("roc_auc")
|
31 |
+
>>> refs = [1, 0, 1, 1, 0, 0]
|
32 |
+
>>> pred_scores = [0.5, 0.2, 0.99, 0.3, 0.1, 0.7]
|
33 |
+
>>> results = roc_auc_score.compute(references=refs, prediction_scores=pred_scores)
|
34 |
+
>>> print(round(results['roc_auc'], 2))
|
35 |
+
0.78
|
36 |
+
```
|
37 |
+
|
38 |
+
The default implementation of this metric is the **binary** implementation. If employing the **multiclass** or **multilabel** use cases, the keyword `"multiclass"` or `"multilabel"` must be specified when loading the metric:
|
39 |
+
- In the **multiclass** case, the metric is loaded with:
|
40 |
+
```python
|
41 |
+
>>> roc_auc_score = evaluate.load("roc_auc", "multiclass")
|
42 |
+
```
|
43 |
+
- In the **multilabel** case, the metric is loaded with:
|
44 |
+
```python
|
45 |
+
>>> roc_auc_score = evaluate.load("roc_auc", "multilabel")
|
46 |
+
```
|
47 |
+
|
48 |
+
See the [Examples Section Below](#examples_section) for more extensive examples.
|
49 |
+
|
50 |
+
|
51 |
+
### Inputs
|
52 |
+
- **`references`** (array-like of shape (n_samples,) or (n_samples, n_classes)): Ground truth labels. Expects different inputs based on use case:
|
53 |
+
- binary: expects an array-like of shape (n_samples,)
|
54 |
+
- multiclass: expects an array-like of shape (n_samples,)
|
55 |
+
- multilabel: expects an array-like of shape (n_samples, n_classes)
|
56 |
+
- **`prediction_scores`** (array-like of shape (n_samples,) or (n_samples, n_classes)): Model predictions. Expects different inputs based on use case:
|
57 |
+
- binary: expects an array-like of shape (n_samples,)
|
58 |
+
- multiclass: expects an array-like of shape (n_samples, n_classes). The probability estimates must sum to 1 across the possible classes.
|
59 |
+
- multilabel: expects an array-like of shape (n_samples, n_classes)
|
60 |
+
- **`average`** (`str`): Type of average, and is ignored in the binary use case. Defaults to `'macro'`. Options are:
|
61 |
+
- `'micro'`: Calculates metrics globally by considering each element of the label indicator matrix as a label. Only works with the multilabel use case.
|
62 |
+
- `'macro'`: Calculate metrics for each label, and find their unweighted mean. This does not take label imbalance into account.
|
63 |
+
- `'weighted'`: Calculate metrics for each label, and find their average, weighted by support (i.e. the number of true instances for each label).
|
64 |
+
- `'samples'`: Calculate metrics for each instance, and find their average. Only works with the multilabel use case.
|
65 |
+
- `None`: No average is calculated, and scores for each class are returned. Only works with the multilabels use case.
|
66 |
+
- **`sample_weight`** (array-like of shape (n_samples,)): Sample weights. Defaults to None.
|
67 |
+
- **`max_fpr`** (`float`): If not None, the standardized partial AUC over the range [0, `max_fpr`] is returned. Must be greater than `0` and less than or equal to `1`. Defaults to `None`. Note: For the multiclass use case, `max_fpr` should be either `None` or `1.0` as ROC AUC partial computation is not currently supported for `multiclass`.
|
68 |
+
- **`multi_class`** (`str`): Only used for multiclass targets, in which case it is required. Determines the type of configuration to use. Options are:
|
69 |
+
- `'ovr'`: Stands for One-vs-rest. Computes the AUC of each class against the rest. This treats the multiclass case in the same way as the multilabel case. Sensitive to class imbalance even when `average == 'macro'`, because class imbalance affects the composition of each of the 'rest' groupings.
|
70 |
+
- `'ovo'`: Stands for One-vs-one. Computes the average AUC of all possible pairwise combinations of classes. Insensitive to class imbalance when `average == 'macro'`.
|
71 |
+
- **`labels`** (array-like of shape (n_classes,)): Only used for multiclass targets. List of labels that index the classes in `prediction_scores`. If `None`, the numerical or lexicographical order of the labels in `prediction_scores` is used. Defaults to `None`.
|
72 |
+
|
73 |
+
### Output Values
|
74 |
+
This metric returns a dict containing the `roc_auc` score. The score is a `float`, unless it is the multilabel case with `average=None`, in which case the score is a numpy `array` with entries of type `float`.
|
75 |
+
|
76 |
+
The output therefore generally takes the following format:
|
77 |
+
```python
|
78 |
+
{'roc_auc': 0.778}
|
79 |
+
```
|
80 |
+
|
81 |
+
In contrast, though, the output takes the following format in the multilabel case when `average=None`:
|
82 |
+
```python
|
83 |
+
{'roc_auc': array([0.83333333, 0.375, 0.94444444])}
|
84 |
+
```
|
85 |
+
|
86 |
+
ROC AUC scores can take on any value between `0` and `1`, inclusive.
|
87 |
+
|
88 |
+
#### Values from Popular Papers
|
89 |
+
|
90 |
+
|
91 |
+
### <a name="examples_section"></a>Examples
|
92 |
+
Example 1, the **binary** use case:
|
93 |
+
```python
|
94 |
+
>>> roc_auc_score = evaluate.load("roc_auc")
|
95 |
+
>>> refs = [1, 0, 1, 1, 0, 0]
|
96 |
+
>>> pred_scores = [0.5, 0.2, 0.99, 0.3, 0.1, 0.7]
|
97 |
+
>>> results = roc_auc_score.compute(references=refs, prediction_scores=pred_scores)
|
98 |
+
>>> print(round(results['roc_auc'], 2))
|
99 |
+
0.78
|
100 |
+
```
|
101 |
+
|
102 |
+
Example 2, the **multiclass** use case:
|
103 |
+
```python
|
104 |
+
>>> roc_auc_score = evaluate.load("roc_auc", "multiclass")
|
105 |
+
>>> refs = [1, 0, 1, 2, 2, 0]
|
106 |
+
>>> pred_scores = [[0.3, 0.5, 0.2],
|
107 |
+
... [0.7, 0.2, 0.1],
|
108 |
+
... [0.005, 0.99, 0.005],
|
109 |
+
... [0.2, 0.3, 0.5],
|
110 |
+
... [0.1, 0.1, 0.8],
|
111 |
+
... [0.1, 0.7, 0.2]]
|
112 |
+
>>> results = roc_auc_score.compute(references=refs,
|
113 |
+
... prediction_scores=pred_scores,
|
114 |
+
... multi_class='ovr')
|
115 |
+
>>> print(round(results['roc_auc'], 2))
|
116 |
+
0.85
|
117 |
+
```
|
118 |
+
|
119 |
+
Example 3, the **multilabel** use case:
|
120 |
+
```python
|
121 |
+
>>> roc_auc_score = evaluate.load("roc_auc", "multilabel")
|
122 |
+
>>> refs = [[1, 1, 0],
|
123 |
+
... [1, 1, 0],
|
124 |
+
... [0, 1, 0],
|
125 |
+
... [0, 0, 1],
|
126 |
+
... [0, 1, 1],
|
127 |
+
... [1, 0, 1]]
|
128 |
+
>>> pred_scores = [[0.3, 0.5, 0.2],
|
129 |
+
... [0.7, 0.2, 0.1],
|
130 |
+
... [0.005, 0.99, 0.005],
|
131 |
+
... [0.2, 0.3, 0.5],
|
132 |
+
... [0.1, 0.1, 0.8],
|
133 |
+
... [0.1, 0.7, 0.2]]
|
134 |
+
>>> results = roc_auc_score.compute(references=refs,
|
135 |
+
... prediction_scores=pred_scores,
|
136 |
+
... average=None)
|
137 |
+
>>> print([round(res, 2) for res in results['roc_auc'])
|
138 |
+
[0.83, 0.38, 0.94]
|
139 |
+
```
|
140 |
+
|
141 |
+
|
142 |
+
## Limitations and Bias
|
143 |
+
|
144 |
+
|
145 |
+
## Citation
|
146 |
+
```bibtex
|
147 |
+
@article{doi:10.1177/0272989X8900900307,
|
148 |
+
author = {Donna Katzman McClish},
|
149 |
+
title ={Analyzing a Portion of the ROC Curve},
|
150 |
+
journal = {Medical Decision Making},
|
151 |
+
volume = {9},
|
152 |
+
number = {3},
|
153 |
+
pages = {190-195},
|
154 |
+
year = {1989},
|
155 |
+
doi = {10.1177/0272989X8900900307},
|
156 |
+
note ={PMID: 2668680},
|
157 |
+
URL = {https://doi.org/10.1177/0272989X8900900307},
|
158 |
+
eprint = {https://doi.org/10.1177/0272989X8900900307}
|
159 |
+
}
|
160 |
+
```
|
161 |
+
|
162 |
+
```bibtex
|
163 |
+
@article{10.1023/A:1010920819831,
|
164 |
+
author = {Hand, David J. and Till, Robert J.},
|
165 |
+
title = {A Simple Generalisation of the Area Under the ROC Curve for Multiple Class Classification Problems},
|
166 |
+
year = {2001},
|
167 |
+
issue_date = {November 2001},
|
168 |
+
publisher = {Kluwer Academic Publishers},
|
169 |
+
address = {USA},
|
170 |
+
volume = {45},
|
171 |
+
number = {2},
|
172 |
+
issn = {0885-6125},
|
173 |
+
url = {https://doi.org/10.1023/A:1010920819831},
|
174 |
+
doi = {10.1023/A:1010920819831},
|
175 |
+
journal = {Mach. Learn.},
|
176 |
+
month = {oct},
|
177 |
+
pages = {171–186},
|
178 |
+
numpages = {16},
|
179 |
+
keywords = {Gini index, AUC, error rate, ROC curve, receiver operating characteristic}
|
180 |
+
}
|
181 |
+
```
|
182 |
+
|
183 |
+
```bibtex
|
184 |
+
@article{scikit-learn,
|
185 |
+
title={Scikit-learn: Machine Learning in {P}ython},
|
186 |
+
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
|
187 |
+
journal={Journal of Machine Learning Research},
|
188 |
+
volume={12},
|
189 |
+
pages={2825--2830},
|
190 |
+
year={2011}
|
191 |
+
}
|
192 |
+
```
|
193 |
+
|
194 |
+
## Further References
|
195 |
+
This implementation is a wrapper around the [Scikit-learn implementation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html). Much of the documentation here was adapted from their existing documentation, as well.
|
196 |
+
|
197 |
+
The [Guide to ROC and AUC](https://youtu.be/iCZJfO-7C5Q) video from the channel Data Science Bits is also very informative.
|
app.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import evaluate
|
2 |
+
from evaluate.utils import launch_gradio_widget
|
3 |
+
|
4 |
+
|
5 |
+
module = evaluate.load("roc_auc")
|
6 |
+
launch_gradio_widget(module)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO: fix github to release
|
2 |
+
git+https://github.com/huggingface/evaluate.git@b6e6ed7f3e6844b297bff1b43a1b4be0709b9671
|
3 |
+
datasets~=2.0
|
4 |
+
sklearn
|
roc_auc.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Accuracy metric."""
|
15 |
+
|
16 |
+
import datasets
|
17 |
+
from sklearn.metrics import roc_auc_score
|
18 |
+
|
19 |
+
import evaluate
|
20 |
+
|
21 |
+
|
22 |
+
_DESCRIPTION = """
|
23 |
+
This metric computes the area under the curve (AUC) for the Receiver Operating Characteristic Curve (ROC). The return values represent how well the model used is predicting the correct classes, based on the input data. A score of `0.5` means that the model is predicting exactly at chance, i.e. the model's predictions are correct at the same rate as if the predictions were being decided by the flip of a fair coin or the roll of a fair die. A score above `0.5` indicates that the model is doing better than chance, while a score below `0.5` indicates that the model is doing worse than chance.
|
24 |
+
|
25 |
+
This metric has three separate use cases:
|
26 |
+
- binary: The case in which there are only two different label classes, and each example gets only one label. This is the default implementation.
|
27 |
+
- multiclass: The case in which there can be more than two different label classes, but each example still gets only one label.
|
28 |
+
- multilabel: The case in which there can be more than two different label classes, and each example can have more than one label.
|
29 |
+
"""
|
30 |
+
|
31 |
+
_KWARGS_DESCRIPTION = """
|
32 |
+
Args:
|
33 |
+
- references (array-like of shape (n_samples,) or (n_samples, n_classes)): Ground truth labels. Expects different input based on use case:
|
34 |
+
- binary: expects an array-like of shape (n_samples,)
|
35 |
+
- multiclass: expects an array-like of shape (n_samples,)
|
36 |
+
- multilabel: expects an array-like of shape (n_samples, n_classes)
|
37 |
+
- prediction_scores (array-like of shape (n_samples,) or (n_samples, n_classes)): Model predictions. Expects different inputs based on use case:
|
38 |
+
- binary: expects an array-like of shape (n_samples,)
|
39 |
+
- multiclass: expects an array-like of shape (n_samples, n_classes)
|
40 |
+
- multilabel: expects an array-like of shape (n_samples, n_classes)
|
41 |
+
- average (`str`): Type of average, and is ignored in the binary use case. Defaults to 'macro'. Options are:
|
42 |
+
- `'micro'`: Calculates metrics globally by considering each element of the label indicator matrix as a label. Only works with the multilabel use case.
|
43 |
+
- `'macro'`: Calculate metrics for each label, and find their unweighted mean. This does not take label imbalance into account.
|
44 |
+
- `'weighted'`: Calculate metrics for each label, and find their average, weighted by support (i.e. the number of true instances for each label).
|
45 |
+
- `'samples'`: Calculate metrics for each instance, and find their average. Only works with the multilabel use case.
|
46 |
+
- `None`: No average is calculated, and scores for each class are returned. Only works with the multilabels use case.
|
47 |
+
- sample_weight (array-like of shape (n_samples,)): Sample weights. Defaults to None.
|
48 |
+
- max_fpr (`float`): If not None, the standardized partial AUC over the range [0, `max_fpr`] is returned. Must be greater than `0` and less than or equal to `1`. Defaults to `None`. Note: For the multiclass use case, `max_fpr` should be either `None` or `1.0` as ROC AUC partial computation is not currently supported for `multiclass`.
|
49 |
+
- multi_class (`str`): Only used for multiclass targets, where it is required. Determines the type of configuration to use. Options are:
|
50 |
+
- `'ovr'`: Stands for One-vs-rest. Computes the AUC of each class against the rest. This treats the multiclass case in the same way as the multilabel case. Sensitive to class imbalance even when `average == 'macro'`, because class imbalance affects the composition of each of the 'rest' groupings.
|
51 |
+
- `'ovo'`: Stands for One-vs-one. Computes the average AUC of all possible pairwise combinations of classes. Insensitive to class imbalance when `average == 'macro'`.
|
52 |
+
- labels (array-like of shape (n_classes,)): Only used for multiclass targets. List of labels that index the classes in
|
53 |
+
`prediction_scores`. If `None`, the numerical or lexicographical order of the labels in
|
54 |
+
`prediction_scores` is used. Defaults to `None`.
|
55 |
+
Returns:
|
56 |
+
roc_auc (`float` or array-like of shape (n_classes,)): Returns array if in multilabel use case and `average='None'`. Otherwise, returns `float`.
|
57 |
+
Examples:
|
58 |
+
Example 1:
|
59 |
+
>>> roc_auc_score = evaluate.load("roc_auc")
|
60 |
+
>>> refs = [1, 0, 1, 1, 0, 0]
|
61 |
+
>>> pred_scores = [0.5, 0.2, 0.99, 0.3, 0.1, 0.7]
|
62 |
+
>>> results = roc_auc_score.compute(references=refs, prediction_scores=pred_scores)
|
63 |
+
>>> print(round(results['roc_auc'], 2))
|
64 |
+
0.78
|
65 |
+
|
66 |
+
Example 2:
|
67 |
+
>>> roc_auc_score = evaluate.load("roc_auc", "multiclass")
|
68 |
+
>>> refs = [1, 0, 1, 2, 2, 0]
|
69 |
+
>>> pred_scores = [[0.3, 0.5, 0.2],
|
70 |
+
... [0.7, 0.2, 0.1],
|
71 |
+
... [0.005, 0.99, 0.005],
|
72 |
+
... [0.2, 0.3, 0.5],
|
73 |
+
... [0.1, 0.1, 0.8],
|
74 |
+
... [0.1, 0.7, 0.2]]
|
75 |
+
>>> results = roc_auc_score.compute(references=refs, prediction_scores=pred_scores, multi_class='ovr')
|
76 |
+
>>> print(round(results['roc_auc'], 2))
|
77 |
+
0.85
|
78 |
+
|
79 |
+
Example 3:
|
80 |
+
>>> roc_auc_score = evaluate.load("roc_auc", "multilabel")
|
81 |
+
>>> refs = [[1, 1, 0],
|
82 |
+
... [1, 1, 0],
|
83 |
+
... [0, 1, 0],
|
84 |
+
... [0, 0, 1],
|
85 |
+
... [0, 1, 1],
|
86 |
+
... [1, 0, 1]]
|
87 |
+
>>> pred_scores = [[0.3, 0.5, 0.2],
|
88 |
+
... [0.7, 0.2, 0.1],
|
89 |
+
... [0.005, 0.99, 0.005],
|
90 |
+
... [0.2, 0.3, 0.5],
|
91 |
+
... [0.1, 0.1, 0.8],
|
92 |
+
... [0.1, 0.7, 0.2]]
|
93 |
+
>>> results = roc_auc_score.compute(references=refs, prediction_scores=pred_scores, average=None)
|
94 |
+
>>> print([round(res, 2) for res in results['roc_auc']])
|
95 |
+
[0.83, 0.38, 0.94]
|
96 |
+
"""
|
97 |
+
|
98 |
+
_CITATION = """\
|
99 |
+
@article{doi:10.1177/0272989X8900900307,
|
100 |
+
author = {Donna Katzman McClish},
|
101 |
+
title ={Analyzing a Portion of the ROC Curve},
|
102 |
+
journal = {Medical Decision Making},
|
103 |
+
volume = {9},
|
104 |
+
number = {3},
|
105 |
+
pages = {190-195},
|
106 |
+
year = {1989},
|
107 |
+
doi = {10.1177/0272989X8900900307},
|
108 |
+
note ={PMID: 2668680},
|
109 |
+
URL = {https://doi.org/10.1177/0272989X8900900307},
|
110 |
+
eprint = {https://doi.org/10.1177/0272989X8900900307}
|
111 |
+
}
|
112 |
+
|
113 |
+
|
114 |
+
@article{10.1023/A:1010920819831,
|
115 |
+
author = {Hand, David J. and Till, Robert J.},
|
116 |
+
title = {A Simple Generalisation of the Area Under the ROC Curve for Multiple Class Classification Problems},
|
117 |
+
year = {2001},
|
118 |
+
issue_date = {November 2001},
|
119 |
+
publisher = {Kluwer Academic Publishers},
|
120 |
+
address = {USA},
|
121 |
+
volume = {45},
|
122 |
+
number = {2},
|
123 |
+
issn = {0885-6125},
|
124 |
+
url = {https://doi.org/10.1023/A:1010920819831},
|
125 |
+
doi = {10.1023/A:1010920819831},
|
126 |
+
journal = {Mach. Learn.},
|
127 |
+
month = {oct},
|
128 |
+
pages = {171–186},
|
129 |
+
numpages = {16},
|
130 |
+
keywords = {Gini index, AUC, error rate, ROC curve, receiver operating characteristic}
|
131 |
+
}
|
132 |
+
|
133 |
+
|
134 |
+
@article{scikit-learn,
|
135 |
+
title={Scikit-learn: Machine Learning in {P}ython},
|
136 |
+
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
|
137 |
+
journal={Journal of Machine Learning Research},
|
138 |
+
volume={12},
|
139 |
+
pages={2825--2830},
|
140 |
+
year={2011}
|
141 |
+
}
|
142 |
+
"""
|
143 |
+
|
144 |
+
|
145 |
+
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
146 |
+
class ROCAUC(evaluate.EvaluationModule):
|
147 |
+
def _info(self):
|
148 |
+
return evaluate.EvaluationModuleInfo(
|
149 |
+
description=_DESCRIPTION,
|
150 |
+
citation=_CITATION,
|
151 |
+
inputs_description=_KWARGS_DESCRIPTION,
|
152 |
+
features=datasets.Features(
|
153 |
+
{
|
154 |
+
"prediction_scores": datasets.Sequence(datasets.Value("float")),
|
155 |
+
"references": datasets.Value("int32"),
|
156 |
+
}
|
157 |
+
if self.config_name == "multiclass"
|
158 |
+
else {
|
159 |
+
"references": datasets.Sequence(datasets.Value("int32")),
|
160 |
+
"prediction_scores": datasets.Sequence(datasets.Value("float")),
|
161 |
+
}
|
162 |
+
if self.config_name == "multilabel"
|
163 |
+
else {
|
164 |
+
"references": datasets.Value("int32"),
|
165 |
+
"prediction_scores": datasets.Value("float"),
|
166 |
+
}
|
167 |
+
),
|
168 |
+
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html"],
|
169 |
+
)
|
170 |
+
|
171 |
+
def _compute(
|
172 |
+
self,
|
173 |
+
references,
|
174 |
+
prediction_scores,
|
175 |
+
average="macro",
|
176 |
+
sample_weight=None,
|
177 |
+
max_fpr=None,
|
178 |
+
multi_class="raise",
|
179 |
+
labels=None,
|
180 |
+
):
|
181 |
+
return {
|
182 |
+
"roc_auc": roc_auc_score(
|
183 |
+
references,
|
184 |
+
prediction_scores,
|
185 |
+
average=average,
|
186 |
+
sample_weight=sample_weight,
|
187 |
+
max_fpr=max_fpr,
|
188 |
+
multi_class=multi_class,
|
189 |
+
labels=labels,
|
190 |
+
)
|
191 |
+
}
|