PFEemp2024's picture
add necessary file
63775f2
"""
AttackResult Class
======================
"""
from abc import ABC

from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException

from textattack.goal_function_results import GoalFunctionResult
from textattack.shared import utils
class AttackResult(ABC):
    """Result of an Attack run on a single (output, text_input) pair.

    Args:
        original_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the original text
        perturbed_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the perturbed text. May or may not have been successful.
    """

    # Language codes (as returned by ``langdetect.detect``) for which
    # ``diff_color`` skips word-level highlighting and returns the plain
    # printable texts. Presumably because word-index alignment does not
    # render well for these scripts -- TODO confirm.
    _UNCOLORED_LANGUAGES = frozenset({"zh-cn", "ko"})

    def __init__(self, original_result, perturbed_result):
        if original_result is None:
            raise ValueError("Attack original result cannot be None")
        elif not isinstance(original_result, GoalFunctionResult):
            raise TypeError(f"Invalid original goal function result: {original_result}")
        if perturbed_result is None:
            raise ValueError("Attack perturbed result cannot be None")
        elif not isinstance(perturbed_result, GoalFunctionResult):
            raise TypeError(
                f"Invalid perturbed goal function result: {perturbed_result}"
            )

        self.original_result = original_result
        self.perturbed_result = perturbed_result
        # Query count is taken from the perturbed result.
        self.num_queries = perturbed_result.num_queries

        # We don't want the AttackedText attributes sticking around clogging up
        # space on our devices. Delete them here, if they're still present,
        # because we won't need them anymore anyway.
        self.original_result.attacked_text.free_memory()
        self.perturbed_result.attacked_text.free_memory()

    def original_text(self, color_method=None):
        """Returns the text portion of `self.original_result`.

        Helper method. Key words are styled with ``("bold", "underline")``
        using ``color_method`` (forwarded to ``printable_text``).
        """
        return self.original_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def perturbed_text(self, color_method=None):
        """Returns the text portion of `self.perturbed_result`.

        Helper method. Key words are styled with ``("bold", "underline")``
        using ``color_method`` (forwarded to ``printable_text``).
        """
        return self.perturbed_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def str_lines(self, color_method=None):
        """A list of the lines to be printed for this result's string
        representation.

        The first line is the goal-function summary
        (see :meth:`goal_function_result_str`); it is followed by the
        original and perturbed texts as returned by :meth:`diff_color`.
        """
        lines = [self.goal_function_result_str(color_method=color_method)]
        lines.extend(self.diff_color(color_method))
        return lines

    def __str__(self, color_method=None):
        """Join :meth:`str_lines` with blank lines between entries."""
        return "\n\n".join(self.str_lines(color_method=color_method))

    def goal_function_result_str(self, color_method=None):
        """Returns a string illustrating the results of the goal function.

        Formatted as ``"<original output> --> <perturbed output>"``, each
        side rendered by the result's ``get_colored_output(color_method)``.
        """
        orig_colored = self.original_result.get_colored_output(color_method)
        pert_colored = self.perturbed_result.get_colored_output(color_method)
        return orig_colored + " --> " + pert_colored

    def _skip_word_coloring(self, text):
        """Return True if word-level highlighting should be skipped for ``text``.

        Detection failures are treated as "do not skip", so highlighting
        proceeds normally.
        """
        try:
            # Call ``detect`` exactly once: langdetect is non-deterministic
            # unless seeded, so repeated calls on the same text may disagree.
            language = detect(text)
        except LangDetectException:
            # langdetect raises when the text has no detectable features
            # (e.g. empty or digits/punctuation only); fall back to coloring
            # rather than crashing result printing.
            return False
        return language in self._UNCOLORED_LANGUAGES

    def diff_color(self, color_method=None):
        """Highlights the difference between two texts using color.

        Has to account for deletions and insertions from original text to
        perturbed. Relies on the index map stored in
        ``self.perturbed_result.attacked_text.attack_attrs["original_index_map"]``,
        where entry ``i`` is the perturbed-text index of original word ``i``,
        or ``-1`` if that word was deleted.

        Args:
            color_method: Forwarded to ``utils.color_text`` and
                ``printable_text``. When ``None``, no highlighting is done and
                no language detection is performed.

        Returns:
            tuple: ``(original_printable_text, perturbed_printable_text)``.
        """
        t1 = self.original_result.attacked_text
        t2 = self.perturbed_result.attacked_text

        # No coloring requested: return plain texts without running the
        # (comparatively expensive, possibly raising) language detector.
        if color_method is None:
            return t1.printable_text(), t2.printable_text()

        if self._skip_word_coloring(t1.text):
            return t1.printable_text(), t2.printable_text()

        color_1 = self.original_result.get_text_color_input()
        color_2 = self.perturbed_result.get_text_color_perturbed()

        words_1 = t1.words
        words_2 = t2.words

        # Original words are highlighted when they were deleted (-1) or mapped
        # to a different word; perturbed indices holding an identical word are
        # recorded so they are left uncolored.
        words_1_idxs = []
        t2_equal_idxs = set()
        original_index_map = t2.attack_attrs["original_index_map"]
        for t1_idx, t2_idx in enumerate(original_index_map):
            if t2_idx != -1 and words_1[t1_idx] == words_2[t2_idx]:
                t2_equal_idxs.add(t2_idx)
            else:
                words_1_idxs.append(t1_idx)

        # Words to color in t2 are all the words that didn't have an equal,
        # mapped word in t1 (this includes insertions).
        words_2_idxs = sorted(set(range(t2.num_words)) - t2_equal_idxs)

        colored_1 = [
            utils.color_text(words_1[i], color_1, color_method) for i in words_1_idxs
        ]
        colored_2 = [
            utils.color_text(words_2[i], color_2, color_method) for i in words_2_idxs
        ]

        t1 = t1.replace_words_at_indices(words_1_idxs, colored_1)
        t2 = t2.replace_words_at_indices(words_2_idxs, colored_2)

        key_color = ("bold", "underline")
        return (
            t1.printable_text(key_color=key_color, key_color_method=color_method),
            t2.printable_text(key_color=key_color, key_color_method=color_method),
        )