File size: 2,235 Bytes
4943752
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""

Max Perturb Words Constraints
-------------------------------


"""

import math

from textattack.constraints import Constraint


class MaxWordsPerturbed(Constraint):
    """A constraint representing a maximum allowed perturbed words.

    At least one of ``max_num_words`` and ``max_percent`` must be given. When
    both are given, a transformed text must satisfy both limits. A limit of
    ``0`` is a real limit (no perturbed words allowed), not "unset".

    Args:
        max_num_words (:obj:`int`, optional): Maximum number of perturbed words allowed.
        max_percent (:obj:`float`, optional): Maximum percentage of words allowed to be perturbed,
            expressed as a fraction between 0 and 1.
        compare_against_original (bool):  If `True`, compare new `x_adv` against the original `x`.
            Otherwise, compare it against the previous `x_adv`. This constraint only
            supports `True`.

    Raises:
        ValueError: If ``compare_against_original`` is false, if neither limit
            is set, or if ``max_percent`` lies outside ``[0, 1]``.
    """

    def __init__(
        self, max_num_words=None, max_percent=None, compare_against_original=True
    ):
        super().__init__(compare_against_original)
        if not compare_against_original:
            raise ValueError(
                "Cannot apply constraint MaxWordsPerturbed with `compare_against_original=False`"
            )

        if (max_num_words is None) and (max_percent is None):
            raise ValueError("must set either `max_percent` or `max_num_words`")
        # Compare against None rather than relying on truthiness so that the
        # "is it set?" test is the same everywhere in this class.
        if max_percent is not None and not (0 <= max_percent <= 1):
            raise ValueError("max perc must be between 0 and 1")
        self.max_num_words = max_num_words
        self.max_percent = max_percent

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` if ``transformed_text`` stays within every configured limit.

        The amount of perturbation is the number of entries returned by
        ``transformed_text.all_words_diff(reference_text)``.
        """
        num_words_diff = len(transformed_text.all_words_diff(reference_text))

        # A limit is active whenever it is not None; an explicit 0 must still
        # be enforced (a truthiness test would silently skip it).
        if self.max_percent is not None:
            # Base the budget on the shorter of the two texts, rounding up.
            min_num_words = min(len(transformed_text.words), len(reference_text.words))
            max_words_perturbed = math.ceil(min_num_words * self.max_percent)
            max_percent_met = num_words_diff <= max_words_perturbed
        else:
            max_percent_met = True

        if self.max_num_words is not None:
            max_num_words_met = num_words_diff <= self.max_num_words
        else:
            max_num_words_met = True

        return max_percent_met and max_num_words_met

    def extra_repr_keys(self):
        """Return the names of the configured limits followed by the parent's repr keys."""
        metric = []
        if self.max_percent is not None:
            metric.append("max_percent")
        if self.max_num_words is not None:
            metric.append("max_num_words")
        return metric + super().extra_repr_keys()