# Source: anonymous8/RPD-Demo (initial commit 4943752, 2.82 kB).
"""
BERT Score
---------------------
BERT Score was introduced in the paper (BERTScore: Evaluating Text Generation with BERT) `arxiv link`_.
.. _arxiv link: https://arxiv.org/abs/1904.09675
BERT Score measures token similarity between two texts using contextual embeddings.
To decide which tokens to compare, it greedily matches each token in one text to the most similar token in the other text.
"""
import bert_score
from textattack.constraints import Constraint
from textattack.shared import utils
class BERTScore(Constraint):
    """A constraint on BERT-Score difference.

    A transformed text passes this constraint only if its BERT Score against
    the reference text is at least ``min_bert_score``.

    Args:
        min_bert_score (float): Minimum BERT Score a transformed text must
            reach to pass the constraint. Must be between 0.0 and 1.0.
        model_name (str): Name of the model to use for scoring.
        num_layers (int): Passed through to ``bert_score.BERTScorer``;
            ``None`` leaves the choice to that library.
        score_type (str): Pick one of the following three choices:

            -(1) ``precision``: match words from candidate text to reference text
            -(2) ``recall``: match words from reference text to candidate text
            -(3) ``f1``: harmonic mean of precision and recall (recommended)
        compare_against_original (bool):
            If ``True``, compare new ``x_adv`` against the original ``x``.
            Otherwise, compare it against the previous ``x_adv``.

    Raises:
        TypeError: If ``min_bert_score`` is not a float.
        ValueError: If ``min_bert_score`` is not between 0.0 and 1.0, or if
            ``score_type`` is not one of the supported choices.
    """

    # Position of each score type in the triple returned by
    # ``BERTScorer.score``.
    SCORE_TYPE2IDX = {"precision": 0, "recall": 1, "f1": 2}

    def __init__(
        self,
        min_bert_score,
        model_name="bert-base-uncased",
        num_layers=None,
        score_type="f1",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(min_bert_score, float):
            raise TypeError("min_bert_score must be a float")
        # Chained comparison also rejects NaN, which `< 0 or > 1` let through.
        if not 0.0 <= min_bert_score <= 1.0:
            raise ValueError("min_bert_score must be a value between 0.0 and 1.0")
        # Validate before loading the scoring model so a typo fails fast
        # instead of surfacing later as a KeyError in `_check_constraint`.
        if score_type not in BERTScore.SCORE_TYPE2IDX:
            raise ValueError(
                "score_type must be one of "
                f"{', '.join(BERTScore.SCORE_TYPE2IDX)}; got {score_type!r}"
            )
        self.min_bert_score = min_bert_score
        self.model = model_name
        self.score_type = score_type
        # Turn off idf-weighting scheme b/c reference sentence set is small
        self._bert_scorer = bert_score.BERTScorer(
            model_type=model_name, idf=False, device=utils.device, num_layers=num_layers
        )

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` if the BERT Score between ``transformed_text`` and
        ``reference_text`` is at least ``min_bert_score``.

        The score component compared (precision, recall or f1) is selected
        by ``self.score_type``.
        """
        cand = transformed_text.text
        ref = reference_text.text
        # `score` takes parallel lists of candidates and references and
        # returns a (precision, recall, f1) triple; a single pair is scored
        # here, so `.item()` extracts the scalar for the chosen component.
        result = self._bert_scorer.score([cand], [ref])
        score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
        return score >= self.min_bert_score

    def extra_repr_keys(self):
        """Attribute names included in this constraint's string representation."""
        return ["min_bert_score", "model", "score_type"] + super().extra_repr_keys()