"""
Trainer Class
=============
"""
import collections
import json
import logging
import math
import os
import scipy
import torch
import tqdm
import transformers
import textattack
from textattack.shared.utils import logger
from .attack import Attack
from .attack_args import AttackArgs
from .attack_results import MaximizedAttackResult, SuccessfulAttackResult
from .attacker import Attacker
from .model_args import HUGGINGFACE_MODELS
from .models.helpers import LSTMForClassification, WordCNNForClassification
from .models.wrappers import ModelWrapper
from .training_args import CommandLineTrainingArgs, TrainingArgs
class Trainer:
"""Trainer is training and eval loop for adversarial training.
It is designed to work with PyTorch and Transformers models.
Args:
model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
Model wrapper containing both the model and the tokenizer.
task_type (:obj:`str`, `optional`, defaults to :obj:`"classification"`):
The task that the model is trained to perform.
Currently, :class:`~textattack.Trainer` supports two tasks: (1) :obj:`"classification"`, (2) :obj:`"regression"`.
attack (:class:`~textattack.Attack`):
:class:`~textattack.Attack` used to generate adversarial examples for training.
train_dataset (:class:`~textattack.datasets.Dataset`):
Dataset for training.
eval_dataset (:class:`~textattack.datasets.Dataset`):
Dataset for evaluation.
training_args (:class:`~textattack.TrainingArgs`):
Arguments for training.
Example::
>>> import textattack
>>> import transformers
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>>> # We use DeepWordBugGao2018 for demonstration purposes only.
>>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
>>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
>>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch, a learning rate of 5e-5, and an effective batch size of 32 (8 x 4).
>>> training_args = textattack.TrainingArgs(
... num_epochs=3,
... num_clean_epochs=1,
... num_train_adv_examples=1000,
... learning_rate=5e-5,
... per_device_train_batch_size=8,
... gradient_accumulation_steps=4,
... log_to_tb=True,
... )
>>> trainer = textattack.Trainer(
... model_wrapper,
... "classification",
... attack,
... train_dataset,
... eval_dataset,
... training_args
... )
>>> trainer.train()
.. note::
When using :class:`~textattack.Trainer` with `parallel=True` in :class:`~textattack.TrainingArgs`,
make sure to protect the “entry point” of the program by using :obj:`if __name__ == '__main__':`.
If not, each worker process used for generating adversarial examples will execute the training code again.
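A minimal sketch of that layout when ``parallel=True`` (``main`` is just an illustrative name; build the wrapper, attack, datasets, and args as in the example above)::
    >>> def main():
    ...     trainer = textattack.Trainer(
    ...         model_wrapper, "classification", attack, train_dataset, eval_dataset, training_args
    ...     )
    ...     trainer.train()
    >>> if __name__ == "__main__":
    ...     main()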
"""
def __init__(
self,
model_wrapper,
task_type="classification",
attack=None,
train_dataset=None,
eval_dataset=None,
training_args=None,
):
assert isinstance(
model_wrapper, ModelWrapper
), f"`model_wrapper` must be of type `textattack.models.wrappers.ModelWrapper`, but got type `{type(model_wrapper)}`."
# TODO: Support seq2seq training
assert task_type in {
"classification",
"regression",
}, '`task_type` must either be "classification" or "regression"'
if attack:
assert isinstance(
attack, Attack
), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
if id(model_wrapper) != id(attack.goal_function.model):
logger.warning(
"`model_wrapper` and the victim model of `attack` are not the same model."
)
if train_dataset:
assert isinstance(
train_dataset, textattack.datasets.Dataset
), f"`train_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(train_dataset)}`."
if eval_dataset:
assert isinstance(
eval_dataset, textattack.datasets.Dataset
), f"`eval_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(eval_dataset)}`."
if training_args:
assert isinstance(
training_args, TrainingArgs
), f"`training_args` must be of type `textattack.TrainingArgs`, but got type `{type(training_args)}`."
else:
training_args = TrainingArgs()
if not hasattr(model_wrapper, "model"):
raise ValueError("Cannot detect `model` in `model_wrapper`")
else:
assert isinstance(
model_wrapper.model, torch.nn.Module
), f"`model` in `model_wrapper` must be of type `torch.nn.Module`, but got type `{type(model_wrapper.model)}`."
if not hasattr(model_wrapper, "tokenizer"):
raise ValueError("Cannot detect `tokenizer` in `model_wrapper`")
self.model_wrapper = model_wrapper
self.task_type = task_type
self.attack = attack
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.training_args = training_args
self._metric_name = (
"pearson_correlation" if self.task_type == "regression" else "accuracy"
)
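# The loss is built with reduction="none" so per-sample losses are kept and can be
# re-weighted in `training_step` (e.g. scaling adversarial examples by `training_args.alpha`).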
if self.task_type == "regression":
self.loss_fct = torch.nn.MSELoss(reduction="none")
else:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
self._global_step = 0
def _generate_adversarial_examples(self, epoch):
"""Generate adversarial examples using attacker."""
assert (
self.attack is not None
), "`attack` is `None` but attempting to generate adversarial examples."
base_file_name = f"attack-train-{epoch}"
log_file_name = os.path.join(self.training_args.output_dir, base_file_name)
logger.info("Attacking model to generate new adversarial training set...")
if isinstance(self.training_args.num_train_adv_examples, float):
num_train_adv_examples = math.ceil(
len(self.train_dataset) * self.training_args.num_train_adv_examples
)
else:
num_train_adv_examples = self.training_args.num_train_adv_examples
# Use different AttackArgs depending on the value of num_train_adv_examples.
# If num_train_adv_examples >= 0, it is interpreted as the number of
# successful adversarial examples to generate.
# If num_train_adv_examples == -1, num_examples is set to -1 so that
# adversarial examples are generated for the entire training set.
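# For example (assuming the 25,000-example IMDB training split used in the class docstring):
#   num_train_adv_examples=0.1  -> attack until ceil(0.1 * 25000) = 2500 successful examples
#   num_train_adv_examples=1000 -> attack until 1000 successful examples
#   num_train_adv_examples=-1   -> attack every training example once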
if num_train_adv_examples >= 0:
attack_args = AttackArgs(
num_successful_examples=num_train_adv_examples,
num_examples_offset=0,
query_budget=self.training_args.query_budget_train,
shuffle=True,
parallel=self.training_args.parallel,
num_workers_per_device=self.training_args.attack_num_workers_per_device,
disable_stdout=True,
silent=True,
log_to_txt=log_file_name + ".txt",
log_to_csv=log_file_name + ".csv",
)
elif num_train_adv_examples == -1:
# Set num_examples to -1 so that the entire training set is attacked.
attack_args = AttackArgs(
num_examples=num_train_adv_examples,
num_examples_offset=0,
query_budget=self.training_args.query_budget_train,
shuffle=True,
parallel=self.training_args.parallel,
num_workers_per_device=self.training_args.attack_num_workers_per_device,
disable_stdout=True,
silent=True,
log_to_txt=log_file_name + ".txt",
log_to_csv=log_file_name + ".csv",
)
else:
assert False, "num_train_adv_examples is negative and not equal to -1."
attacker = Attacker(self.attack, self.train_dataset, attack_args=attack_args)
results = attacker.attack_dataset()
attack_types = collections.Counter(r.__class__.__name__ for r in results)
total_attacks = (
attack_types["SuccessfulAttackResult"] + attack_types["FailedAttackResult"]
)
success_rate = attack_types["SuccessfulAttackResult"] / total_attacks * 100
logger.info(f"Total number of attack results: {len(results)}")
logger.info(
f"Attack success rate: {success_rate:.2f}% [{attack_types['SuccessfulAttackResult']} / {total_attacks}]"
)
# TODO: This will produce a bug if we need to manipulate ground truth output.
# To fix Issue #498, we pack the non-output columns into one tuple to represent the input columns.
# Since "adversarial_example" is not an input to the model, it is removed from the input
# dictionary in collate_fn.
adversarial_examples = [
(
tuple(r.perturbed_result.attacked_text._text_input.values())
+ ("adversarial_example",),
r.perturbed_result.ground_truth_output,
)
for r in results
if isinstance(r, (SuccessfulAttackResult, MaximizedAttackResult))
]
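# For example, a single entry for a one-column text dataset might look like
# (("an adversarially perturbed review ...", "adversarial_example"), 0); the extra
# "_example_type" marker is stripped back out by collate_fn in get_train_dataloader.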
# The column indicating whether an example is adversarial is named "_example_type".
adversarial_dataset = textattack.datasets.Dataset(
adversarial_examples,
input_columns=self.train_dataset.input_columns + ("_example_type",),
label_map=self.train_dataset.label_map,
label_names=self.train_dataset.label_names,
output_scale_factor=self.train_dataset.output_scale_factor,
shuffle=False,
)
return adversarial_dataset
def _print_training_args(
self, total_training_steps, train_batch_size, num_clean_epochs
):
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(self.train_dataset)}")
logger.info(f" Num epochs = {self.training_args.num_epochs}")
logger.info(f" Num clean epochs = {num_clean_epochs}")
logger.info(
f" Instantaneous batch size per device = {self.training_args.per_device_train_batch_size}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {train_batch_size * self.training_args.gradient_accumulation_steps}"
)
logger.info(
f" Gradient accumulation steps = {self.training_args.gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {total_training_steps}")
def _save_model_checkpoint(
self, model, tokenizer, step=None, epoch=None, best=False, last=False
):
# Save model checkpoint
if step:
dir_name = f"checkpoint-step-{step}"
if epoch:
dir_name = f"checkpoint-epoch-{epoch}"
if best:
dir_name = "best_model"
if last:
dir_name = "last_model"
output_dir = os.path.join(self.training_args.output_dir, dir_name)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if isinstance(model, torch.nn.DataParallel):
model = model.module
if isinstance(model, (WordCNNForClassification, LSTMForClassification)):
model.save_pretrained(output_dir)
elif isinstance(model, transformers.PreTrainedModel):
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
else:
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
torch.save(
state_dict,
os.path.join(output_dir, "pytorch_model.bin"),
)
def _tb_log(self, log, step):
if not hasattr(self, "_tb_writer"):
from torch.utils.tensorboard import SummaryWriter
self._tb_writer = SummaryWriter(self.training_args.tb_log_dir)
self._tb_writer.add_hparams(self.training_args.__dict__, {})
self._tb_writer.flush()
for key in log:
self._tb_writer.add_scalar(key, log[key], step)
def _wandb_log(self, log, step):
if not hasattr(self, "_wandb_init"):
global wandb
import wandb
self._wandb_init = True
wandb.init(
project=self.training_args.wandb_project,
config=self.training_args.__dict__,
)
wandb.log(log, step=step)
def get_optimizer_and_scheduler(self, model, num_training_steps):
"""Returns optimizer and scheduler to use for training. If you are
overriding this method and do not want to use a scheduler, simply
return :obj:`None` for scheduler.
Args:
model (:obj:`torch.nn.Module`):
Model to be trained. Pass its parameters to optimizer for training.
num_training_steps (:obj:`int`):
Number of total training steps.
Returns:
Tuple of optimizer and scheduler :obj:`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`
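A minimal sketch of overriding this method without a scheduler (``MyTrainer`` is a hypothetical subclass used only for illustration)::
    >>> import torch
    >>> import textattack
    >>> class MyTrainer(textattack.Trainer):
    ...     def get_optimizer_and_scheduler(self, model, num_training_steps):
    ...         # Plain SGD over trainable parameters, no learning rate schedule.
    ...         optimizer = torch.optim.SGD(
    ...             filter(lambda p: p.requires_grad, model.parameters()),
    ...             lr=self.training_args.learning_rate,
    ...         )
    ...         return optimizer, None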
"""
if isinstance(model, torch.nn.DataParallel):
model = model.module
if isinstance(model, transformers.PreTrainedModel):
# Reference https://huggingface.co/transformers/training.html
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in param_optimizer
if not any(nd in n for nd in no_decay)
],
"weight_decay": self.training_args.weight_decay,
},
{
"params": [
p for n, p in param_optimizer if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
optimizer = transformers.optimization.AdamW(
optimizer_grouped_parameters, lr=self.training_args.learning_rate
)
if isinstance(self.training_args.num_warmup_steps, float):
num_warmup_steps = math.ceil(
self.training_args.num_warmup_steps * num_training_steps
)
else:
num_warmup_steps = self.training_args.num_warmup_steps
scheduler = transformers.optimization.get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)
else:
optimizer = torch.optim.Adam(
filter(lambda x: x.requires_grad, model.parameters()),
lr=self.training_args.learning_rate,
)
scheduler = None
return optimizer, scheduler
def get_train_dataloader(self, dataset, adv_dataset, batch_size):
"""Returns the :obj:`torch.utils.data.DataLoader` for training.
Args:
dataset (:class:`~textattack.datasets.Dataset`):
Original training dataset.
adv_dataset (:class:`~textattack.datasets.Dataset`):
Adversarial examples generated from the original training dataset. :obj:`None` if no adversarial attack takes place.
batch_size (:obj:`int`):
Batch size for training.
Returns:
:obj:`torch.utils.data.DataLoader`
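By default, each batch from the returned dataloader is a tuple ``(input_texts, targets, is_adv_sample)``. A minimal sketch (assuming `trainer`, `train_dataset`, and `adv_dataset` as in the class-level example; `adv_dataset` may be :obj:`None`)::
    >>> train_dataloader = trainer.get_train_dataloader(train_dataset, adv_dataset, batch_size=8)
    >>> input_texts, targets, is_adv_sample = next(iter(train_dataloader))
    >>> # input_texts: list of str (or tuples of str for multi-column inputs)
    >>> # targets: torch.Tensor of labels / target values
    >>> # is_adv_sample: torch.Tensor of booleans marking adversarial examples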
"""
# TODO: Add pairing option where we can pair original examples with adversarial examples.
# Helper functions for collating data
def collate_fn(data):
input_texts = []
targets = []
is_adv_sample = []
for item in data:
if "_example_type" in item[0].keys():
# Get example type value from OrderedDict and remove it
adv = item[0].pop("_example_type")
# With "_example_type" removed from the item[0] OrderedDict,
# all remaining keys are model inputs.
_input, label = item
if adv != "adversarial_example":
raise ValueError(
"`item` has length of 3 but last element is not for marking if the item is an `adversarial example`."
)
else:
is_adv_sample.append(True)
else:
# else `len(item)` is 2.
_input, label = item
is_adv_sample.append(False)
if isinstance(_input, collections.OrderedDict):
_input = tuple(_input.values())
else:
_input = tuple(_input)
if len(_input) == 1:
_input = _input[0]
input_texts.append(_input)
targets.append(label)
return input_texts, torch.tensor(targets), torch.tensor(is_adv_sample)
if adv_dataset:
dataset = torch.utils.data.ConcatDataset([dataset, adv_dataset])
train_dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn,
pin_memory=True,
)
return train_dataloader
def get_eval_dataloader(self, dataset, batch_size):
"""Returns the :obj:`torch.utils.data.DataLoader` for evaluation.
Args:
dataset (:class:`~textattack.datasets.Dataset`):
Dataset to use for evaluation.
batch_size (:obj:`int`):
Batch size for evaluation.
Returns:
:obj:`torch.utils.data.DataLoader`
"""
# Helper functions for collating data
def collate_fn(data):
input_texts = []
targets = []
for _input, label in data:
if isinstance(_input, collections.OrderedDict):
_input = tuple(_input.values())
else:
_input = tuple(_input)
if len(_input) == 1:
_input = _input[0]
input_texts.append(_input)
targets.append(label)
return input_texts, torch.tensor(targets)
eval_dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn,
pin_memory=True,
)
return eval_dataloader
def training_step(self, model, tokenizer, batch):
"""Perform a single training step on a batch of inputs.
Args:
model (:obj:`torch.nn.Module`):
Model to train.
tokenizer:
Tokenizer used to tokenize input text.
batch (:obj:`tuple[list[str], torch.Tensor, torch.Tensor]`):
By default, this will be a tuple of input texts, targets, and boolean tensor indicating if the sample is an adversarial example.
.. note::
If you override the :meth:`get_train_dataloader` method, then shape/type of :obj:`batch` will depend on how you created your batch.
Returns:
:obj:`tuple[torch.Tensor, torch.Tensor, torch.Tensor]` where
- **loss**: :obj:`torch.FloatTensor` of shape 1 containing the loss.
- **preds**: :obj:`torch.FloatTensor` of model's prediction for the batch.
- **targets**: :obj:`torch.Tensor` of model's targets (e.g. labels, target values).
"""
input_texts, targets, is_adv_sample = batch
_targets = targets
targets = targets.to(textattack.shared.utils.device)
if isinstance(model, transformers.PreTrainedModel) or (
isinstance(model, torch.nn.DataParallel)
and isinstance(model.module, transformers.PreTrainedModel)
):
input_ids = tokenizer(
input_texts,
padding="max_length",
return_tensors="pt",
truncation=True,
)
input_ids.to(textattack.shared.utils.device)
logits = model(**input_ids)[0]
else:
input_ids = tokenizer(input_texts)
if not isinstance(input_ids, torch.Tensor):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.to(textattack.shared.utils.device)
logits = model(input_ids)
if self.task_type == "regression":
loss = self.loss_fct(logits.squeeze(), targets.squeeze())
preds = logits
else:
loss = self.loss_fct(logits, targets)
preds = logits.argmax(dim=-1)
sample_weights = torch.ones(
is_adv_sample.size(), device=textattack.shared.utils.device
)
sample_weights[is_adv_sample] *= self.training_args.alpha
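# For example, with `alpha=2.0` each adversarial example's per-sample loss counts
# twice as much as a clean example's; with `alpha=1.0` clean and adversarial
# examples are weighted equally.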
loss = loss * sample_weights
loss = torch.mean(loss)
preds = preds.cpu()
return loss, preds, _targets
def evaluate_step(self, model, tokenizer, batch):
"""Perform a single evaluation step on a batch of inputs.
Args:
model (:obj:`torch.nn.Module`):
Model to train.
tokenizer:
Tokenizer used to tokenize input text.
batch (:obj:`tuple[list[str], torch.Tensor]`):
By default, this will be a tuple of input texts and target tensors.
.. note::
If you override the :meth:`get_eval_dataloader` method, then shape/type of :obj:`batch` will depend on how you created your batch.
Returns:
:obj:`tuple[torch.Tensor, torch.Tensor]` where
- **preds**: :obj:`torch.FloatTensor` of model's prediction for the batch.
- **targets**: :obj:`torch.Tensor` of model's targets (e.g. labels, target values).
"""
input_texts, targets = batch
_targets = targets
targets = targets.to(textattack.shared.utils.device)
if isinstance(model, transformers.PreTrainedModel):
input_ids = tokenizer(
input_texts,
padding="max_length",
return_tensors="pt",
truncation=True,
)
input_ids.to(textattack.shared.utils.device)
logits = model(**input_ids)[0]
else:
input_ids = tokenizer(input_texts)
if not isinstance(input_ids, torch.Tensor):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.to(textattack.shared.utils.device)
logits = model(input_ids)
if self.task_type == "regression":
preds = logits
else:
preds = logits.argmax(dim=-1)
return preds.cpu(), _targets
def train(self):
"""Train the model on given training dataset."""
if not self.train_dataset:
raise ValueError("No `train_dataset` available for training.")
textattack.shared.utils.set_seed(self.training_args.random_seed)
if not os.path.exists(self.training_args.output_dir):
os.makedirs(self.training_args.output_dir)
# Also write logger output to a file.
log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt")
fh = logging.FileHandler(log_txt_path)
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
logger.info(f"Writing logs to {log_txt_path}.")
# Save original self.training_args to file
args_save_path = os.path.join(
self.training_args.output_dir, "training_args.json"
)
with open(args_save_path, "w", encoding="utf-8") as f:
json.dump(self.training_args.__dict__, f)
logger.info(f"Wrote original training args to {args_save_path}.")
num_gpus = torch.cuda.device_count()
tokenizer = self.model_wrapper.tokenizer
model = self.model_wrapper.model
if self.training_args.parallel and num_gpus > 1:
# TODO: torch.nn.parallel.DistributedDataParallel
# Supposedly faster than DataParallel, but requires more work to set up properly.
model = torch.nn.DataParallel(model)
logger.info(f"Training on {num_gpus} GPUs via `torch.nn.DataParallel`.")
train_batch_size = self.training_args.per_device_train_batch_size * num_gpus
else:
train_batch_size = self.training_args.per_device_train_batch_size
if self.attack is None:
num_clean_epochs = self.training_args.num_epochs
else:
num_clean_epochs = self.training_args.num_clean_epochs
total_clean_training_steps = (
math.ceil(
len(self.train_dataset)
/ (train_batch_size * self.training_args.gradient_accumulation_steps)
)
* num_clean_epochs
)
# Calculate total_adv_training_data_length based on the type of num_train_adv_examples.
# If num_train_adv_examples is a float, it is a fraction of train_dataset.
if isinstance(self.training_args.num_train_adv_examples, float):
total_adv_training_data_length = (
len(self.train_dataset) * self.training_args.num_train_adv_examples
)
# If num_train_adv_examples is an int >= 0, it is taken as an absolute count.
elif (
isinstance(self.training_args.num_train_adv_examples, int)
and self.training_args.num_train_adv_examples >= 0
):
total_adv_training_data_length = self.training_args.num_train_adv_examples
# If num_train_adv_examples == -1, we generate adversarial examples for the whole
# training set, so at most len(train_dataset) adversarial examples are produced.
else:
total_adv_training_data_length = len(self.train_dataset)
# Based on total_adv_training_data_length, compute total_adv_training_steps.
total_adv_training_steps = math.ceil(
(len(self.train_dataset) + total_adv_training_data_length)
/ (train_batch_size * self.training_args.gradient_accumulation_steps)
) * (self.training_args.num_epochs - num_clean_epochs)
total_training_steps = total_clean_training_steps + total_adv_training_steps
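# For example, with 25,000 clean examples, num_train_adv_examples=1000, an effective
# batch size of 32 (train_batch_size * gradient_accumulation_steps), num_epochs=3,
# and num_clean_epochs=1:
#   total_clean_training_steps = ceil(25000 / 32) * 1          = 782
#   total_adv_training_steps   = ceil((25000 + 1000) / 32) * 2 = 1626
#   total_training_steps       = 782 + 1626                    = 2408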
optimizer, scheduler = self.get_optimizer_and_scheduler(
model, total_training_steps
)
self._print_training_args(
total_training_steps, train_batch_size, num_clean_epochs
)
model.to(textattack.shared.utils.device)
# Variables across epochs
self._total_loss = 0.0
self._current_loss = 0.0
self._last_log_step = 0
# `best_eval_score` is used to keep track of the best model across training.
# It could be accuracy, Pearson correlation, or another metric.
best_eval_score = 0.0
best_eval_score_epoch = 0
best_model_path = None
epochs_since_best_eval_score = 0
for epoch in range(1, self.training_args.num_epochs + 1):
logger.info("==========================================================")
logger.info(f"Epoch {epoch}")
if self.attack and epoch > num_clean_epochs:
if (
epoch - num_clean_epochs - 1
) % self.training_args.attack_epoch_interval == 0:
# Only generate a new adversarial training set every `attack_epoch_interval` epochs after the clean epochs.
# adv_dataset is instance of `textattack.datasets.Dataset`
model.eval()
adv_dataset = self._generate_adversarial_examples(epoch)
model.train()
model.to(textattack.shared.utils.device)
else:
adv_dataset = None
else:
logger.info(f"Running clean epoch {epoch}/{num_clean_epochs}")
adv_dataset = None
train_dataloader = self.get_train_dataloader(
self.train_dataset, adv_dataset, train_batch_size
)
model.train()
# Epoch variables
all_preds = []
all_targets = []
prog_bar = tqdm.tqdm(
train_dataloader,
desc="Iteration",
position=0,
leave=True,
dynamic_ncols=True,
)
for step, batch in enumerate(prog_bar):
loss, preds, targets = self.training_step(model, tokenizer, batch)
if isinstance(model, torch.nn.DataParallel):
loss = loss.mean()
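# Scale the loss so that gradients accumulated over `gradient_accumulation_steps`
# micro-batches average to one effective optimizer step.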
loss = loss / self.training_args.gradient_accumulation_steps
loss.backward()
loss = loss.item()
self._total_loss += loss
self._current_loss += loss
all_preds.append(preds)
all_targets.append(targets)
if (step + 1) % self.training_args.gradient_accumulation_steps == 0:
optimizer.step()
if scheduler:
scheduler.step()
optimizer.zero_grad()
self._global_step += 1
if self._global_step > 0:
prog_bar.set_description(
f"Loss {self._total_loss/self._global_step:.5f}"
)
# TODO: Better way to handle TB and Wandb logging
if (self._global_step > 0) and (
self._global_step % self.training_args.logging_interval_step == 0
):
lr_to_log = (
scheduler.get_last_lr()[0]
if scheduler
else self.training_args.learning_rate
)
if self._global_step - self._last_log_step >= 1:
loss_to_log = round(
self._current_loss
/ (self._global_step - self._last_log_step),
4,
)
else:
loss_to_log = round(self._current_loss, 4)
log = {"train/loss": loss_to_log, "train/learning_rate": lr_to_log}
if self.training_args.log_to_tb:
self._tb_log(log, self._global_step)
if self.training_args.log_to_wandb:
self._wandb_log(log, self._global_step)
self._current_loss = 0.0
self._last_log_step = self._global_step
# Save model checkpoint to file.
if self.training_args.checkpoint_interval_steps:
if (
self._global_step > 0
and (
self._global_step
% self.training_args.checkpoint_interval_steps
)
== 0
):
self._save_model_checkpoint(
model, tokenizer, step=self._global_step
)
preds = torch.cat(all_preds)
targets = torch.cat(all_targets)
if self._metric_name == "accuracy":
correct_predictions = (preds == targets).sum().item()
accuracy = correct_predictions / len(targets)
metric_log = {"train/train_accuracy": accuracy}
logger.info(f"Train accuracy: {accuracy*100:.2f}%")
else:
pearson_correlation, pearson_pvalue = scipy.stats.pearsonr(
preds, targets
)
metric_log = {
"train/pearson_correlation": pearson_correlation,
"train/pearson_pvalue": pearson_pvalue,
}
logger.info(f"Train Pearson correlation: {pearson_correlation:.4f}%")
if len(targets) > 0:
if self.training_args.log_to_tb:
self._tb_log(metric_log, epoch)
if self.training_args.log_to_wandb:
metric_log["epoch"] = epoch
self._wandb_log(metric_log, self._global_step)
# Evaluate after each epoch.
eval_score = self.evaluate()
if self.training_args.log_to_tb:
self._tb_log({f"eval/{self._metric_name}": eval_score}, epoch)
if self.training_args.log_to_wandb:
self._wandb_log(
{f"eval/{self._metric_name}": eval_score, "epoch": epoch},
self._global_step,
)
if (
self.training_args.checkpoint_interval_epochs
and (epoch % self.training_args.checkpoint_interval_epochs) == 0
):
self._save_model_checkpoint(model, tokenizer, epoch=epoch)
if eval_score > best_eval_score:
best_eval_score = eval_score
best_eval_score_epoch = epoch
epochs_since_best_eval_score = 0
self._save_model_checkpoint(model, tokenizer, best=True)
logger.info(
f"Best score found. Saved model to {self.training_args.output_dir}/best_model/"
)
else:
epochs_since_best_eval_score += 1
if self.training_args.early_stopping_epochs and (
epochs_since_best_eval_score
> self.training_args.early_stopping_epochs
):
logger.info(
f"Stopping early since it's been {self.training_args.early_stopping_epochs} steps since validation score increased."
)
break
if self.training_args.log_to_tb:
self._tb_writer.flush()
# Finish training
if isinstance(model, torch.nn.DataParallel):
model = model.module
if self.training_args.load_best_model_at_end:
best_model_path = os.path.join(self.training_args.output_dir, "best_model")
if hasattr(model, "from_pretrained"):
model = model.__class__.from_pretrained(best_model_path)
else:
model = model.load_state_dict(
torch.load(os.path.join(best_model_path, "pytorch_model.bin"))
)
if self.training_args.save_last:
self._save_model_checkpoint(model, tokenizer, last=True)
self.model_wrapper.model = model
self._write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)
def evaluate(self):
"""Evaluate the model on given evaluation dataset."""
if not self.eval_dataset:
raise ValueError("No `eval_dataset` available for training.")
logger.info("Evaluating model on evaluation dataset.")
model = self.model_wrapper.model
tokenizer = self.model_wrapper.tokenizer
model.eval()
all_preds = []
all_targets = []
if isinstance(model, torch.nn.DataParallel):
num_gpus = torch.cuda.device_count()
eval_batch_size = self.training_args.per_device_eval_batch_size * num_gpus
else:
eval_batch_size = self.training_args.per_device_eval_batch_size
eval_dataloader = self.get_eval_dataloader(self.eval_dataset, eval_batch_size)
with torch.no_grad():
for step, batch in enumerate(eval_dataloader):
preds, targets = self.evaluate_step(model, tokenizer, batch)
all_preds.append(preds)
all_targets.append(targets)
preds = torch.cat(all_preds)
targets = torch.cat(all_targets)
if self.task_type == "regression":
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(preds, targets)
eval_score = pearson_correlation
else:
correct_predictions = (preds == targets).sum().item()
accuracy = correct_predictions / len(targets)
eval_score = accuracy
if self._metric_name == "accuracy":
logger.info(f"Eval {self._metric_name}: {eval_score*100:.2f}%")
else:
logger.info(f"Eval {self._metric_name}: {eval_score:.4f}%")
return eval_score
def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
if isinstance(self.training_args, CommandLineTrainingArgs):
model_name = self.training_args.model_name_or_path
elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
if (
hasattr(self.model_wrapper.model.config, "_name_or_path")
and self.model_wrapper.model.config._name_or_path in HUGGINGFACE_MODELS
):
# TODO Better way than just checking HUGGINGFACE_MODELS ?
model_name = self.model_wrapper.model.config._name_or_path
elif hasattr(self.model_wrapper.model.config, "model_type"):
model_name = self.model_wrapper.model.config.model_type
else:
model_name = ""
else:
model_name = ""
if model_name:
model_name = f"`{model_name}`"
if (
isinstance(self.training_args, CommandLineTrainingArgs)
and self.training_args.model_max_length
):
model_max_length = self.training_args.model_max_length
elif isinstance(
self.model_wrapper.model,
(
transformers.PreTrainedModel,
LSTMForClassification,
WordCNNForClassification,
),
):
model_max_length = self.model_wrapper.tokenizer.model_max_length
else:
model_max_length = None
if model_max_length:
model_max_length_str = f" a maximum sequence length of {model_max_length},"
else:
model_max_length_str = ""
if isinstance(
self.train_dataset, textattack.datasets.HuggingFaceDataset
) and hasattr(self.train_dataset, "_name"):
dataset_name = self.train_dataset._name
if hasattr(self.train_dataset, "_subset"):
dataset_name += f" ({self.train_dataset._subset})"
elif isinstance(
self.eval_dataset, textattack.datasets.HuggingFaceDataset
) and hasattr(self.eval_dataset, "_name"):
dataset_name = self.eval_dataset._name
if hasattr(self.eval_dataset, "_subset"):
dataset_name += f" ({self.eval_dataset._subset})"
else:
dataset_name = None
if dataset_name:
dataset_str = f" and the `{dataset_name}` dataset loaded using the `datasets` library"
else:
dataset_str = ""
loss_func = (
"mean squared error" if self.task_type == "regression" else "cross-entropy"
)
metric_name = (
"pearson correlation" if self.task_type == "regression" else "accuracy"
)
epoch_info = f"{best_eval_score_epoch} epoch" + (
"s" if best_eval_score_epoch > 1 else ""
)
readme_text = f"""
## TextAttack Model Card
This {model_name} model was fine-tuned using TextAttack{dataset_str}. The model was fine-tuned
for {self.training_args.num_epochs} epochs with a batch size of {train_batch_size},
{model_max_length_str} and an initial learning rate of {self.training_args.learning_rate}.
Since this was a {self.task_type} task, the model was trained with a {loss_func} loss function.
The best score the model achieved on this task was {best_eval_score}, as measured by the
eval set {metric_name}, found after {epoch_info}.
For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
"""
readme_save_path = os.path.join(self.training_args.output_dir, "README.md")
with open(readme_save_path, "w", encoding="utf-8") as f:
f.write(readme_text.strip() + "\n")
logger.info(f"Wrote README to {readme_save_path}.")