File size: 8,235 Bytes
4a1df2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""
Misc Checkpoints
===================

The ``AttackCheckpoint`` class saves in-progress attacks and loads saved attacks from disk.
"""
import copy
import datetime
import os
import pickle
import time

import textattack
from textattack.attack_results import (
    FailedAttackResult,
    MaximizedAttackResult,
    SkippedAttackResult,
    SuccessfulAttackResult,
)
from textattack.shared import logger, utils

# TODO: Consider still keeping the old `Checkpoint` class and allow older checkpoints to be loaded to new TextAttack


class AttackCheckpoint:
    """Snapshot of an in-progress attack that can be pickled to disk and restored.

    On construction the checkpoint is checked for internal consistency (see
    ``_verify``).

    Args:
        attack_args (textattack.AttackArgs): Arguments of the original attack.
            A deep copy is stored, so later changes to the caller's object do
            not affect the checkpoint.
        attack_log_manager (textattack.loggers.AttackLogManager): Object for
            storing attack results. Stored by reference (not copied).
        worklist (deque[int]): List of examples that will be attacked. Examples
            are represented by their indices within the dataset.
        worklist_candidates (collection of int, presumably a deque like ``worklist``):
            Other available examples we can attack. Used to get the next
            dataset element when `attack_n=True`.
        chkpt_time (float): Epoch time representing when the checkpoint was
            made. When ``None`` (the default) the current time is used.
    """

    def __init__(
        self,
        attack_args,
        attack_log_manager,
        worklist,
        worklist_candidates,
        chkpt_time=None,
    ):
        assert isinstance(
            attack_args, textattack.AttackArgs
        ), "`attack_args` must be of type `textattack.AttackArgs`."
        assert isinstance(
            attack_log_manager, textattack.loggers.AttackLogManager
        ), "`attack_log_manager` must be of type `textattack.loggers.AttackLogManager`."

        self.attack_args = copy.deepcopy(attack_args)
        self.attack_log_manager = attack_log_manager
        self.worklist = worklist
        self.worklist_candidates = worklist_candidates
        # Compare against None explicitly: a truthiness test would discard a
        # caller-supplied timestamp of 0 and silently substitute "now".
        self.time = time.time() if chkpt_time is None else chkpt_time

        self._verify()

    def __repr__(self):
        """Return a multi-line, indented summary of the checkpoint.

        The summary lists the checkpoint time, every attribute of
        ``attack_args`` and the counts of results recorded so far. If
        ``attack_args`` has a truthy ``recipe`` attribute it is shown first;
        otherwise ``search``, ``transformation`` and ``constraints`` are shown
        first. Any of those attributes that is absent is skipped instead of
        raising ``KeyError``.
        """
        args = vars(self.attack_args)

        def entry(label, value, sep=" "):
            # One "(label): value" line, indented the way the rest of the
            # summary expects.
            return utils.add_indent(f"({label}):{sep}{value}", 2)

        def nested(entries):
            # Join child lines into an indented sub-block that starts on a
            # fresh line beneath its parent label.
            return utils.add_indent("\n" + "\n".join(entries), 2)

        # These four keys are printed up front (recipe OR the other three) and
        # must therefore be excluded from the generic listing below.
        mutually_exclusive_args = ["search", "transformation", "constraints", "recipe"]
        if args.get("recipe"):
            leading_keys = ["recipe"]
        else:
            leading_keys = ["search", "transformation", "constraints"]

        args_lines = [entry(key, args[key]) for key in leading_keys if key in args]
        args_lines.extend(
            entry(key, value)
            for key, value in args.items()
            if key not in mutually_exclusive_args
        )

        breakdown_lines = [
            entry("Number of successful attacks", self.num_successful_attacks),
            entry("Number of failed attacks", self.num_failed_attacks),
            entry("Number of maximized attacks", self.num_maximized_attacks),
            entry("Number of skipped attacks", self.num_skipped_attacks),
        ]
        attack_logger_lines = [
            entry("Total number of examples to attack", self.attack_args.num_examples),
            entry("Number of attacks performed", self.results_count),
            entry("Number of remaining attacks", self.num_remaining_attacks),
            entry("Latest result breakdown", nested(breakdown_lines)),
        ]

        # The three top-level sections use a two-space separator after the
        # colon; nested entries use one.
        lines = [
            entry("Time", self.datetime, sep="  "),
            entry("attack_args", nested(args_lines), sep="  "),
            entry("Previous attack summary", nested(attack_logger_lines), sep="  "),
        ]

        main_str = "AttackCheckpoint("
        main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str

    __str__ = __repr__

    def _count_results(self, result_type):
        """Return how many logged results are instances of ``result_type``."""
        return sum(
            isinstance(r, result_type) for r in self.attack_log_manager.results
        )

    @property
    def results_count(self):
        """Return number of attacks made so far."""
        return len(self.attack_log_manager.results)

    @property
    def num_skipped_attacks(self):
        """Return number of logged results that are ``SkippedAttackResult``."""
        return self._count_results(SkippedAttackResult)

    @property
    def num_failed_attacks(self):
        """Return number of logged results that are ``FailedAttackResult``."""
        return self._count_results(FailedAttackResult)

    @property
    def num_successful_attacks(self):
        """Return number of logged results that are ``SuccessfulAttackResult``."""
        return self._count_results(SuccessfulAttackResult)

    @property
    def num_maximized_attacks(self):
        """Return number of logged results that are ``MaximizedAttackResult``."""
        return self._count_results(MaximizedAttackResult)

    @property
    def num_remaining_attacks(self):
        """Return how many attacks are still outstanding.

        When ``attack_args.attack_n`` is set, only successful and failed
        results count against ``attack_args.num_examples``; otherwise every
        logged result counts.
        """
        if self.attack_args.attack_n:
            completed = self.num_successful_attacks + self.num_failed_attacks
        else:
            completed = self.results_count
        return self.attack_args.num_examples - completed

    @property
    def dataset_offset(self):
        """Calculate offset into the dataset to start from."""
        # Original offset + # of results processed so far
        return self.attack_args.num_examples_offset + self.results_count

    @property
    def datetime(self):
        """Return the checkpoint time formatted as ``YYYY-MM-DD HH:MM:SS`` (local time)."""
        return datetime.datetime.fromtimestamp(self.time).strftime("%Y-%m-%d %H:%M:%S")

    def save(self, quiet=False):
        """Pickle this checkpoint into ``attack_args.checkpoint_dir``.

        The file is named ``<checkpoint time in milliseconds>.ta.chkpt``. The
        directory is created if it does not exist yet.

        Args:
            quiet (bool): If ``True``, do not print/log the "Saving checkpoint"
                banner.
        """
        file_name = "{}.ta.chkpt".format(int(self.time * 1000))
        # exist_ok avoids the check-then-create race of testing
        # os.path.exists() first.
        os.makedirs(self.attack_args.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.attack_args.checkpoint_dir, file_name)
        if not quiet:
            print("\n\n" + "=" * 125)
            logger.info(
                'Saving checkpoint under "{}" at {} after {} attacks.'.format(
                    path, self.datetime, self.results_count
                )
            )
            print("=" * 125 + "\n")
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path):
        """Load a checkpoint previously written by ``save``.

        Args:
            path (str): Path of the pickled checkpoint file.

        Returns:
            The unpickled checkpoint object.

        Raises:
            AssertionError: If the unpickled object is not an instance of ``cls``.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load
        # checkpoint files that come from a trusted source.
        with open(path, "rb") as f:
            checkpoint = pickle.load(f)
        assert isinstance(
            checkpoint, cls
        ), "Loaded object is not an `AttackCheckpoint`."

        return checkpoint

    def _verify(self):
        """Check that the checkpoint has no duplicates and is consistent."""
        assert self.num_remaining_attacks == len(
            self.worklist
        ), "Recorded number of remaining attacks and size of worklist are different."

        # Results are de-duplicated by their `original_text`; any collision
        # means the same example was logged more than once.
        results_set = {
            result.original_text for result in self.attack_log_manager.results
        }
        assert (
            len(results_set) == self.results_count
        ), "Duplicate `AttackResults` found."