"""
Misc Validators
=================

Validators ensure compatibility between search methods, transformations,
constraints, and goal functions.
"""
import re

import textattack
from textattack.goal_functions import (
    InputReduction,
    MinimizeBleu,
    NonOverlappingOutput,
    TargetedClassification,
    UntargetedClassification,
)

from . import logger

# Maps tuples of goal-function classes to the regular expressions (despite the
# "glob" naming used below, these are ``re`` patterns) that a compatible
# model's fully-qualified class path (``module.ClassName``) must match.
MODELS_BY_GOAL_FUNCTIONS = {
    (TargetedClassification, UntargetedClassification, InputReduction): [
        r"^textattack.models.helpers.lstm_for_classification.*",
        r"^textattack.models.helpers.word_cnn_for_classification.*",
        # Older ``transformers`` releases: flat layout, e.g.
        # ``transformers.modeling_bert.BertForSequenceClassification``.
        r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
        # Newer ``transformers`` releases nest each architecture in its own
        # subpackage, e.g.
        # ``transformers.models.bert.modeling_bert.BertForSequenceClassification``.
        r"^transformers\.models\.\w+\.modeling_\w*\.\w*ForSequenceClassification$",
    ],
    (
        NonOverlappingOutput,
        MinimizeBleu,
    ): [
        r"^textattack.models.helpers.t5_for_text_to_text.*",
    ],
}

# Unroll the `MODELS_BY_GOAL_FUNCTIONS` dictionary into a dictionary that has
# a key for each goal function. (Note the plurality here that distinguishes
# the two variables from one another.) Goal functions grouped under the same
# key share the same pattern list object.
MODELS_BY_GOAL_FUNCTION = {
    goal_function: matching_model_globs
    for goal_functions, matching_model_globs in MODELS_BY_GOAL_FUNCTIONS.items()
    for goal_function in goal_functions
}


def validate_model_goal_function_compatibility(goal_function_class, model_class):
    """Determines if ``model_class`` is task-compatible with ``goal_function_class``.

    For example, a text-generative model like one intended for translation or
    summarization would not be compatible with a goal function that requires
    probability scores, like ``UntargetedClassification``.

    The model's fully-qualified class path (``module.ClassName``) is matched
    with ``re.match`` against the patterns registered for the goal function in
    ``MODELS_BY_GOAL_FUNCTION``. This function only logs its verdict and
    always returns ``None``; it never raises on incompatibility.

    Args:
        goal_function_class: The goal function class (not an instance).
        model_class: The model class (not an instance).
    """
    # Look up the patterns registered for this goal function; an unregistered
    # goal function simply has no known-compatible models.
    try:
        matching_model_globs = MODELS_BY_GOAL_FUNCTION[goal_function_class]
    except KeyError:
        matching_model_globs = []
        logger.warning(f"No entry found for goal function {goal_function_class}.")

    # The path the patterns are matched against, e.g.
    # "textattack.models.helpers.lstm_for_classification.LSTMForClassification".
    model_module_path = ".".join((model_class.__module__, model_class.__name__))

    # Ensure the model matches one of these options.
    for glob in matching_model_globs:
        if re.match(glob, model_module_path):
            logger.info(
                f"Goal function {goal_function_class} compatible with model {model_class.__name__}."
            )
            return

    # If we got here, the model does not match the intended goal function.
    # If it matches another goal function instead, warn the user about that.
    for goal_functions, globs in MODELS_BY_GOAL_FUNCTIONS.items():
        for glob in globs:
            if re.match(glob, model_module_path):
                logger.warning(
                    f"Unknown if model {model_class.__name__} compatible with provided goal function {goal_function_class}."
                    f" Found match with other goal functions: {goal_functions}."
                )
                return

    # Otherwise, this is an unknown model: perhaps user-provided, or we forgot
    # to update the corresponding dictionary. Warn user and return.
    logger.warning(
        f"Unknown if model of class {model_class} compatible with goal function {goal_function_class}."
    )


def validate_model_gradient_word_swap_compatibility(model):
    """Determines if ``model`` is task-compatible with ``GradientBasedWordSwap``.

    We can only take the gradient with respect to an individual word if the
    model uses a word-based tokenizer. Currently only ``LSTMForClassification``
    instances are accepted.

    Returns:
        ``True`` when the model is supported.

    Raises:
        ValueError: If ``model`` is not an ``LSTMForClassification``.
    """
    if isinstance(model, textattack.models.helpers.LSTMForClassification):
        return True
    raise ValueError(f"Cannot perform GradientBasedWordSwap on model {model}.")


def transformation_consists_of(transformation, transformation_classes):
    """Determines if ``transformation`` is or consists only of instances of a
    class in ``transformation_classes``.

    A ``CompositeTransformation`` is checked recursively: every member must
    itself satisfy the check (an empty composite therefore counts as
    satisfying it). Any other transformation must be an instance of at least
    one class in ``transformation_classes``.
    """
    # Imported locally, presumably to avoid a circular import at module load.
    from textattack.transformations import CompositeTransformation

    if isinstance(transformation, CompositeTransformation):
        return all(
            transformation_consists_of(t, transformation_classes)
            for t in transformation.transformations
        )
    return isinstance(transformation, tuple(transformation_classes))


def transformation_consists_of_word_swaps(transformation):
    """Determines if ``transformation`` is a word swap or consists of only word swaps."""
    from textattack.transformations import WordSwap, WordSwapGradientBased

    return transformation_consists_of(transformation, [WordSwap, WordSwapGradientBased])


def transformation_consists_of_word_swaps_and_deletions(transformation):
    """Determines if ``transformation`` is a word swap or consists of only word swaps and deletions."""
    from textattack.transformations import WordDeletion, WordSwap, WordSwapGradientBased

    return transformation_consists_of(
        transformation, [WordDeletion, WordSwap, WordSwapGradientBased]
    )