lvwerra committed
Commit c7ea9bb
1 Parent(s): b394cb6

Update Space (evaluate main: c447fc8e)

Files changed (2)
  1. requirements.txt +1 -1
  2. roc_auc.py +11 -27
requirements.txt CHANGED
@@ -1,2 +1,2 @@
-git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
+git+https://github.com/huggingface/evaluate@c447fc8eda9c62af501bfdc6988919571050d950
 sklearn
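The pin above rebuilds `evaluate` from the exact commit this Space was regenerated against (`c447fc8e`). As a minimal sketch (not part of this commit), note that a git-pinned install still reports the package's regular release version at runtime, not the commit SHA:

import evaluate

# A git-pinned install (git+https://...@<sha>) reports the normal
# release version string; the pinned SHA itself only appears in the
# installed package's metadata, not in __version__.
print(evaluate.__version__)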
roc_auc.py CHANGED
@@ -13,9 +13,6 @@
 # limitations under the License.
 """Accuracy metric."""
 
-from dataclasses import dataclass
-from typing import List, Optional, Union
-
 import datasets
 from sklearn.metrics import roc_auc_score
 
@@ -145,31 +142,13 @@ year={2011}
 """
 
 
-@dataclass
-class ROCAUCConfig(evaluate.info.Config):
-
-    name: str = "default"
-
-    pos_label: Union[str, int] = 1
-    average: str = "macro"
-    labels: Optional[List[str]] = None
-    sample_weight: Optional[List[float]] = None
-    max_fpr: Optional[float] = None
-    multi_class: str = "raise"
-
-
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class ROCAUC(evaluate.Metric):
-
-    CONFIG_CLASS = ROCAUCConfig
-    ALLOWED_CONFIG_NAMES = ["default", "multilabel", "multiclass"]
-
-    def _info(self, config):
+    def _info(self):
         return evaluate.MetricInfo(
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
-            config=config,
             features=datasets.Features(
                 {
                     "prediction_scores": datasets.Sequence(datasets.Value("float")),
@@ -193,15 +172,20 @@ class ROCAUC(evaluate.Metric):
         self,
         references,
         prediction_scores,
+        average="macro",
+        sample_weight=None,
+        max_fpr=None,
+        multi_class="raise",
+        labels=None,
     ):
         return {
             "roc_auc": roc_auc_score(
                 references,
                 prediction_scores,
-                average=self.config.average,
-                sample_weight=self.config.sample_weight,
-                max_fpr=self.config.max_fpr,
-                multi_class=self.config.multi_class,
-                labels=self.config.labels,
+                average=average,
+                sample_weight=sample_weight,
+                max_fpr=max_fpr,
+                multi_class=multi_class,
+                labels=labels,
             )
         }
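For context, a minimal usage sketch of the metric after this change (assuming the default binary config and loading via `evaluate.load`): options such as `average` and `max_fpr` are now plain keyword arguments of `compute()` instead of fields on the removed `ROCAUCConfig` dataclass, and are forwarded directly to sklearn's `roc_auc_score`:

import evaluate

roc_auc = evaluate.load("roc_auc")

refs = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

# Options that previously lived on ROCAUCConfig are passed
# per-call to compute() and forwarded to roc_auc_score.
result = roc_auc.compute(
    references=refs,
    prediction_scores=scores,
    average="macro",
)
print(result)  # {'roc_auc': 0.75}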