|
""" |
|
Attack Class |
|
============ |
|
""" |
|
|
|
from collections import OrderedDict |
|
from typing import List, Union |
|
|
|
import lru |
|
import torch |
|
|
|
import textattack |
|
from textattack.attack_results import ( |
|
FailedAttackResult, |
|
MaximizedAttackResult, |
|
SkippedAttackResult, |
|
SuccessfulAttackResult, |
|
) |
|
from textattack.constraints import Constraint, PreTransformationConstraint |
|
from textattack.goal_function_results import GoalFunctionResultStatus |
|
from textattack.goal_functions import GoalFunction |
|
from textattack.models.wrappers import ModelWrapper |
|
from textattack.search_methods import SearchMethod |
|
from textattack.shared import AttackedText, utils |
|
from textattack.transformations import CompositeTransformation, Transformation |
|
|
|
|
|
class Attack:
    """An attack generates adversarial examples on text.

    An attack is comprised of a goal function, constraints, transformation, and a search method. Use :meth:`attack` method to attack one sample at a time.

    Args:
        goal_function (:class:`~textattack.goal_functions.GoalFunction`):
            A function for determining how well a perturbation is doing at achieving the attack's goal.
        constraints (list of :class:`~textattack.constraints.Constraint` or :class:`~textattack.constraints.PreTransformationConstraint`):
            A list of constraints to add to the attack, defining which perturbations are valid.
        transformation (:class:`~textattack.transformations.Transformation`):
            The transformation applied at each step of the attack.
        search_method (:class:`~textattack.search_methods.SearchMethod`):
            The method for exploring the search space of possible perturbations.
        transformation_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
            The number of items to keep in the transformations cache.
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
            The number of items to keep in the constraints cache.

    Example::

        >>> import textattack
        >>> import transformers

        >>> # Load model, tokenizer, and model_wrapper
        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # Construct our four components for `Attack`
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.constraints.semantics import WordEmbeddingDistance
        >>> from textattack.transformations import WordSwapEmbedding
        >>> from textattack.search_methods import GreedyWordSwapWIR

        >>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
        >>> constraints = [
        ...     RepeatModification(),
        ...     StopwordModification(),
        ...     WordEmbeddingDistance(min_cos_sim=0.9)
        ... ]
        >>> transformation = WordSwapEmbedding(max_candidates=50)
        >>> search_method = GreedyWordSwapWIR(wir_method="delete")

        >>> # Construct the actual attack
        >>> from textattack import Attack
        >>> attack = Attack(goal_function, constraints, transformation, search_method)

        >>> input_text = "I really enjoyed the new movie that came out last month."
        >>> label = 1 #Positive
        >>> attack_result = attack.attack(input_text, label)
    """

    def __init__(
        self,
        goal_function: GoalFunction,
        constraints: List[Union[Constraint, PreTransformationConstraint]],
        transformation: Transformation,
        search_method: SearchMethod,
        transformation_cache_size=2**15,
        constraint_cache_size=2**15,
    ):
        """Initialize an attack object.

        Attacks can be run multiple times.

        Raises:
            AssertionError: If a component does not have the expected type.
            ValueError: If ``search_method`` reports that it is incompatible
                with ``transformation``.
        """
        assert isinstance(
            goal_function, GoalFunction
        ), f"`goal_function` must be of type `textattack.goal_functions.GoalFunction`, but got type `{type(goal_function)}`."
        assert isinstance(
            constraints, list
        ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
        for c in constraints:
            assert isinstance(
                c, (Constraint, PreTransformationConstraint)
            ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
        assert isinstance(
            transformation, Transformation
        ), f"`transformation` must be of type `textattack.transformations.Transformation`, but got type `{type(transformation)}`."
        assert isinstance(
            search_method, SearchMethod
        ), f"`search_method` must be of type `textattack.search_methods.SearchMethod`, but got type `{type(search_method)}`."

        self.goal_function = goal_function
        self.search_method = search_method
        self.transformation = transformation
        # A transformation that does not declare `is_black_box` is treated as
        # black-box; the attack is black-box only if the search method is too.
        self.is_black_box = (
            getattr(transformation, "is_black_box", True) and search_method.is_black_box
        )

        if not self.search_method.check_transformation_compatibility(
            self.transformation
        ):
            raise ValueError(
                f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
            )

        # Pre-transformation constraints are passed into the transformation
        # call itself; all other constraints filter the transformation output.
        self.constraints = []
        self.pre_transformation_constraints = []
        for constraint in constraints:
            if isinstance(constraint, PreTransformationConstraint):
                self.pre_transformation_constraints.append(constraint)
            else:
                self.constraints.append(constraint)

        # Caching transformation output is only valid when every transformation
        # involved is deterministic (same input -> same candidates).
        if not self.transformation.deterministic:
            self.use_transformation_cache = False
        elif isinstance(self.transformation, CompositeTransformation):
            self.use_transformation_cache = all(
                t.deterministic for t in self.transformation.transformations
            )
        else:
            self.use_transformation_cache = True
        self.transformation_cache_size = transformation_cache_size
        self.transformation_cache = lru.LRU(transformation_cache_size)

        self.constraint_cache_size = constraint_cache_size
        self.constraints_cache = lru.LRU(constraint_cache_size)

        # Route the search method's transformation, scoring and filtering
        # through this attack (and therefore through its caches).
        self.search_method.get_transformations = self.get_transformations
        self.search_method.goal_function = self.goal_function
        self.search_method.get_goal_results = self.goal_function.get_results
        self.search_method.get_indices_to_order = self.get_indices_to_order
        self.search_method.filter_transformations = self.filter_transformations

    def clear_cache(self, recursive=True):
        """Empty the attack's constraint cache and, if enabled, its
        transformation cache.

        Args:
            recursive: When true, also clear the goal function's cache and the
                cache of every (non-pre-transformation) constraint that has a
                ``clear_cache`` method.
        """
        self.constraints_cache.clear()
        if self.use_transformation_cache:
            self.transformation_cache.clear()
        if recursive:
            self.goal_function.clear_cache()
            for constraint in self.constraints:
                if hasattr(constraint, "clear_cache"):
                    constraint.clear_cache()

    def _move_modules(self, move_module):
        """Call ``move_module`` on every ``torch.nn.Module`` reachable from
        this attack.

        The walk descends into the ``__dict__`` of TextAttack components
        (attacks, goal functions, transformations, search methods,
        constraints, model wrappers) and into list/tuple items that are
        transformations or constraints. Objects are tracked by ``id`` so each
        is visited at most once, which also guards against reference cycles.
        """
        component_types = (
            Attack,
            GoalFunction,
            Transformation,
            SearchMethod,
            Constraint,
            PreTransformationConstraint,
            ModelWrapper,
        )
        sequence_item_types = (Transformation, Constraint, PreTransformationConstraint)
        visited = set()

        def visit(obj):
            visited.add(id(obj))
            if isinstance(obj, torch.nn.Module):
                move_module(obj)
            elif isinstance(obj, component_types):
                for child in obj.__dict__.values():
                    if id(child) not in visited:
                        visit(child)
            elif isinstance(obj, (list, tuple)):
                for item in obj:
                    if id(item) not in visited and isinstance(
                        item, sequence_item_types
                    ):
                        visit(item)

        visit(self)

    def cpu_(self):
        """Move any `torch.nn.Module` models that are part of Attack to CPU."""
        self._move_modules(lambda module: module.cpu())

    def cuda_(self):
        """Move any `torch.nn.Module` models that are part of Attack to GPU."""
        self._move_modules(lambda module: module.to(textattack.shared.utils.device))

    def get_indices_to_order(self, current_text, **kwargs):
        """Apply ``pre_transformation_constraints`` to ``current_text`` to get
        the indices that search methods may use to search and order.

        Args:
            current_text: The current ``AttackedText`` whose eligible indices
                are needed.
            **kwargs: Forwarded to the transformation call.

        Returns:
            A ``(length, indices)`` tuple: the number of eligible indices and
            those indices as a list.
        """
        indices_to_order = self.transformation(
            current_text,
            pre_transformation_constraints=self.pre_transformation_constraints,
            return_indices=True,
            **kwargs,
        )
        return len(indices_to_order), list(indices_to_order)

    def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
        """Apply ``self.transformation`` to ``current_text`` without consulting
        any cache.

        Args:
            current_text: The current ``AttackedText`` on which to perform the transformations.
            original_text: The original ``AttackedText`` from which the attack
                started (accepted for interface symmetry; not used here).
            **kwargs: Forwarded to the transformation call.

        Returns:
            The transformation's candidate texts, not yet filtered by
            ``self.constraints``.
        """
        return self.transformation(
            current_text,
            pre_transformation_constraints=self.pre_transformation_constraints,
            **kwargs,
        )

    def get_transformations(self, current_text, original_text=None, **kwargs):
        """Applies ``self.transformation`` to ``current_text``, then filters the
        list of possible transformations through the applicable constraints.

        Transformation output is served from / stored in an LRU cache when
        every transformation is deterministic and the cache key is hashable.

        Args:
            current_text: The current ``AttackedText`` on which to perform the transformations.
            original_text: The original ``AttackedText`` from which the attack started.
            **kwargs: Forwarded to the transformation call and included in the cache key.

        Returns:
            A filtered list of transformations where each transformation matches the constraints.

        Raises:
            RuntimeError: If the attack has no transformation.
        """
        if self.transformation is None:
            raise RuntimeError(
                "Cannot call `get_transformations` without a transformation."
            )

        if not self.use_transformation_cache:
            transformed_texts = self._get_transformations_uncached(
                current_text, original_text, **kwargs
            )
        else:
            # Sorting kwargs makes the key independent of keyword order.
            cache_key = tuple([current_text] + sorted(kwargs.items()))
            cacheable = utils.hashable(cache_key)
            if cacheable and cache_key in self.transformation_cache:
                cached = self.transformation_cache[cache_key]
                # Re-assigning promotes the entry to most-recently-used.
                self.transformation_cache[cache_key] = cached
                transformed_texts = list(cached)
            else:
                transformed_texts = self._get_transformations_uncached(
                    current_text, original_text, **kwargs
                )
                if cacheable:
                    # Stored as a tuple so callers cannot mutate the cached value.
                    self.transformation_cache[cache_key] = tuple(transformed_texts)

        return self.filter_transformations(
            transformed_texts, current_text, original_text
        )

    def _filter_transformations_uncached(
        self, transformed_texts, current_text, original_text=None
    ):
        """Filters a list of potential transformed texts based on
        ``self.constraints`` and records every verdict in the constraint cache.

        Args:
            transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
            current_text: The current ``AttackedText`` on which the transformation was applied.
            original_text: The original ``AttackedText`` from which the attack started.

        Returns:
            The candidates that satisfy every constraint.

        Raises:
            ValueError: If a constraint compares against the original text but
                ``original_text`` was not given.
        """
        filtered_texts = transformed_texts[:]
        for C in self.constraints:
            if not filtered_texts:
                break
            if C.compare_against_original:
                if original_text is None:
                    raise ValueError(
                        f"Missing `original_text` argument when constraint {type(C)} is set to compare against `original_text`"
                    )
                filtered_texts = C.call_many(filtered_texts, original_text)
            else:
                filtered_texts = C.call_many(filtered_texts, current_text)

        # Mark every candidate as rejected first, then flip the survivors.
        # NOTE: keys omit `original_text`; `_attack` clears this cache after
        # every attack so verdicts never leak across different originals.
        for candidate in transformed_texts:
            self.constraints_cache[(current_text, candidate)] = False
        for candidate in filtered_texts:
            self.constraints_cache[(current_text, candidate)] = True
        return filtered_texts

    def filter_transformations(
        self, transformed_texts, current_text, original_text=None
    ):
        """Filters a list of potential transformed texts based on
        ``self.constraints``. Utilizes an LRU cache to attempt to avoid
        recomputing common transformations.

        Candidates whose text equals ``current_text``'s text are dropped, and
        the result is sorted by text so ordering is stable between runs.

        Args:
            transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
            current_text: The current ``AttackedText`` on which the transformation was applied.
            original_text: The original ``AttackedText`` from which the attack started.

        Returns:
            The sorted list of candidates that satisfy every constraint.
        """
        transformed_texts = [
            t for t in transformed_texts if t.text != current_text.text
        ]

        uncached_texts = []
        filtered_texts = []
        for transformed_text in transformed_texts:
            key = (current_text, transformed_text)
            if key not in self.constraints_cache:
                uncached_texts.append(transformed_text)
                continue
            passed = self.constraints_cache[key]
            # Re-assigning promotes the entry to most-recently-used.
            self.constraints_cache[key] = passed
            if passed:
                filtered_texts.append(transformed_text)

        filtered_texts += self._filter_transformations_uncached(
            uncached_texts, current_text, original_text=original_text
        )

        filtered_texts.sort(key=lambda t: t.text)
        return filtered_texts

    def _attack(self, initial_result):
        """Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
        ``initial_result``.

        Args:
            initial_result: The initial ``GoalFunctionResult`` from which to perturb.

        Returns:
            A ``SuccessfulAttackResult``, ``FailedAttackResult``,
            or ``MaximizedAttackResult``.

        Raises:
            ValueError: If the search ends with an unrecognized goal status.
        """
        try:
            final_result = self.search_method(initial_result)
        finally:
            # The constraint cache is keyed without the original text, so it
            # must be emptied after every attack -- even one that raised --
            # or verdicts from this attack could be reused by the next one.
            self.clear_cache()

        status = final_result.goal_status
        if status == GoalFunctionResultStatus.SUCCEEDED:
            return SuccessfulAttackResult(initial_result, final_result)
        if status == GoalFunctionResultStatus.SEARCHING:
            return FailedAttackResult(initial_result, final_result)
        if status == GoalFunctionResultStatus.MAXIMIZING:
            return MaximizedAttackResult(initial_result, final_result)
        raise ValueError(f"Unrecognized goal status {final_result.goal_status}")

    def attack(self, example, ground_truth_output):
        """Attack a single example.

        Args:
            example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
                Example to attack. It can be a single string or an `OrderedDict` where
                keys represent the input fields (e.g. "premise", "hypothesis") and the values are the actual input texts.
                Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
            ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
                Ground truth output of `example`.
                For classification tasks, it should be an integer representing the ground truth label.
                For regression tasks (e.g. STS), it should be the target value.
                For seq2seq tasks (e.g. translation), it should be the target string.

        Returns:
            :class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
        """
        assert isinstance(
            example, (str, OrderedDict, AttackedText)
        ), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
        if isinstance(example, (str, OrderedDict)):
            example = AttackedText(example)

        # Floats must be accepted: regression targets are documented above.
        assert isinstance(
            ground_truth_output, (int, float, str)
        ), "`ground_truth_output` must either be `str`, `int`, or `float`."
        goal_function_result, _ = self.goal_function.init_attack_example(
            example, ground_truth_output
        )
        if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
            return SkippedAttackResult(goal_function_result)
        return self._attack(goal_function_result)

    def __repr__(self):
        """Prints attack parameters in a human-readable string.

        Inspired by the readability of printing PyTorch nn.Modules:
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
        """
        lines = [
            utils.add_indent(f"(search_method): {self.search_method}", 2),
            utils.add_indent(f"(goal_function): {self.goal_function}", 2),
            utils.add_indent(f"(transformation): {self.transformation}", 2),
        ]

        constraints = self.constraints + self.pre_transformation_constraints
        if constraints:
            constraints_lines = [
                utils.add_indent(f"({i}): {constraint}", 2)
                for i, constraint in enumerate(constraints)
            ]
            constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
        else:
            constraints_str = "None"
        lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))

        lines.append(utils.add_indent(f"(is_black_box): {self.is_black_box}", 2))
        return "Attack" + "(" + "\n " + "\n ".join(lines) + "\n" + ")"

    def __getstate__(self):
        """Return picklable state; the LRU caches are dropped, not serialized."""
        state = self.__dict__.copy()
        state["transformation_cache"] = None
        state["constraints_cache"] = None
        return state

    def __setstate__(self, state):
        """Restore state and rebuild empty LRU caches at their saved sizes."""
        self.__dict__ = state
        self.transformation_cache = lru.LRU(self.transformation_cache_size)
        self.constraints_cache = lru.LRU(self.constraint_cache_size)

    __str__ = __repr__
|
|