lvwerra HF staff commited on
Commit
8357825
1 Parent(s): 8283662

Update Space (evaluate main: 8e481b15)

Browse files
Files changed (1) hide show
  1. bertscore.py +24 -22
bertscore.py CHANGED
@@ -105,12 +105,20 @@ class BERTScore(evaluate.Metric):
105
  citation=_CITATION,
106
  homepage="https://github.com/Tiiiger/bert_score",
107
  inputs_description=_KWARGS_DESCRIPTION,
108
- features=datasets.Features(
109
- {
110
- "predictions": datasets.Value("string", id="sequence"),
111
- "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
112
- }
113
- ),
 
 
 
 
 
 
 
 
114
  codebase_urls=["https://github.com/Tiiiger/bert_score"],
115
  reference_urls=[
116
  "https://github.com/Tiiiger/bert_score",
@@ -135,6 +143,15 @@ class BERTScore(evaluate.Metric):
135
  baseline_path=None,
136
  use_fast_tokenizer=False,
137
  ):
 
 
 
 
 
 
 
 
 
138
  get_hash = bert_score.utils.get_hash
139
  scorer = bert_score.BERTScorer
140
 
@@ -171,6 +188,7 @@ class BERTScore(evaluate.Metric):
171
  nthreads=nthreads,
172
  all_layers=all_layers,
173
  idf=idf,
 
174
  device=device,
175
  lang=lang,
176
  rescale_with_baseline=rescale_with_baseline,
@@ -190,19 +208,3 @@ class BERTScore(evaluate.Metric):
190
  "hashcode": hashcode,
191
  }
192
  return output_dict
193
-
194
- def add_batch(self, predictions=None, references=None, **kwargs):
195
- """Add a batch of predictions and references for the metric's stack."""
196
- # References can be strings or lists of strings
197
- # Let's change strings to lists of strings with one element
198
- if references is not None:
199
- references = [[ref] if isinstance(ref, str) else ref for ref in references]
200
- super().add_batch(predictions=predictions, references=references, **kwargs)
201
-
202
- def add(self, prediction=None, reference=None, **kwargs):
203
- """Add one prediction and reference for the metric's stack."""
204
- # References can be strings or lists of strings
205
- # Let's change strings to lists of strings with one element
206
- if isinstance(reference, str):
207
- reference = [reference]
208
- super().add(prediction=prediction, reference=reference, **kwargs)
 
105
  citation=_CITATION,
106
  homepage="https://github.com/Tiiiger/bert_score",
107
  inputs_description=_KWARGS_DESCRIPTION,
108
+ features=[
109
+ datasets.Features(
110
+ {
111
+ "predictions": datasets.Value("string", id="sequence"),
112
+ "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
113
+ }
114
+ ),
115
+ datasets.Features(
116
+ {
117
+ "predictions": datasets.Value("string", id="sequence"),
118
+ "references": datasets.Value("string", id="sequence"),
119
+ }
120
+ ),
121
+ ],
122
  codebase_urls=["https://github.com/Tiiiger/bert_score"],
123
  reference_urls=[
124
  "https://github.com/Tiiiger/bert_score",
 
143
  baseline_path=None,
144
  use_fast_tokenizer=False,
145
  ):
146
+
147
+ if isinstance(references[0], str):
148
+ references = [[ref] for ref in references]
149
+
150
+ if idf:
151
+ idf_sents = [r for ref in references for r in ref]
152
+ else:
153
+ idf_sents = None
154
+
155
  get_hash = bert_score.utils.get_hash
156
  scorer = bert_score.BERTScorer
157
 
 
188
  nthreads=nthreads,
189
  all_layers=all_layers,
190
  idf=idf,
191
+ idf_sents=idf_sents,
192
  device=device,
193
  lang=lang,
194
  rescale_with_baseline=rescale_with_baseline,
 
208
  "hashcode": hashcode,
209
  }
210
  return output_dict