File size: 7,950 Bytes
0f81934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f814a86
b0f3cfd
f814a86
0f81934
 
 
 
8357825
 
 
 
 
 
 
 
 
 
 
 
 
 
0f81934
 
 
 
 
 
 
 
 
 
 
b0f3cfd
 
 
 
 
 
 
 
 
 
 
 
0f81934
8357825
 
 
 
b0f3cfd
8357825
 
 
 
0f81934
 
 
 
b0f3cfd
 
 
0f81934
b3dd9f7
 
0f81934
 
 
b0f3cfd
 
b3dd9f7
 
 
 
b0f3cfd
0f81934
b0f3cfd
0f81934
 
 
 
 
b0f3cfd
 
 
0f81934
 
 
 
 
 
 
b0f3cfd
 
 
 
8357825
b0f3cfd
 
 
 
0f81934
 
 
 
 
b0f3cfd
 
0f81934
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# Copyright 2020 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERTScore metric. """

import functools
from contextlib import contextmanager

import bert_score
import datasets
from packaging import version

import evaluate


@contextmanager
def filter_logging_context():
    """Temporarily drop one specific log message while the ``with`` body runs.

    A filter is attached to the logger obtained from
    ``datasets.utils.logging.get_logger("transformers.modeling_utils")``.
    Records whose message contains "This IS expected if you are initializing"
    are rejected; all other records pass. The filter is detached on exit,
    including when the body raises.
    """

    def filter_log(record):
        # `record.msg` is not guaranteed to be a str (the logging API accepts
        # arbitrary objects as the message), so stringify before the substring
        # test rather than letting `in` raise TypeError inside a logging call.
        return "This IS expected if you are initializing" not in str(record.msg)

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        # Always detach so the filter does not leak past this context.
        logger.removeFilter(filter_log)


# BibTeX entry for the BERTScore paper; handed to `evaluate.MetricInfo` as
# `citation` in `BERTScore._info`.
_CITATION = """\
@inproceedings{bert-score,
  title={BERTScore: Evaluating Text Generation with BERT},
  author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi},
  booktitle={International Conference on Learning Representations},
  year={2020},
  url={https://openreview.net/forum?id=SkeHuCVFDr}
}
"""

# Human-readable summary of the metric; used as `description` in
# `BERTScore._info` and passed to the `add_start_docstrings` decorator on the
# `BERTScore` class.
_DESCRIPTION = """\
BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference
sentences by cosine similarity.
It has been shown to correlate with human judgment on sentence-level and system-level evaluation.
Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language
generation tasks.

See the project's README at https://github.com/Tiiiger/bert_score#readme for more information.
"""

# Argument/return documentation for `BERTScore.compute`; used as
# `inputs_description` in `BERTScore._info` and passed to the
# `add_start_docstrings` decorator on the `BERTScore` class.
_KWARGS_DESCRIPTION = """
BERTScore Metrics with the hashcode from a source against one or more references.

Args:
    predictions (list of str): Prediction/candidate sentences.
    references (list of str or list of list of str): Reference sentences.
    lang (str): Language of the sentences (e.g. 'en'); has to specify
        at least one of `model_type` or `lang`. `lang` needs to be
        specified when `rescale_with_baseline` is True.
    model_type (str): Bert specification, default using the suggested
        model for the target language; has to specify at least one of
        `model_type` or `lang`.
    num_layers (int): The layer of representation to use,
        default using the number of layers tuned on WMT16 correlation data.
    verbose (bool): Turn on intermediate status update.
    idf (bool or dict): Use idf weighting; can also be a precomputed idf_dict.
    device (str): On which the contextual embedding model will be allocated on.
        If this argument is None, the model lives on cuda:0 if cuda is available.
    nthreads (int): Number of threads.
    batch_size (int): Bert score processing batch size.
    all_layers (bool): Passed through unchanged to `bert_score.BERTScorer`.
    rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline.
    baseline_path (str): Customized baseline file.
    use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer. New in version 0.3.10.

Returns:
    precision: Precision.
    recall: Recall.
    f1: F1 score.
    hashcode: Hashcode of the library.

Examples:

    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> bertscore = evaluate.load("bertscore")
    >>> results = bertscore.compute(predictions=predictions, references=references, lang="en")
    >>> print([round(v, 2) for v in results["f1"]])
    [1.0, 1.0]
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class BERTScore(evaluate.Metric):
    # `evaluate.Metric` wrapper around `bert_score.BERTScorer`.
    # No class docstring on purpose: the decorator above manages `__doc__`.
    # The scorer built in `_compute` is kept on the instance
    # (`cached_bertscorer`) so repeated calls with the same configuration do
    # not construct it again.

    def _info(self):
        """Return the metric's metadata.

        Two input schemas are declared: `references` as a sequence of strings
        per prediction, or as a single string per prediction.
        """
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/Tiiiger/bert_score",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/Tiiiger/bert_score"],
            reference_urls=[
                "https://github.com/Tiiiger/bert_score",
                "https://arxiv.org/abs/1904.09675",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        lang=None,
        model_type=None,
        num_layers=None,
        verbose=False,
        idf=False,
        device=None,
        batch_size=64,
        nthreads=4,
        all_layers=False,
        rescale_with_baseline=False,
        baseline_path=None,
        use_fast_tokenizer=False,
    ):
        """Score `predictions` against `references` with a `bert_score.BERTScorer`.

        Args:
            predictions: Candidate sentences.
            references: One string per prediction, or a list of strings per
                prediction.
            lang: Language code; used to pick `model_type` when that is None.
            model_type: Model name; looked up from `lang` when None.
            num_layers: Layer to use; looked up from `model_type` when None.
            verbose: Forwarded to the scorer's `score` call.
            idf: Truthy to enable idf weighting; the idf sentences are the
                flattened references of this call.
            device, nthreads, all_layers, rescale_with_baseline, baseline_path:
                Forwarded to the scorer constructor.
            batch_size: Forwarded to the scorer constructor and to `score`.
            use_fast_tokenizer: Forwarded only when `bert_score>=0.3.10`.

        Returns:
            dict with keys "precision", "recall", "f1" (the scorer's three
            outputs converted with `.tolist()`) and "hashcode".

        Raises:
            ImportWarning: `use_fast_tokenizer` is truthy but the installed
                `bert_score` is older than 0.3.10.
            ValueError: Neither `lang` nor `model_type` was given.
        """
        # Normalise the single-reference layout to one list per prediction.
        if isinstance(references[0], str):
            references = [[ref] for ref in references]

        # idf weights are computed over every reference sentence of this call.
        idf_sents = [r for ref in references for r in ref] if idf else None

        get_hash = bert_score.utils.get_hash
        scorer = bert_score.BERTScorer

        if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
            # Only these versions accept `use_fast_tokenizer`.
            get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer)
            scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer)
        elif use_fast_tokenizer:
            # Raised (not warned) on purpose: exception type kept for callers.
            raise ImportWarning(
                "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
                "`bert-score` doesn't match this condition.\n"
                'You can install it with `pip install "bert-score>=0.3.10"`.'
            )

        if model_type is None:
            if lang is None:
                raise ValueError(
                    "Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
                    " must be specified"
                )
            model_type = bert_score.utils.lang2model[lang.lower()]

        if num_layers is None:
            num_layers = bert_score.utils.model2layers[model_type]

        hashcode = get_hash(
            model=model_type,
            num_layers=num_layers,
            idf=idf,
            rescale_with_baseline=rescale_with_baseline,
            use_custom_baseline=baseline_path is not None,
        )

        # The hash above does not encode every constructor argument, so also
        # compare the remaining ones before reusing a cached scorer.
        # (`batch_size` is omitted because it is passed to `score` each call.)
        scorer_config = (hashcode, all_layers, device, lang, nthreads, baseline_path)
        reusable = (
            hasattr(self, "cached_bertscorer")
            and self.cached_bertscorer.hash == hashcode
            and getattr(self, "_cached_bertscorer_config", None) == scorer_config
            # `idf_sents` comes from this call's references and is consumed at
            # construction time; reusing an idf scorer would silently apply
            # weights derived from an earlier call's references.
            and not idf
        )

        if not reusable:
            with filter_logging_context():
                self.cached_bertscorer = scorer(
                    model_type=model_type,
                    num_layers=num_layers,
                    batch_size=batch_size,
                    nthreads=nthreads,
                    all_layers=all_layers,
                    idf=idf,
                    idf_sents=idf_sents,
                    device=device,
                    lang=lang,
                    rescale_with_baseline=rescale_with_baseline,
                    baseline_path=baseline_path,
                )
            self._cached_bertscorer_config = scorer_config

        (P, R, F) = self.cached_bertscorer.score(
            cands=predictions,
            refs=references,
            verbose=verbose,
            batch_size=batch_size,
        )
        output_dict = {
            "precision": P.tolist(),
            "recall": R.tolist(),
            "f1": F.tolist(),
            "hashcode": hashcode,
        }
        return output_dict