|
""" |
|
|
|
Max Perturb Words Constraints |
|
------------------------------- |
|
|
|
|
|
""" |
|
|
|
import math |
|
|
|
from textattack.constraints import Constraint |
|
|
|
|
|
class MaxWordsPerturbed(Constraint):
    """A constraint representing a maximum allowed perturbed words.

    At least one of ``max_num_words`` and ``max_percent`` must be provided.
    When both are provided, a transformed text must satisfy both limits.
    A limit of ``0`` is a real limit (no perturbed words allowed), not a
    synonym for "unset"; only ``None`` disables a limit.

    Args:
        max_num_words (:obj:`int`, optional): Maximum number of perturbed words allowed.
        max_percent (:obj:`float`, optional): Maximum percentage of words allowed to be perturbed.
            Expressed as a fraction between 0 and 1 (inclusive).
        compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`.
            Otherwise, compare it against the previous `x_adv`. This constraint
            only supports `True`.

    Raises:
        ValueError: If ``compare_against_original`` is false, if neither limit
            is given, or if ``max_percent`` lies outside ``[0, 1]``.
    """

    def __init__(
        self, max_num_words=None, max_percent=None, compare_against_original=True
    ):
        super().__init__(compare_against_original)
        if not compare_against_original:
            raise ValueError(
                "Cannot apply constraint MaxWordsPerturbed with `compare_against_original=False`"
            )

        if (max_num_words is None) and (max_percent is None):
            raise ValueError("must set either `max_percent` or `max_num_words`")
        # Compare against None (not truthiness) so the check reads the same
        # way as the rest of the class; 0 is a legal value and passes.
        if max_percent is not None and not (0 <= max_percent <= 1):
            raise ValueError("max perc must be between 0 and 1")
        self.max_num_words = max_num_words
        self.max_percent = max_percent

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` if the number of differing words is within every configured limit.

        The number of perturbed words is the size of
        ``transformed_text.all_words_diff(reference_text)``. The percentage
        limit is converted to a word count using the shorter of the two texts,
        rounded up with ``math.ceil``.
        """
        num_words_diff = len(transformed_text.all_words_diff(reference_text))

        # `is not None` rather than truthiness: a limit of 0 / 0.0 must be
        # enforced instead of being treated as "no limit configured".
        if self.max_percent is not None:
            min_num_words = min(len(transformed_text.words), len(reference_text.words))
            max_words_perturbed = math.ceil(min_num_words * self.max_percent)
            if num_words_diff > max_words_perturbed:
                return False

        if self.max_num_words is not None and num_words_diff > self.max_num_words:
            return False

        return True

    def extra_repr_keys(self):
        """Return the names of the configured limits followed by the parent's repr keys."""
        metric = []
        if self.max_percent is not None:
            metric.append("max_percent")
        if self.max_num_words is not None:
            metric.append("max_num_words")
        return metric + super().extra_repr_keys()
|
|