|
""" |
|
Misc Validators |
|
================= |
|
Validators ensure compatibility between search methods, transformations, constraints, and goal functions. |
|
|
|
""" |
|
import re |
|
|
|
import textattack |
|
from textattack.goal_functions import ( |
|
InputReduction, |
|
MinimizeBleu, |
|
NonOverlappingOutput, |
|
TargetedClassification, |
|
UntargetedClassification, |
|
) |
|
|
|
from . import logger |
|
|
|
|
|
# Maps groups of goal-function classes to regular expressions that are matched
# (with ``re.match``) against a model's fully qualified ``module.ClassName``
# path. A model matching one of a group's patterns is treated as
# task-compatible with every goal function in that group.
MODELS_BY_GOAL_FUNCTIONS = {
    (TargetedClassification, UntargetedClassification, InputReduction): [
        r"^textattack\.models\.helpers\.lstm_for_classification.*",
        r"^textattack\.models\.helpers\.word_cnn_for_classification.*",
        # Accepts both the flat pre-v4 transformers layout
        # (``transformers.modeling_bert.BertForSequenceClassification``) and
        # the nested v4+ layout
        # (``transformers.models.bert.modeling_bert.BertForSequenceClassification``).
        r"^transformers\.(models\.\w+\.)?modeling_\w*\.\w*ForSequenceClassification$",
    ],
    (
        NonOverlappingOutput,
        MinimizeBleu,
    ): [
        r"^textattack\.models\.helpers\.t5_for_text_to_text.*",
    ],
}
|
|
|
|
|
|
|
|
|
# Flattened view of MODELS_BY_GOAL_FUNCTIONS: each individual goal-function
# class maps directly to its group's list of model-path regexes. If a class
# appeared in more than one group, the later group's list would win.
MODELS_BY_GOAL_FUNCTION = {
    goal_function: model_patterns
    for goal_function_group, model_patterns in MODELS_BY_GOAL_FUNCTIONS.items()
    for goal_function in goal_function_group
}
|
|
|
|
|
def validate_model_goal_function_compatibility(goal_function_class, model_class):
    """Determines if ``model_class`` is task-compatible with
    ``goal_function_class``.

    For example, a text-generative model like one intended for
    translation or summarization would not be compatible with a goal
    function that requires probability scores, like
    ``UntargetedClassification``.

    The check is advisory: every outcome is reported through the logger
    and the function returns ``None``; an incompatible or unrecognized
    pair does not raise.

    Args:
        goal_function_class: Goal function class used as the lookup key in
            ``MODELS_BY_GOAL_FUNCTION``.
        model_class: Model class whose ``module.ClassName`` path is matched
            against the registered regular expressions.
    """
    try:
        matching_model_globs = MODELS_BY_GOAL_FUNCTION[goal_function_class]
    except KeyError:
        matching_model_globs = []
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
        logger.warning(f"No entry found for goal function {goal_function_class}.")

    # Fully qualified path, e.g.
    # "textattack.models.helpers.lstm_for_classification.LSTMForClassification".
    model_module_path = ".".join((model_class.__module__, model_class.__name__))

    # 1) Known compatible: matches a pattern registered for this goal function.
    for glob in matching_model_globs:
        if re.match(glob, model_module_path):
            logger.info(
                f"Goal function {goal_function_class} compatible with model {model_class.__name__}."
            )
            return

    # 2) The model is recognized, but only for other goal functions.
    for goal_functions, globs in MODELS_BY_GOAL_FUNCTIONS.items():
        for glob in globs:
            if re.match(glob, model_module_path):
                logger.warning(
                    f"Unknown if model {model_class.__name__} compatible with provided goal function {goal_function_class}."
                    f" Found match with other goal functions: {goal_functions}."
                )
                return

    # 3) The model is not recognized by any registered pattern.
    logger.warning(
        f"Unknown if model of class {model_class} compatible with goal function {goal_function_class}."
    )
|
|
|
|
|
def validate_model_gradient_word_swap_compatibility(model):
    """Check that ``model`` can be attacked with ``GradientBasedWordSwap``.

    A gradient with respect to an individual word only makes sense when the
    model uses a word-based tokenizer, so only instances of
    ``textattack.models.helpers.LSTMForClassification`` (or its subclasses)
    are accepted.

    Returns:
        ``True`` when ``model`` is supported.

    Raises:
        ValueError: if ``model`` is not an ``LSTMForClassification``.
    """
    supported_model_class = textattack.models.helpers.LSTMForClassification
    if not isinstance(model, supported_model_class):
        raise ValueError(f"Cannot perform GradientBasedWordSwap on model {model}.")
    return True
|
|
|
|
|
def transformation_consists_of(transformation, transformation_classes):
    """Return whether ``transformation`` is, or is built only from, instances
    of the classes listed in ``transformation_classes``.

    A ``CompositeTransformation`` is checked recursively and qualifies only
    if every one of its ``transformations`` qualifies (an empty composite
    therefore qualifies). Any other transformation qualifies if it is an
    instance of at least one class in ``transformation_classes``.
    """
    from textattack.transformations import CompositeTransformation

    if not isinstance(transformation, CompositeTransformation):
        return any(
            isinstance(transformation, candidate_class)
            for candidate_class in transformation_classes
        )

    return all(
        transformation_consists_of(child, transformation_classes)
        for child in transformation.transformations
    )
|
|
|
|
|
def transformation_consists_of_word_swaps(transformation):
    """Return whether ``transformation`` is a word swap, or a composite made
    up only of word swaps (``WordSwap`` or ``WordSwapGradientBased``)."""
    from textattack.transformations import WordSwap, WordSwapGradientBased

    word_swap_classes = [WordSwap, WordSwapGradientBased]
    return transformation_consists_of(transformation, word_swap_classes)
|
|
|
|
|
def transformation_consists_of_word_swaps_and_deletions(transformation):
    """Return whether ``transformation`` is a word swap or word deletion, or a
    composite made up only of those (``WordDeletion``, ``WordSwap`` or
    ``WordSwapGradientBased``)."""
    from textattack.transformations import WordDeletion, WordSwap, WordSwapGradientBased

    allowed_classes = [WordDeletion, WordSwap, WordSwapGradientBased]
    return transformation_consists_of(transformation, allowed_classes)
|
|