Spaces:
Sleeping
Sleeping
""" | |
ListThingsCommand class | |
============================== | |
""" | |
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser | |
import textattack | |
from textattack.attack_args import ( | |
ATTACK_RECIPE_NAMES, | |
BLACK_BOX_TRANSFORMATION_CLASS_NAMES, | |
CONSTRAINT_CLASS_NAMES, | |
GOAL_FUNCTION_CLASS_NAMES, | |
SEARCH_METHOD_CLASS_NAMES, | |
WHITE_BOX_TRANSFORMATION_CLASS_NAMES, | |
) | |
from textattack.augment_args import AUGMENTATION_RECIPE_NAMES | |
from textattack.commands import TextAttackCommand | |
from textattack.model_args import HUGGINGFACE_MODELS, TEXTATTACK_MODELS | |
def _cb(s):
    """Return ``str(s)`` colored blue via textattack's ANSI color helper."""
    text = str(s)
    return textattack.shared.utils.color_text(text, method="ansi", color="blue")
class ListThingsCommand(TextAttackCommand):
    """The list module:

    List default things in textattack.
    """

    def _list(self, list_of_things, plain=False):
        """Print a list or dict of things, sorted by name, one per line.

        Args:
            list_of_things: Either a ``list`` of names, or a ``dict`` mapping
                each name to a description that is printed in parentheses
                after the name.
            plain: If True, print names without ANSI color.

        Raises:
            TypeError: If ``list_of_things`` is neither a list nor a dict.
        """
        if isinstance(list_of_things, list):
            for thing in sorted(list_of_things):
                print(thing if plain else _cb(thing))
        elif isinstance(list_of_things, dict):
            for thing in sorted(list_of_things):
                thing_key = thing if plain else _cb(thing)
                print(f"{thing_key} ({list_of_things[thing]})")
        else:
            raise TypeError(f"Cannot print list of type {type(list_of_things)}")

    @staticmethod
    def things():
        """Return a dict mapping each listable feature name to its entries.

        ``"models"`` is a list of the HuggingFace and TextAttack model names.
        ``"transformations"`` merges the black-box and white-box
        transformation registries; on a name collision the white-box entry
        wins. The remaining values are the registries imported from
        ``attack_args`` / ``augment_args`` as-is (presumably dicts of
        name -> description/class path; verify against those modules).
        """
        list_dict = {}
        list_dict["models"] = list(HUGGINGFACE_MODELS.keys()) + list(
            TEXTATTACK_MODELS.keys()
        )
        list_dict["search-methods"] = SEARCH_METHOD_CLASS_NAMES
        list_dict["transformations"] = {
            **BLACK_BOX_TRANSFORMATION_CLASS_NAMES,
            **WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
        }
        list_dict["constraints"] = CONSTRAINT_CLASS_NAMES
        list_dict["goal-functions"] = GOAL_FUNCTION_CLASS_NAMES
        list_dict["attack-recipes"] = ATTACK_RECIPE_NAMES
        list_dict["augmentation-recipes"] = AUGMENTATION_RECIPE_NAMES
        return list_dict

    def run(self, args):
        """Print every entry registered under ``args.feature``.

        Args:
            args: Parsed arguments providing ``feature`` (a key of
                ``things()``) and ``plain`` (bool, disable color).

        Raises:
            ValueError: If ``args.feature`` is not a known feature name.
        """
        registry = ListThingsCommand.things()
        try:
            list_of_things = registry[args.feature]
        except KeyError:
            # argparse ``choices`` guards the CLI path; this covers callers
            # that build ``args`` themselves.
            raise ValueError(f"Unknown list key {args.feature}") from None
        self._list(list_of_things, plain=args.plain)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``list`` subcommand on ``main_parser``.

        ``main_parser`` must provide ``add_parser`` (presumably the action
        returned by ``ArgumentParser.add_subparsers()``). The new subparser
        takes a positional ``feature`` restricted to the keys of
        ``things()`` and a ``--plain`` flag. Its ``func`` default is set to
        a fresh ``ListThingsCommand`` instance.
        """
        parser = main_parser.add_parser(
            "list",
            help="list features in TextAttack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "feature", help="the feature to list", choices=ListThingsCommand.things()
        )
        parser.add_argument(
            "--plain",
            help="print output without color",
            default=False,
            action="store_true",
        )
        parser.set_defaults(func=ListThingsCommand())