lvwerra HF staff committed on
Commit
eadb728
1 Parent(s): d2e253a

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. bertscore.py +54 -34
  2. requirements.txt +1 -1
bertscore.py CHANGED
@@ -15,6 +15,8 @@
15
 
16
  import functools
17
  from contextlib import contextmanager
 
 
18
 
19
  import bert_score
20
  import datasets
@@ -97,14 +99,42 @@ Examples:
97
  """
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
101
  class BERTScore(evaluate.Metric):
102
- def _info(self):
 
 
 
103
  return evaluate.MetricInfo(
104
  description=_DESCRIPTION,
105
  citation=_CITATION,
106
  homepage="https://github.com/Tiiiger/bert_score",
107
  inputs_description=_KWARGS_DESCRIPTION,
 
108
  features=[
109
  datasets.Features(
110
  {
@@ -130,24 +160,12 @@ class BERTScore(evaluate.Metric):
130
  self,
131
  predictions,
132
  references,
133
- lang=None,
134
- model_type=None,
135
- num_layers=None,
136
- verbose=False,
137
- idf=False,
138
- device=None,
139
- batch_size=64,
140
- nthreads=4,
141
- all_layers=False,
142
- rescale_with_baseline=False,
143
- baseline_path=None,
144
- use_fast_tokenizer=False,
145
  ):
146
 
147
  if isinstance(references[0], str):
148
  references = [[ref] for ref in references]
149
 
150
- if idf:
151
  idf_sents = [r for ref in references for r in ref]
152
  else:
153
  idf_sents = None
@@ -156,32 +174,34 @@ class BERTScore(evaluate.Metric):
156
  scorer = bert_score.BERTScorer
157
 
158
  if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
159
- get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer)
160
- scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer)
161
- elif use_fast_tokenizer:
162
  raise ImportWarning(
163
  "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
164
  "`bert-score` doesn't match this condition.\n"
165
  'You can install it with `pip install "bert-score>=0.3.10"`.'
166
  )
167
 
168
- if model_type is None:
169
- if lang is None:
170
  raise ValueError(
171
  "Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
172
  " must be specified"
173
  )
174
- model_type = bert_score.utils.lang2model[lang.lower()]
 
 
175
 
176
- if num_layers is None:
177
  num_layers = bert_score.utils.model2layers[model_type]
178
 
179
  hashcode = get_hash(
180
  model=model_type,
181
  num_layers=num_layers,
182
- idf=idf,
183
- rescale_with_baseline=rescale_with_baseline,
184
- use_custom_baseline=baseline_path is not None,
185
  )
186
 
187
  with filter_logging_context():
@@ -189,22 +209,22 @@ class BERTScore(evaluate.Metric):
189
  self.cached_bertscorer = scorer(
190
  model_type=model_type,
191
  num_layers=num_layers,
192
- batch_size=batch_size,
193
- nthreads=nthreads,
194
- all_layers=all_layers,
195
- idf=idf,
196
  idf_sents=idf_sents,
197
- device=device,
198
- lang=lang,
199
- rescale_with_baseline=rescale_with_baseline,
200
- baseline_path=baseline_path,
201
  )
202
 
203
  (P, R, F) = self.cached_bertscorer.score(
204
  cands=predictions,
205
  refs=references,
206
- verbose=verbose,
207
- batch_size=batch_size,
208
  )
209
  output_dict = {
210
  "precision": P.tolist(),
 
15
 
16
  import functools
17
  from contextlib import contextmanager
18
+ from dataclasses import dataclass
19
+ from typing import List, Optional, Union
20
 
21
  import bert_score
22
  import datasets
 
99
  """
100
 
101
 
102
@dataclass
class BERTScoreConfig(evaluate.info.Config):
    """Configuration options for the BERTScore metric.

    The values stored here are read by ``BERTScore._compute`` through
    ``self.config`` instead of being passed as keyword arguments.

    Every option must be declared with an annotation (``name: type = default``)
    so that ``@dataclass`` registers it as a field; an un-annotated assignment
    would be a plain class attribute that the generated ``__init__`` rejects.
    """

    name: str = "default"

    # NOTE(review): `pos_label`, `average` and `sample_weight` are not read by
    # any BERTScore code visible in this change; they look like leftovers from
    # a classification-metric config. Kept so the constructor signature stays
    # backward-compatible -- confirm whether they can be dropped.
    pos_label: Union[str, int] = 1
    average: str = "binary"
    # Language code; used to pick a default model when `model_type` is None
    # and forwarded to the scorer. Declared once here (the duplicate
    # declaration further down was redundant and has been removed).
    lang: Optional[str] = None
    sample_weight: Optional[List[float]] = None

    # Model identifier; when None, `lang` must be provided instead.
    model_type: Optional[str] = None
    # Number of layers to use; when None it is looked up from the model type.
    num_layers: Optional[int] = None
    # Forwarded to the scorer's `score` call.
    verbose: bool = False
    # Enables IDF weighting (reference sentences are used as IDF corpus).
    # Was `idf = bool = False`, which shadowed `bool` in the class body and
    # left `idf` un-annotated, i.e. not a dataclass field.
    idf: bool = False
    device: Optional[str] = None
    batch_size: int = 64
    nthreads: int = 4
    all_layers: bool = False
    rescale_with_baseline: bool = False
    # A non-None value also marks the baseline as custom in the scorer hash.
    baseline_path: Optional[str] = None
    # Only supported with `bert-score>=0.3.10`.
    use_fast_tokenizer: bool = False
124
+
125
+
126
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
127
  class BERTScore(evaluate.Metric):
128
+ CONFIG_CLASS = BERTScoreConfig
129
+ ALLOWED_CONFIG_NAMES = ["default"]
130
+
131
+ def _info(self, config):
132
  return evaluate.MetricInfo(
133
  description=_DESCRIPTION,
134
  citation=_CITATION,
135
  homepage="https://github.com/Tiiiger/bert_score",
136
  inputs_description=_KWARGS_DESCRIPTION,
137
+ config=config,
138
  features=[
139
  datasets.Features(
140
  {
 
160
  self,
161
  predictions,
162
  references,
 
 
 
 
 
 
 
 
 
 
 
 
163
  ):
164
 
165
  if isinstance(references[0], str):
166
  references = [[ref] for ref in references]
167
 
168
+ if self.config.idf:
169
  idf_sents = [r for ref in references for r in ref]
170
  else:
171
  idf_sents = None
 
174
  scorer = bert_score.BERTScorer
175
 
176
  if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
177
+ get_hash = functools.partial(get_hash, use_fast_tokenizer=self.config.use_fast_tokenizer)
178
+ scorer = functools.partial(scorer, use_fast_tokenizer=self.config.use_fast_tokenizer)
179
+ elif self.config.use_fast_tokenizer:
180
  raise ImportWarning(
181
  "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
182
  "`bert-score` doesn't match this condition.\n"
183
  'You can install it with `pip install "bert-score>=0.3.10"`.'
184
  )
185
 
186
+ if self.config.model_type is None:
187
+ if self.config.lang is None:
188
  raise ValueError(
189
  "Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
190
  " must be specified"
191
  )
192
+ model_type = bert_score.utils.lang2model[self.config.lang.lower()]
193
+ else:
194
+ model_type = self.config.model_type
195
 
196
+ if self.config.num_layers is None:
197
  num_layers = bert_score.utils.model2layers[model_type]
198
 
199
  hashcode = get_hash(
200
  model=model_type,
201
  num_layers=num_layers,
202
+ idf=self.config.idf,
203
+ rescale_with_baseline=self.config.rescale_with_baseline,
204
+ use_custom_baseline=self.config.baseline_path is not None,
205
  )
206
 
207
  with filter_logging_context():
 
209
  self.cached_bertscorer = scorer(
210
  model_type=model_type,
211
  num_layers=num_layers,
212
+ batch_size=self.config.batch_size,
213
+ nthreads=self.config.nthreads,
214
+ all_layers=self.config.all_layers,
215
+ idf=self.config.idf,
216
  idf_sents=idf_sents,
217
+ device=self.config.device,
218
+ lang=self.config.lang,
219
+ rescale_with_baseline=self.config.rescale_with_baseline,
220
+ baseline_path=self.config.baseline_path,
221
  )
222
 
223
  (P, R, F) = self.cached_bertscorer.score(
224
  cands=predictions,
225
  refs=references,
226
+ verbose=self.config.verbose,
227
+ batch_size=self.config.batch_size,
228
  )
229
  output_dict = {
230
  "precision": P.tolist(),
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  bert_score
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  bert_score