lvwerra HF staff commited on
Commit
f980503
·
1 Parent(s): 220636d

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. accuracy.py +21 -3
  2. requirements.txt +1 -1
accuracy.py CHANGED
@@ -13,6 +13,9 @@
13
  # limitations under the License.
14
  """Accuracy metric."""
15
 
 
 
 
16
  import datasets
17
  from sklearn.metrics import accuracy_score
18
 
@@ -77,13 +80,26 @@ _CITATION = """
77
  """
78
 
79
 
 
 
 
 
 
 
 
 
 
80
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
81
  class Accuracy(evaluate.Metric):
82
- def _info(self):
 
 
 
83
  return evaluate.MetricInfo(
84
  description=_DESCRIPTION,
85
  citation=_CITATION,
86
  inputs_description=_KWARGS_DESCRIPTION,
 
87
  features=datasets.Features(
88
  {
89
  "predictions": datasets.Sequence(datasets.Value("int32")),
@@ -98,9 +114,11 @@ class Accuracy(evaluate.Metric):
98
  reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
99
  )
100
 
101
- def _compute(self, predictions, references, normalize=True, sample_weight=None):
102
  return {
103
  "accuracy": float(
104
- accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
 
 
105
  )
106
  }
 
13
  # limitations under the License.
14
  """Accuracy metric."""
15
 
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional
18
+
19
  import datasets
20
  from sklearn.metrics import accuracy_score
21
 
 
80
  """
81
 
82
 
83
+ @dataclass
84
+ class AccuracyConfig(evaluate.info.Config):
85
+
86
+ name: str = "default"
87
+
88
+ normalize: bool = True
89
+ sample_weight: Optional[List[float]] = None
90
+
91
+
92
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
93
  class Accuracy(evaluate.Metric):
94
+ CONFIG_CLASS = AccuracyConfig
95
+ ALLOWED_CONFIG_NAMES = ["default", "multilabel"]
96
+
97
+ def _info(self, config):
98
  return evaluate.MetricInfo(
99
  description=_DESCRIPTION,
100
  citation=_CITATION,
101
  inputs_description=_KWARGS_DESCRIPTION,
102
+ config=config,
103
  features=datasets.Features(
104
  {
105
  "predictions": datasets.Sequence(datasets.Value("int32")),
 
114
  reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
115
  )
116
 
117
+ def _compute(self, predictions, references):
118
  return {
119
  "accuracy": float(
120
+ accuracy_score(
121
+ references, predictions, normalize=self.config.normalize, sample_weight=self.config.sample_weight
122
+ )
123
  )
124
  }
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  sklearn
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  sklearn