lvwerra HF staff commited on
Commit
4dedfce
1 Parent(s): 67f8166

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. recall.py +26 -11
  2. requirements.txt +1 -1
recall.py CHANGED
@@ -13,6 +13,9 @@
13
  # limitations under the License.
14
  """Recall metric."""
15
 
 
 
 
16
  import datasets
17
  from sklearn.metrics import recall_score
18
 
@@ -92,13 +95,30 @@ _CITATION = """
92
  """
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
96
  class Recall(evaluate.Metric):
97
- def _info(self):
 
 
 
 
98
  return evaluate.MetricInfo(
99
  description=_DESCRIPTION,
100
  citation=_CITATION,
101
  inputs_description=_KWARGS_DESCRIPTION,
 
102
  features=datasets.Features(
103
  {
104
  "predictions": datasets.Sequence(datasets.Value("int32")),
@@ -117,19 +137,14 @@ class Recall(evaluate.Metric):
117
  self,
118
  predictions,
119
  references,
120
- labels=None,
121
- pos_label=1,
122
- average="binary",
123
- sample_weight=None,
124
- zero_division="warn",
125
  ):
126
  score = recall_score(
127
  references,
128
  predictions,
129
- labels=labels,
130
- pos_label=pos_label,
131
- average=average,
132
- sample_weight=sample_weight,
133
- zero_division=zero_division,
134
  )
135
  return {"recall": float(score) if score.size == 1 else score}
13
  # limitations under the License.
14
  """Recall metric."""
15
 
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Union
18
+
19
  import datasets
20
  from sklearn.metrics import recall_score
21
 
95
  """
96
 
97
 
98
+ @dataclass
99
+ class RecallConfig(evaluate.info.Config):
100
+
101
+ name: str = "default"
102
+
103
+ pos_label: Union[str, int] = 1
104
+ average: str = "binary"
105
+ labels: Optional[List[str]] = None
106
+ sample_weight: Optional[List[float]] = None
107
+ zero_division: str = "warn"
108
+
109
+
110
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
111
  class Recall(evaluate.Metric):
112
+
113
+ CONFIG_CLASS = RecallConfig
114
+ ALLOWED_CONFIG_NAMES = ["default", "multilabel"]
115
+
116
+ def _info(self, config):
117
  return evaluate.MetricInfo(
118
  description=_DESCRIPTION,
119
  citation=_CITATION,
120
  inputs_description=_KWARGS_DESCRIPTION,
121
+ config=config,
122
  features=datasets.Features(
123
  {
124
  "predictions": datasets.Sequence(datasets.Value("int32")),
137
  self,
138
  predictions,
139
  references,
 
 
 
 
 
140
  ):
141
  score = recall_score(
142
  references,
143
  predictions,
144
+ labels=self.config.labels,
145
+ pos_label=self.config.pos_label,
146
+ average=self.config.average,
147
+ sample_weight=self.config.sample_weight,
148
+ zero_division=self.config.zero_division,
149
  )
150
  return {"recall": float(score) if score.size == 1 else score}
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  sklearn
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  sklearn