anonymous8/RPD-Demo
initial commit
4943752
raw
history blame
3.12 kB
"""
ListThingsCommand class
==============================
"""
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import textattack
from textattack.attack_args import (
ATTACK_RECIPE_NAMES,
BLACK_BOX_TRANSFORMATION_CLASS_NAMES,
CONSTRAINT_CLASS_NAMES,
GOAL_FUNCTION_CLASS_NAMES,
SEARCH_METHOD_CLASS_NAMES,
WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
)
from textattack.augment_args import AUGMENTATION_RECIPE_NAMES
from textattack.commands import TextAttackCommand
from textattack.model_args import HUGGINGFACE_MODELS, TEXTATTACK_MODELS
def _cb(s):
    """Return ``str(s)`` wrapped in blue ANSI coloring for terminal output."""
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")
class ListThingsCommand(TextAttackCommand):
    """The list module:

    List default things in textattack.
    """

    def _list(self, list_of_things, plain=False):
        """Print a list or dict of things, one entry per line, in sorted order.

        Args:
            list_of_things: Either a ``list`` of names, or a ``dict`` mapping
                each name to a longer description that is printed in
                parentheses after the name.
            plain: When true, print names as-is; otherwise color them blue
                with ``_cb``.

        Raises:
            TypeError: If ``list_of_things`` is neither a list nor a dict.
        """
        if isinstance(list_of_things, list):
            for thing in sorted(list_of_things):
                print(thing if plain else _cb(thing))
        elif isinstance(list_of_things, dict):
            for thing in sorted(list_of_things):
                thing_long_description = list_of_things[thing]
                thing_key = thing if plain else _cb(thing)
                print(f"{thing_key} ({thing_long_description})")
        else:
            raise TypeError(f"Cannot print list of type {type(list_of_things)}")

    @staticmethod
    def things():
        """Return a dict mapping each listable feature name to its entries.

        The ``"models"`` entry is a list built from the keys of
        ``HUGGINGFACE_MODELS`` and ``TEXTATTACK_MODELS``; the
        ``"transformations"`` entry merges the black-box and white-box
        transformation mappings (white-box wins on duplicate keys). The
        remaining entries are the imported name constants, passed through
        unchanged.
        """
        return {
            "models": list(HUGGINGFACE_MODELS.keys())
            + list(TEXTATTACK_MODELS.keys()),
            "search-methods": SEARCH_METHOD_CLASS_NAMES,
            "transformations": {
                **BLACK_BOX_TRANSFORMATION_CLASS_NAMES,
                **WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
            },
            "constraints": CONSTRAINT_CLASS_NAMES,
            "goal-functions": GOAL_FUNCTION_CLASS_NAMES,
            "attack-recipes": ATTACK_RECIPE_NAMES,
            "augmentation-recipes": AUGMENTATION_RECIPE_NAMES,
        }

    def run(self, args):
        """Print the entries for the feature named in ``args.feature``.

        Args:
            args: Parsed arguments carrying the ``feature`` and ``plain``
                attributes registered by ``register_subcommand``.

        Raises:
            ValueError: If ``args.feature`` is not a key of ``things()``.
        """
        available = ListThingsCommand.things()
        try:
            list_of_things = available[args.feature]
        except KeyError as err:
            # The parser only defines ``feature``; reading any other attribute
            # here would raise AttributeError and mask the intended error.
            raise ValueError(f"Unknown list key {args.feature}") from err
        self._list(list_of_things, plain=args.plain)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``list`` subcommand and its arguments on ``main_parser``.

        Adds a positional ``feature`` argument restricted to the keys of
        ``things()`` and a ``--plain`` flag, and stores a
        ``ListThingsCommand`` instance as the parser's ``func`` default.

        Args:
            main_parser: Object exposing ``add_parser`` (presumably the
                sub-parsers action of the top-level CLI parser -- confirm
                against the caller).
        """
        parser = main_parser.add_parser(
            "list",
            help="list features in TextAttack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "feature", help="the feature to list", choices=ListThingsCommand.things()
        )
        parser.add_argument(
            "--plain",
            help="print output without color",
            default=False,
            action="store_true",
        )
        parser.set_defaults(func=ListThingsCommand())