""" | |
Trainer Class | |
============= | |
""" | |
import collections | |
import json | |
import logging | |
import math | |
import os | |
import scipy | |
import torch | |
import tqdm | |
import transformers | |
import textattack | |
from textattack.shared.utils import logger | |
from .attack import Attack | |
from .attack_args import AttackArgs | |
from .attack_results import MaximizedAttackResult, SuccessfulAttackResult | |
from .attacker import Attacker | |
from .model_args import HUGGINGFACE_MODELS | |
from .models.helpers import LSTMForClassification, WordCNNForClassification | |
from .models.wrappers import ModelWrapper | |
from .training_args import CommandLineTrainingArgs, TrainingArgs | |


class Trainer:
    """Trainer is the training and eval loop for adversarial training.

    It is designed to work with PyTorch and Transformers models.

    Args:
        model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
            Model wrapper containing both the model and the tokenizer.
        task_type (:obj:`str`, `optional`, defaults to :obj:`"classification"`):
            The task that the model is trained to perform.
            Currently, :class:`~textattack.Trainer` supports two tasks: (1) :obj:`"classification"`, (2) :obj:`"regression"`.
        attack (:class:`~textattack.Attack`):
            :class:`~textattack.Attack` used to generate adversarial examples for training.
        train_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for training.
        eval_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for evaluation.
        training_args (:class:`~textattack.TrainingArgs`):
            Arguments for training.

    Example::

        >>> import textattack
        >>> import transformers

        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # We only use DeepWordBugGao2018 for demonstration purposes.
        >>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
        >>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
        >>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

        >>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch,
        >>> # a learning rate of 5e-5, and an effective batch size of 32 (8 x 4).
        >>> training_args = textattack.TrainingArgs(
        ...     num_epochs=3,
        ...     num_clean_epochs=1,
        ...     num_train_adv_examples=1000,
        ...     learning_rate=5e-5,
        ...     per_device_train_batch_size=8,
        ...     gradient_accumulation_steps=4,
        ...     log_to_tb=True,
        ... )

        >>> trainer = textattack.Trainer(
        ...     model_wrapper,
        ...     "classification",
        ...     attack,
        ...     train_dataset,
        ...     eval_dataset,
        ...     training_args
        ... )
        >>> trainer.train()

    .. note::
        When using :class:`~textattack.Trainer` with `parallel=True` in :class:`~textattack.TrainingArgs`,
        make sure to protect the "entry point" of the program by using :obj:`if __name__ == '__main__':`.
        If not, each worker process used for generating adversarial examples will execute the training code again.
    """

    def __init__(
        self,
        model_wrapper,
        task_type="classification",
        attack=None,
        train_dataset=None,
        eval_dataset=None,
        training_args=None,
    ):
        assert isinstance(
            model_wrapper, ModelWrapper
        ), f"`model_wrapper` must be of type `textattack.models.wrappers.ModelWrapper`, but got type `{type(model_wrapper)}`."

        # TODO: Support seq2seq training
        assert task_type in {
            "classification",
            "regression",
        }, '`task_type` must either be "classification" or "regression"'

        if attack:
            assert isinstance(
                attack, Attack
            ), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."

            if id(model_wrapper) != id(attack.goal_function.model):
                logger.warning(
                    "`model_wrapper` and the victim model of `attack` are not the same model."
                )

        if train_dataset:
            assert isinstance(
                train_dataset, textattack.datasets.Dataset
            ), f"`train_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(train_dataset)}`."

        if eval_dataset:
            assert isinstance(
                eval_dataset, textattack.datasets.Dataset
            ), f"`eval_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(eval_dataset)}`."

        if training_args:
            assert isinstance(
                training_args, TrainingArgs
            ), f"`training_args` must be of type `textattack.TrainingArgs`, but got type `{type(training_args)}`."
        else:
            training_args = TrainingArgs()

        if not hasattr(model_wrapper, "model"):
            raise ValueError("Cannot detect `model` in `model_wrapper`")
        else:
            assert isinstance(
                model_wrapper.model, torch.nn.Module
            ), f"`model` in `model_wrapper` must be of type `torch.nn.Module`, but got type `{type(model_wrapper.model)}`."

        if not hasattr(model_wrapper, "tokenizer"):
            raise ValueError("Cannot detect `tokenizer` in `model_wrapper`")

        self.model_wrapper = model_wrapper
        self.task_type = task_type
        self.attack = attack
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.training_args = training_args

        self._metric_name = (
            "pearson_correlation" if self.task_type == "regression" else "accuracy"
        )
        if self.task_type == "regression":
            self.loss_fct = torch.nn.MSELoss(reduction="none")
        else:
            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        self._global_step = 0

    def _generate_adversarial_examples(self, epoch):
        """Generate adversarial examples using attacker."""
        assert (
            self.attack is not None
        ), "`attack` is `None` but attempting to generate adversarial examples."
        base_file_name = f"attack-train-{epoch}"
        log_file_name = os.path.join(self.training_args.output_dir, base_file_name)
        logger.info("Attacking model to generate new adversarial training set...")

        if isinstance(self.training_args.num_train_adv_examples, float):
            num_train_adv_examples = math.ceil(
                len(self.train_dataset) * self.training_args.num_train_adv_examples
            )
        else:
            num_train_adv_examples = self.training_args.num_train_adv_examples
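        # For example (hypothetical numbers): with 25,000 training examples and
        # `num_train_adv_examples=0.2`, the float branch above requests
        # math.ceil(25000 * 0.2) == 5000 successful adversarial examples.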

        # Use different AttackArgs based on the value of num_train_adv_examples.
        # If num_train_adv_examples >= 0, num_train_adv_examples is
        # set as the number of successful examples.
        # If num_train_adv_examples == -1, num_examples is set to -1 to
        # generate examples for all of the training data.
        if num_train_adv_examples >= 0:
            attack_args = AttackArgs(
                num_successful_examples=num_train_adv_examples,
                num_examples_offset=0,
                query_budget=self.training_args.query_budget_train,
                shuffle=True,
                parallel=self.training_args.parallel,
                num_workers_per_device=self.training_args.attack_num_workers_per_device,
                disable_stdout=True,
                silent=True,
                log_to_txt=log_file_name + ".txt",
                log_to_csv=log_file_name + ".csv",
            )
        elif num_train_adv_examples == -1:
            # Set num_examples when num_train_adv_examples == -1.
            attack_args = AttackArgs(
                num_examples=num_train_adv_examples,
                num_examples_offset=0,
                query_budget=self.training_args.query_budget_train,
                shuffle=True,
                parallel=self.training_args.parallel,
                num_workers_per_device=self.training_args.attack_num_workers_per_device,
                disable_stdout=True,
                silent=True,
                log_to_txt=log_file_name + ".txt",
                log_to_csv=log_file_name + ".csv",
            )
        else:
            assert False, "num_train_adv_examples is negative and not equal to -1."

        attacker = Attacker(self.attack, self.train_dataset, attack_args=attack_args)
        results = attacker.attack_dataset()

        attack_types = collections.Counter(r.__class__.__name__ for r in results)
        total_attacks = (
            attack_types["SuccessfulAttackResult"] + attack_types["FailedAttackResult"]
        )
        success_rate = attack_types["SuccessfulAttackResult"] / total_attacks * 100
        logger.info(f"Total number of attack results: {len(results)}")
        logger.info(
            f"Attack success rate: {success_rate:.2f}% [{attack_types['SuccessfulAttackResult']} / {total_attacks}]"
        )
        # TODO: This will produce a bug if we need to manipulate ground truth output.

        # To fix Issue #498, we need to add the non-output columns in one tuple to represent the input columns.
        # Since "adversarial_example" won't be an input to the model, we will have to remove it from the input
        # dictionary in collate_fn.
        adversarial_examples = [
            (
                tuple(r.perturbed_result.attacked_text._text_input.values())
                + ("adversarial_example",),
                r.perturbed_result.ground_truth_output,
            )
            for r in results
            if isinstance(r, (SuccessfulAttackResult, MaximizedAttackResult))
        ]
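        # Each row pairs the perturbed input columns with the ground-truth
        # output, appending an "adversarial_example" marker that `collate_fn`
        # later pops from the inputs. A hypothetical row for a single-column
        # dataset:
        #   (("the acting was awfull !", "adversarial_example"), 0)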
        # Name for the column indicating if an example is adversarial is set as "_example_type".
        adversarial_dataset = textattack.datasets.Dataset(
            adversarial_examples,
            input_columns=self.train_dataset.input_columns + ("_example_type",),
            label_map=self.train_dataset.label_map,
            label_names=self.train_dataset.label_names,
            output_scale_factor=self.train_dataset.output_scale_factor,
            shuffle=False,
        )
        return adversarial_dataset

    def _print_training_args(
        self, total_training_steps, train_batch_size, num_clean_epochs
    ):
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(self.train_dataset)}")
        logger.info(f"  Num epochs = {self.training_args.num_epochs}")
        logger.info(f"  Num clean epochs = {num_clean_epochs}")
        logger.info(
            f"  Instantaneous batch size per device = {self.training_args.per_device_train_batch_size}"
        )
        logger.info(
            f"  Total train batch size (w. parallel, distributed & accumulation) = {train_batch_size * self.training_args.gradient_accumulation_steps}"
        )
        logger.info(
            f"  Gradient accumulation steps = {self.training_args.gradient_accumulation_steps}"
        )
        logger.info(f"  Total optimization steps = {total_training_steps}")

    def _save_model_checkpoint(
        self, model, tokenizer, step=None, epoch=None, best=False, last=False
    ):
        # Save model checkpoint
        if step:
            dir_name = f"checkpoint-step-{step}"
        if epoch:
            dir_name = f"checkpoint-epoch-{epoch}"
        if best:
            dir_name = "best_model"
        if last:
            dir_name = "last_model"

        output_dir = os.path.join(self.training_args.output_dir, dir_name)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        if isinstance(model, (WordCNNForClassification, LSTMForClassification)):
            model.save_pretrained(output_dir)
        elif isinstance(model, transformers.PreTrainedModel):
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
        else:
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
            torch.save(
                state_dict,
                os.path.join(output_dir, "pytorch_model.bin"),
            )

    def _tb_log(self, log, step):
        if not hasattr(self, "_tb_writer"):
            from torch.utils.tensorboard import SummaryWriter

            self._tb_writer = SummaryWriter(self.training_args.tb_log_dir)
            self._tb_writer.add_hparams(self.training_args.__dict__, {})
            self._tb_writer.flush()
        for key in log:
            self._tb_writer.add_scalar(key, log[key], step)

    def _wandb_log(self, log, step):
        if not hasattr(self, "_wandb_init"):
            global wandb
            import wandb

            self._wandb_init = True
            wandb.init(
                project=self.training_args.wandb_project,
                config=self.training_args.__dict__,
            )

        wandb.log(log, step=step)

    def get_optimizer_and_scheduler(self, model, num_training_steps):
        """Returns optimizer and scheduler to use for training. If you are
        overriding this method and do not want to use a scheduler, simply
        return :obj:`None` for the scheduler.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to be trained. Pass its parameters to the optimizer for training.
            num_training_steps (:obj:`int`):
                Number of total training steps.
        Returns:
            Tuple of optimizer and scheduler :obj:`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`
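
        Example::

            >>> # A minimal sketch of overriding this method (illustrative only;
            >>> # `SGDTrainer` is hypothetical): use plain SGD and no scheduler.
            >>> import torch
            >>> import textattack
            >>> class SGDTrainer(textattack.Trainer):
            ...     def get_optimizer_and_scheduler(self, model, num_training_steps):
            ...         optimizer = torch.optim.SGD(
            ...             filter(lambda p: p.requires_grad, model.parameters()),
            ...             lr=self.training_args.learning_rate,
            ...         )
            ...         return optimizer, None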
""" | |
if isinstance(model, torch.nn.DataParallel): | |
model = model.module | |
if isinstance(model, transformers.PreTrainedModel): | |
# Reference https://huggingface.co/transformers/training.html | |
param_optimizer = list(model.named_parameters()) | |
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] | |
optimizer_grouped_parameters = [ | |
{ | |
"params": [ | |
p | |
for n, p in param_optimizer | |
if not any(nd in n for nd in no_decay) | |
], | |
"weight_decay": self.training_args.weight_decay, | |
}, | |
{ | |
"params": [ | |
p for n, p in param_optimizer if any(nd in n for nd in no_decay) | |
], | |
"weight_decay": 0.0, | |
}, | |
] | |
optimizer = transformers.optimization.AdamW( | |
optimizer_grouped_parameters, lr=self.training_args.learning_rate | |
) | |
if isinstance(self.training_args.num_warmup_steps, float): | |
num_warmup_steps = math.ceil( | |
self.training_args.num_warmup_steps * num_training_steps | |
) | |
else: | |
num_warmup_steps = self.training_args.num_warmup_steps | |
scheduler = transformers.optimization.get_linear_schedule_with_warmup( | |
optimizer, | |
num_warmup_steps=num_warmup_steps, | |
num_training_steps=num_training_steps, | |
) | |
else: | |
optimizer = torch.optim.Adam( | |
filter(lambda x: x.requires_grad, model.parameters()), | |
lr=self.training_args.learning_rate, | |
) | |
scheduler = None | |
return optimizer, scheduler | |
def get_train_dataloader(self, dataset, adv_dataset, batch_size): | |
"""Returns the :obj:`torch.utils.data.DataLoader` for training. | |
Args: | |
dataset (:class:`~textattack.datasets.Dataset`): | |
Original training dataset. | |
adv_dataset (:class:`~textattack.datasets.Dataset`): | |
Adversarial examples generated from the original training dataset. :obj:`None` if no adversarial attack takes place. | |
batch_size (:obj:`int`): | |
Batch size for training. | |
Returns: | |
:obj:`torch.utils.data.DataLoader` | |
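
        Example::

            >>> # Illustrative only, assuming `trainer` and `train_dataset` from
            >>> # the class-level example: the default collate function yields a
            >>> # 3-tuple of input texts, targets, and an "is adversarial" mask.
            >>> dataloader = trainer.get_train_dataloader(train_dataset, None, 8)
            >>> input_texts, targets, is_adv_sample = next(iter(dataloader))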
""" | |
# TODO: Add pairing option where we can pair original examples with adversarial examples. | |
# Helper functions for collating data | |
def collate_fn(data): | |
input_texts = [] | |
targets = [] | |
is_adv_sample = [] | |
for item in data: | |
if "_example_type" in item[0].keys(): | |
# Get example type value from OrderedDict and remove it | |
adv = item[0].pop("_example_type") | |
# with _example_type removed from item[0] OrderedDict | |
# all other keys should be part of input | |
_input, label = item | |
if adv != "adversarial_example": | |
raise ValueError( | |
"`item` has length of 3 but last element is not for marking if the item is an `adversarial example`." | |
) | |
else: | |
is_adv_sample.append(True) | |
else: | |
# else `len(item)` is 2. | |
_input, label = item | |
is_adv_sample.append(False) | |
if isinstance(_input, collections.OrderedDict): | |
_input = tuple(_input.values()) | |
else: | |
_input = tuple(_input) | |
if len(_input) == 1: | |
_input = _input[0] | |
input_texts.append(_input) | |
targets.append(label) | |
return input_texts, torch.tensor(targets), torch.tensor(is_adv_sample) | |
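
        # For example (hypothetical data): an adversarial row
        #   (OrderedDict(text="bad movie !", _example_type="adversarial_example"), 0)
        # collates to input_texts=["bad movie !"], targets=tensor([0]),
        # is_adv_sample=tensor([True]).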

        if adv_dataset:
            dataset = torch.utils.data.ConcatDataset([dataset, adv_dataset])

        train_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            pin_memory=True,
        )
        return train_dataloader

    def get_eval_dataloader(self, dataset, batch_size):
        """Returns the :obj:`torch.utils.data.DataLoader` for evaluation.

        Args:
            dataset (:class:`~textattack.datasets.Dataset`):
                Dataset to use for evaluation.
            batch_size (:obj:`int`):
                Batch size for evaluation.
        Returns:
            :obj:`torch.utils.data.DataLoader`
        """

        # Helper functions for collating data
        def collate_fn(data):
            input_texts = []
            targets = []
            for _input, label in data:
                if isinstance(_input, collections.OrderedDict):
                    _input = tuple(_input.values())
                else:
                    _input = tuple(_input)

                if len(_input) == 1:
                    _input = _input[0]
                input_texts.append(_input)
                targets.append(label)
            return input_texts, torch.tensor(targets)

        eval_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            pin_memory=True,
        )
        return eval_dataloader

    def training_step(self, model, tokenizer, batch):
        """Perform a single training step on a batch of inputs.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to train.
            tokenizer:
                Tokenizer used to tokenize input text.
            batch (:obj:`tuple[list[str], torch.Tensor, torch.Tensor]`):
                By default, this will be a tuple of input texts, targets, and a boolean tensor indicating if the sample is an adversarial example.

                .. note::
                    If you override the :meth:`get_train_dataloader` method, then the shape/type of :obj:`batch` will depend on how you created your batch.

        Returns:
            :obj:`tuple[torch.Tensor, torch.Tensor, torch.Tensor]` where

            - **loss**: :obj:`torch.FloatTensor` of shape 1 containing the loss.
            - **preds**: :obj:`torch.FloatTensor` of model's prediction for the batch.
            - **targets**: :obj:`torch.Tensor` of model's targets (e.g. labels, target values).
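
        Example::

            >>> # Illustrative only, assuming `trainer` from the class-level
            >>> # example: run one step on a hypothetical toy batch.
            >>> batch = (["a great film", "a dull film"], torch.tensor([1, 0]), torch.tensor([False, True]))
            >>> loss, preds, targets = trainer.training_step(
            ...     trainer.model_wrapper.model, trainer.model_wrapper.tokenizer, batch
            ... )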
""" | |
input_texts, targets, is_adv_sample = batch | |
_targets = targets | |
targets = targets.to(textattack.shared.utils.device) | |
if isinstance(model, transformers.PreTrainedModel) or ( | |
isinstance(model, torch.nn.DataParallel) | |
and isinstance(model.module, transformers.PreTrainedModel) | |
): | |
input_ids = tokenizer( | |
input_texts, | |
padding="max_length", | |
return_tensors="pt", | |
truncation=True, | |
) | |
input_ids.to(textattack.shared.utils.device) | |
logits = model(**input_ids)[0] | |
else: | |
input_ids = tokenizer(input_texts) | |
if not isinstance(input_ids, torch.Tensor): | |
input_ids = torch.tensor(input_ids) | |
input_ids = input_ids.to(textattack.shared.utils.device) | |
logits = model(input_ids) | |
if self.task_type == "regression": | |
loss = self.loss_fct(logits.squeeze(), targets.squeeze()) | |
preds = logits | |
else: | |
loss = self.loss_fct(logits, targets) | |
preds = logits.argmax(dim=-1) | |
sample_weights = torch.ones( | |
is_adv_sample.size(), device=textattack.shared.utils.device | |
) | |
sample_weights[is_adv_sample] *= self.training_args.alpha | |
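        # For example, with `training_args.alpha == 0.5` (a hypothetical value),
        # each adversarial example contributes half the loss weight of a clean one.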
        loss = loss * sample_weights
        loss = torch.mean(loss)
        preds = preds.cpu()

        return loss, preds, _targets

    def evaluate_step(self, model, tokenizer, batch):
        """Perform a single evaluation step on a batch of inputs.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to evaluate.
            tokenizer:
                Tokenizer used to tokenize input text.
            batch (:obj:`tuple[list[str], torch.Tensor]`):
                By default, this will be a tuple of input texts and target tensors.

                .. note::
                    If you override the :meth:`get_eval_dataloader` method, then the shape/type of :obj:`batch` will depend on how you created your batch.

        Returns:
            :obj:`tuple[torch.Tensor, torch.Tensor]` where

            - **preds**: :obj:`torch.FloatTensor` of model's prediction for the batch.
            - **targets**: :obj:`torch.Tensor` of model's targets (e.g. labels, target values).
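
        Example::

            >>> # Illustrative only, assuming `trainer` from the class-level
            >>> # example: evaluation batches carry no adversarial-example mask.
            >>> batch = (["a great film", "a dull film"], torch.tensor([1, 0]))
            >>> preds, targets = trainer.evaluate_step(
            ...     trainer.model_wrapper.model, trainer.model_wrapper.tokenizer, batch
            ... )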
""" | |
input_texts, targets = batch | |
_targets = targets | |
targets = targets.to(textattack.shared.utils.device) | |
if isinstance(model, transformers.PreTrainedModel): | |
input_ids = tokenizer( | |
input_texts, | |
padding="max_length", | |
return_tensors="pt", | |
truncation=True, | |
) | |
input_ids.to(textattack.shared.utils.device) | |
logits = model(**input_ids)[0] | |
else: | |
input_ids = tokenizer(input_texts) | |
if not isinstance(input_ids, torch.Tensor): | |
input_ids = torch.tensor(input_ids) | |
input_ids = input_ids.to(textattack.shared.utils.device) | |
logits = model(input_ids) | |
if self.task_type == "regression": | |
preds = logits | |
else: | |
preds = logits.argmax(dim=-1) | |
return preds.cpu(), _targets | |
def train(self): | |
"""Train the model on given training dataset.""" | |
if not self.train_dataset: | |
raise ValueError("No `train_dataset` available for training.") | |
textattack.shared.utils.set_seed(self.training_args.random_seed) | |
if not os.path.exists(self.training_args.output_dir): | |
os.makedirs(self.training_args.output_dir) | |
# Save logger writes to file | |
log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt") | |
fh = logging.FileHandler(log_txt_path) | |
fh.setLevel(logging.DEBUG) | |
logger.addHandler(fh) | |
logger.info(f"Writing logs to {log_txt_path}.") | |
# Save original self.training_args to file | |
args_save_path = os.path.join( | |
self.training_args.output_dir, "training_args.json" | |
) | |
with open(args_save_path, "w", encoding="utf-8") as f: | |
json.dump(self.training_args.__dict__, f) | |
logger.info(f"Wrote original training args to {args_save_path}.") | |
num_gpus = torch.cuda.device_count() | |
tokenizer = self.model_wrapper.tokenizer | |
model = self.model_wrapper.model | |
if self.training_args.parallel and num_gpus > 1: | |
# TODO: torch.nn.parallel.DistributedDataParallel | |
# Supposedly faster than DataParallel, but requires more work to setup properly. | |
model = torch.nn.DataParallel(model) | |
logger.info(f"Training on {num_gpus} GPUs via `torch.nn.DataParallel`.") | |
train_batch_size = self.training_args.per_device_train_batch_size * num_gpus | |
else: | |
train_batch_size = self.training_args.per_device_train_batch_size | |
if self.attack is None: | |
num_clean_epochs = self.training_args.num_epochs | |
else: | |
num_clean_epochs = self.training_args.num_clean_epochs | |
total_clean_training_steps = ( | |
math.ceil( | |
len(self.train_dataset) | |
/ (train_batch_size * self.training_args.gradient_accumulation_steps) | |
) | |
* num_clean_epochs | |
) | |
# calculate total_adv_training_data_length based on type of | |
# num_train_adv_examples. | |
# if num_train_adv_examples is float , num_train_adv_examples is a portion of train_dataset. | |
if isinstance(self.training_args.num_train_adv_examples, float): | |
total_adv_training_data_length = ( | |
len(self.train_dataset) * self.training_args.num_train_adv_examples | |
) | |
# if num_train_adv_examples is int and >=0 then it is taken as value. | |
elif ( | |
isinstance(self.training_args.num_train_adv_examples, int) | |
and self.training_args.num_train_adv_examples >= 0 | |
): | |
total_adv_training_data_length = self.training_args.num_train_adv_examples | |
# if num_train_adv_examples is = -1 , we generate all possible adv examples. | |
# Max number of all possible adv examples would be equal to train_dataset. | |
else: | |
total_adv_training_data_length = len(self.train_dataset) | |
# Based on total_adv_training_data_length calculation , find total total_adv_training_steps | |
total_adv_training_steps = math.ceil( | |
(len(self.train_dataset) + total_adv_training_data_length) | |
/ (train_batch_size * self.training_args.gradient_accumulation_steps) | |
) * (self.training_args.num_epochs - num_clean_epochs) | |
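        # Worked example (hypothetical numbers): 10,000 clean examples, effective
        # batch size 8 * 4 = 32, 1 clean epoch, and 2 adversarial epochs with
        # 1,000 adv examples each:
        #   ceil(10000 / 32) * 1 + ceil(11000 / 32) * 2 = 313 + 688 = 1001 steps.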
        total_training_steps = total_clean_training_steps + total_adv_training_steps

        optimizer, scheduler = self.get_optimizer_and_scheduler(
            model, total_training_steps
        )
        self._print_training_args(
            total_training_steps, train_batch_size, num_clean_epochs
        )

        model.to(textattack.shared.utils.device)

        # Variables across epochs
        self._total_loss = 0.0
        self._current_loss = 0.0
        self._last_log_step = 0

        # `best_eval_score` is used to keep track of the best model across training.
        # Could be loss, accuracy, or other metrics.
        best_eval_score = 0.0
        best_eval_score_epoch = 0
        best_model_path = None
        epochs_since_best_eval_score = 0

        for epoch in range(1, self.training_args.num_epochs + 1):
            logger.info("==========================================================")
            logger.info(f"Epoch {epoch}")

            if self.attack and epoch > num_clean_epochs:
                if (
                    epoch - num_clean_epochs - 1
                ) % self.training_args.attack_epoch_interval == 0:
                    # Only generate a new adversarial training set every
                    # `attack_epoch_interval` epochs after the clean epochs.
                    # `adv_dataset` is an instance of `textattack.datasets.Dataset`.
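                    # For example, with num_clean_epochs=1 and
                    # attack_epoch_interval=2 (hypothetical values), a new
                    # adversarial set is generated at epochs 2, 4, 6, ...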
                    model.eval()
                    adv_dataset = self._generate_adversarial_examples(epoch)
                    model.train()
                    model.to(textattack.shared.utils.device)
                else:
                    adv_dataset = None
            else:
                logger.info(f"Running clean epoch {epoch}/{num_clean_epochs}")
                adv_dataset = None

            train_dataloader = self.get_train_dataloader(
                self.train_dataset, adv_dataset, train_batch_size
            )
            model.train()

            # Epoch variables
            all_preds = []
            all_targets = []

            prog_bar = tqdm.tqdm(
                train_dataloader,
                desc="Iteration",
                position=0,
                leave=True,
                dynamic_ncols=True,
            )
            for step, batch in enumerate(prog_bar):
                loss, preds, targets = self.training_step(model, tokenizer, batch)

                if isinstance(model, torch.nn.DataParallel):
                    loss = loss.mean()
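
                # Scale the loss so that gradients summed over
                # `gradient_accumulation_steps` micro-batches match one
                # optimizer step over the effective batch.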
                loss = loss / self.training_args.gradient_accumulation_steps
                loss.backward()
                loss = loss.item()
                self._total_loss += loss
                self._current_loss += loss

                all_preds.append(preds)
                all_targets.append(targets)

                if (step + 1) % self.training_args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    if scheduler:
                        scheduler.step()
                    optimizer.zero_grad()
                    self._global_step += 1

                if self._global_step > 0:
                    prog_bar.set_description(
                        f"Loss {self._total_loss/self._global_step:.5f}"
                    )

                # TODO: Better way to handle TB and Wandb logging
                if (self._global_step > 0) and (
                    self._global_step % self.training_args.logging_interval_step == 0
                ):
                    lr_to_log = (
                        scheduler.get_last_lr()[0]
                        if scheduler
                        else self.training_args.learning_rate
                    )
                    if self._global_step - self._last_log_step >= 1:
                        loss_to_log = round(
                            self._current_loss
                            / (self._global_step - self._last_log_step),
                            4,
                        )
                    else:
                        loss_to_log = round(self._current_loss, 4)

                    log = {"train/loss": loss_to_log, "train/learning_rate": lr_to_log}
                    if self.training_args.log_to_tb:
                        self._tb_log(log, self._global_step)
                    if self.training_args.log_to_wandb:
                        self._wandb_log(log, self._global_step)
                    self._current_loss = 0.0
                    self._last_log_step = self._global_step

                # Save model checkpoint to file.
                if self.training_args.checkpoint_interval_steps:
                    if (
                        self._global_step > 0
                        and (
                            self._global_step
                            % self.training_args.checkpoint_interval_steps
                        )
                        == 0
                    ):
                        self._save_model_checkpoint(
                            model, tokenizer, step=self._global_step
                        )

            preds = torch.cat(all_preds)
            targets = torch.cat(all_targets)
            if self._metric_name == "accuracy":
                correct_predictions = (preds == targets).sum().item()
                accuracy = correct_predictions / len(targets)
                metric_log = {"train/train_accuracy": accuracy}
                logger.info(f"Train accuracy: {accuracy*100:.2f}%")
            else:
                pearson_correlation, pearson_pvalue = scipy.stats.pearsonr(
                    preds, targets
                )
                metric_log = {
                    "train/pearson_correlation": pearson_correlation,
                    "train/pearson_pvalue": pearson_pvalue,
                }
                logger.info(f"Train Pearson correlation: {pearson_correlation:.4f}")

            if len(targets) > 0:
                if self.training_args.log_to_tb:
                    self._tb_log(metric_log, epoch)
                if self.training_args.log_to_wandb:
                    metric_log["epoch"] = epoch
                    self._wandb_log(metric_log, self._global_step)

            # Evaluate after each epoch.
            eval_score = self.evaluate()

            if self.training_args.log_to_tb:
                self._tb_log({f"eval/{self._metric_name}": eval_score}, epoch)
            if self.training_args.log_to_wandb:
                self._wandb_log(
                    {f"eval/{self._metric_name}": eval_score, "epoch": epoch},
                    self._global_step,
                )

            if (
                self.training_args.checkpoint_interval_epochs
                and (epoch % self.training_args.checkpoint_interval_epochs) == 0
            ):
                self._save_model_checkpoint(model, tokenizer, epoch=epoch)

            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_eval_score_epoch = epoch
                epochs_since_best_eval_score = 0
                self._save_model_checkpoint(model, tokenizer, best=True)
                logger.info(
                    f"Best score found. Saved model to {self.training_args.output_dir}/best_model/"
                )
            else:
                epochs_since_best_eval_score += 1
                if self.training_args.early_stopping_epochs and (
                    epochs_since_best_eval_score
                    > self.training_args.early_stopping_epochs
                ):
                    logger.info(
                        f"Stopping early since it's been {self.training_args.early_stopping_epochs} epochs since the validation score increased."
                    )
                    break

        if self.training_args.log_to_tb:
            self._tb_writer.flush()

        # Finish training
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        if self.training_args.load_best_model_at_end:
            best_model_path = os.path.join(self.training_args.output_dir, "best_model")
            if hasattr(model, "from_pretrained"):
                model = model.__class__.from_pretrained(best_model_path)
            else:
                model.load_state_dict(
                    torch.load(os.path.join(best_model_path, "pytorch_model.bin"))
                )

        if self.training_args.save_last:
            self._save_model_checkpoint(model, tokenizer, last=True)

        self.model_wrapper.model = model
        self._write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)

    def evaluate(self):
        """Evaluate the model on the given evaluation dataset."""
        if not self.eval_dataset:
            raise ValueError("No `eval_dataset` available for evaluation.")

        logger.info("Evaluating model on evaluation dataset.")
        model = self.model_wrapper.model
        tokenizer = self.model_wrapper.tokenizer

        model.eval()
        all_preds = []
        all_targets = []

        if isinstance(model, torch.nn.DataParallel):
            num_gpus = torch.cuda.device_count()
            eval_batch_size = self.training_args.per_device_eval_batch_size * num_gpus
        else:
            eval_batch_size = self.training_args.per_device_eval_batch_size

        eval_dataloader = self.get_eval_dataloader(self.eval_dataset, eval_batch_size)

        with torch.no_grad():
            for step, batch in enumerate(eval_dataloader):
                preds, targets = self.evaluate_step(model, tokenizer, batch)
                all_preds.append(preds)
                all_targets.append(targets)

        preds = torch.cat(all_preds)
        targets = torch.cat(all_targets)

        if self.task_type == "regression":
            pearson_correlation, pearson_p_value = scipy.stats.pearsonr(preds, targets)
            eval_score = pearson_correlation
        else:
            correct_predictions = (preds == targets).sum().item()
            accuracy = correct_predictions / len(targets)
            eval_score = accuracy

        if self._metric_name == "accuracy":
            logger.info(f"Eval {self._metric_name}: {eval_score*100:.2f}%")
        else:
            logger.info(f"Eval {self._metric_name}: {eval_score:.4f}")

        return eval_score

    def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
        if isinstance(self.training_args, CommandLineTrainingArgs):
            model_name = self.training_args.model_name_or_path
        elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
            if (
                hasattr(self.model_wrapper.model.config, "_name_or_path")
                and self.model_wrapper.model.config._name_or_path in HUGGINGFACE_MODELS
            ):
                # TODO: Better way than just checking HUGGINGFACE_MODELS?
                model_name = self.model_wrapper.model.config._name_or_path
            elif hasattr(self.model_wrapper.model.config, "model_type"):
                model_name = self.model_wrapper.model.config.model_type
            else:
                model_name = ""
        else:
            model_name = ""

        if model_name:
            model_name = f"`{model_name}`"

        if (
            isinstance(self.training_args, CommandLineTrainingArgs)
            and self.training_args.model_max_length
        ):
            model_max_length = self.training_args.model_max_length
        elif isinstance(
            self.model_wrapper.model,
            (
                transformers.PreTrainedModel,
                LSTMForClassification,
                WordCNNForClassification,
            ),
        ):
            model_max_length = self.model_wrapper.tokenizer.model_max_length
        else:
            model_max_length = None

        if model_max_length:
            model_max_length_str = f" a maximum sequence length of {model_max_length},"
        else:
            model_max_length_str = ""

        if isinstance(
            self.train_dataset, textattack.datasets.HuggingFaceDataset
        ) and hasattr(self.train_dataset, "_name"):
            dataset_name = self.train_dataset._name
            if hasattr(self.train_dataset, "_subset"):
                dataset_name += f" ({self.train_dataset._subset})"
        elif isinstance(
            self.eval_dataset, textattack.datasets.HuggingFaceDataset
        ) and hasattr(self.eval_dataset, "_name"):
            dataset_name = self.eval_dataset._name
            if hasattr(self.eval_dataset, "_subset"):
                dataset_name += f" ({self.eval_dataset._subset})"
        else:
            dataset_name = None

        if dataset_name:
            dataset_str = (
                f" and the `{dataset_name}` dataset loaded using the `datasets` library"
            )
        else:
            dataset_str = ""

        loss_func = (
            "mean squared error" if self.task_type == "regression" else "cross-entropy"
        )
        metric_name = (
            "pearson correlation" if self.task_type == "regression" else "accuracy"
        )
        epoch_info = f"{best_eval_score_epoch} epoch" + (
            "s" if best_eval_score_epoch > 1 else ""
        )
        readme_text = f"""
            ## TextAttack Model Card

            This {model_name} model was fine-tuned using TextAttack{dataset_str}. The model was fine-tuned
            for {self.training_args.num_epochs} epochs with a batch size of {train_batch_size},
            {model_max_length_str} and an initial learning rate of {self.training_args.learning_rate}.
            Since this was a {self.task_type} task, the model was trained with a {loss_func} loss function.
            The best score the model achieved on this task was {best_eval_score}, as measured by the
            eval set {metric_name}, found after {epoch_info}.

            For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
            """

        readme_save_path = os.path.join(self.training_args.output_dir, "README.md")
        with open(readme_save_path, "w", encoding="utf-8") as f:
            f.write(readme_text.strip() + "\n")
        logger.info(f"Wrote README to {readme_save_path}.")