|
""".. _attacked_text: |
|
|
|
Attacked Text Class |
|
===================== |
|
|
|
A helper class that represents a string that can be attacked. |
|
""" |
|
|
|
from collections import OrderedDict |
|
import math |
|
|
|
import flair |
|
from flair.data import Sentence |
|
import numpy as np |
|
import torch |
|
|
|
import textattack |
|
|
|
from .utils import device, words_from_text |
|
|
|
flair.device = device |
|
|
|
|
|
class AttackedText: |
|
|
|
"""A helper class that represents a string that can be attacked. |
|
|
|
Models that take multiple sentences as input separate them by ``SPLIT_TOKEN``. |
|
Attacks "see" the entire input, joined into one string, without the split token. |
|
|
|
``AttackedText`` instances that were perturbed from other ``AttackedText`` |
|
objects contain a pointer to the previous text |
|
(``attack_attrs["previous_attacked_text"]``), so that the full chain of |
|
perturbations might be reconstructed by using this key to form a linked |
|
list. |
|
|
|
Args: |
|
text (string): The string that this AttackedText represents |
|
attack_attrs (dict): Dictionary of various attributes stored |
|
during the course of an attack. |
|
""" |
|
|
|
SPLIT_TOKEN = "<SPLIT>" |
|
|
|
def __init__(self, text_input, attack_attrs=None): |
|
|
|
if isinstance(text_input, str): |
|
self._text_input = OrderedDict([("text", text_input)]) |
|
elif isinstance(text_input, OrderedDict): |
|
self._text_input = text_input |
|
else: |
|
raise TypeError( |
|
f"Invalid text_input type {type(text_input)} (required str or OrderedDict)" |
|
) |
|
|
|
self._words = None |
|
self._words_per_input = None |
|
self._pos_tags = None |
|
self._ner_tags = None |
|
|
|
self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()]) |
|
if attack_attrs is None: |
|
self.attack_attrs = dict() |
|
elif isinstance(attack_attrs, dict): |
|
self.attack_attrs = attack_attrs |
|
else: |
|
raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}") |
|
|
|
|
|
self.attack_attrs.setdefault("original_index_map", np.arange(self.num_words)) |
|
|
|
self.attack_attrs.setdefault("modified_indices", set()) |
|
|
|
def __eq__(self, other): |
|
"""Compares two text instances to make sure they have the same attack |
|
attributes. |
|
|
|
Since some elements stored in ``self.attack_attrs`` may be numpy |
|
arrays, we have to take special care when comparing them. |
|
""" |
|
if not (self.text == other.text): |
|
return False |
|
if len(self.attack_attrs) != len(other.attack_attrs): |
|
return False |
|
for key in self.attack_attrs: |
|
if key not in other.attack_attrs: |
|
return False |
|
elif isinstance(self.attack_attrs[key], np.ndarray): |
|
if not (self.attack_attrs[key].shape == other.attack_attrs[key].shape): |
|
return False |
|
elif not (self.attack_attrs[key] == other.attack_attrs[key]).all(): |
|
return False |
|
else: |
|
if not self.attack_attrs[key] == other.attack_attrs[key]: |
|
return False |
|
return True |
|
|
|
def __hash__(self): |
|
return hash(self.text) |
|
|
|
def free_memory(self): |
|
"""Delete items that take up memory. |
|
|
|
Can be called once the AttackedText is only needed to display. |
|
""" |
|
if "previous_attacked_text" in self.attack_attrs: |
|
self.attack_attrs["previous_attacked_text"].free_memory() |
|
self.attack_attrs.pop("previous_attacked_text", None) |
|
|
|
self.attack_attrs.pop("last_transformation", None) |
|
|
|
for key in self.attack_attrs: |
|
if isinstance(self.attack_attrs[key], torch.Tensor): |
|
self.attack_attrs.pop(key, None) |
|
|
|
def text_window_around_index(self, index, window_size): |
|
"""The text window of ``window_size`` words centered around |
|
``index``.""" |
|
length = self.num_words |
|
half_size = (window_size - 1) / 2.0 |
|
if index - half_size < 0: |
|
start = 0 |
|
end = min(window_size - 1, length - 1) |
|
elif index + half_size >= length: |
|
start = max(0, length - window_size) |
|
end = length - 1 |
|
else: |
|
start = index - math.ceil(half_size) |
|
end = index + math.floor(half_size) |
|
text_idx_start = self._text_index_of_word_index(start) |
|
text_idx_end = self._text_index_of_word_index(end) + len(self.words[end]) |
|
return self.text[text_idx_start:text_idx_end] |
|
|
|
    def pos_of_word_index(self, desired_word_idx):
        """Returns the part-of-speech of the word at index `word_idx`.

        Uses FLAIR part-of-speech tagger.

        Tagging runs once per instance; the tagged sentence is cached in
        ``self._pos_tags`` and reused on later calls.
        """
        if not self._pos_tags:
            sentence = Sentence(
                self.text,
                use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
            )
            textattack.shared.utils.flair_tag(sentence)
            self._pos_tags = sentence
        flair_word_list, flair_pos_list = textattack.shared.utils.zip_flair_result(
            self._pos_tags
        )

        # Walk our own word list in order, consuming the flair lists up to
        # each match so that duplicate words resolve to the correct
        # occurrence (``.index`` alone would always find the first one).
        for word_idx, word in enumerate(self.words):
            assert (
                word in flair_word_list
            ), "word absent in flair returned part-of-speech tags"
            word_idx_in_flair_tags = flair_word_list.index(word)
            if word_idx == desired_word_idx:
                return flair_pos_list[word_idx_in_flair_tags]
            else:
                # Drop everything up to and including this match before
                # searching for the next word.
                flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
                flair_pos_list = flair_pos_list[word_idx_in_flair_tags + 1 :]

        raise ValueError(
            f"Did not find word from index {desired_word_idx} in flair POS tag"
        )
|
|
|
def ner_of_word_index(self, desired_word_idx, model_name="ner"): |
|
"""Returns the ner tag of the word at index `word_idx`. |
|
|
|
Uses FLAIR ner tagger. |
|
""" |
|
if not self._ner_tags: |
|
sentence = Sentence( |
|
self.text, |
|
use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(), |
|
) |
|
textattack.shared.utils.flair_tag(sentence, model_name) |
|
self._ner_tags = sentence |
|
flair_word_list, flair_ner_list = textattack.shared.utils.zip_flair_result( |
|
self._ner_tags, "ner" |
|
) |
|
|
|
for word_idx, word in enumerate(flair_word_list): |
|
word_idx_in_flair_tags = flair_word_list.index(word) |
|
if word_idx == desired_word_idx: |
|
return flair_ner_list[word_idx_in_flair_tags] |
|
else: |
|
flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :] |
|
flair_ner_list = flair_ner_list[word_idx_in_flair_tags + 1 :] |
|
|
|
raise ValueError( |
|
f"Did not find word from index {desired_word_idx} in flair POS tag" |
|
) |
|
|
|
def _text_index_of_word_index(self, i): |
|
"""Returns the index of word ``i`` in self.text.""" |
|
pre_words = self.words[: i + 1] |
|
lower_text = self.text.lower() |
|
|
|
look_after_index = 0 |
|
for word in pre_words: |
|
look_after_index = lower_text.find(word.lower(), look_after_index) + len( |
|
word |
|
) |
|
look_after_index -= len(self.words[i]) |
|
return look_after_index |
|
|
|
def text_until_word_index(self, i): |
|
"""Returns the text before the beginning of word at index ``i``.""" |
|
look_after_index = self._text_index_of_word_index(i) |
|
return self.text[:look_after_index] |
|
|
|
def text_after_word_index(self, i): |
|
"""Returns the text after the end of word at index ``i``.""" |
|
|
|
look_after_index = self._text_index_of_word_index(i) + len(self.words[i]) |
|
return self.text[look_after_index:] |
|
|
|
def first_word_diff(self, other_attacked_text): |
|
"""Returns the first word in self.words that differs from |
|
other_attacked_text. |
|
|
|
Useful for word swap strategies. |
|
""" |
|
w1 = self.words |
|
w2 = other_attacked_text.words |
|
for i in range(min(len(w1), len(w2))): |
|
if w1[i] != w2[i]: |
|
return w1[i] |
|
return None |
|
|
|
def first_word_diff_index(self, other_attacked_text): |
|
"""Returns the index of the first word in self.words that differs from |
|
other_attacked_text. |
|
|
|
Useful for word swap strategies. |
|
""" |
|
w1 = self.words |
|
w2 = other_attacked_text.words |
|
for i in range(min(len(w1), len(w2))): |
|
if w1[i] != w2[i]: |
|
return i |
|
return None |
|
|
|
def all_words_diff(self, other_attacked_text): |
|
"""Returns the set of indices for which this and other_attacked_text |
|
have different words.""" |
|
indices = set() |
|
w1 = self.words |
|
w2 = other_attacked_text.words |
|
for i in range(min(len(w1), len(w2))): |
|
if w1[i] != w2[i]: |
|
indices.add(i) |
|
return indices |
|
|
|
def ith_word_diff(self, other_attacked_text, i): |
|
"""Returns whether the word at index i differs from |
|
other_attacked_text.""" |
|
w1 = self.words |
|
w2 = other_attacked_text.words |
|
if len(w1) - 1 < i or len(w2) - 1 < i: |
|
return True |
|
return w1[i] != w2[i] |
|
|
|
def words_diff_num(self, other_attacked_text): |
|
|
|
def generate_tokens(words): |
|
result = {} |
|
idx = 1 |
|
for w in words: |
|
if w not in result: |
|
result[w] = idx |
|
idx += 1 |
|
return result |
|
|
|
def words_to_tokens(words, tokens): |
|
result = [] |
|
for w in words: |
|
result.append(tokens[w]) |
|
return result |
|
|
|
def edit_distance(w1_t, w2_t): |
|
matrix = [ |
|
[i + j for j in range(len(w2_t) + 1)] for i in range(len(w1_t) + 1) |
|
] |
|
|
|
for i in range(1, len(w1_t) + 1): |
|
for j in range(1, len(w2_t) + 1): |
|
if w1_t[i - 1] == w2_t[j - 1]: |
|
d = 0 |
|
else: |
|
d = 1 |
|
matrix[i][j] = min( |
|
matrix[i - 1][j] + 1, |
|
matrix[i][j - 1] + 1, |
|
matrix[i - 1][j - 1] + d, |
|
) |
|
|
|
return matrix[len(w1_t)][len(w2_t)] |
|
|
|
def cal_dif(w1, w2): |
|
tokens = generate_tokens(w1 + w2) |
|
w1_t = words_to_tokens(w1, tokens) |
|
w2_t = words_to_tokens(w2, tokens) |
|
return edit_distance(w1_t, w2_t) |
|
|
|
w1 = self.words |
|
w2 = other_attacked_text.words |
|
return cal_dif(w1, w2) |
|
|
|
def convert_from_original_idxs(self, idxs): |
|
"""Takes indices of words from original string and converts them to |
|
indices of the same words in the current string. |
|
|
|
Uses information from |
|
``self.attack_attrs['original_index_map']``, which maps word |
|
indices from the original to perturbed text. |
|
""" |
|
if len(self.attack_attrs["original_index_map"]) == 0: |
|
return idxs |
|
elif isinstance(idxs, set): |
|
idxs = list(idxs) |
|
|
|
elif not isinstance(idxs, [list, np.ndarray]): |
|
raise TypeError( |
|
f"convert_from_original_idxs got invalid idxs type {type(idxs)}" |
|
) |
|
|
|
return [self.attack_attrs["original_index_map"][i] for i in idxs] |
|
|
|
def replace_words_at_indices(self, indices, new_words): |
|
"""This code returns a new AttackedText object where the word at |
|
``index`` is replaced with a new word.""" |
|
if len(indices) != len(new_words): |
|
raise ValueError( |
|
f"Cannot replace {len(new_words)} words at {len(indices)} indices." |
|
) |
|
words = self.words[:] |
|
for i, new_word in zip(indices, new_words): |
|
if not isinstance(new_word, str): |
|
raise TypeError( |
|
f"replace_words_at_indices requires ``str`` words, got {type(new_word)}" |
|
) |
|
if (i < 0) or (i > len(words)): |
|
raise ValueError(f"Cannot assign word at index {i}") |
|
words[i] = new_word |
|
return self.generate_new_attacked_text(words) |
|
|
|
def replace_word_at_index(self, index, new_word): |
|
"""This code returns a new AttackedText object where the word at |
|
``index`` is replaced with a new word.""" |
|
if not isinstance(new_word, str): |
|
raise TypeError( |
|
f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}" |
|
) |
|
return self.replace_words_at_indices([index], [new_word]) |
|
|
|
def delete_word_at_index(self, index): |
|
"""This code returns a new AttackedText object where the word at |
|
``index`` is removed.""" |
|
return self.replace_word_at_index(index, "") |
|
|
|
def insert_text_after_word_index(self, index, text): |
|
"""Inserts a string before word at index ``index`` and attempts to add |
|
appropriate spacing.""" |
|
if not isinstance(text, str): |
|
raise TypeError(f"text must be an str, got type {type(text)}") |
|
word_at_index = self.words[index] |
|
new_text = " ".join((word_at_index, text)) |
|
return self.replace_word_at_index(index, new_text) |
|
|
|
def insert_text_before_word_index(self, index, text): |
|
"""Inserts a string before word at index ``index`` and attempts to add |
|
appropriate spacing.""" |
|
if not isinstance(text, str): |
|
raise TypeError(f"text must be an str, got type {type(text)}") |
|
word_at_index = self.words[index] |
|
|
|
|
|
new_text = " ".join((text, word_at_index)) |
|
return self.replace_word_at_index(index, new_text) |
|
|
|
def get_deletion_indices(self): |
|
return self.attack_attrs["original_index_map"][ |
|
self.attack_attrs["original_index_map"] == -1 |
|
] |
|
|
|
    def generate_new_attacked_text(self, new_words):
        """Returns a new AttackedText object and replaces old list of words
        with a new list of words, but preserves the punctuation and spacing of
        the original message.

        ``self.words`` is a list of the words in the current text with
        punctuation removed. However, each "word" in ``new_words`` could
        be an empty string, representing a word deletion, or a string
        with multiple space-separated words, representation an insertion
        of one or more words.
        """
        perturbed_text = ""
        # Work on the joined multi-column text; columns are re-split on
        # SPLIT_TOKEN at the end.
        original_text = AttackedText.SPLIT_TOKEN.join(self._text_input.values())
        new_attack_attrs = dict()
        if "label_names" in self.attack_attrs:
            new_attack_attrs["label_names"] = self.attack_attrs["label_names"]
        new_attack_attrs["newly_modified_indices"] = set()
        # Keep a pointer to the predecessor so the perturbation chain can be
        # walked back (see class docstring).
        new_attack_attrs["previous_attacked_text"] = self
        # Copy index bookkeeping so mutations below never touch ``self``.
        new_attack_attrs["modified_indices"] = self.attack_attrs[
            "modified_indices"
        ].copy()
        new_attack_attrs["original_index_map"] = self.attack_attrs[
            "original_index_map"
        ].copy()
        new_i = 0
        # Rebuild the text word by word: copy the inter-word text (spacing,
        # punctuation) verbatim, then emit the replacement word sequence.
        for i, (input_word, adv_word_seq) in enumerate(zip(self.words, new_words)):
            word_start = original_text.index(input_word)
            word_end = word_start + len(input_word)
            perturbed_text += original_text[:word_start]
            original_text = original_text[word_end:]
            adv_words = words_from_text(adv_word_seq)
            adv_num_words = len(adv_words)
            num_words_diff = adv_num_words - len(words_from_text(input_word))
            # If the replacement inserted or deleted words, all later
            # indices shift and the bookkeeping must be re-based.
            if num_words_diff != 0:
                # Shift previously-modified indices that come after the
                # current position; the current index itself is dropped.
                shifted_modified_indices = set()
                for modified_idx in new_attack_attrs["modified_indices"]:
                    if modified_idx < i:
                        shifted_modified_indices.add(modified_idx)
                    elif modified_idx > i:
                        shifted_modified_indices.add(modified_idx + num_words_diff)
                    else:
                        pass
                new_attack_attrs["modified_indices"] = shifted_modified_indices
                # Re-base the original->current index map the same way.
                new_idx_map = new_attack_attrs["original_index_map"].copy()
                if num_words_diff == -1:
                    # Word deleted: mark its original position with -1.
                    new_idx_map[new_idx_map == i] = -1
                new_idx_map[new_idx_map > i] += num_words_diff
                # Insertion *before* the kept word (first new word differs
                # from the original word) shifts the word itself too.
                if num_words_diff > 0 and input_word != adv_words[0]:
                    new_idx_map[new_idx_map == i] += num_words_diff
                new_attack_attrs["original_index_map"] = new_idx_map
            # Record every position produced by this replacement as
            # modified (unless the word is unchanged), advancing ``new_i``
            # through the new text's indices.
            for j in range(i, i + adv_num_words):
                if input_word != adv_word_seq:
                    new_attack_attrs["modified_indices"].add(new_i)
                    new_attack_attrs["newly_modified_indices"].add(new_i)
                new_i += 1
            # A deletion leaves a doubled space behind; swallow one of the
            # neighboring spaces to keep spacing natural.
            if adv_num_words == 0 and len(original_text):
                if i == 0:
                    # First word deleted: drop the following space.
                    if original_text[0] == " ":
                        original_text = original_text[1:]
                else:
                    # Otherwise drop the preceding space.
                    if perturbed_text[-1] == " ":
                        perturbed_text = perturbed_text[:-1]
            perturbed_text += adv_word_seq
        # Append any trailing text (ending punctuation etc.).
        perturbed_text += original_text
        new_attack_attrs["prev_attacked_text"] = self
        # Split the joined text back into its original columns.
        perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
        perturbed_input = OrderedDict(
            zip(self._text_input.keys(), perturbed_input_texts)
        )
        return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
|
|
|
def words_diff_ratio(self, x): |
|
"""Get the ratio of words difference between current text and `x`. |
|
|
|
Note that current text and `x` must have same number of words. |
|
""" |
|
assert self.num_words == x.num_words |
|
return float(np.sum(self.words != x.words)) / self.num_words |
|
|
|
def align_with_model_tokens(self, model_wrapper): |
|
"""Align AttackedText's `words` with target model's tokenization scheme |
|
(e.g. word, character, subword). Specifically, we map each word to list |
|
of indices of tokens that compose the word (e.g. embedding --> ["em", |
|
"##bed", "##ding"]) |
|
|
|
Args: |
|
model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper of the target model |
|
|
|
Returns: |
|
word2token_mapping (dict[int, list[int]]): Dictionary that maps i-th word to list of indices. |
|
""" |
|
tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0] |
|
word2token_mapping = {} |
|
j = 0 |
|
last_matched = 0 |
|
|
|
for i, word in enumerate(self.words): |
|
matched_tokens = [] |
|
while j < len(tokens) and len(word) > 0: |
|
token = tokens[j].lower() |
|
idx = word.lower().find(token) |
|
if idx == 0: |
|
word = word[idx + len(token) :] |
|
matched_tokens.append(j) |
|
last_matched = j |
|
j += 1 |
|
|
|
if not matched_tokens: |
|
word2token_mapping[i] = None |
|
j = last_matched |
|
else: |
|
word2token_mapping[i] = matched_tokens |
|
|
|
return word2token_mapping |
|
|
|
@property |
|
def tokenizer_input(self): |
|
"""The tuple of inputs to be passed to the tokenizer.""" |
|
input_tuple = tuple(self._text_input.values()) |
|
|
|
if len(input_tuple) == 1: |
|
return input_tuple[0] |
|
else: |
|
return input_tuple |
|
|
|
@property |
|
def column_labels(self): |
|
"""Returns the labels for this text's columns. |
|
|
|
For single-sequence inputs, this simply returns ['text']. |
|
""" |
|
return list(self._text_input.keys()) |
|
|
|
@property |
|
def words_per_input(self): |
|
"""Returns a list of lists of words corresponding to each input.""" |
|
if not self._words_per_input: |
|
self._words_per_input = [ |
|
words_from_text(_input) for _input in self._text_input.values() |
|
] |
|
return self._words_per_input |
|
|
|
@property |
|
def words(self): |
|
if not self._words: |
|
self._words = words_from_text(self.text) |
|
return self._words |
|
|
|
@property |
|
def text(self): |
|
"""Represents full text input. |
|
|
|
Multiply inputs are joined with a line break. |
|
""" |
|
return "\n".join(self._text_input.values()) |
|
|
|
@property |
|
def num_words(self): |
|
"""Returns the number of words in the sequence.""" |
|
return len(self.words) |
|
|
|
@property |
|
def newly_swapped_words(self): |
|
return [self.words[i] for i in self.attack_attrs["newly_modified_indices"]] |
|
|
|
def printable_text(self, key_color="bold", key_color_method=None): |
|
"""Represents full text input. Adds field descriptions. |
|
|
|
For example, entailment inputs look like: |
|
``` |
|
premise: ... |
|
hypothesis: ... |
|
``` |
|
""" |
|
|
|
if len(self._text_input) == 1: |
|
return next(iter(self._text_input.values())) |
|
|
|
|
|
else: |
|
if key_color_method: |
|
|
|
def ck(k): |
|
return textattack.shared.utils.color_text( |
|
k, key_color, key_color_method |
|
) |
|
|
|
else: |
|
|
|
def ck(k): |
|
return k |
|
|
|
return "\n".join( |
|
f"{ck(key.capitalize())}: {value}" |
|
for key, value in self._text_input.items() |
|
) |
|
|
|
def __repr__(self): |
|
return f'<AttackedText "{self.text}">' |
|
|