lvwerra (HF staff) committed
Commit: 1a6c1ef
Parent: 0df3a9f

Update Space (evaluate main: e4a27243)

Files changed (2)
  1. f1.py +39 -11
  2. requirements.txt +1 -1
f1.py CHANGED
@@ -13,6 +13,9 @@
 # limitations under the License.
 """F1 metric."""
 
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
 import datasets
 from sklearn.metrics import f1_score
 
@@ -52,30 +55,34 @@ Examples:
     {'f1': 0.5}
 
     Example 2-The same simple binary example as in Example 1, but with `pos_label` set to `0`.
-        >>> f1_metric = evaluate.load("f1")
-        >>> results = f1_metric.compute(references=[0, 1, 0, 1, 0], predictions=[0, 0, 1, 1, 0], pos_label=0)
+        >>> f1_metric = evaluate.load("f1", pos_label=0)
+        >>> results = f1_metric.compute(references=[0, 1, 0, 1, 0], predictions=[0, 0, 1, 1, 0])
         >>> print(round(results['f1'], 2))
         0.67
 
     Example 3-The same simple binary example as in Example 1, but with `sample_weight` included.
-        >>> f1_metric = evaluate.load("f1")
-        >>> results = f1_metric.compute(references=[0, 1, 0, 1, 0], predictions=[0, 0, 1, 1, 0], sample_weight=[0.9, 0.5, 3.9, 1.2, 0.3])
+        >>> f1_metric = evaluate.load("f1", sample_weight=[0.9, 0.5, 3.9, 1.2, 0.3])
+        >>> results = f1_metric.compute(references=[0, 1, 0, 1, 0], predictions=[0, 0, 1, 1, 0])
         >>> print(round(results['f1'], 2))
         0.35
 
     Example 4-A multiclass example, with different values for the `average` input.
+        >>> f1_metric = evaluate.load("f1", average="macro")
         >>> predictions = [0, 2, 1, 0, 0, 1]
         >>> references = [0, 1, 2, 0, 1, 2]
-        >>> results = f1_metric.compute(predictions=predictions, references=references, average="macro")
+        >>> results = f1_metric.compute(predictions=predictions, references=references)
         >>> print(round(results['f1'], 2))
         0.27
-        >>> results = f1_metric.compute(predictions=predictions, references=references, average="micro")
+        >>> f1_metric = evaluate.load("f1", average="micro")
+        >>> results = f1_metric.compute(predictions=predictions, references=references)
         >>> print(round(results['f1'], 2))
         0.33
-        >>> results = f1_metric.compute(predictions=predictions, references=references, average="weighted")
+        >>> f1_metric = evaluate.load("f1", average="weighted")
+        >>> results = f1_metric.compute(predictions=predictions, references=references)
        >>> print(round(results['f1'], 2))
         0.27
-        >>> results = f1_metric.compute(predictions=predictions, references=references, average=None)
+        >>> f1_metric = evaluate.load("f1", average=None)
+        >>> results = f1_metric.compute(predictions=predictions, references=references)
         >>> print(results)
         {'f1': array([0.8, 0. , 0. ])}
 
@@ -102,13 +109,29 @@ _CITATION = """
 """
 
 
+@dataclass
+class F1Config(evaluate.info.Config):
+
+    name: str = "default"
+
+    pos_label: Union[str, int] = 1
+    average: str = "binary"
+    labels: Optional[List[str]] = None
+    sample_weight: Optional[List[float]] = None
+
+
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class F1(evaluate.Metric):
-    def _info(self):
+
+    CONFIG_CLASS = F1Config
+    ALLOWED_CONFIG_NAMES = ["default", "multilabel"]
+
+    def _info(self, config):
         return evaluate.MetricInfo(
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
+            config=config,
             features=datasets.Features(
                 {
                     "predictions": datasets.Sequence(datasets.Value("int32")),
@@ -123,8 +146,13 @@ class F1(evaluate.Metric):
             reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"],
         )
 
-    def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None):
+    def _compute(self, predictions, references):
         score = f1_score(
-            references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight
+            references,
+            predictions,
+            labels=self.config.labels,
+            pos_label=self.config.pos_label,
+            average=self.config.average,
+            sample_weight=self.config.sample_weight,
         )
         return {"f1": float(score) if score.size == 1 else score}
requirements.txt CHANGED
@@ -1,2 +1,2 @@
-git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
+git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
 sklearn
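The evaluate dependency is re-pinned to revision e4a2724377909fe2aeb4357e3971e5a569673b39, matching the commit message and carrying the config support that f1.py now relies on; to mirror the Space environment locally, the same pin can be installed with pip install "git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39".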