""" ListThingsCommand class ============================== """ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser import textattack from textattack.attack_args import ( ATTACK_RECIPE_NAMES, BLACK_BOX_TRANSFORMATION_CLASS_NAMES, CONSTRAINT_CLASS_NAMES, GOAL_FUNCTION_CLASS_NAMES, SEARCH_METHOD_CLASS_NAMES, WHITE_BOX_TRANSFORMATION_CLASS_NAMES, ) from textattack.augment_args import AUGMENTATION_RECIPE_NAMES from textattack.commands import TextAttackCommand from textattack.model_args import HUGGINGFACE_MODELS, TEXTATTACK_MODELS def _cb(s): return textattack.shared.utils.color_text(str(s), color="blue", method="ansi") class ListThingsCommand(TextAttackCommand): """The list module: List default things in textattack. """ def _list(self, list_of_things, plain=False): """Prints a list or dict of things.""" if isinstance(list_of_things, list): list_of_things = sorted(list_of_things) for thing in list_of_things: if plain: print(thing) else: print(_cb(thing)) elif isinstance(list_of_things, dict): for thing in sorted(list_of_things.keys()): thing_long_description = list_of_things[thing] if plain: thing_key = thing else: thing_key = _cb(thing) print(f"{thing_key} ({thing_long_description})") else: raise TypeError(f"Cannot print list of type {type(list_of_things)}") @staticmethod def things(): list_dict = {} list_dict["models"] = list(HUGGINGFACE_MODELS.keys()) + list( TEXTATTACK_MODELS.keys() ) list_dict["search-methods"] = SEARCH_METHOD_CLASS_NAMES list_dict["transformations"] = { **BLACK_BOX_TRANSFORMATION_CLASS_NAMES, **WHITE_BOX_TRANSFORMATION_CLASS_NAMES, } list_dict["constraints"] = CONSTRAINT_CLASS_NAMES list_dict["goal-functions"] = GOAL_FUNCTION_CLASS_NAMES list_dict["attack-recipes"] = ATTACK_RECIPE_NAMES list_dict["augmentation-recipes"] = AUGMENTATION_RECIPE_NAMES return list_dict def run(self, args): try: list_of_things = ListThingsCommand.things()[args.feature] except KeyError: raise ValueError(f"Unknown list key {args.thing}") self._list(list_of_things, plain=args.plain) @staticmethod def register_subcommand(main_parser: ArgumentParser): parser = main_parser.add_parser( "list", help="list features in TextAttack", formatter_class=ArgumentDefaultsHelpFormatter, ) parser.add_argument( "feature", help="the feature to list", choices=ListThingsCommand.things() ) parser.add_argument( "--plain", help="print output without color", default=False, action="store_true", ) parser.set_defaults(func=ListThingsCommand())