""" BERT Score --------------------- BERT Score is introduced in this paper (BERTScore: Evaluating Text Generation with BERT) `arxiv link`_. .. _arxiv link: https://arxiv.org/abs/1904.09675 BERT Score measures token similarity between two text using contextual embedding. To decide which two tokens to compare, it greedily chooses the most similar token from one text and matches it to a token in the second text. """ import bert_score from textattack.constraints import Constraint from textattack.shared import utils class BERTScore(Constraint): """A constraint on BERT-Score difference. Args: min_bert_score (float), minimum threshold value for BERT-Score model_name (str), name of model to use for scoring num_layers (int), number of hidden layers in the model score_type (str), Pick one of following three choices -(1) ``precision`` : match words from candidate text to reference text -(2) ``recall`` : match words from reference text to candidate text -(3) ``f1``: harmonic mean of precision and recall (recommended) compare_against_original (bool): If ``True``, compare new ``x_adv`` against the original ``x``. Otherwise, compare it against the previous ``x_adv``. """ SCORE_TYPE2IDX = {"precision": 0, "recall": 1, "f1": 2} def __init__( self, min_bert_score, model_name="bert-base-uncased", num_layers=None, score_type="f1", compare_against_original=True, ): super().__init__(compare_against_original) if not isinstance(min_bert_score, float): raise TypeError("max_bert_score must be a float") if min_bert_score < 0.0 or min_bert_score > 1.0: raise ValueError("max_bert_score must be a value between 0.0 and 1.0") self.min_bert_score = min_bert_score self.model = model_name self.score_type = score_type # Turn off idf-weighting scheme b/c reference sentence set is small self._bert_scorer = bert_score.BERTScorer( model_type=model_name, idf=False, device=utils.device, num_layers=num_layers ) def _check_constraint(self, transformed_text, reference_text): """Return `True` if BERT Score between `transformed_text` and `reference_text` is lower than minimum BERT Score.""" cand = transformed_text.text ref = reference_text.text result = self._bert_scorer.score([cand], [ref]) score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item() if score >= self.min_bert_score: return True else: return False def extra_repr_keys(self): return ["min_bert_score", "model", "score_type"] + super().extra_repr_keys()