lvwerra HF staff commited on
Commit
b394cb6
1 Parent(s): cf0c51e

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. roc_auc.py +27 -11
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  sklearn
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  sklearn
roc_auc.py CHANGED
@@ -13,6 +13,9 @@
13
  # limitations under the License.
14
  """Accuracy metric."""
15
 
 
 
 
16
  import datasets
17
  from sklearn.metrics import roc_auc_score
18
 
@@ -142,13 +145,31 @@ year={2011}
142
  """
143
 
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
146
  class ROCAUC(evaluate.Metric):
147
- def _info(self):
 
 
 
 
148
  return evaluate.MetricInfo(
149
  description=_DESCRIPTION,
150
  citation=_CITATION,
151
  inputs_description=_KWARGS_DESCRIPTION,
 
152
  features=datasets.Features(
153
  {
154
  "prediction_scores": datasets.Sequence(datasets.Value("float")),
@@ -172,20 +193,15 @@ class ROCAUC(evaluate.Metric):
172
  self,
173
  references,
174
  prediction_scores,
175
- average="macro",
176
- sample_weight=None,
177
- max_fpr=None,
178
- multi_class="raise",
179
- labels=None,
180
  ):
181
  return {
182
  "roc_auc": roc_auc_score(
183
  references,
184
  prediction_scores,
185
- average=average,
186
- sample_weight=sample_weight,
187
- max_fpr=max_fpr,
188
- multi_class=multi_class,
189
- labels=labels,
190
  )
191
  }
 
13
  # limitations under the License.
14
  """Accuracy metric."""
15
 
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Union
18
+
19
  import datasets
20
  from sklearn.metrics import roc_auc_score
21
 
 
145
  """
146
 
147
 
148
+ @dataclass
149
+ class ROCAUCConfig(evaluate.info.Config):
150
+
151
+ name: str = "default"
152
+
153
+ pos_label: Union[str, int] = 1
154
+ average: str = "macro"
155
+ labels: Optional[List[str]] = None
156
+ sample_weight: Optional[List[float]] = None
157
+ max_fpr: Optional[float] = None
158
+ multi_class: str = "raise"
159
+
160
+
161
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
162
  class ROCAUC(evaluate.Metric):
163
+
164
+ CONFIG_CLASS = ROCAUCConfig
165
+ ALLOWED_CONFIG_NAMES = ["default", "multilabel", "multiclass"]
166
+
167
+ def _info(self, config):
168
  return evaluate.MetricInfo(
169
  description=_DESCRIPTION,
170
  citation=_CITATION,
171
  inputs_description=_KWARGS_DESCRIPTION,
172
+ config=config,
173
  features=datasets.Features(
174
  {
175
  "prediction_scores": datasets.Sequence(datasets.Value("float")),
 
193
  self,
194
  references,
195
  prediction_scores,
 
 
 
 
 
196
  ):
197
  return {
198
  "roc_auc": roc_auc_score(
199
  references,
200
  prediction_scores,
201
+ average=self.config.average,
202
+ sample_weight=self.config.sample_weight,
203
+ max_fpr=self.config.max_fpr,
204
+ multi_class=self.config.multi_class,
205
+ labels=self.config.labels,
206
  )
207
  }