PFEemp2024's picture
solving GPU error for previous version
4a1df2e
raw
history blame
35.1 kB
"""
AttackArgs Class
================
"""
from dataclasses import dataclass, field
import json
import os
import sys
import time
from typing import Dict, Optional
import textattack
from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file
from .attack import Attack
from .dataset_args import DatasetArgs
from .model_args import ModelArgs
# Maps each command-line recipe alias to the dotted path of its recipe class;
# the path is later evaluated to call the class's ``build`` method.
ATTACK_RECIPE_NAMES = {
    alias: f"textattack.attack_recipes.{class_name}"
    for alias, class_name in (
        ("alzantot", "GeneticAlgorithmAlzantot2018"),
        ("bae", "BAEGarg2019"),
        ("bert-attack", "BERTAttackLi2020"),
        ("faster-alzantot", "FasterGeneticAlgorithmJia2019"),
        ("deepwordbug", "DeepWordBugGao2018"),
        ("hotflip", "HotFlipEbrahimi2017"),
        ("input-reduction", "InputReductionFeng2018"),
        ("kuleshov", "Kuleshov2017"),
        ("morpheus", "MorpheusTan2020"),
        ("seq2sick", "Seq2SickCheng2018BlackBox"),
        ("textbugger", "TextBuggerLi2018"),
        ("textfooler", "TextFoolerJin2019"),
        ("pwws", "PWWSRen2019"),
        ("iga", "IGAWang2019"),
        ("pruthi", "Pruthi2019"),
        ("pso", "PSOZang2020"),
        ("checklist", "CheckList2020"),
        ("clare", "CLARE2020"),
        ("a2t", "A2TYoo2021"),
    )
}
# Transformations that are constructed without the victim model, keyed by
# command-line alias and mapped to the dotted path of the class.
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
    alias: f"textattack.transformations.{class_name}"
    for alias, class_name in (
        ("random-synonym-insertion", "RandomSynonymInsertion"),
        ("word-deletion", "WordDeletion"),
        ("word-swap-embedding", "WordSwapEmbedding"),
        ("word-swap-homoglyph", "WordSwapHomoglyphSwap"),
        ("word-swap-inflections", "WordSwapInflections"),
        ("word-swap-neighboring-char-swap", "WordSwapNeighboringCharacterSwap"),
        ("word-swap-random-char-deletion", "WordSwapRandomCharacterDeletion"),
        ("word-swap-random-char-insertion", "WordSwapRandomCharacterInsertion"),
        ("word-swap-random-char-substitution", "WordSwapRandomCharacterSubstitution"),
        ("word-swap-wordnet", "WordSwapWordNet"),
        ("word-swap-masked-lm", "WordSwapMaskedLM"),
        ("word-swap-hownet", "WordSwapHowNet"),
        ("word-swap-qwerty", "WordSwapQWERTY"),
    )
}
# Transformations whose constructor receives the victim model as its first
# argument, keyed by command-line alias.
WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
    alias: f"textattack.transformations.{class_name}"
    for alias, class_name in (("word-swap-gradient", "WordSwapGradientBased"),)
}
# Command-line constraint aliases mapped to the dotted path of the class,
# given relative to the ``textattack.constraints`` package.
CONSTRAINT_CLASS_NAMES = {
    alias: f"textattack.constraints.{relative_path}"
    for alias, relative_path in (
        # Semantics constraints
        ("embedding", "semantics.WordEmbeddingDistance"),
        ("bert", "semantics.sentence_encoders.BERT"),
        ("infer-sent", "semantics.sentence_encoders.InferSent"),
        ("thought-vector", "semantics.sentence_encoders.ThoughtVector"),
        ("use", "semantics.sentence_encoders.UniversalSentenceEncoder"),
        ("muse", "semantics.sentence_encoders.MultilingualUniversalSentenceEncoder"),
        ("bert-score", "semantics.BERTScore"),
        # Grammaticality constraints
        ("lang-tool", "grammaticality.LanguageTool"),
        ("part-of-speech", "grammaticality.PartOfSpeech"),
        ("goog-lm", "grammaticality.language_models.GoogleLanguageModel"),
        ("gpt2", "grammaticality.language_models.GPT2"),
        ("learning-to-write", "grammaticality.language_models.LearningToWriteLanguageModel"),
        ("cola", "grammaticality.COLA"),
        # Overlap constraints
        ("bleu", "overlap.BLEU"),
        ("chrf", "overlap.chrF"),
        ("edit-distance", "overlap.LevenshteinEditDistance"),
        ("meteor", "overlap.METEOR"),
        ("max-words-perturbed", "overlap.MaxWordsPerturbed"),
        # Pre-transformation constraints
        ("repeat", "pre_transformation.RepeatModification"),
        ("stopword", "pre_transformation.StopwordModification"),
        ("max-word-index", "pre_transformation.MaxWordIndexModification"),
    )
}
# Command-line search-method aliases mapped to the dotted path of the class.
SEARCH_METHOD_CLASS_NAMES = {
    alias: f"textattack.search_methods.{class_name}"
    for alias, class_name in (
        ("beam-search", "BeamSearch"),
        ("greedy", "GreedySearch"),
        ("ga-word", "GeneticAlgorithm"),
        ("greedy-word-wir", "GreedyWordSwapWIR"),
        ("pso", "ParticleSwarmOptimization"),
    )
}
# Command-line goal-function aliases mapped to the dotted path of the class,
# given relative to the ``textattack.goal_functions`` package.
GOAL_FUNCTION_CLASS_NAMES = {
    alias: f"textattack.goal_functions.{relative_path}"
    for alias, relative_path in (
        # Classification goal functions
        ("targeted-classification", "classification.TargetedClassification"),
        ("untargeted-classification", "classification.UntargetedClassification"),
        ("input-reduction", "classification.InputReduction"),
        # Text goal functions
        ("minimize-bleu", "text.MinimizeBleu"),
        ("non-overlapping-output", "text.NonOverlappingOutput"),
        ("text-to-text", "text.TextToTextGoalFunction"),
    )
}
@dataclass
class AttackArgs:
    """Attack arguments to be passed to :class:`~textattack.Attacker`.

    Args:
        num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
            The number of examples to attack. :obj:`-1` for entire dataset.
        num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
            The number of successful adversarial examples we want. This is different from :obj:`num_examples`
            as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
            until we have `N` successful cases.

            .. note::
                If set, this argument overrides `num_examples` argument (which is reset to :obj:`None`).
        num_examples_offset (:obj:`int`, `optional`, defaults to :obj:`0`):
            The offset index to start at in the dataset.
        attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to run attack until total of `N` examples have been attacked (and not skipped).
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
            the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
            :obj:`shuffle` can now be used with checkpoint saving.
        query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
            The maximum number of model queries allowed per example attacked. Must be greater than 0 when set.
            If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).

            .. note::
                Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
        checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
            If set (must be greater than 0), checkpoint will be saved after attacking every `N` examples.
            If :obj:`None` is passed, no checkpoints will be saved.
        checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
            The directory to save checkpoint files.
        random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
            Random seed for reproducibility.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, run attack using multiple CPUs/GPUs.
        num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
            Number of worker processes to run per device in parallel mode (i.e. :obj:`parallel=True`). For example, if you are using GPUs and :obj:`num_workers_per_device=2`,
            then 2 processes will be running in each GPU. Must be greater than 0.
        log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save attack logs as a `.txt` file to the directory specified by this argument.
            If the last part of the provided path ends with `.txt` extension, it is assumed to be the desired path of the log file.
        log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save attack logs as a CSV file to the directory specified by this argument.
            If the last part of the provided path ends with `.csv` extension, it is assumed to be the desired path of the log file.
        log_summary_to_json (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save the attack summary as a JSON file to the directory specified by this argument.
            If the last part of the provided path ends with `.json` extension, it is assumed to be the desired path of the file.
        csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
            Method for choosing how to mark perturbed parts of the text. Options are :obj:`"file"`, :obj:`"plain"`, and :obj:`"html"`.
            :obj:`"file"` wraps perturbed parts with double brackets :obj:`[[ <text> ]]` while :obj:`"plain"` does not mark the text in any way.
        log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
            If set, Visdom logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
            Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
            three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
        log_to_wandb (:obj:`dict`, `optional`, defaults to :obj:`None`):
            If set, WandB logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
            Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
            key and its corresponding value: :obj:`"project"`.
        disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable displaying individual attack results to stdout.
        silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
        enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc.
        metrics (:obj:`dict`, `optional`, defaults to :obj:`None`):
            Passed unchanged to :class:`~textattack.loggers.AttackLogManager` by :meth:`create_loggers_from_args`.
    """

    num_examples: Optional[int] = 10
    num_successful_examples: Optional[int] = None
    num_examples_offset: int = 0
    attack_n: bool = False
    shuffle: bool = False
    query_budget: Optional[int] = None
    checkpoint_interval: Optional[int] = None
    checkpoint_dir: str = "checkpoints"
    random_seed: int = 765  # equivalent to sum((ord(c) for c in "TEXTATTACK"))
    parallel: bool = False
    num_workers_per_device: int = 1
    log_to_txt: Optional[str] = None
    log_to_csv: Optional[str] = None
    log_summary_to_json: Optional[str] = None
    csv_coloring_style: str = "file"
    log_to_visdom: Optional[dict] = None
    log_to_wandb: Optional[dict] = None
    disable_stdout: bool = False
    silent: bool = False
    enable_advance_metrics: bool = False
    metrics: Optional[Dict] = None

    def __post_init__(self):
        """Validate the numeric arguments after dataclass initialization.

        Raises:
            AssertionError: if any numeric argument is outside its allowed range.
        """
        # A target number of successes replaces the fixed number of examples.
        if self.num_successful_examples:
            self.num_examples = None
        if self.num_examples:
            assert (
                self.num_examples >= 0 or self.num_examples == -1
            ), "`num_examples` must be greater than or equal to 0 or equal to -1."
        if self.num_successful_examples:
            assert (
                self.num_successful_examples >= 0
            ), "`num_successful_examples` must be greater than or equal to 0."
        # Compare against None so that an explicit 0 is rejected instead of
        # silently being treated as "not set".
        if self.query_budget is not None:
            assert self.query_budget > 0, "`query_budget` must be greater than 0."
        if self.checkpoint_interval is not None:
            assert (
                self.checkpoint_interval > 0
            ), "`checkpoint_interval` must be greater than 0."
        assert (
            self.num_workers_per_device > 0
        ), "`num_workers_per_device` must be greater than 0."

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser.

        Defaults are taken from a default-constructed instance of ``cls``.
        Returns the same ``parser`` to allow chaining.
        """
        default_obj = cls()
        # A fixed example count and a target number of successes are alternatives.
        num_ex_group = parser.add_mutually_exclusive_group(required=False)
        num_ex_group.add_argument(
            "--num-examples",
            "-n",
            type=int,
            default=default_obj.num_examples,
            help="The number of examples to process, -1 for entire dataset.",
        )
        num_ex_group.add_argument(
            "--num-successful-examples",
            type=int,
            default=default_obj.num_successful_examples,
            help="The number of successful adversarial examples we want.",
        )
        parser.add_argument(
            "--num-examples-offset",
            "-o",
            type=int,
            required=False,
            default=default_obj.num_examples_offset,
            help="The offset to start at in the dataset.",
        )
        parser.add_argument(
            "--query-budget",
            "-q",
            type=int,
            default=default_obj.query_budget,
            help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            default=default_obj.shuffle,
            help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
        )
        parser.add_argument(
            "--attack-n",
            action="store_true",
            default=default_obj.attack_n,
            help="Whether to run attack until `n` examples have been attacked (not skipped).",
        )
        parser.add_argument(
            "--checkpoint-dir",
            required=False,
            type=str,
            default=default_obj.checkpoint_dir,
            help="The directory to save checkpoint files.",
        )
        parser.add_argument(
            "--checkpoint-interval",
            required=False,
            type=int,
            default=default_obj.checkpoint_interval,
            help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
        )
        parser.add_argument(
            "--random-seed",
            default=default_obj.random_seed,
            type=int,
            help="Random seed for reproducibility.",
        )
        parser.add_argument(
            "--parallel",
            action="store_true",
            default=default_obj.parallel,
            help="Run attack using multiple GPUs.",
        )
        parser.add_argument(
            "--num-workers-per-device",
            default=default_obj.num_workers_per_device,
            type=int,
            help="Number of worker processes to run per device.",
        )
        # For the three file-log flags, a bare flag (no value) yields "" which
        # means "default file name in the current directory".
        parser.add_argument(
            "--log-to-txt",
            nargs="?",
            default=default_obj.log_to_txt,
            const="",
            type=str,
            help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
            "If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--log-to-csv",
            nargs="?",
            default=default_obj.log_to_csv,
            const="",
            type=str,
            help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
            "If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--log-summary-to-json",
            nargs="?",
            default=default_obj.log_summary_to_json,
            const="",
            type=str,
            help="Path to which to save attack summary as a JSON file. Set this argument if you want to save attack results summary in a JSON. "
            "If the last part of the path ends with `.json` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--csv-coloring-style",
            default=default_obj.csv_coloring_style,
            type=str,
            help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
            '"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
        )
        parser.add_argument(
            "--log-to-visdom",
            nargs="?",
            default=default_obj.log_to_visdom,
            const='{"env": "main", "port": 8097, "hostname": "localhost"}',
            type=json.loads,
            help="Set this argument if you want to log attacks to Visdom. The dictionary should have the following "
            'three keys and their corresponding values: `"env", "port", "hostname"`. '
            'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
        )
        parser.add_argument(
            "--log-to-wandb",
            nargs="?",
            default=default_obj.log_to_wandb,
            const='{"project": "textattack"}',
            type=json.loads,
            help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
            'key and its corresponding value: `"project"`. '
            'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
        )
        parser.add_argument(
            "--disable-stdout",
            action="store_true",
            default=default_obj.disable_stdout,
            help="Disable logging attack results to stdout",
        )
        parser.add_argument(
            "--silent",
            action="store_true",
            default=default_obj.silent,
            help="Disable all logging",
        )
        parser.add_argument(
            "--enable-advance-metrics",
            action="store_true",
            default=default_obj.enable_advance_metrics,
            help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
        )
        return parser

    @classmethod
    def create_loggers_from_args(cls, args):
        """Creates AttackLogManager from an AttackArgs object.

        Args:
            args (AttackArgs): arguments whose ``log_*``, ``csv_coloring_style``,
                ``disable_stdout`` and ``metrics`` fields configure the loggers.

        Returns:
            :class:`~textattack.loggers.AttackLogManager` with every requested
            output attached. Missing output directories are created.
        """
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{cls}`, but got type `{type(args)}`."

        attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)
        # One timestamp shared by every default file name created in this call.
        timestamp = time.strftime("%Y-%m-%d-%H-%M")

        def _prepare_output_path(path_arg, extension, default_filename):
            """Resolve a log destination and ensure its directory exists.

            If ``path_arg`` ends with ``extension`` (case-insensitive) it is the
            file path; otherwise it is a directory and ``default_filename`` is
            placed inside it.
            """
            if path_arg.lower().endswith(extension):
                file_path = path_arg
            else:
                file_path = os.path.join(path_arg, default_filename)
            # An empty dirname means the current working directory.
            dir_path = os.path.dirname(file_path) or "."
            # exist_ok avoids a race between an existence check and creation.
            os.makedirs(dir_path, exist_ok=True)
            return file_path

        # '--log-to-txt'
        if args.log_to_txt is not None:
            txt_file_path = _prepare_output_path(
                args.log_to_txt, ".txt", f"{timestamp}-log.txt"
            )
            attack_log_manager.add_output_file(txt_file_path, "file")

        # '--log-to-csv'
        if args.log_to_csv is not None:
            csv_file_path = _prepare_output_path(
                args.log_to_csv, ".csv", f"{timestamp}-log.csv"
            )
            # "plain" is expressed to the log manager as "no color method".
            color_method = (
                None if args.csv_coloring_style == "plain" else args.csv_coloring_style
            )
            attack_log_manager.add_output_csv(csv_file_path, color_method)

        # '--log-summary-to-json'
        if args.log_summary_to_json is not None:
            summary_json_file_path = _prepare_output_path(
                args.log_summary_to_json,
                ".json",
                f"{timestamp}-attack_summary_log.json",
            )
            attack_log_manager.add_output_summary_json(summary_json_file_path)

        # Visdom
        if args.log_to_visdom is not None:
            attack_log_manager.enable_visdom(**args.log_to_visdom)

        # Weights & Biases
        if args.log_to_wandb is not None:
            attack_log_manager.enable_wandb(**args.log_to_wandb)

        # Stdout
        if not args.disable_stdout:
            if sys.stdout.isatty():
                attack_log_manager.enable_stdout()
            else:
                # NOTE(review): for non-terminal stdout only disable_color() is
                # called; presumably the log manager then emits uncolored
                # stdout output itself — confirm in AttackLogManager.
                attack_log_manager.disable_color()
        return attack_log_manager
@dataclass
class _CommandLineAttackArgs:
"""Attack args for command line execution. This requires more arguments to
create ``Attack`` object as specified.
Args:
transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
Name of transformation to use.
constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
List of names of constraints to use.
goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
Name of goal function to use.
search_method (:obj:`str`, `optional`, defualts to :obj:`"greedy-word-wir"`):
Name of search method to use.
attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
Name of attack recipe to use.
.. note::
Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.
.. note::
If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method
interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
If `True`, carry attack in interactive mode.
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
If `True`, attack in parallel.
model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
The batch size for making queries to the victim model.
model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
The maximum number of items to keep in the model results cache at once.
constraint-cache-size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
The maximum number of items to keep in the constraints cache at once.
"""
transformation: str = "word-swap-embedding"
constraints: list = field(default_factory=lambda: ["repeat", "stopword"])
goal_function: str = "untargeted-classification"
search_method: str = "greedy-word-wir"
attack_recipe: str = None
attack_from_file: str = None
interactive: bool = False
parallel: bool = False
model_batch_size: int = 32
model_cache_size: int = 2**18
constraint_cache_size: int = 2**18
@classmethod
def _add_parser_args(cls, parser):
"""Add listed args to command line parser."""
default_obj = cls()
transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
)
parser.add_argument(
"--transformation",
type=str,
required=False,
default=default_obj.transformation,
help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
+ str(transformation_names),
)
parser.add_argument(
"--constraints",
type=str,
required=False,
nargs="*",
default=default_obj.constraints,
help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
+ str(CONSTRAINT_CLASS_NAMES.keys()),
)
goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
parser.add_argument(
"--goal-function",
"-g",
default=default_obj.goal_function,
help=f"The goal function to use. choices: {goal_function_choices}",
)
attack_group = parser.add_mutually_exclusive_group(required=False)
search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
attack_group.add_argument(
"--search-method",
"--search",
"-s",
type=str,
required=False,
default=default_obj.search_method,
help=f"The search method to use. choices: {search_choices}",
)
attack_group.add_argument(
"--attack-recipe",
"--recipe",
"-r",
type=str,
required=False,
default=default_obj.attack_recipe,
help="full attack recipe (overrides provided goal function, transformation & constraints)",
choices=ATTACK_RECIPE_NAMES.keys(),
)
attack_group.add_argument(
"--attack-from-file",
type=str,
required=False,
default=default_obj.attack_from_file,
help="Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.",
)
parser.add_argument(
"--interactive",
action="store_true",
default=default_obj.interactive,
help="Whether to run attacks interactively.",
)
parser.add_argument(
"--model-batch-size",
type=int,
default=default_obj.model_batch_size,
help="The batch size for making calls to the model.",
)
parser.add_argument(
"--model-cache-size",
type=int,
default=default_obj.model_cache_size,
help="The maximum number of items to keep in the model results cache at once.",
)
parser.add_argument(
"--constraint-cache-size",
type=int,
default=default_obj.constraint_cache_size,
help="The maximum number of items to keep in the constraints cache at once.",
)
return parser
@classmethod
def _create_transformation_from_args(cls, args, model_wrapper):
"""Create `Transformation` based on provided `args` and
`model_wrapper`."""
transformation_name = args.transformation
if ARGS_SPLIT_TOKEN in transformation_name:
transformation_name, params = transformation_name.split(ARGS_SPLIT_TOKEN)
if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
transformation = eval(
f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model, {params})"
)
elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
transformation = eval(
f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
)
else:
raise ValueError(
f"Error: unsupported transformation {transformation_name}"
)
else:
if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
transformation = eval(
f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model)"
)
elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
transformation = eval(
f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
)
else:
raise ValueError(
f"Error: unsupported transformation {transformation_name}"
)
return transformation
@classmethod
def _create_goal_function_from_args(cls, args, model_wrapper):
"""Create `GoalFunction` based on provided `args` and
`model_wrapper`."""
goal_function = args.goal_function
if ARGS_SPLIT_TOKEN in goal_function:
goal_function_name, params = goal_function.split(ARGS_SPLIT_TOKEN)
if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
raise ValueError(
f"Error: unsupported goal_function {goal_function_name}"
)
goal_function = eval(
f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model_wrapper, {params})"
)
elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
goal_function = eval(
f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model_wrapper)"
)
else:
raise ValueError(f"Error: unsupported goal_function {goal_function}")
if args.query_budget:
goal_function.query_budget = args.query_budget
goal_function.model_cache_size = args.model_cache_size
goal_function.batch_size = args.model_batch_size
return goal_function
@classmethod
def _create_constraints_from_args(cls, args):
"""Create list of `Constraints` based on provided `args`."""
if not args.constraints:
return []
_constraints = []
for constraint in args.constraints:
if ARGS_SPLIT_TOKEN in constraint:
constraint_name, params = constraint.split(ARGS_SPLIT_TOKEN)
if constraint_name not in CONSTRAINT_CLASS_NAMES:
raise ValueError(f"Error: unsupported constraint {constraint_name}")
_constraints.append(
eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
)
elif constraint in CONSTRAINT_CLASS_NAMES:
_constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
else:
raise ValueError(f"Error: unsupported constraint {constraint}")
return _constraints
@classmethod
def _create_attack_from_args(cls, args, model_wrapper):
"""Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
``Attack`` object."""
assert isinstance(
args, cls
), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."
if args.attack_recipe:
if ARGS_SPLIT_TOKEN in args.attack_recipe:
recipe_name, params = args.attack_recipe.split(ARGS_SPLIT_TOKEN)
if recipe_name not in ATTACK_RECIPE_NAMES:
raise ValueError(f"Error: unsupported recipe {recipe_name}")
recipe = eval(
f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
)
elif args.attack_recipe in ATTACK_RECIPE_NAMES:
recipe = eval(
f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
)
else:
raise ValueError(f"Invalid recipe {args.attack_recipe}")
if args.query_budget:
recipe.goal_function.query_budget = args.query_budget
recipe.goal_function.model_cache_size = args.model_cache_size
recipe.constraint_cache_size = args.constraint_cache_size
return recipe
elif args.attack_from_file:
if ARGS_SPLIT_TOKEN in args.attack_from_file:
attack_file, attack_name = args.attack_from_file.split(ARGS_SPLIT_TOKEN)
else:
attack_file, attack_name = args.attack_from_file, "attack"
attack_module = load_module_from_file(attack_file)
if not hasattr(attack_module, attack_name):
raise ValueError(
f"Loaded `{attack_file}` but could not find `{attack_name}`."
)
attack_func = getattr(attack_module, attack_name)
return attack_func(model_wrapper)
else:
goal_function = cls._create_goal_function_from_args(args, model_wrapper)
transformation = cls._create_transformation_from_args(args, model_wrapper)
constraints = cls._create_constraints_from_args(args)
if ARGS_SPLIT_TOKEN in args.search_method:
search_name, params = args.search_method.split(ARGS_SPLIT_TOKEN)
if search_name not in SEARCH_METHOD_CLASS_NAMES:
raise ValueError(f"Error: unsupported search {search_name}")
search_method = eval(
f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})"
)
elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
search_method = eval(
f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()"
)
else:
raise ValueError(f"Error: unsupported attack {args.search_method}")
return Attack(
goal_function,
constraints,
transformation,
search_method,
constraint_cache_size=args.constraint_cache_size,
)
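# The order of the base classes below controls the order in which dataclass fields
# are collected (base classes are scanned in reverse MRO order). This avoids the
# "non-default argument follows default argument" TypeError commonly raised when
# inheriting dataclasses.
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses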
@dataclass
class CommandLineAttackArgs(AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs):
    """All arguments needed to build and run an attack from the command line.

    Combines the model, dataset, attack-construction and attack-run argument
    dataclasses into one.
    """

    @classmethod
    def _add_parser_args(cls, parser):
        """Register every component's arguments on ``parser`` and return it.

        Components are registered in this order: model, dataset, command-line
        attack construction, then attack run options.
        """
        for component in (ModelArgs, DatasetArgs, _CommandLineAttackArgs, AttackArgs):
            parser = component._add_parser_args(parser)
        return parser