lvwerra HF staff commited on
Commit
55b6ff7
1 Parent(s): 8027463

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. wer.py +18 -3
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  jiwer
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  jiwer
wer.py CHANGED
@@ -13,6 +13,8 @@
13
  # limitations under the License.
14
  """ Word Error Ratio (WER) metric. """
15
 
 
 
16
  import datasets
17
  from jiwer import compute_measures
18
 
@@ -74,13 +76,26 @@ Examples:
74
  """
75
 
76
 
 
 
 
 
 
 
 
 
77
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
78
  class WER(evaluate.Metric):
79
- def _info(self):
 
 
 
 
80
  return evaluate.MetricInfo(
81
  description=_DESCRIPTION,
82
  citation=_CITATION,
83
  inputs_description=_KWARGS_DESCRIPTION,
 
84
  features=datasets.Features(
85
  {
86
  "predictions": datasets.Value("string", id="sequence"),
@@ -93,8 +108,8 @@ class WER(evaluate.Metric):
93
  ],
94
  )
95
 
96
- def _compute(self, predictions=None, references=None, concatenate_texts=False):
97
- if concatenate_texts:
98
  return compute_measures(references, predictions)["wer"]
99
  else:
100
  incorrect = 0
 
13
  # limitations under the License.
14
  """ Word Error Ratio (WER) metric. """
15
 
16
+ from dataclasses import dataclass
17
+
18
  import datasets
19
  from jiwer import compute_measures
20
 
 
76
  """
77
 
78
 
79
+ @dataclass
80
+ class WERConfig(evaluate.info.Config):
81
+
82
+ name: str = "default"
83
+
84
+ concatenate_texts: bool = False
85
+
86
+
87
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
88
  class WER(evaluate.Metric):
89
+
90
+ CONFIG_CLASS = WERConfig
91
+ ALLOWED_CONFIG_NAMES = ["default", "multilabel"]
92
+
93
+ def _info(self, config):
94
  return evaluate.MetricInfo(
95
  description=_DESCRIPTION,
96
  citation=_CITATION,
97
  inputs_description=_KWARGS_DESCRIPTION,
98
+ config=config,
99
  features=datasets.Features(
100
  {
101
  "predictions": datasets.Value("string", id="sequence"),
 
108
  ],
109
  )
110
 
111
+ def _compute(self, predictions=None, references=None):
112
+ if self.config.concatenate_texts:
113
  return compute_measures(references, predictions)["wer"]
114
  else:
115
  incorrect = 0