""" | |
TrainingArgs Class | |
================== | |
""" | |
from dataclasses import dataclass, field | |
import datetime | |
import os | |
from typing import Union | |
from textattack.datasets import HuggingFaceDataset | |
from textattack.models.helpers import LSTMForClassification, WordCNNForClassification | |
from textattack.models.wrappers import ( | |
HuggingFaceModelWrapper, | |
ModelWrapper, | |
PyTorchModelWrapper, | |
) | |
from textattack.shared import logger | |
from textattack.shared.utils import ARGS_SPLIT_TOKEN | |
from .attack import Attack | |
from .attack_args import ATTACK_RECIPE_NAMES | |
def default_output_dir(): | |
return os.path.join( | |
"./outputs", datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") | |
) | |
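

# For example (illustrative): `default_output_dir()` returns a timestamped path
# such as "./outputs/2021-06-15-10-30-00-123456".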


@dataclass
class TrainingArgs:
"""Arguments for ``Trainer`` class that is used for adversarial training. | |
Args: | |
num_epochs (:obj:`int`, `optional`, defaults to :obj:`3`): | |
Total number of epochs for training. | |
num_clean_epochs (:obj:`int`, `optional`, defaults to :obj:`1`): | |
Number of epochs to train on just the original training dataset before adversarial training. | |
attack_epoch_interval (:obj:`int`, `optional`, defaults to :obj:`1`): | |
Generate a new adversarial training set every `N` epochs. | |
early_stopping_epochs (:obj:`int`, `optional`, defaults to :obj:`None`): | |
Number of epochs validation must increase before stopping early (:obj:`None` for no early stopping). | |
learning_rate (:obj:`float`, `optional`, defaults to :obj:`5e-5`): | |
Learning rate for optimizer. | |
num_warmup_steps (:obj:`int` or :obj:`float`, `optional`, defaults to :obj:`500`): | |
The number of steps for the warmup phase of linear scheduler. | |
If :obj:`num_warmup_steps` is a :obj:`float` between 0 and 1, the number of warmup steps will be :obj:`math.ceil(num_training_steps * num_warmup_steps)`. | |
weight_decay (:obj:`float`, `optional`, defaults to :obj:`0.01`): | |
Weight decay (L2 penalty). | |
per_device_train_batch_size (:obj:`int`, `optional`, defaults to :obj:`8`): | |
The batch size per GPU/CPU for training. | |
per_device_eval_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`): | |
The batch size per GPU/CPU for evaluation. | |
gradient_accumulation_steps (:obj:`int`, `optional`, defaults to :obj:`1`): | |
Number of updates steps to accumulate the gradients before performing a backward/update pass. | |
random_seed (:obj:`int`, `optional`, defaults to :obj:`786`): | |
Random seed for reproducibility. | |
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
If :obj:`True`, train using multiple GPUs using :obj:`torch.DataParallel`. | |
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
If :obj:`True`, keep track of the best model across training and load it at the end. | |
alpha (:obj:`float`, `optional`, defaults to :obj:`1.0`): | |
The weight for adversarial loss. | |
num_train_adv_examples (:obj:`int` or :obj:`float`, `optional`, defaults to :obj:`-1`): | |
The number of samples to successfully attack when generating adversarial training set before start of every epoch. | |
If :obj:`num_train_adv_examples` is a :obj:`float` between 0 and 1, the number of adversarial examples generated is | |
fraction of the original training set. | |
query_budget_train (:obj:`int`, `optional`, defaults to :obj:`None`): | |
The max query budget to use when generating adversarial training set. :obj:`None` means infinite query budget. | |
attack_num_workers_per_device (:obj:`int`, defaults to `optional`, :obj:`1`): | |
Number of worker processes to run per device for attack. Same as :obj:`num_workers_per_device` argument for :class:`~textattack.AttackArgs`. | |
output_dir (:obj:`str`, `optional`): | |
Directory to output training logs and checkpoints. Defaults to :obj:`./outputs/%Y-%m-%d-%H-%M-%S-%f` format. | |
checkpoint_interval_steps (:obj:`int`, `optional`, defaults to :obj:`None`): | |
If set, save model checkpoint after every `N` updates to the model. | |
checkpoint_interval_epochs (:obj:`int`, `optional`, defaults to :obj:`None`): | |
If set, save model checkpoint after every `N` epochs. | |
save_last (:obj:`bool`, `optional`, defaults to :obj:`True`): | |
If :obj:`True`, save the model at end of training. Can be used with :obj:`load_best_model_at_end` to save the best model at the end. | |
log_to_tb (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
If :obj:`True`, log to Tensorboard. | |
tb_log_dir (:obj:`str`, `optional`, defaults to :obj:`"./runs"`): | |
Path of Tensorboard log directory. | |
log_to_wandb (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
If :obj:`True`, log to Wandb. | |
wandb_project (:obj:`str`, `optional`, defaults to :obj:`"textattack"`): | |
Name of Wandb project for logging. | |
logging_interval_step (:obj:`int`, `optional`, defaults to :obj:`1`): | |
Log to Tensorboard/Wandb every `N` training steps. | |
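
    Example::

        A minimal usage sketch. The values below are illustrative, and assume
        ``TrainingArgs`` is exported at the package root as
        ``textattack.TrainingArgs``.

        >>> import textattack
        >>> training_args = textattack.TrainingArgs(
        ...     num_epochs=5,
        ...     learning_rate=2e-5,
        ...     per_device_train_batch_size=16,
        ... )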
""" | |

    num_epochs: int = 3
    num_clean_epochs: int = 1
    attack_epoch_interval: int = 1
    early_stopping_epochs: Optional[int] = None
    learning_rate: float = 5e-5
    num_warmup_steps: Union[int, float] = 500
    weight_decay: float = 0.01
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 32
    gradient_accumulation_steps: int = 1
    random_seed: int = 786
    parallel: bool = False
    load_best_model_at_end: bool = False
    alpha: float = 1.0
    num_train_adv_examples: Union[int, float] = -1
    query_budget_train: Optional[int] = None
    attack_num_workers_per_device: int = 1
    output_dir: str = field(default_factory=default_output_dir)
    checkpoint_interval_steps: Optional[int] = None
    checkpoint_interval_epochs: Optional[int] = None
    save_last: bool = True
    log_to_tb: bool = False
    tb_log_dir: Optional[str] = None
    log_to_wandb: bool = False
    wandb_project: str = "textattack"
    logging_interval_step: int = 1

    def __post_init__(self):
        assert self.num_epochs > 0, "`num_epochs` must be greater than 0."
        assert (
            self.num_clean_epochs >= 0
        ), "`num_clean_epochs` must be greater than or equal to 0."
        if self.early_stopping_epochs is not None:
            assert (
                self.early_stopping_epochs > 0
            ), "`early_stopping_epochs` must be greater than 0."
        if self.attack_epoch_interval is not None:
            assert (
                self.attack_epoch_interval > 0
            ), "`attack_epoch_interval` must be greater than 0."
        assert (
            self.num_warmup_steps >= 0
        ), "`num_warmup_steps` must be greater than or equal to 0."
        assert (
            self.gradient_accumulation_steps > 0
        ), "`gradient_accumulation_steps` must be greater than 0."
        assert (
            self.num_clean_epochs <= self.num_epochs
        ), f"`num_clean_epochs` cannot be greater than `num_epochs` ({self.num_clean_epochs} > {self.num_epochs})."

        if isinstance(self.num_train_adv_examples, float):
            assert (
                self.num_train_adv_examples >= 0.0
                and self.num_train_adv_examples <= 1.0
            ), "If `num_train_adv_examples` is float, it must be between 0 and 1."
        elif isinstance(self.num_train_adv_examples, int):
            assert (
                self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1
            ), "If `num_train_adv_examples` is int, it must be greater than 0 or equal to -1."
        else:
            raise TypeError(
                "`num_train_adv_examples` must be of either type `int` or `float`."
            )
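
    # For example, `TrainingArgs(num_epochs=0)` fails the first assertion above,
    # while `TrainingArgs(num_train_adv_examples=0.2)` is valid and means 20%
    # of the original training set.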

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser."""
        default_obj = cls()

        def int_or_float(v):
            # Arguments such as --num-warmup-steps and --num-train-adv-examples
            # accept either an absolute count (int) or a fraction (float).
            try:
                return int(v)
            except ValueError:
                return float(v)

        parser.add_argument(
            "--num-epochs",
            "--epochs",
            type=int,
            default=default_obj.num_epochs,
            help="Total number of epochs for training.",
        )
        parser.add_argument(
            "--num-clean-epochs",
            type=int,
            default=default_obj.num_clean_epochs,
            help="Number of epochs to train on the clean dataset before adversarial training (N/A if --attack unspecified).",
        )
        parser.add_argument(
            "--attack-epoch-interval",
            type=int,
            default=default_obj.attack_epoch_interval,
            help="Generate a new adversarial training set every N epochs.",
        )
        parser.add_argument(
            "--early-stopping-epochs",
            type=int,
            default=default_obj.early_stopping_epochs,
            help="Number of epochs validation must increase before stopping early (omit for no early stopping).",
        )
        parser.add_argument(
            "--learning-rate",
            "--lr",
            type=float,
            default=default_obj.learning_rate,
            help="Learning rate for the optimizer.",
        )
        parser.add_argument(
            "--num-warmup-steps",
            type=int_or_float,
            default=default_obj.num_warmup_steps,
            help="The number of steps for the warmup phase of the linear scheduler.",
        )
        parser.add_argument(
            "--weight-decay",
            type=float,
            default=default_obj.weight_decay,
            help="Weight decay (L2 penalty).",
        )
        parser.add_argument(
            "--per-device-train-batch-size",
            type=int,
            default=default_obj.per_device_train_batch_size,
            help="The batch size per GPU/CPU for training.",
        )
        parser.add_argument(
            "--per-device-eval-batch-size",
            type=int,
            default=default_obj.per_device_eval_batch_size,
            help="The batch size per GPU/CPU for evaluation.",
        )
        parser.add_argument(
            "--gradient-accumulation-steps",
            type=int,
            default=default_obj.gradient_accumulation_steps,
            help="Number of update steps to accumulate gradients over before performing a backward/update pass.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=default_obj.random_seed,
            help="Random seed for reproducibility.",
        )
        parser.add_argument(
            "--parallel",
            action="store_true",
            default=default_obj.parallel,
            help="If set, run training on multiple GPUs.",
        )
        parser.add_argument(
            "--load-best-model-at-end",
            action="store_true",
            default=default_obj.load_best_model_at_end,
            help="If set, keep track of the best model across training and load it at the end.",
        )
        parser.add_argument(
            "--alpha",
            type=float,
            default=default_obj.alpha,
            help="The weight of adversarial loss.",
        )
        parser.add_argument(
            "--num-train-adv-examples",
            type=int_or_float,
            default=default_obj.num_train_adv_examples,
            help="The number of samples to attack when generating the adversarial training set. Default is -1 (attack all possible samples).",
        )
        parser.add_argument(
            "--query-budget-train",
            type=int,
            default=default_obj.query_budget_train,
            help="The max query budget to use when generating the adversarial training set.",
        )
        parser.add_argument(
            "--attack-num-workers-per-device",
            type=int,
            default=default_obj.attack_num_workers_per_device,
            help="Number of worker processes to run per device for the attack. Same as the `num_workers_per_device` argument of `AttackArgs`.",
        )
        parser.add_argument(
            "--output-dir",
            type=str,
            default=default_output_dir(),
            help="Directory to output training logs and checkpoints.",
        )
        parser.add_argument(
            "--checkpoint-interval-steps",
            type=int,
            default=default_obj.checkpoint_interval_steps,
            help="Save a model checkpoint after every N updates to the model.",
        )
        parser.add_argument(
            "--checkpoint-interval-epochs",
            type=int,
            default=default_obj.checkpoint_interval_epochs,
            help="Save a model checkpoint after every N epochs.",
        )
        parser.add_argument(
            "--save-last",
            action="store_true",
            default=default_obj.save_last,
            help="If set, save the model at the end of training. Can be used with `--load-best-model-at-end` to save the best model at the end.",
        )
        parser.add_argument(
            "--log-to-tb",
            action="store_true",
            default=default_obj.log_to_tb,
            help="If set, log to TensorBoard.",
        )
        parser.add_argument(
            "--tb-log-dir",
            type=str,
            default=default_obj.tb_log_dir,
            help="Path of the TensorBoard log directory.",
        )
        parser.add_argument(
            "--log-to-wandb",
            action="store_true",
            default=default_obj.log_to_wandb,
            help="If set, log to Weights & Biases.",
        )
        parser.add_argument(
            "--wandb-project",
            type=str,
            default=default_obj.wandb_project,
            help="Name of the Weights & Biases project for logging.",
        )
        parser.add_argument(
            "--logging-interval-step",
            type=int,
            default=default_obj.logging_interval_step,
            help="Log to TensorBoard/Weights & Biases every N training steps.",
        )

        return parser
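

# A minimal sketch of how the parser hook above is typically consumed
# (illustrative usage, not part of this module):
#
#     import argparse
#
#     parser = TrainingArgs._add_parser_args(argparse.ArgumentParser())
#     namespace = parser.parse_args(["--num-epochs", "5", "--lr", "2e-5"])
#     training_args = TrainingArgs(**vars(namespace))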


@dataclass
class _CommandLineTrainingArgs:
    """Command line interface training args.

    This requires more arguments to create models and get datasets.

    Args:
        model_name_or_path (str): Name or path of the model to create. "lstm" and "cnn" will create TextAttack's LSTM and CNN models, while
            any other input will be used to create a Transformers model (e.g. "bert-base-uncased").
        attack (str): Attack recipe to use (enables adversarial training).
        dataset (str): Dataset for training; will be loaded from the `datasets` library.
        task_type (str): Type of task the model is supposed to perform. Options: `classification`, `regression`.
        model_max_length (int): The maximum sequence length of the model.
        model_num_labels (int): The number of labels for classification (1 for regression).
        dataset_train_split (str): Name of the train split. If not provided, will try `train` as the split name.
        dataset_eval_split (str): Name of the eval split. If not provided, will try `dev`, `validation`, or `eval` as the split name.
    """

    model_name_or_path: str
    attack: str
    dataset: str
    task_type: str = "classification"
    model_max_length: Optional[int] = None
    model_num_labels: Optional[int] = None
    dataset_train_split: Optional[str] = None
    dataset_eval_split: Optional[str] = None
    filter_train_by_labels: Optional[list] = None
    filter_eval_by_labels: Optional[list] = None

    @classmethod
    def _add_parser_args(cls, parser):
        # Arguments that are needed if we want to create a model to train.
        parser.add_argument(
            "--model-name-or-path",
            "--model",
            type=str,
            required=True,
            help='Name or path of the model to create. "lstm" and "cnn" will create TextAttack\'s LSTM and CNN models, while'
            ' any other input will be used to create a Transformers model (e.g. "bert-base-uncased").',
        )
        parser.add_argument(
            "--model-max-length",
            type=int,
            default=None,
            help="The maximum sequence length of the model.",
        )
        parser.add_argument(
            "--model-num-labels",
            type=int,
            default=None,
            help="The number of labels for classification.",
        )
        parser.add_argument(
            "--attack",
            type=str,
            required=False,
            default=None,
            help="Attack recipe to use (enables adversarial training).",
        )
        parser.add_argument(
            "--task-type",
            type=str,
            default="classification",
            help="Type of task the model is supposed to perform. Options: `classification`, `regression`.",
        )
        parser.add_argument(
            "--dataset",
            type=str,
            required=True,
            default="yelp",
            help="Dataset for training; will be loaded from the `datasets` library. "
            "If the dataset has a subset, separate the two with a `^`, "
            "e.g. `glue^sst2` or `rotten_tomatoes`.",
        )
        parser.add_argument(
            "--dataset-train-split",
            type=str,
            default="",
            help="Train dataset split, if non-standard "
            "(can automatically detect 'train').",
        )
        parser.add_argument(
            "--dataset-eval-split",
            type=str,
            default="",
            help="Eval dataset split, if non-standard "
            "(can automatically detect 'dev', 'validation', 'eval').",
        )
        parser.add_argument(
            "--filter-train-by-labels",
            nargs="+",
            type=int,
            required=False,
            default=None,
            help="List of labels to keep in the train dataset; all others are discarded.",
        )
        parser.add_argument(
            "--filter-eval-by-labels",
            nargs="+",
            type=int,
            required=False,
            default=None,
            help="List of labels to keep in the eval dataset; all others are discarded.",
        )
        return parser

    @classmethod
    def _create_model_from_args(cls, args):
        """Given ``CommandLineTrainingArgs``, return the specified
        ``textattack.models.wrappers.ModelWrapper`` object."""
        assert isinstance(
            args, cls
        ), f"Expected args to be of type `{cls}`, but got type `{type(args)}`."
        if args.model_name_or_path == "lstm":
            logger.info("Loading textattack model: LSTMForClassification")
            max_seq_len = args.model_max_length if args.model_max_length else 128
            num_labels = args.model_num_labels if args.model_num_labels else 2
            model = LSTMForClassification(
                max_seq_length=max_seq_len,
                num_labels=num_labels,
                emb_layer_trainable=True,
            )
            model = PyTorchModelWrapper(model, model.tokenizer)
        elif args.model_name_or_path == "cnn":
            logger.info("Loading textattack model: WordCNNForClassification")
            max_seq_len = args.model_max_length if args.model_max_length else 128
            num_labels = args.model_num_labels if args.model_num_labels else 2
            model = WordCNNForClassification(
                max_seq_length=max_seq_len,
                num_labels=num_labels,
                emb_layer_trainable=True,
            )
            model = PyTorchModelWrapper(model, model.tokenizer)
        else:
            import transformers

            logger.info(
                f"Loading transformers AutoModelForSequenceClassification: {args.model_name_or_path}"
            )
            max_seq_len = args.model_max_length if args.model_max_length else 512
            num_labels = args.model_num_labels if args.model_num_labels else 2
            config = transformers.AutoConfig.from_pretrained(
                args.model_name_or_path,
                num_labels=num_labels,
            )
            model = transformers.AutoModelForSequenceClassification.from_pretrained(
                args.model_name_or_path,
                config=config,
            )
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                args.model_name_or_path,
                model_max_length=max_seq_len,
            )
            model = HuggingFaceModelWrapper(model, tokenizer)

        assert isinstance(
            model, ModelWrapper
        ), "`model` must be of type `textattack.models.wrappers.ModelWrapper`."
        return model
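
    # For example (illustrative): `--model lstm` builds TextAttack's word-level
    # LSTM with a default max sequence length of 128, while
    # `--model bert-base-uncased` loads a HuggingFace checkpoint with a default
    # max length of 512.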

    @classmethod
    def _create_dataset_from_args(cls, args):
        dataset_args = args.dataset.split(ARGS_SPLIT_TOKEN)
        if args.dataset_train_split:
            train_dataset = HuggingFaceDataset(
                *dataset_args, split=args.dataset_train_split
            )
        else:
            try:
                train_dataset = HuggingFaceDataset(*dataset_args, split="train")
                args.dataset_train_split = "train"
            except KeyError:
                raise KeyError(
                    f"Error: no `train` split found in `{args.dataset}` dataset"
                )

        if args.dataset_eval_split:
            eval_dataset = HuggingFaceDataset(
                *dataset_args, split=args.dataset_eval_split
            )
        else:
            # Try common evaluation split names in order.
            for split in ("dev", "eval", "validation", "test"):
                try:
                    eval_dataset = HuggingFaceDataset(*dataset_args, split=split)
                    args.dataset_eval_split = split
                    break
                except KeyError:
                    continue
            else:
                raise KeyError(
                    f"Could not find `dev`, `eval`, `validation`, or `test` split in dataset {args.dataset}."
                )

        if args.filter_train_by_labels:
            train_dataset.filter_by_labels_(args.filter_train_by_labels)
        if args.filter_eval_by_labels:
            eval_dataset.filter_by_labels_(args.filter_eval_by_labels)

        # Sanity-check that the model's number of output labels covers the
        # labels in the dataset. Only performed when the output column is "label".
        num_labels = args.model_num_labels if args.model_num_labels else 2
        if (
            train_dataset.output_column == "label"
            and eval_dataset.output_column == "label"
        ):
            train_dataset_labels = train_dataset._dataset["label"]
            eval_dataset_labels = eval_dataset._dataset["label"]
            train_dataset_labels_set = set(train_dataset_labels)
            assert all(
                label >= 0
                for label in train_dataset_labels_set
                if isinstance(label, int)
            ), f"Train dataset has negative label(s) {[label for label in train_dataset_labels_set if isinstance(label, int) and label < 0]}, which are not supported by PyTorch. Use --filter-train-by-labels to keep suitable labels."
            assert num_labels >= len(
                train_dataset_labels_set
            ), f"Model has {num_labels} output nodes, but the train dataset has {len(train_dataset_labels_set)} labels. The model must have at least as many output nodes as there are labels in the train dataset. Use --model-num-labels to set the model's number of output nodes."
            eval_dataset_labels_set = set(eval_dataset_labels)
            assert all(
                label >= 0
                for label in eval_dataset_labels_set
                if isinstance(label, int)
            ), f"Eval dataset has negative label(s) {[label for label in eval_dataset_labels_set if isinstance(label, int) and label < 0]}, which are not supported by PyTorch. Use --filter-eval-by-labels to keep suitable labels."
            assert num_labels >= len(
                eval_dataset_labels_set
            ), f"Model has {num_labels} output nodes, but the eval dataset has {len(eval_dataset_labels_set)} labels. The model must have at least as many output nodes as there are labels in the eval dataset. Use --model-num-labels to set the model's number of output nodes."

        return train_dataset, eval_dataset

    @classmethod
    def _create_attack_from_args(cls, args, model_wrapper):
        import textattack  # noqa: F401

        if args.attack is None:
            return None
        assert (
            args.attack in ATTACK_RECIPE_NAMES
        ), f"Unavailable attack recipe {args.attack}"
        # `args.attack` has been validated against ATTACK_RECIPE_NAMES above,
        # so this eval only invokes a known recipe's `build` method.
        attack = eval(f"{ATTACK_RECIPE_NAMES[args.attack]}.build(model_wrapper)")
        assert isinstance(
            attack, Attack
        ), "`attack` must be of type `textattack.Attack`."
        return attack


# This neat trick allows us to reorder the arguments to avoid TypeErrors commonly
# found when inheriting dataclasses.
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class CommandLineTrainingArgs(TrainingArgs, _CommandLineTrainingArgs):
    @classmethod
    def _add_parser_args(cls, parser):
        parser = _CommandLineTrainingArgs._add_parser_args(parser)
        parser = TrainingArgs._add_parser_args(parser)
        return parser
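

# End-to-end sketch of how these pieces fit together (assumed flow, roughly
# mirroring the `textattack train` command; for illustration only):
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser = CommandLineTrainingArgs._add_parser_args(parser)
#     args = CommandLineTrainingArgs(**vars(parser.parse_args()))
#     model_wrapper = CommandLineTrainingArgs._create_model_from_args(args)
#     train_ds, eval_ds = CommandLineTrainingArgs._create_dataset_from_args(args)
#     attack = CommandLineTrainingArgs._create_attack_from_args(args, model_wrapper)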