# Source: DCWIR-Demo/textattack/goal_function_results/goal_function_result.py
# Last commit: 63775f2 "add necessary file" (PFEemp2024)
"""
GoalFunctionResult class
====================================
"""
from abc import ABC, abstractmethod
import torch
from textattack.shared import utils
class GoalFunctionResultStatus:
    """Integer status codes carried by a goal function result.

    ``SUCCEEDED`` is 0, ``SEARCHING`` is 1, ``MAXIMIZING`` is 2 and
    ``SKIPPED`` is 3. ``SEARCHING`` marks a result that is still in the
    process of searching for a success.
    """

    SUCCEEDED, SEARCHING, MAXIMIZING, SKIPPED = range(4)
class GoalFunctionResult(ABC):
    """Represents the result of a goal function evaluating an AttackedText
    object.

    Args:
        attacked_text: The sequence that was evaluated.
        raw_output: The raw output stored next to ``output``. A
            ``torch.Tensor`` is converted to a NumPy array on construction.
        output: The display-friendly output.
        goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal.
        score: A score representing how close the model is to achieving its goal.
            A ``torch.Tensor`` score is converted to a Python scalar.
        num_queries: How many model queries have been used
        ground_truth_output: The ground truth output
        goal_function_result_type: Label printed by ``__repr__``; defaults
            to the empty string.
    """

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
        goal_function_result_type="",
    ):
        self.attacked_text = attacked_text
        self.raw_output = raw_output
        self.output = output
        self.score = score
        self.goal_status = goal_status
        self.num_queries = num_queries
        self.ground_truth_output = ground_truth_output
        self.goal_function_result_type = goal_function_result_type

        if isinstance(self.raw_output, torch.Tensor):
            # ``Tensor.numpy()`` raises for tensors that require grad or that
            # live on a non-CPU device, so detach and move to the host first.
            # For a plain CPU tensor both calls are no-copy, so the resulting
            # array still shares memory with the tensor as it did before.
            self.raw_output = self.raw_output.detach().cpu().numpy()

        if isinstance(self.score, torch.Tensor):
            # ``item()`` requires a single-element tensor.
            self.score = self.score.item()

    def __repr__(self):
        """Return a multi-line summary of the result's main fields.

        Reads ``self.attacked_text.text``, so ``attacked_text`` must expose a
        ``text`` attribute.
        """
        fields = (
            f"(goal_function_result_type): {self.goal_function_result_type}",
            f"(attacked_text): {self.attacked_text.text}",
            f"(ground_truth_output): {self.ground_truth_output}",
            f"(model_output): {self.output}",
            f"(score): {self.score}",
        )
        lines = [utils.add_indent(field, 2) for field in fields]
        main_str = "GoalFunctionResult( "
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    @abstractmethod
    def get_text_color_input(self):
        """A string representing the color this result's changed portion should
        be if it represents the original input."""
        raise NotImplementedError()

    @abstractmethod
    def get_text_color_perturbed(self):
        """A string representing the color this result's changed portion should
        be if it represents the perturbed input."""
        raise NotImplementedError()

    @abstractmethod
    def get_colored_output(self, color_method=None):
        """Returns a string representation of this result's output, colored
        according to `color_method`."""
        raise NotImplementedError()