Rapid-Textual-Adversarial-Defense
/
textattack
/goal_function_results
/classification_goal_function_result.py
"""
ClassificationGoalFunctionResult Class
========================================
"""
import torch | |
import textattack | |
from textattack.shared import utils | |
from .goal_function_result import GoalFunctionResult | |
class ClassificationGoalFunctionResult(GoalFunctionResult):
    """Represents the result of a classification goal function.

    Specializes :class:`GoalFunctionResult` by tagging results with
    ``goal_function_result_type="Classification"``. It also adds helpers that
    render the predicted class (optionally mapped to a human-readable label
    name) together with a display color.
    """

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
    ):
        super().__init__(
            attacked_text,
            raw_output,
            output,
            goal_status,
            score,
            num_queries,
            ground_truth_output,
            goal_function_result_type="Classification",
        )

    @property
    def _processed_output(self):
        """Takes a model output (like `1`) and returns the class labeled output
        (like `positive`), if possible, together with its associated color.

        This must be a property: every caller unpacks it directly
        (``output, color = self._processed_output``). Unpacking a bound
        method would raise ``TypeError``.

        Returns:
            tuple: ``(output, color)``. If ``attacked_text.attack_attrs`` has
            a non-None ``"label_names"`` entry, ``output`` is the label name
            at index ``self.output`` after ``utils.process_label_name``, and
            the color comes from ``utils.color_from_output``. Otherwise
            ``output`` is the ``int`` index of the highest-scoring entry of
            ``raw_output``, and the color comes from
            ``utils.color_from_label``.
        """
        # argmax() yields a 0-d tensor/array; int() makes it render as "1"
        # instead of e.g. "tensor(1)" when shown in get_colored_output.
        output_label = int(self.raw_output.argmax())
        label_names = self.attacked_text.attack_attrs.get("label_names")
        if label_names is None:
            return output_label, utils.color_from_label(output_label)
        # NOTE(review): the name is looked up via self.output while the color
        # uses argmax of raw_output; these presumably agree for classification
        # results -- confirm with the goal function that builds this object.
        output = utils.process_label_name(label_names[self.output])
        color = utils.color_from_output(output, output_label)
        return output, color

    def get_text_color_input(self):
        """A string representing the color this result's changed portion should
        be if it represents the original input."""
        _, color = self._processed_output
        return color

    def get_text_color_perturbed(self):
        """A string representing the color this result's changed portion should
        be if it represents the perturbed input."""
        _, color = self._processed_output
        return color

    def get_colored_output(self, color_method=None):
        """Returns a string representation of this result's output, colored
        according to `color_method`.

        The string has the form ``"<output> (<confidence>%)"``, where the
        confidence is the ``raw_output`` value at its argmax, shown as a
        whole-number percentage.

        Args:
            color_method: Forwarded unchanged to ``utils.color_text`` as
                ``method``.
        """
        output_label = int(self.raw_output.argmax())
        # float() unwraps both 0-d torch tensors and numpy scalars.
        confidence_score = float(self.raw_output[output_label])
        output, color = self._processed_output
        # concatenate with label and convert confidence score to percent, like '33%'
        output_str = f"{output} ({confidence_score:.0%})"
        return utils.color_text(output_str, color=color, method=color_method)