""" GoalFunctionResult class ==================================== """ from abc import ABC, abstractmethod import torch from textattack.shared import utils class GoalFunctionResultStatus: SUCCEEDED = 0 SEARCHING = 1 # In process of searching for a success MAXIMIZING = 2 SKIPPED = 3 class GoalFunctionResult(ABC): """Represents the result of a goal function evaluating a AttackedText object. Args: attacked_text: The sequence that was evaluated. output: The display-friendly output. goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal. score: A score representing how close the model is to achieving its goal. num_queries: How many model queries have been used ground_truth_output: The ground truth output """ def __init__( self, attacked_text, raw_output, output, goal_status, score, num_queries, ground_truth_output, goal_function_result_type="", ): self.attacked_text = attacked_text self.raw_output = raw_output self.output = output self.score = score self.goal_status = goal_status self.num_queries = num_queries self.ground_truth_output = ground_truth_output self.goal_function_result_type = goal_function_result_type if isinstance(self.raw_output, torch.Tensor): self.raw_output = self.raw_output.numpy() if isinstance(self.score, torch.Tensor): self.score = self.score.item() def __repr__(self): main_str = "GoalFunctionResult( " lines = [] lines.append( utils.add_indent( f"(goal_function_result_type): {self.goal_function_result_type}", 2 ) ) lines.append(utils.add_indent(f"(attacked_text): {self.attacked_text.text}", 2)) lines.append( utils.add_indent(f"(ground_truth_output): {self.ground_truth_output}", 2) ) lines.append(utils.add_indent(f"(model_output): {self.output}", 2)) lines.append(utils.add_indent(f"(score): {self.score}", 2)) main_str += "\n " + "\n ".join(lines) + "\n" main_str += ")" return main_str @abstractmethod def get_text_color_input(self): """A string representing the color this result's changed portion should be if it represents the original input.""" raise NotImplementedError() @abstractmethod def get_text_color_perturbed(self): """A string representing the color this result's changed portion should be if it represents the perturbed input.""" raise NotImplementedError() @abstractmethod def get_colored_output(self, color_method=None): """Returns a string representation of this result's output, colored according to `color_method`.""" raise NotImplementedError()