DCWIR-Offcial-Demo / textattack /goal_function_results /classification_goal_function_result.py
PFEemp2024's picture
solving GPU error for previous version
4a1df2e
raw
history blame
2.66 kB
"""
ClassificationGoalFunctionResult Class
========================================
"""
import torch
import textattack
from textattack.shared import utils
from .goal_function_result import GoalFunctionResult
class ClassificationGoalFunctionResult(GoalFunctionResult):
    """Result object produced by a classification goal function.

    Adds label/color helpers on top of ``GoalFunctionResult`` so a
    prediction can be rendered as text together with its confidence.
    """

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
    ):
        # Everything is forwarded unchanged to the base class; the only
        # thing this subclass pins is the result type tag.
        super().__init__(
            attacked_text,
            raw_output,
            output,
            goal_status,
            score,
            num_queries,
            ground_truth_output,
            goal_function_result_type="Classification",
        )

    @property
    def _processed_output(self):
        """Return a ``(label, color)`` pair describing this result.

        If ``attacked_text.attack_attrs`` carries a non-``None``
        ``"label_names"`` entry, the label is the name found at index
        ``self.output`` (passed through ``utils.process_label_name``) and
        the color comes from ``utils.color_from_output``. Otherwise the
        label is the argmax index of ``raw_output`` and the color comes
        from ``utils.color_from_label``.
        """
        predicted_idx = self.raw_output.argmax()
        label_names = self.attacked_text.attack_attrs.get("label_names")

        if label_names is None:
            # No human-readable names available: report the raw class index.
            return predicted_idx, utils.color_from_label(predicted_idx)

        label = utils.process_label_name(label_names[self.output])
        return label, utils.color_from_output(label, predicted_idx)

    def get_text_color_input(self):
        """Return the color used for the changed portion of the text when
        this result represents the original input."""
        return self._processed_output[1]

    def get_text_color_perturbed(self):
        """Return the color used for the changed portion of the text when
        this result represents the perturbed input."""
        return self._processed_output[1]

    def get_colored_output(self, color_method=None):
        """Return this result's output as a string colored via
        `color_method`.

        The string has the form ``"<label> (<confidence>)"``, where the
        confidence is the ``raw_output`` entry at the argmax index,
        formatted as a whole-number percentage.
        """
        predicted_idx = self.raw_output.argmax()
        confidence = self.raw_output[predicted_idx]
        # A tensor element must be unwrapped to a Python number before it
        # is formatted as a percentage.
        if isinstance(confidence, torch.Tensor):
            confidence = confidence.item()

        label, color = self._processed_output
        # label followed by the confidence as a percent, like '33%'
        rendered = f"{label} ({confidence:.0%})"
        return utils.color_text(rendered, color=color, method=color_method)