|
""" |
|
Misc Checkpoints |
|
=================== |
|
|
|
The ``AttackCheckpoint`` class saves in-progress attacks and loads saved attacks from disk. |
|
""" |
|
import copy |
|
import datetime |
|
import os |
|
import pickle |
|
import time |
|
|
|
import textattack |
|
from textattack.attack_results import ( |
|
FailedAttackResult, |
|
MaximizedAttackResult, |
|
SkippedAttackResult, |
|
SuccessfulAttackResult, |
|
) |
|
from textattack.shared import logger, utils |
|
|
|
|
|
|
|
|
|
class AttackCheckpoint:
    """Snapshot of an in-progress attack that can be saved to and loaded from disk.

    Args:
        attack_args (textattack.AttackArgs): Arguments of the original attack.
            A deep copy is stored on the checkpoint.
        attack_log_manager (textattack.loggers.AttackLogManager): Object for
            storing attack results.
        worklist (deque[int]): Examples that will be attacked. Examples are
            represented by their indices within the dataset.
        worklist_candidates (deque[int]): Other available examples we can
            attack. Used to get the next dataset element when `attack_n=True`.
        chkpt_time (float, optional): Epoch time representing when the
            checkpoint was made. Defaults to the current time when ``None``.
    """

    def __init__(
        self,
        attack_args,
        attack_log_manager,
        worklist,
        worklist_candidates,
        chkpt_time=None,
    ):
        assert isinstance(
            attack_args, textattack.AttackArgs
        ), "`attack_args` must be of type `textattack.AttackArgs`."
        assert isinstance(
            attack_log_manager, textattack.loggers.AttackLogManager
        ), "`attack_log_manager` must be of type `textattack.loggers.AttackLogManager`."

        self.attack_args = copy.deepcopy(attack_args)
        self.attack_log_manager = attack_log_manager
        self.worklist = worklist
        self.worklist_candidates = worklist_candidates
        # Compare against None explicitly so that a caller-supplied timestamp
        # of 0 (the epoch) is honoured instead of being replaced by "now".
        self.time = chkpt_time if chkpt_time is not None else time.time()

        self._verify()

    def __repr__(self):
        """Return a multi-line, indented summary of the checkpoint."""
        main_str = "AttackCheckpoint("
        lines = [utils.add_indent(f"(Time): {self.datetime}", 2)]

        # --- attack arguments -------------------------------------------
        args = self.attack_args.__dict__
        mutually_exclusive_args = ["search", "transformation", "constraints", "recipe"]
        args_lines = []
        if args.get("recipe"):
            # A recipe is set: show it instead of the individual components.
            args_lines.append(utils.add_indent(f'(recipe): {args["recipe"]}', 2))
        else:
            for key in ("search", "transformation", "constraints"):
                args_lines.append(utils.add_indent(f"({key}): {args[key]}", 2))

        # Every remaining argument is listed in the order it is stored.
        for key in args:
            if key not in mutually_exclusive_args:
                args_lines.append(utils.add_indent(f"({key}): {args[key]}", 2))

        args_str = utils.add_indent("\n" + "\n".join(args_lines), 2)
        lines.append(utils.add_indent(f"(attack_args): {args_str}", 2))

        # --- summary of results recorded so far -------------------------
        attack_logger_lines = [
            utils.add_indent(
                f"(Total number of examples to attack): {self.attack_args.num_examples}",
                2,
            ),
            utils.add_indent(f"(Number of attacks performed): {self.results_count}", 2),
            utils.add_indent(
                f"(Number of remaining attacks): {self.num_remaining_attacks}", 2
            ),
        ]
        breakdown_lines = [
            utils.add_indent(
                f"(Number of successful attacks): {self.num_successful_attacks}", 2
            ),
            utils.add_indent(
                f"(Number of failed attacks): {self.num_failed_attacks}", 2
            ),
            utils.add_indent(
                f"(Number of maximized attacks): {self.num_maximized_attacks}", 2
            ),
            utils.add_indent(
                f"(Number of skipped attacks): {self.num_skipped_attacks}", 2
            ),
        ]
        breakdown_str = utils.add_indent("\n" + "\n".join(breakdown_lines), 2)
        attack_logger_lines.append(
            utils.add_indent(f"(Latest result breakdown): {breakdown_str}", 2)
        )
        attack_logger_str = utils.add_indent("\n" + "\n".join(attack_logger_lines), 2)
        lines.append(
            utils.add_indent(f"(Previous attack summary): {attack_logger_str}", 2)
        )

        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    __str__ = __repr__

    @property
    def results_count(self):
        """Return number of attacks made so far."""
        return len(self.attack_log_manager.results)

    @property
    def num_skipped_attacks(self):
        """Return the number of recorded results that are ``SkippedAttackResult``."""
        return sum(
            isinstance(r, SkippedAttackResult) for r in self.attack_log_manager.results
        )

    @property
    def num_failed_attacks(self):
        """Return the number of recorded results that are ``FailedAttackResult``."""
        return sum(
            isinstance(r, FailedAttackResult) for r in self.attack_log_manager.results
        )

    @property
    def num_successful_attacks(self):
        """Return the number of recorded results that are ``SuccessfulAttackResult``."""
        return sum(
            isinstance(r, SuccessfulAttackResult)
            for r in self.attack_log_manager.results
        )

    @property
    def num_maximized_attacks(self):
        """Return the number of recorded results that are ``MaximizedAttackResult``."""
        return sum(
            isinstance(r, MaximizedAttackResult)
            for r in self.attack_log_manager.results
        )

    @property
    def num_remaining_attacks(self):
        """Return how many attacks are still to be performed.

        When ``attack_args.attack_n`` is set, only successful and failed
        results are subtracted from ``attack_args.num_examples``; otherwise
        every recorded result is subtracted.
        """
        if self.attack_args.attack_n:
            non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks
            return self.attack_args.num_examples - non_skipped_attacks
        return self.attack_args.num_examples - self.results_count

    @property
    def dataset_offset(self):
        """Calculate offset into the dataset to start from."""
        return self.attack_args.num_examples_offset + self.results_count

    @property
    def datetime(self):
        """Return the checkpoint time formatted as ``YYYY-MM-DD HH:MM:SS`` (local time)."""
        return datetime.datetime.fromtimestamp(self.time).strftime("%Y-%m-%d %H:%M:%S")

    def save(self, quiet=False):
        """Pickle this checkpoint into ``attack_args.checkpoint_dir``.

        The file is named ``<checkpoint time in ms>.ta.chkpt``. The directory
        is created if it does not exist yet.

        Args:
            quiet (bool): If ``True``, do not print/log the save banner.
        """
        file_name = f"{int(self.time * 1000)}.ta.chkpt"
        checkpoint_dir = self.attack_args.checkpoint_dir
        # exist_ok avoids the check-then-create race of exists() + makedirs()
        # when another process creates the directory at the same time.
        os.makedirs(checkpoint_dir, exist_ok=True)
        path = os.path.join(checkpoint_dir, file_name)
        if not quiet:
            print("\n\n" + "=" * 125)
            logger.info(
                'Saving checkpoint under "{}" at {} after {} attacks.'.format(
                    path, self.datetime, self.results_count
                )
            )
            print("=" * 125 + "\n")
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path):
        """Load a checkpoint previously written by :meth:`save`.

        Args:
            path (str): Path of the pickled checkpoint file.

        Returns:
            AttackCheckpoint: The unpickled checkpoint.

        Raises:
            AssertionError: If the unpickled object is not an instance of ``cls``.

        Warning:
            The file is read with ``pickle``, which can execute arbitrary code
            while loading. Only load checkpoint files from trusted sources.
        """
        with open(path, "rb") as f:
            # SECURITY: pickle.load runs code embedded in the file; the type
            # check below happens only after that code has already executed.
            checkpoint = pickle.load(f)
        assert isinstance(checkpoint, cls)

        return checkpoint

    def _verify(self):
        """Check that the checkpoint has no duplicates and is consistent."""
        assert self.num_remaining_attacks == len(
            self.worklist
        ), "Recorded number of remaining attacks and size of worklist are different."

        # NOTE(review): this relies on `original_text` yielding a hashable
        # value that is equal for duplicate results; if it is a method rather
        # than an attribute the set would hold bound methods -- confirm.
        results_set = {
            result.original_text for result in self.attack_log_manager.results
        }
        assert (
            len(results_set) == self.results_count
        ), "Duplicate `AttackResults` found."
|
|