anonymous8/RPD-Demo
initial commit
4943752
raw
history blame
8.24 kB
"""
Misc Checkpoints
===================
The ``AttackCheckpoint`` class saves in-progress attacks and loads saved attacks from disk.
"""
import copy
import datetime
import os
import pickle
import time
import textattack
from textattack.attack_results import (
FailedAttackResult,
MaximizedAttackResult,
SkippedAttackResult,
SuccessfulAttackResult,
)
from textattack.shared import logger, utils
# TODO: Consider still keeping the old `Checkpoint` class and allow older checkpoints to be loaded to new TextAttack
class AttackCheckpoint:
    """An object that stores necessary information for saving and loading
    checkpoints.

    Args:
        attack_args (textattack.AttackArgs): Arguments of the original attack
        attack_log_manager (textattack.loggers.AttackLogManager): Object for storing attack results
        worklist (deque[int]): Examples that will be attacked. Examples are represented by their indices within the dataset.
        worklist_candidates (deque[int]): Other available examples (dataset indices) we can attack. Used to get the next dataset element when `attack_n=True`.
        chkpt_time (float): epoch time representing when checkpoint was made. Defaults to the current time.
    """

    def __init__(
        self,
        attack_args,
        attack_log_manager,
        worklist,
        worklist_candidates,
        chkpt_time=None,
    ):
        assert isinstance(
            attack_args, textattack.AttackArgs
        ), "`attack_args` must be of type `textattack.AttackArgs`."
        assert isinstance(
            attack_log_manager, textattack.loggers.AttackLogManager
        ), "`attack_log_manager` must be of type `textattack.loggers.AttackLogManager`."

        # Deep-copy so later changes to the caller's args do not leak into
        # the saved checkpoint.
        self.attack_args = copy.deepcopy(attack_args)
        self.attack_log_manager = attack_log_manager
        self.worklist = worklist
        self.worklist_candidates = worklist_candidates
        # Compare against None so an explicit timestamp of 0.0 is honoured.
        self.time = chkpt_time if chkpt_time is not None else time.time()

        self._verify()

    def __repr__(self):
        """Return a multi-line, human-readable summary of the checkpoint."""

        def entry(label, value):
            # One "(label): value" line, indented via utils.add_indent(..., 2).
            return utils.add_indent(f"({label}): {value}", 2)

        def nested(entries):
            # Newline-prefixed, indented sub-listing used as an entry's value.
            return utils.add_indent("\n" + "\n".join(entries), 2)

        args_dict = self.attack_args.__dict__
        # Either the recipe or its individual components are shown, never both;
        # all remaining args are listed afterwards.
        mutually_exclusive_args = ("search", "transformation", "constraints", "recipe")
        if args_dict.get("recipe"):
            args_lines = [entry("recipe", args_dict["recipe"])]
        else:
            # NOTE(review): these legacy keys may not exist on a given
            # AttackArgs instance; indexing them unconditionally raised
            # KeyError, so only the keys that are present are printed.
            args_lines = [
                entry(key, args_dict[key])
                for key in ("search", "transformation", "constraints")
                if key in args_dict
            ]
        args_lines += [
            entry(key, value)
            for key, value in args_dict.items()
            if key not in mutually_exclusive_args
        ]

        breakdown_lines = [
            entry("Number of successful attacks", self.num_successful_attacks),
            entry("Number of failed attacks", self.num_failed_attacks),
            entry("Number of maximized attacks", self.num_maximized_attacks),
            entry("Number of skipped attacks", self.num_skipped_attacks),
        ]
        summary_lines = [
            entry("Total number of examples to attack", self.attack_args.num_examples),
            entry("Number of attacks performed", self.results_count),
            entry("Number of remaining attacks", self.num_remaining_attacks),
            entry("Latest result breakdown", nested(breakdown_lines)),
        ]
        lines = [
            entry("Time", self.datetime),
            entry("attack_args", nested(args_lines)),
            entry("Previous attack summary", nested(summary_lines)),
        ]
        return "AttackCheckpoint(" + "\n " + "\n ".join(lines) + "\n" + ")"

    __str__ = __repr__

    def _count_results(self, result_type):
        """Return how many logged results are instances of ``result_type``."""
        return sum(
            isinstance(r, result_type) for r in self.attack_log_manager.results
        )

    @property
    def results_count(self):
        """Return number of attacks made so far."""
        return len(self.attack_log_manager.results)

    @property
    def num_skipped_attacks(self):
        """Return the number of ``SkippedAttackResult``s logged so far."""
        return self._count_results(SkippedAttackResult)

    @property
    def num_failed_attacks(self):
        """Return the number of ``FailedAttackResult``s logged so far."""
        return self._count_results(FailedAttackResult)

    @property
    def num_successful_attacks(self):
        """Return the number of ``SuccessfulAttackResult``s logged so far."""
        return self._count_results(SuccessfulAttackResult)

    @property
    def num_maximized_attacks(self):
        """Return the number of ``MaximizedAttackResult``s logged so far."""
        return self._count_results(MaximizedAttackResult)

    @property
    def num_remaining_attacks(self):
        """Return how many more attacks are needed to reach ``num_examples``.

        When ``attack_args.attack_n`` is set, skipped examples do not count
        towards ``num_examples``, so only non-skipped results are subtracted.
        """
        if self.attack_args.attack_n:
            # Every non-skipped result counts, including maximized results
            # (previously only successful + failed were subtracted, which
            # overstated the remaining count when maximized results existed).
            non_skipped_attacks = self.results_count - self.num_skipped_attacks
            return self.attack_args.num_examples - non_skipped_attacks
        return self.attack_args.num_examples - self.results_count

    @property
    def dataset_offset(self):
        """Calculate offset into the dataset to start from."""
        # Original offset + # of results processed so far
        return self.attack_args.num_examples_offset + self.results_count

    @property
    def datetime(self):
        """Return ``self.time`` as a local-time ``%Y-%m-%d %H:%M:%S`` string."""
        # Inside the method body, ``datetime`` resolves to the module-level
        # import, not this property.
        return datetime.datetime.fromtimestamp(self.time).strftime("%Y-%m-%d %H:%M:%S")

    def save(self, quiet=False):
        """Pickle this checkpoint into ``attack_args.checkpoint_dir``.

        The file name is ``<int(self.time * 1000)>.ta.chkpt``; the directory
        is created if needed.

        Args:
            quiet (bool): If True, do not print the banner / log message.
        """
        file_name = "{}.ta.chkpt".format(int(self.time * 1000))
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(self.attack_args.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.attack_args.checkpoint_dir, file_name)
        if not quiet:
            print("\n\n" + "=" * 125)
            logger.info(
                'Saving checkpoint under "{}" at {} after {} attacks.'.format(
                    path, self.datetime, self.results_count
                )
            )
            print("=" * 125 + "\n")
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path):
        """Load a checkpoint previously written by ``save``.

        Security: ``pickle.load`` can execute arbitrary code while
        unpickling; only load checkpoint files from trusted sources.

        Raises:
            AssertionError: if the unpickled object is not an instance of ``cls``.
        """
        with open(path, "rb") as f:
            checkpoint = pickle.load(f)
        assert isinstance(checkpoint, cls)
        return checkpoint

    def _verify(self):
        """Check that the checkpoint has no duplicates and is consistent.

        Raises:
            AssertionError: if the worklist size differs from
                ``num_remaining_attacks`` or duplicate results are detected.
        """
        assert self.num_remaining_attacks == len(
            self.worklist
        ), "Recorded number of remaining attacks and size of worklist are different."

        # NOTE(review): in upstream TextAttack ``AttackResult.original_text`` is
        # a method. If that holds here, this set collects bound methods, so it
        # only detects the same result object appearing twice, not two results
        # with identical text. Confirm the intended duplicate semantics.
        results_set = {
            result.original_text for result in self.attack_log_manager.results
        }
        assert (
            len(results_set) == self.results_count
        ), "Duplicate `AttackResults` found."