# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
"""
import collections
import gc
import inspect
import math
import os
import re
import shutil
import sys
import time
import warnings
from logging import StreamHandler
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
# Integrations must be imported before ML frameworks:
from .integrations import ( # isort: split
default_hp_search_backend,
get_reporting_integration_callbacks,
hp_params,
is_fairscale_available,
is_optuna_available,
is_ray_tune_available,
run_hp_search_optuna,
run_hp_search_ray,
init_deepspeed,
)
import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import (
WEIGHTS_NAME,
is_apex_available,
is_datasets_available,
is_in_notebook,
is_sagemaker_distributed_available,
is_torch_tpu_available,
is_training_run_on_sagemaker,
)
from .modeling_utils import PreTrainedModel, unwrap_model
from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from .trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_parameter_names,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
reissue_pt_warnings,
)
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalPrediction,
HPSearchBackend,
PredictionOutput,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
default_hp_space,
denumpify_detensorize,
get_last_checkpoint,
set_seed,
speed_metrics,
)
from .training_args import ParallelMode, TrainingArguments
from .utils import logging
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
_is_native_amp_available = False
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
if is_in_notebook():
from .utils.notebook import NotebookProgressCallback
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
if is_datasets_available():
import datasets
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
if is_fairscale_available():
import fairscale
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
if version.parse(fairscale.__version__) >= version.parse("0.3"):
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.wrap import auto_wrap
else:
FullyShardedDDP = None
if is_sagemaker_distributed_available():
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
else:
import torch.distributed as dist
if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout))
if TYPE_CHECKING:
import optuna
logger = logging.get_logger(__name__)
[docs]class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
Args:
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
.. note::
:class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
they work the same way as the 🤗 Transformers models.
args (:class:`~transformers.TrainingArguments`, `optional`):
The arguments to tweak for training. Will default to a basic instance of
:class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in
the current directory if not provided.
data_collator (:obj:`DataCollator`, `optional`):
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
:func:`~transformers.DataCollatorWithPadding` otherwise.
train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
interrupted training or reuse the fine-tuned model.
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
A function that instantiates the model to be used. If provided, each call to
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
The function may have zero argument, or a single one containing the optuna/Ray Tune trial object, to be
able to choose different architectures according to hyper parameters (such as layer count, sizes of inner
layers, dropout probabilities etc).
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in :doc:`here <callback>`.
If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`): A tuple
containing the optimizer and the scheduler to use. Will default to an instance of
:class:`~transformers.AdamW` on your model and a scheduler given by
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
Important attributes:
- **model** -- Always points to the core model. If using a transformers model, it will be a
:class:`~transformers.PreTrainedModel` subclass.
- **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
the inner model is wrapped in ``DeepSpeed`` and then again in ``torch.nn.DistributedDataParallel``. If the
inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
data parallelism, this means some of the model layers are split on different GPUs).
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
to :obj:`False` if model parallel or deepspeed is used, or if the default
``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
- **is_in_train** -- Whether or not a model is currently running ``train`` (e.g. when ``evaluate`` is called
while in ``train``)
"""
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
def __init__(
    self,
    model: Union[PreTrainedModel, torch.nn.Module] = None,
    args: TrainingArguments = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Dataset] = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
    model_init: Callable[[], PreTrainedModel] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
):
    """
    Set up the trainer: arguments, model, sharded DDP mode, data collator, datasets, callbacks, mixed precision
    and label smoothing.

    Raises:
        RuntimeError: If neither ``model`` nor ``model_init`` is given, or if ``model_init`` is combined with a
            non-:obj:`None` optimizer or scheduler in ``optimizers``.
        ValueError: If sharded DDP is requested together with DeepSpeed or outside distributed training, if
            ``data_collator`` is not callable but exposes a callable ``collate_batch``, if ``train_dataset`` has
            no ``__len__`` while ``args.max_steps <= 0``, or if ``eval_dataset`` has no ``__len__``.
        ImportError: If sharded DDP is requested without a suitable ``fairscale``, or the apex fp16 backend is
            selected without apex installed.
    """
    if args is None:
        output_dir = "tmp_trainer"
        logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
        args = TrainingArguments(output_dir=output_dir)
    self.args = args
    # Seed must be set before instantiating the model when using `model_init`.
    set_seed(self.args.seed)
    self.hp_name = None
    self.deepspeed = None
    self.is_in_train = False

    # memory metrics - must set up as early as possible
    self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
    self._memory_tracker.start()

    # force device and distributed setup init explicitly
    # (bare attribute access: evaluated only for whatever side effect `_setup_devices` has)
    args._setup_devices

    if model is None:
        if model_init is not None:
            self.model_init = model_init
            model = self.call_model_init()
        else:
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
    else:
        if model_init is not None:
            warnings.warn(
                "`Trainer` requires either a `model` or `model_init` argument, but not both. "
                "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
                FutureWarning,
            )
        self.model_init = model_init

    if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
        self.is_model_parallel = True
    else:
        self.is_model_parallel = False

    # Setup Sharded DDP training
    self.sharded_ddp = None
    if len(args.sharded_ddp) > 0:
        if args.deepspeed:
            raise ValueError(
                "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
            )

        # Validation comes first; the first matching mode (SIMPLE, then ZERO_DP_2, then ZERO_DP_3) wins.
        if args.local_rank == -1:
            raise ValueError("Using sharded DDP only works in distributed training.")
        elif not is_fairscale_available():
            raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
        elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
            raise ImportError(
                "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
            )
        elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
            self.sharded_ddp = ShardedDDPOption.SIMPLE
        elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
            self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
        elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
            self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

    # one place to sort out whether to place the model on device or not
    self.place_model_on_device = args.place_model_on_device
    if (
        self.is_model_parallel
        or (args.deepspeed and args.do_train)
        or (args.fp16_full_eval and not args.do_train)
        or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
    ):
        self.place_model_on_device = False

    default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
    self.data_collator = data_collator if data_collator is not None else default_collator
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.tokenizer = tokenizer

    # postpone switching model to cuda when:
    # 1. MP - since we are trying to fit a much bigger than 1 gpu model
    # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
    #    and we only use deepspeed for training at the moment
    if self.place_model_on_device:
        model = model.to(args.device)

    # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
    if self.is_model_parallel:
        self.args._n_gpu = 1

    # later use `self.model is self.model_wrapped` to check if it's wrapped or not
    self.model_wrapped = model
    self.model = model

    self.compute_metrics = compute_metrics
    self.optimizer, self.lr_scheduler = optimizers
    if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
        # The two literals are concatenated: the trailing space keeps the sentences apart.
        raise RuntimeError(
            "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

    # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
    self._loggers_initialized = False

    # Create output directory if needed
    if self.is_world_process_zero():
        os.makedirs(self.args.output_dir, exist_ok=True)
    if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
        raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

    if args.max_steps > 0:
        logger.info("max_steps is given, it will override any value given in num_train_epochs")

    # Enforce rules on using datasets with no __len__
    if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
        raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
    if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
        raise ValueError("eval_dataset must implement __len__")

    # Filled in by the first call to `_remove_unused_columns`.
    self._signature_columns = None
    if is_datasets_available():
        if isinstance(train_dataset, datasets.Dataset):
            self._remove_unused_columns(self.train_dataset, description="training")
        if isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(self.eval_dataset, description="evaluation")

    # Mixed precision setup
    self.use_apex = False
    self.use_amp = False
    self.fp16_backend = None

    if args.fp16:
        if args.fp16_backend == "auto":
            self.fp16_backend = "amp" if _is_native_amp_available else "apex"
        else:
            self.fp16_backend = args.fp16_backend
        logger.info(f"Using {self.fp16_backend} fp16 backend")

    if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
        if self.fp16_backend == "amp":
            self.use_amp = True
            self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
        else:
            if not is_apex_available():
                raise ImportError(
                    "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
                )
            self.use_apex = True

    # Label smoothing
    if self.args.label_smoothing_factor != 0:
        self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
    else:
        self.label_smoother = None

    self.state = TrainerState()
    self.control = TrainerControl()
    # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
    # state at each call to self.log.
    self._total_flos = None
    self.hp_search_backend = None
    self.use_tune_checkpoints = False
    # Question-answering model classes use start/end positions as labels; everything else uses "labels".
    default_label_names = (
        ["start_positions", "end_positions"]
        if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
        else ["labels"]
    )
    self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
    self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

    # very last
    self._memory_tracker.stop_and_update_metrics()
def add_callback(self, callback):
    """
    Register an extra callback with this trainer's callback handler.

    Args:
        callback (:obj:`type` or :class:`~transformers.TrainerCallback`):
            Either a :class:`~transformers.TrainerCallback` class or an instance of one. When a class is
            passed, a member of that class is instantiated.
    """
    handler = self.callback_handler
    handler.add_callback(callback)
def pop_callback(self, callback):
    """
    Detach a callback from this trainer's callback handler and hand it back to the caller.

    No error is raised when the callback is not registered; :obj:`None` is returned in that case.

    Args:
        callback (:obj:`type` or :class:`~transformers.TrainerCallback`):
            Either a :class:`~transformers.TrainerCallback` class or an instance of one. When a class is
            passed, the first registered member of that class is popped.

    Returns:
        :class:`~transformers.TrainerCallback`: The callback that was removed, if one was found.
    """
    removed = self.callback_handler.pop_callback(callback)
    return removed
def remove_callback(self, callback):
    """
    Drop a callback from this trainer's callback handler without returning it.

    Args:
        callback (:obj:`type` or :class:`~transformers.TrainerCallback`):
            Either a :class:`~transformers.TrainerCallback` class or an instance of one. When a class is
            passed, the first registered member of that class is removed.
    """
    handler = self.callback_handler
    handler.remove_callback(callback)
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
return
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += ["label", "label_ids"]
columns = [k for k in self._signature_columns if k in dataset.column_names]
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
if len(ignored_columns) > 0:
dset_description = "" if description is None else f"in the {description} set "
logger.info(
f"The following columns {dset_description} don't have a corresponding argument in "
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
)
dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"])
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
self.train_dataset, collections.abc.Sized
):
return None
# Build the sampler.
if self.args.group_by_length:
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
if self.args.world_size <= 1:
return LengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
)
else:
return DistributedLengthGroupedSampler(
self.train_dataset,
self.args.train_batch_size,
num_replicas=self.args.world_size,
rank=self.args.process_index,
model_input_name=model_input_name,
)
else:
if self.args.world_size <= 1:
return RandomSampler(self.train_dataset)
elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
# Use a loop for TPUs when drop_last is False to have all batches have the same size.
return DistributedSamplerWithLoop(
self.train_dataset,
batch_size=self.args.per_device_train_batch_size,
num_replicas=self.args.world_size,
rank=self.args.process_index,
)
else:
return DistributedSampler(
self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
)
def get_train_dataloader(self) -> DataLoader:
    """
    Build the :class:`~torch.utils.data.DataLoader` used for training.

    The sampler comes from :obj:`self._get_train_sampler()`: none if :obj:`self.train_dataset` does not implement
    :obj:`__len__`, otherwise a random sampler (adapted to distributed training if necessary).

    Subclass and override this method if you want to inject some custom behavior.

    Raises:
        ValueError: If :obj:`self.train_dataset` is :obj:`None`.
    """
    dataset = self.train_dataset
    if dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    sampler = self._get_train_sampler()
    args = self.args
    loader = DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        sampler=sampler,
        collate_fn=self.data_collator,
        drop_last=args.dataloader_drop_last,
        num_workers=args.dataloader_num_workers,
        pin_memory=args.dataloader_pin_memory,
    )
    return loader
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
    """
    Pick the sampler used to iterate over ``eval_dataset`` in order.

    A :obj:`SequentialDistributedSampler` is used on TPU (with the XLA world size and ordinal) and when
    ``args.local_rank`` is not ``-1``; otherwise a plain :obj:`SequentialSampler` is used.
    """
    if is_torch_tpu_available():
        replicas = xm.xrt_world_size()
        ordinal = xm.get_ordinal()
        return SequentialDistributedSampler(eval_dataset, num_replicas=replicas, rank=ordinal)

    if self.args.local_rank != -1:
        return SequentialDistributedSampler(eval_dataset)

    return SequentialSampler(eval_dataset)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Build the :class:`~torch.utils.data.DataLoader` used for evaluation.

    Subclass and override this method if you want to inject some custom behavior.

    Args:
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            Takes precedence over :obj:`self.eval_dataset` when given. If it is an :obj:`datasets.Dataset`,
            columns not accepted by the ``model.forward()`` method are automatically removed. It must
            implement :obj:`__len__`.

    Raises:
        ValueError: If no evaluation dataset is available, or if the one passed does not implement
            :obj:`__len__`.
    """
    if eval_dataset is None:
        if self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
    elif not isinstance(eval_dataset, collections.abc.Sized):
        raise ValueError("eval_dataset must implement __len__")

    if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
        self._remove_unused_columns(eval_dataset, description="evaluation")

    dataset = self.eval_dataset if eval_dataset is None else eval_dataset
    sampler = self._get_eval_sampler(dataset)
    args = self.args

    loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.eval_batch_size,
        collate_fn=self.data_collator,
        drop_last=args.dataloader_drop_last,
        num_workers=args.dataloader_num_workers,
        pin_memory=args.dataloader_pin_memory,
    )
    return loader
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
    """
    Returns the test :class:`~torch.utils.data.DataLoader`.

    Subclass and override this method if you want to inject some custom behavior.

    Args:
        test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
            The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.

    Raises:
        ValueError: If ``test_dataset`` does not implement :obj:`__len__`.
    """
    if not isinstance(test_dataset, collections.abc.Sized):
        raise ValueError("test_dataset must implement __len__")
    elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
        self._remove_unused_columns(test_dataset, description="test")
    test_sampler = self._get_eval_sampler(test_dataset)

    # We use the same batch_size as for eval.
    return DataLoader(
        test_dataset,
        sampler=test_sampler,
        batch_size=self.args.eval_batch_size,
        collate_fn=self.data_collator,
        drop_last=self.args.dataloader_drop_last,
        # Honor the configured worker count, as the train and eval dataloaders do.
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
    )
def create_optimizer_and_scheduler(self, num_training_steps: int):
    """
    Setup the optimizer and the learning rate scheduler.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through :obj:`optimizers`, or subclass and override this method.

    An optimizer or scheduler that is already set is left untouched. The default optimizer is
    :class:`~transformers.Adafactor` when ``args.adafactor`` is set and :class:`~transformers.AdamW` otherwise
    (wrapped in fairscale's ``OSS`` for simple sharded DDP). ``args.weight_decay`` is applied only to the
    parameters named by ``get_parameter_names(self.model, [torch.nn.LayerNorm])`` whose name does not contain
    ``"bias"``.

    Args:
        num_training_steps (:obj:`int`):
            Total number of training steps, passed to the scheduler and used to derive the number of warmup
            steps from ``args.warmup_ratio`` when ``args.warmup_steps`` is not positive.
    """
    if self.optimizer is None:
        # A set keeps the per-parameter membership test O(1) instead of scanning a list each time.
        decay_parameters = {
            name for name in get_parameter_names(self.model, [torch.nn.LayerNorm]) if "bias" not in name
        }
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
                "weight_decay": 0.0,
            },
        ]
        if self.args.adafactor:
            optimizer_cls = Adafactor
            optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
        else:
            optimizer_cls = AdamW
            optimizer_kwargs = {
                "betas": (self.args.adam_beta1, self.args.adam_beta2),
                "eps": self.args.adam_epsilon,
            }
        optimizer_kwargs["lr"] = self.args.learning_rate
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer = OSS(
                params=optimizer_grouped_parameters,
                optim=optimizer_cls,
                **optimizer_kwargs,
            )
        else:
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

    if self.lr_scheduler is None:
        # An explicit positive `warmup_steps` wins over `warmup_ratio`.
        warmup_steps = (
            self.args.warmup_steps
            if self.args.warmup_steps > 0
            else math.ceil(num_training_steps * self.args.warmup_ratio)
        )

        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )
def num_examples(self, dataloader: DataLoader) -> int:
    """
    Count the samples behind a :class:`~torch.utils.data.DataLoader` by taking the length of its dataset.

    An exception is raised if the underlying dataset does not implement :obj:`__len__`.
    """
    dataset = dataloader.dataset
    return len(dataset)
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """
self._trial = trial
if self.hp_search_backend is None or trial is None:
return
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
for key, value in params.items():
if not hasattr(self.args, key):
raise AttributeError(
f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
)
old_attr = getattr(self.args, key, None)
# Casting value to the proper type
if old_attr is not None:
value = type(old_attr)(value)
setattr(self.args, key, value)
if self.hp_search_backend == HPSearchBackend.OPTUNA:
logger.info("Trial:", trial.params)
def _report_to_hp_search(
self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
):
if self.hp_search_backend is None or trial is None:
return
self.objective = self.compute_objective(metrics.copy())
if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna
trial.report(self.objective, epoch)
if trial.should_prune():
raise optuna.TrialPruned()
elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune
if self.control.should_save:
self._tune_save_checkpoint()
tune.report(objective=self.objective, **metrics)
def _tune_save_checkpoint(self):
    """
    Save a checkpoint inside the directory Ray Tune provides for the current global step.

    Does nothing (beyond importing ``ray.tune``) when ``self.use_tune_checkpoints`` is falsy. The model is saved
    through ``self.save_model``; the trainer state, optimizer state and scheduler state are written only when
    ``self.is_world_process_zero()`` is true.
    """
    from ray import tune

    if not self.use_tune_checkpoints:
        return

    step = self.state.global_step
    with tune.checkpoint_dir(step=step) as checkpoint_dir:
        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{step}")
        self.save_model(output_dir)
        if self.is_world_process_zero():
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
            # Optimizer first, then scheduler, each under its own file name.
            for component, filename in ((self.optimizer, "optimizer.pt"), (self.lr_scheduler, "scheduler.pt")):
                torch.save(component.state_dict(), os.path.join(output_dir, filename))
def call_model_init(self, trial=None):
    """
    Build a model by calling ``self.model_init``.

    ``self.model_init`` is called without arguments when it declares no parameter, and with ``trial`` when it
    declares exactly one.

    Raises:
        RuntimeError: If ``self.model_init`` declares more than one parameter, or if it returns :obj:`None`.
    """
    factory = self.model_init
    n_params = len(inspect.signature(factory).parameters)
    if n_params > 1:
        raise RuntimeError("model_init should have 0 or 1 argument.")

    model = factory(trial) if n_params == 1 else factory()
    if model is None:
        raise RuntimeError("model_init should not return None.")
    return model
def _wrap_model(self, model, training=True):
    """
    Wrap :obj:`model` in the parallel / mixed-precision containers selected by the trainer configuration.

    The steps are applied in a fixed order: apex initialization (training only), ``DataParallel`` when
    ``args.n_gpu > 1``, and then, for training only, one of sharded DDP, SageMaker DDP or
    ``DistributedDataParallel``.

    Args:
        model: The model to wrap.
        training (:obj:`bool`, defaults to :obj:`True`):
            When :obj:`False`, apex initialization and all distributed wrappers are skipped.

    Returns:
        ``self.deepspeed`` when it is set, :obj:`model` unchanged when ``unwrap_model(model)`` is not
        :obj:`model` itself, otherwise the (possibly) wrapped model.

    Side effects:
        ``self.optimizer`` is replaced by the one returned from ``amp.initialize`` when apex is used, and
        ``self.model`` is reassigned in the ``FullyShardedDDP`` branch.
    """
    # already initialized its own DDP and AMP
    if self.deepspeed:
        return self.deepspeed

    # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
    if unwrap_model(model) is not model:
        return model

    # Mixed precision training with apex (torch < 1.6)
    if self.use_apex and training:
        model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

    # Multi-gpu training (should be after apex fp16 initialization)
    if self.args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
    if not training:
        return model

    # Distributed training (should be after apex fp16 initialization)
    if self.sharded_ddp is not None:
        # Sharded DDP!
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            model = ShardedDDP(model, self.optimizer)
        else:
            # Options for FullyShardedDDP are derived from the training arguments.
            mixed_precision = self.args.fp16
            cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
            zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
            # XXX: Breaking the self.model convention but I see no way around it for now.
            if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                model = auto_wrap(model)
            self.model = model = FullyShardedDDP(
                model,
                mixed_precision=mixed_precision,
                reshard_after_forward=zero_3,
                cpu_offload=cpu_offload,
            ).to(self.args.device)

    elif is_sagemaker_distributed_available():
        model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
    elif self.args.local_rank != -1:
        # An explicit `ddp_find_unused_parameters` argument wins; otherwise the value is derived from the model.
        if self.args.ddp_find_unused_parameters is not None:
            find_unused_parameters = self.args.ddp_find_unused_parameters
        elif isinstance(model, PreTrainedModel):
            # find_unused_parameters breaks checkpointing as per
            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
            find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
        else:
            find_unused_parameters = True
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
            find_unused_parameters=find_unused_parameters,
        )

    return model
def train(
    self,
    resume_from_checkpoint: Optional[Union[str, bool]] = None,
    trial: Union["optuna.Trial", Dict[str, Any]] = None,
    **kwargs,
):
    """
    Main training entry point.

    Args:
        resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
            If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
            :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
            `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
            training will resume from the model/optimizer/scheduler states loaded here.
        trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
            The trial run or the hyperparameter dictionary for hyperparameter search.
        kwargs:
            Additional keyword arguments used to hide deprecated arguments. Only ``model_path`` (deprecated alias
            of :obj:`resume_from_checkpoint`) is accepted.

    Returns:
        A ``TrainOutput`` built from the final global step, the accumulated training loss divided by the number of
        global steps, and the metrics dictionary that was logged at the end of training.

    Raises:
        TypeError: If an unknown keyword argument is passed.
        ValueError: If :obj:`resume_from_checkpoint` is :obj:`True` and no checkpoint is found in
            ``args.output_dir``.
    """
    # memory metrics - must set up as early as possible
    self._memory_tracker.start()

    self.is_in_train = True

    if "model_path" in kwargs:
        resume_from_checkpoint = kwargs.pop("model_path")
        warnings.warn(
            "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
            "instead.",
            FutureWarning,
        )
    if len(kwargs) > 0:
        raise TypeError(f"train() received unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
    # This might change the seed so needs to run first.
    self._hp_search_setup(trial)

    # Model re-init
    model_reloaded = False
    if self.model_init is not None:
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        self.model = self.call_model_init(trial)
        model_reloaded = True
        # Reinitializes optimizer and scheduler
        self.optimizer, self.lr_scheduler = None, None

    # Load potential model checkpoint
    if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
        resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
        if resume_from_checkpoint is None:
            raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")

    if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
        logger.info(f"Loading model from {resume_from_checkpoint}.")
        if isinstance(self.model, PreTrainedModel):
            self.model = self.model.from_pretrained(resume_from_checkpoint)
            model_reloaded = True
        else:
            state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
            self.model.load_state_dict(state_dict)

    # If model was re-initialized, put it on the right device and update self.model_wrapped
    if model_reloaded:
        if self.place_model_on_device:
            self.model = self.model.to(self.args.device)
        self.model_wrapped = self.model

    # Keeping track whether we can call len() on the dataset or not
    train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)

    # Data loader and number of training steps
    train_dataloader = self.get_train_dataloader()

    # Setting up training control variables:
    # number of training epochs: num_train_epochs
    # number of training steps per epoch: num_update_steps_per_epoch
    # total number of training steps to execute: max_steps
    if train_dataset_is_sized:
        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        if self.args.max_steps > 0:
            max_steps = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
            num_train_epochs = math.ceil(self.args.num_train_epochs)
    else:
        # see __init__. max_steps is set when the dataset has no __len__
        max_steps = self.args.max_steps
        num_train_epochs = 1
        num_update_steps_per_epoch = max_steps

    delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
    if self.args.deepspeed:
        model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
        self.model = model.module
        self.model_wrapped = model  # will get further wrapped in DDP
        self.deepspeed = model  # DeepSpeedEngine object
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
    elif not delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    self.state = TrainerState()
    self.state.is_hyper_param_search = trial is not None

    model = self._wrap_model(self.model_wrapped)

    # for the rest of this function `model` is the outside model, whether it was wrapped or not
    if model is not self.model:
        self.model_wrapped = model

    if delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    # Check if saved optimizer or scheduler states exist
    self._load_optimizer_and_scheduler(resume_from_checkpoint)

    # important: at this point:
    # self.model is the Transformers Model
    # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.

    # Train!
    if is_torch_tpu_available():
        world_size = xm.xrt_world_size()
    elif self.args.local_rank != -1:
        world_size = dist.get_world_size()
    else:
        world_size = 1

    total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
    num_examples = (
        self.num_examples(train_dataloader)
        if train_dataset_is_sized
        else total_train_batch_size * self.args.max_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {num_examples}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps}")

    self.state.epoch = 0
    start_time = time.time()
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    # Check if continuing training from a checkpoint
    if resume_from_checkpoint is not None and os.path.isfile(
        os.path.join(resume_from_checkpoint, "trainer_state.json")
    ):
        self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
        epochs_trained = self.state.global_step // num_update_steps_per_epoch
        if not self.args.ignore_data_skip:
            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
            # Convert optimizer updates into dataloader batches.
            steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
        else:
            steps_trained_in_current_epoch = 0

        logger.info(" Continuing training from checkpoint, will skip to saved global_step")
        logger.info(f" Continuing training from epoch {epochs_trained}")
        logger.info(f" Continuing training from global step {self.state.global_step}")
        if not self.args.ignore_data_skip:
            logger.info(
                f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                "batches in the first epoch."
            )

    # Update the references
    self.callback_handler.model = self.model
    self.callback_handler.optimizer = self.optimizer
    self.callback_handler.lr_scheduler = self.lr_scheduler
    self.callback_handler.train_dataloader = train_dataloader
    self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
    self.state.trial_params = hp_params(trial) if trial is not None else None
    # This should be the same if the state has been saved but in case the training arguments changed, it's safer
    # to set this after the load.
    self.state.max_steps = max_steps
    self.state.num_train_epochs = num_train_epochs
    self.state.is_local_process_zero = self.is_local_process_zero()
    self.state.is_world_process_zero = self.is_world_process_zero()

    # tr_loss is a tensor to avoid synchronization of TPUs through .item()
    tr_loss = torch.tensor(0.0).to(self.args.device)
    # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
    self._total_loss_scalar = 0.0
    self._globalstep_last_logged = self.state.global_step
    self._total_flos = self.state.total_flos
    model.zero_grad()

    self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

    # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
    if not self.args.ignore_data_skip:
        for epoch in range(epochs_trained):
            # We just need to begin an iteration to create the randomization of the sampler.
            for _ in train_dataloader:
                break

    for epoch in range(epochs_trained, num_train_epochs):
        if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        if is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                self.args.device
            )
            epoch_iterator = parallel_loader
        else:
            epoch_iterator = train_dataloader

        # Reset the past mems state at the beginning of each epoch if necessary.
        if self.args.past_index >= 0:
            self._past = None

        steps_in_epoch = (
            len(epoch_iterator)
            if train_dataset_is_sized
            else self.args.max_steps * self.args.gradient_accumulation_steps
        )
        self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)

        for step, inputs in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            if (step + 1) % self.args.gradient_accumulation_steps == 0:
                self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

            if (
                ((step + 1) % self.args.gradient_accumulation_steps != 0)
                and self.args.local_rank != -1
                and self.args._no_sync_in_gradient_accumulation
            ):
                # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                with model.no_sync():
                    tr_loss += self.training_step(model, inputs)
            else:
                tr_loss += self.training_step(model, inputs)
            self._total_flos += float(self.floating_point_ops(inputs))

            # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
            if self.deepspeed:
                self.deepspeed.step()

            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                steps_in_epoch <= self.args.gradient_accumulation_steps
                and (step + 1) == steps_in_epoch
            ):
                # Gradient clipping
                if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
                    # deepspeed does its own clipping

                    if self.use_amp:
                        # AMP: gradients need unscaling
                        self.scaler.unscale_(self.optimizer)

                    if hasattr(self.optimizer, "clip_grad_norm"):
                        # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                        self.optimizer.clip_grad_norm(self.args.max_grad_norm)
                    elif hasattr(model, "clip_grad_norm_"):
                        # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                        model.clip_grad_norm_(self.args.max_grad_norm)
                    else:
                        # Revert to normal clipping otherwise, handling Apex or full precision
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                            self.args.max_grad_norm,
                        )

                # Optimizer step
                if self.deepspeed:
                    pass  # called outside the loop
                elif is_torch_tpu_available():
                    xm.optimizer_step(self.optimizer)
                elif self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                if not self.deepspeed:
                    self.lr_scheduler.step()

                model.zero_grad()
                self.state.global_step += 1
                self.state.epoch = epoch + (step + 1) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

                self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
        self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

        if self.args.tpu_metrics_debug or self.args.debug:
            if is_torch_tpu_available():
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())
            else:
                logger.warning(
                    "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                    "configured. Check your training configuration if this is unexpected."
                )
        if self.control.should_training_stop:
            break

    # `past_index` is enabled for any value >= 0 (see the per-epoch reset above), so 0 must count as enabled here.
    if self.args.past_index >= 0 and hasattr(self, "_past"):
        # Clean the state at the end of training
        delattr(self, "_past")

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
        # Wait for everyone to get here so we are sure the model has been saved by process 0.
        if is_torch_tpu_available():
            xm.rendezvous("load_best_model_at_end")
        elif self.args.local_rank != -1:
            dist.barrier()

        logger.info(
            f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
        )
        if isinstance(self.model, PreTrainedModel):
            self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
            if self.place_model_on_device:
                self.model = self.model.to(self.args.device)
        else:
            state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
            self.model.load_state_dict(state_dict)

        if self.deepspeed:
            self.deepspeed.load_checkpoint(
                self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
            )

    metrics = speed_metrics("train", start_time, self.state.max_steps)
    if self._total_flos is not None:
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
    self.log(metrics)

    self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
    # add remaining tr_loss
    self._total_loss_scalar += tr_loss.item()

    if self.deepspeed:
        # free up any memory that might be useful for eval
        self.deepspeed = None
        self.optimizer = None
        self.lr_scheduler = None
        self.model_wrapped = self.model
        gc.collect()  # force memory release
        # to restore normal behavior outside of train replay the place_model_on_device logic w/o deepspeed
        self.place_model_on_device = self.args.place_model_on_device
        if self.is_model_parallel:
            self.place_model_on_device = False

    self.is_in_train = False

    self._memory_tracker.stop_and_update_metrics(metrics)

    # No optimization step may have run (e.g. an empty dataloader); avoid a ZeroDivisionError in that case.
    effective_global_step = max(self.state.global_step, 0.001)
    return TrainOutput(self.state.global_step, self._total_loss_scalar / effective_global_step, metrics)
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
    """
    Run the logging, evaluation and checkpointing actions requested through ``self.control``.

    Args:
        tr_loss: Running loss accumulator; it is read with ``.item()`` and zeroed in place when logging happens.
        model: Forwarded to :meth:`_save_checkpoint`.
        trial: Forwarded to :meth:`_report_to_hp_search` and :meth:`_save_checkpoint`.
        epoch: Forwarded to :meth:`_report_to_hp_search`.
    """
    if self.control.should_log:
        accumulated_loss = tr_loss.item()

        # reset tr_loss to zero (in place, so the caller's accumulator is cleared too)
        tr_loss -= tr_loss

        steps_since_last_log = self.state.global_step - self._globalstep_last_logged
        logs: Dict[str, float] = {"loss": round(accumulated_loss / steps_since_last_log, 4)}
        logs["learning_rate"] = self._get_learning_rate()

        self._total_loss_scalar += accumulated_loss
        self._globalstep_last_logged = self.state.global_step

        self.log(logs)

    # Metrics stay None unless an evaluation was requested for this step.
    metrics = None
    if self.control.should_evaluate:
        metrics = self.evaluate()
        self._report_to_hp_search(trial, epoch, metrics)

    if self.control.should_save:
        self._save_checkpoint(model, trial, metrics=metrics)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Save a full checkpoint (model, optimizer, scheduler, trainer state) for the current global step.

    The checkpoint folder is named ``{PREFIX_CHECKPOINT_DIR}-{global_step}``. During a hyperparameter search it is
    placed in a per-run sub-folder of ``args.output_dir``; otherwise directly in ``args.output_dir``. When
    :obj:`metrics` and ``args.metric_for_best_model`` are both set, the best metric / best checkpoint stored in
    ``self.state`` are updated. Older checkpoints are rotated at the end on the world process zero.

    Args:
        model: Currently not used by this method (``self.model`` is what gets saved).
        trial: The hyperparameter-search trial (or :obj:`None`), used to derive the run folder name.
        metrics (:obj:`Dict[str, float]`, `optional`):
            Evaluation metrics used to decide whether this checkpoint is the best one so far.

    Raises:
        KeyError: If ``args.metric_for_best_model`` is not among the keys of :obj:`metrics`.
    """
    # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
    # want to save except FullyShardedDDP.
    # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

    # Save model checkpoint
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    if self.hp_search_backend is not None and trial is not None:
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            run_id = trial.number
        else:
            from ray import tune

            run_id = tune.get_trial_id()
        run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
        run_dir = os.path.join(self.args.output_dir, run_name)
    else:
        run_dir = self.args.output_dir
        self.store_flos()

    output_dir = os.path.join(run_dir, checkpoint_folder)
    self.save_model(output_dir)
    if self.deepspeed:
        self.deepspeed.save_checkpoint(output_dir)

    # Save optimizer and scheduler
    if self.sharded_ddp == ShardedDDPOption.SIMPLE:
        self.optimizer.consolidate_state_dict()

    if is_torch_tpu_available():
        xm.rendezvous("saving_optimizer_states")
        xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        # Re-issue only after the recording context has exited, otherwise the warnings would be captured again.
        reissue_pt_warnings(caught_warnings)
    elif self.is_world_process_zero() and not self.deepspeed:
        # deepspeed.save_checkpoint above saves model/optim/sched
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        reissue_pt_warnings(caught_warnings)

    # Determine the new best metric / best model checkpoint
    if metrics is not None and self.args.metric_for_best_model is not None:
        metric_to_check = self.args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        try:
            metric_value = metrics[metric_to_check]
        except KeyError as exc:
            # Same exception type as before, but with a message that points at the misconfigured argument.
            raise KeyError(
                f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not "
                f"found in the evaluation metrics. The available evaluation metrics are: {list(metrics.keys())}."
            ) from exc

        operator = np.greater if self.args.greater_is_better else np.less
        if (
            self.state.best_metric is None
            or self.state.best_model_checkpoint is None
            or operator(metric_value, self.state.best_metric)
        ):
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir

    # Save the Trainer state
    if self.is_world_process_zero():
        self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

    # Maybe delete some older checkpoints.
    if self.is_world_process_zero():
        self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def _load_optimizer_and_scheduler(self, checkpoint):
    """
    Restore optimizer and scheduler states from :obj:`checkpoint` when they are available.

    Nothing is done when :obj:`checkpoint` is :obj:`None`. The ``optimizer.pt`` / ``scheduler.pt`` files are only
    loaded when both exist. When ``self.deepspeed`` is set, its own ``load_checkpoint`` is called as well.

    Args:
        checkpoint (:obj:`str` or :obj:`None`): Path of the checkpoint folder to read from.
    """
    if checkpoint is None:
        return

    optimizer_file = os.path.join(checkpoint, "optimizer.pt")
    scheduler_file = os.path.join(checkpoint, "scheduler.pt")

    if os.path.isfile(optimizer_file) and os.path.isfile(scheduler_file):
        # Load in optimizer and scheduler states
        if not is_torch_tpu_available():
            self.optimizer.load_state_dict(torch.load(optimizer_file, map_location=self.args.device))
            with warnings.catch_warnings(record=True) as caught_warnings:
                self.lr_scheduler.load_state_dict(torch.load(scheduler_file))
            reissue_pt_warnings(caught_warnings)
        else:
            # On TPU we have to take some extra precautions to properly load the states on the right device.
            optimizer_state = torch.load(optimizer_file, map_location="cpu")
            with warnings.catch_warnings(record=True) as caught_warnings:
                lr_scheduler_state = torch.load(scheduler_file, map_location="cpu")
            reissue_pt_warnings(caught_warnings)

            xm.send_cpu_data_to_device(optimizer_state, self.args.device)
            xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

            self.optimizer.load_state_dict(optimizer_state)
            self.lr_scheduler.load_state_dict(lr_scheduler_state)

    if self.deepspeed:
        # Not sure how to check if there is a saved deepspeed checkpoint, but since it just return None if it fails to find a deepspeed checkpoint this is sort of a check-n-load function
        self.deepspeed.load_checkpoint(checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True)
def hyperparameter_search(
    self,
    hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
    compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
    n_trials: int = 20,
    direction: str = "minimize",
    backend: Optional[Union["str", HPSearchBackend]] = None,
    hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
    **kwargs,
) -> BestRun:
    """
    Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
    :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
    provided, the sum of all metrics otherwise.

    .. warning::

        To use this method, you need to have provided a ``model_init`` when initializing your
        :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
        with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
        method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

    Args:
        hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
            A function that defines the hyperparameter search space. Will default to
            :func:`~transformers.trainer_utils.default_hp_space_optuna` or
            :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
        compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
            A function computing the objective to minimize or maximize from the metrics returned by the
            :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
        n_trials (:obj:`int`, `optional`, defaults to 20):
            The number of trial runs to test.
        direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
            Whether to optimize greater or lower objects. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
            pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
            several metrics.
        backend (:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
            The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
            one is installed. If both are installed, will default to optuna.
        hp_name (:obj:`Callable[["optuna.Trial"], str]`, `optional`):
            A function that returns the name of a trial/run. When set, it is used for the trial name stored in the
            trainer state and for the per-run checkpoint folder.
        kwargs:
            Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
            more information see:

            - the documentation of `optuna.create_study
              <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html>`__
            - the documentation of `tune.run
              <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

    Returns:
        :class:`transformers.trainer_utils.BestRun`: All the information about the best run.

    Raises:
        RuntimeError: If no backend is installed, if the selected backend is not installed, or if the trainer was
            created without a ``model_init``.
    """
    if backend is None:
        backend = default_hp_search_backend()
        if backend is None:
            raise RuntimeError(
                "At least one of optuna or ray should be installed. "
                "To install optuna run `pip install optuna`. "
                "To install ray run `pip install ray[tune]`."
            )
    backend = HPSearchBackend(backend)
    if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
        raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
    if backend == HPSearchBackend.RAY and not is_ray_tune_available():
        raise RuntimeError(
            "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
        )
    # Validate everything before touching the trainer state so a failed call leaves the trainer untouched.
    if self.model_init is None:
        raise RuntimeError(
            "To use hyperparameter search, you need to pass your model through a model_init function."
        )

    self.hp_search_backend = backend
    self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
    self.hp_name = hp_name
    self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

    run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
    try:
        best_run = run_hp_search(self, n_trials, direction, **kwargs)
    finally:
        # Always leave search mode, even when the search itself raised.
        self.hp_search_backend = None
    return best_run
def log(self, logs: Dict[str, float]) -> None:
    """
    Log :obj:`logs` on the various objects watching training.

    Subclass and override this method to inject custom behavior.

    The current epoch (rounded to two decimals) is added to :obj:`logs` in place when it is known, a copy
    extended with the current global step is appended to ``self.state.log_history``, and the ``on_log`` callback
    event is fired with :obj:`logs`.

    Args:
        logs (:obj:`Dict[str, float]`):
            The values to log.
    """
    current_epoch = self.state.epoch
    if current_epoch is not None:
        logs["epoch"] = round(current_epoch, 2)

    history_entry = dict(logs)
    history_entry["step"] = self.state.global_step
    self.state.log_history.append(history_entry)

    self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
    """
    Prepare :obj:`inputs` before feeding them to the model.

    Every tensor value is moved to ``self.args.device`` (the dictionary is updated in place). When
    ``self.args.past_index`` is non-negative and a past state is stored on the trainer, it is added under the
    ``"mems"`` key.

    Returns:
        The same dictionary object that was passed in.
    """
    tensor_items = [(key, value) for key, value in inputs.items() if isinstance(value, torch.Tensor)]
    for key, tensor in tensor_items:
        inputs[key] = tensor.to(self.args.device)

    if self.args.past_index >= 0 and self._past is not None:
        inputs["mems"] = self._past

    return inputs
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    """
    Perform a training step on a batch of inputs.

    Subclass and override to inject custom behavior.

    Args:
        model (:obj:`nn.Module`):
            The model to train.
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument :obj:`labels`. Check your model's documentation for all accepted arguments.

    Return:
        :obj:`torch.Tensor`: The tensor with training loss on this batch.
    """
    model.train()
    inputs = self._prepare_inputs(inputs)

    # Forward pass, under autocast when native AMP is enabled.
    if not self.use_amp:
        loss = self.compute_loss(model, inputs)
    else:
        with autocast():
            loss = self.compute_loss(model, inputs)

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    accumulation_steps = self.args.gradient_accumulation_steps
    if accumulation_steps > 1 and not self.deepspeed:
        # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
        loss = loss / accumulation_steps

    # Backward pass through whichever scaling mechanism is active.
    if self.use_amp:
        self.scaler.scale(loss).backward()
    elif self.use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    elif self.deepspeed:
        # loss gets scaled under gradient_accumulation_steps in deepspeed
        loss = self.deepspeed.backward(loss)
    else:
        loss.backward()

    return loss.detach()
def compute_loss(self, model, inputs, return_outputs=False):
    """
    How the loss is computed by Trainer. By default, all models return the loss in the first element.

    Subclass and override for custom behavior.

    When a label smoother is configured and :obj:`inputs` contains ``"labels"``, the labels are removed from
    :obj:`inputs` before the forward pass and the loss is computed by the label smoother instead of the model.

    Returns:
        The loss, or the tuple ``(loss, outputs)`` when :obj:`return_outputs` is :obj:`True`.
    """
    smooth_labels = self.label_smoother is not None and "labels" in inputs
    labels = inputs.pop("labels") if smooth_labels else None

    outputs = model(**inputs)

    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    if labels is None:
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    else:
        loss = self.label_smoother(outputs, labels)

    if return_outputs:
        return (loss, outputs)
    return loss
def is_local_process_zero(self) -> bool:
    """
    Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
    machines) main process.

    On TPU the answer comes from ``xm.is_master_ordinal(local=True)``; otherwise it is true when
    ``args.local_rank`` is ``-1`` or ``0``.
    """
    if not is_torch_tpu_available():
        return self.args.local_rank in [-1, 0]
    return xm.is_master_ordinal(local=True)
def is_world_process_zero(self) -> bool:
    """
    Whether or not this process is the global main process (when training in a distributed fashion on several
    machines, this is only going to be :obj:`True` for one process).

    On TPU the answer comes from ``xm.is_master_ordinal(local=False)``; otherwise it is true when
    ``args.local_rank`` is ``-1`` or when ``dist.get_rank()`` is ``0``.
    """
    if not is_torch_tpu_available():
        return self.args.local_rank == -1 or dist.get_rank() == 0
    return xm.is_master_ordinal(local=False)
def save_model(self, output_dir: Optional[str] = None):
    """
    Will save the model, so you can reload it using :obj:`from_pretrained()`.

    Will only save from the main process.

    Args:
        output_dir (:obj:`str`, `optional`):
            Destination folder, forwarded to the TPU or regular save helper.
    """
    if is_torch_tpu_available():
        self._save_tpu(output_dir)
        return

    sharded_options = self.args.sharded_ddp
    if ShardedDDPOption.ZERO_DP_2 in sharded_options or ShardedDDPOption.ZERO_DP_3 in sharded_options:
        # The state dict is requested on every process; only the world process zero writes it out.
        state_dict = self.model.state_dict()
        if self.is_world_process_zero():
            self._save(output_dir, state_dict=state_dict)
    elif self.is_world_process_zero():
        self._save(output_dir)
def _save_tpu(self, output_dir: Optional[str] = None):
    """
    Save the model (and tokenizer, if any) when running on TPU, using ``xm.save`` as the save function.

    Args:
        output_dir (:obj:`str`, `optional`):
            Destination folder; falls back to ``self.args.output_dir`` when :obj:`None`.
    """
    if output_dir is None:
        output_dir = self.args.output_dir
    logger.info("Saving model checkpoint to %s", output_dir)

    # The folder and the training arguments are only written by the master ordinal.
    if xm.is_master_ordinal():
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    xm.rendezvous("saving_checkpoint")
    if isinstance(self.model, PreTrainedModel):
        self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
    elif isinstance(unwrap_model(self.model), PreTrainedModel):
        # The wrapper is not a `PreTrainedModel` but the model inside is: save it with the wrapper's state dict.
        unwrap_model(self.model).save_pretrained(
            output_dir,
            save_config=self.is_world_process_zero(),
            state_dict=self.model.state_dict(),
            save_function=xm.save,
        )
    else:
        logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
        state_dict = self.model.state_dict()
        xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

    if self.tokenizer is not None and self.is_world_process_zero():
        self.tokenizer.save_pretrained(output_dir)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    """
    Save the model, the tokenizer (if any) and the training arguments into :obj:`output_dir`.

    Args:
        output_dir (:obj:`str`, `optional`):
            Destination folder; falls back to ``self.args.output_dir`` when :obj:`None`. It is created if needed.
        state_dict (`optional`):
            State dictionary to save. When :obj:`None` and ``self.model`` is not a ``PreTrainedModel``, it is
            taken from ``self.model.state_dict()``.
    """
    # If we are executing this function, we are the process zero, so we don't check for that.
    if output_dir is None:
        output_dir = self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Saving model checkpoint to %s", output_dir)

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    if isinstance(self.model, PreTrainedModel):
        self.model.save_pretrained(output_dir, state_dict=state_dict)
    elif isinstance(unwrap_model(self.model), PreTrainedModel):
        if state_dict is None:
            state_dict = self.model.state_dict()
        unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
    else:
        logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
        if state_dict is None:
            state_dict = self.model.state_dict()
        torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

    if self.tokenizer is not None:
        self.tokenizer.save_pretrained(output_dir)

    # Good practice: save your training arguments together with the trained model
    torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
def store_flos(self):
    """
    Copy the accumulated floating-point operation count into :obj:`self.state.total_flos`.

    Nothing happens when :obj:`self._total_flos` is :obj:`None`. When :obj:`self.args.local_rank` is not ``-1``, the
    stored value is the sum of what :obj:`distributed_broadcast_scalars` returns for this process's count; otherwise
    the local count is stored as is.
    """
    if self._total_flos is None:
        return
    if self.args.local_rank == -1:
        self.state.total_flos = self._total_flos
        return
    self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
def _sorted_checkpoints(
    self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
) -> List[str]:
    """
    List the checkpoint folders found in a directory, ordered from first-to-delete to last-to-delete.

    Args:
        output_dir (:obj:`str`, `optional`):
            Directory to scan for entries named ``{checkpoint_prefix}-*``. Falls back to
            :obj:`self.args.output_dir` when :obj:`None`.
        checkpoint_prefix (:obj:`str`, `optional`):
            Prefix of the checkpoint folder names.
        use_mtime (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Order by modification time instead of by the integer that follows the prefix. When :obj:`False`,
            entries whose name has no integer after the prefix are left out.

    Returns:
        :obj:`List[str]`: The checkpoint paths in ascending order, except that the path equal to
        :obj:`self.state.best_model_checkpoint` (when set and present in the listing) is swapped with the last
        entry.
    """
    if output_dir is None:
        # `Path(None)` raises a TypeError, so the declared default needs an explicit fallback.
        output_dir = self.args.output_dir

    step_pattern = re.compile(f".*{checkpoint_prefix}-([0-9]+)")
    ordering_and_checkpoint_path = []
    for candidate in Path(output_dir).glob(f"{checkpoint_prefix}-*"):
        path = str(candidate)
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            continue
        regex_match = step_pattern.match(path)
        if regex_match is not None:
            ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))

    checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

    # Make sure we don't delete the best model: move it to the end of the list.
    best_model_checkpoint = self.state.best_model_checkpoint
    if best_model_checkpoint is not None:
        best_model_path = str(Path(best_model_checkpoint))
        # The best checkpoint may not be part of this listing (different directory, or already removed);
        # `list.index` would raise a ValueError in that case, so only reorder when it is actually present.
        if best_model_path in checkpoints_sorted:
            best_model_index = checkpoints_sorted.index(best_model_path)
            checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
                checkpoints_sorted[-1],
                checkpoints_sorted[best_model_index],
            )
    return checkpoints_sorted
def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
return
# Check if we should delete older checkpoint(s)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
if len(checkpoints_sorted) <= self.args.save_total_limit:
return
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
for checkpoint in checkpoints_to_be_deleted:
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
shutil.rmtree(checkpoint)
def evaluate(
    self,
    eval_dataset: Optional[Dataset] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
) -> Dict[str, float]:
    """
    Run the evaluation loop and return the resulting metrics.

    Metrics are task-dependent, so the calling script is responsible for supplying a way to compute them (through
    the :obj:`compute_metrics` argument of the init). When no such function is set, only the loss is gathered.

    Subclass and override this method to inject custom behavior.

    Args:
        eval_dataset (:obj:`Dataset`, `optional`):
            Dataset to use instead of :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
            accepted by the ``model.forward()`` method are automatically removed. It must implement the
            :obj:`__len__` method.
        ignore_keys (:obj:`List[str]`, `optional`):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
            An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
            "eval_bleu" if the prefix is "eval" (default)

    Returns:
        :obj:`Dict[str, float]`: The metrics produced by the prediction loop, extended with the speed metrics
        computed here and with whatever the memory tracker adds.

    Raises:
        ValueError: If :obj:`eval_dataset` is given but does not implement :obj:`__len__`.
    """
    # memory metrics - must set up as early as possible
    self._memory_tracker.start()

    dataset_is_usable = eval_dataset is None or isinstance(eval_dataset, collections.abc.Sized)
    if not dataset_is_usable:
        raise ValueError("eval_dataset must implement __len__")

    eval_dataloader = self.get_eval_dataloader(eval_dataset)
    start_time = time.time()

    # No point gathering the predictions if there are no metrics, otherwise we defer to
    # self.args.prediction_loss_only
    loss_only = True if self.compute_metrics is None else None
    output = self.prediction_loop(
        eval_dataloader,
        description="Evaluation",
        prediction_loss_only=loss_only,
        ignore_keys=ignore_keys,
        metric_key_prefix=metric_key_prefix,
    )

    measured_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
    n_samples = len(measured_dataset)
    output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
    self.log(output.metrics)

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
    self._memory_tracker.stop_and_update_metrics(output.metrics)
    return output.metrics
def predict(
    self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval"
) -> PredictionOutput:
    """
    Run the prediction loop on a dataset and return the predictions together with any metrics.

    The test dataset may or may not contain labels depending on your use case. When it does, metrics are returned
    as well, like in :obj:`evaluate()`.

    Args:
        test_dataset (:obj:`Dataset`):
            Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
        ignore_keys (:obj:`List[str]`, `optional`):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
            An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
            "eval_bleu" if the prefix is "eval" (default)

    .. note::

        If your predictions or labels have different sequence length (for instance because you're doing dynamic
        padding in a token classification task) the predictions will be padded (on the right) to allow for
        concatenation into one array. The padding index is -100.

    Returns: `NamedTuple` A namedtuple with the following keys:

        - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
        - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
        - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
          contained labels).

    Raises:
        ValueError: If :obj:`test_dataset` is not :obj:`None` and does not implement :obj:`__len__`.
    """
    # memory metrics - must set up as early as possible
    self._memory_tracker.start()

    dataset_is_usable = test_dataset is None or isinstance(test_dataset, collections.abc.Sized)
    if not dataset_is_usable:
        raise ValueError("test_dataset must implement __len__")

    test_dataloader = self.get_test_dataloader(test_dataset)
    start_time = time.time()

    output = self.prediction_loop(
        test_dataloader,
        description="Prediction",
        ignore_keys=ignore_keys,
        metric_key_prefix=metric_key_prefix,
    )

    n_samples = len(test_dataset)
    output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
    self._memory_tracker.stop_and_update_metrics(output.metrics)
    return output
def prediction_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
) -> PredictionOutput:
    """
    Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

    Works both with or without labels.

    Args:
        dataloader (:obj:`DataLoader`):
            Loader to iterate over. Its dataset must implement :obj:`__len__`.
        description (:obj:`str`):
            Name of the run, only used in the log output.
        prediction_loss_only (:obj:`bool`, `optional`):
            Whether to gather the loss only. Defaults to :obj:`self.args.prediction_loss_only` when :obj:`None`.
        ignore_keys (:obj:`List[str]`, `optional`):
            Forwarded to :obj:`self.prediction_step`.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
            Prefix added (followed by ``_``) to every metric key that does not already start with it.

    Returns:
        :obj:`PredictionOutput`: The gathered predictions and label ids (both :obj:`None` when only the loss is
        gathered) and the metrics dictionary.

    Raises:
        ValueError: If the dataset of :obj:`dataloader` does not implement :obj:`__len__`.
    """
    if not isinstance(dataloader.dataset, collections.abc.Sized):
        raise ValueError("dataset must implement __len__")
    prediction_loss_only = (
        prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
    )

    if self.args.deepspeed and not self.args.do_train:
        # no harm, but flagging to the user that deepspeed config is ignored for eval
        # flagging only for when --do_train wasn't passed as only then it's redundant
        logger.info("Detected the deepspeed argument but it will not be used for evaluation")

    model = self._wrap_model(self.model, training=False)

    # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
    # ``train`` is running, half it first and then put on device
    if not self.is_in_train and self.args.fp16_full_eval:
        model = model.half().to(self.args.device)

    batch_size = dataloader.batch_size
    num_examples = self.num_examples(dataloader)
    logger.info("***** Running %s *****", description)
    logger.info(" Num examples = %d", num_examples)
    logger.info(" Batch size = %d", batch_size)

    # Per-step results accumulated here until they are gathered and converted to numpy.
    losses_host: torch.Tensor = None
    preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
    labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

    world_size = max(1, self.args.world_size)

    eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
    if not prediction_loss_only:
        # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
        # a batch size to the sampler)
        make_multiple_of = None
        if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
            make_multiple_of = dataloader.sampler.batch_size
        preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
        labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

    model.eval()

    if is_torch_tpu_available():
        dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

    if self.args.past_index >= 0:
        self._past = None

    self.callback_handler.eval_dataloader = dataloader

    for step, inputs in enumerate(dataloader):
        loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
        if loss is not None:
            # One loss value per batch is repeated so that there is one entry per example of a full batch.
            losses = loss.repeat(batch_size)
            losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
        if logits is not None:
            preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
        if labels is not None:
            labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
        self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

        # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
        if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
            eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
            if not prediction_loss_only:
                preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

            # Set back to None to begin a new accumulation
            losses_host, preds_host, labels_host = None, None, None

    # Clean the state at the end of the evaluation loop. The test mirrors the `past_index >= 0` check that
    # initialises `_past` above; a plain truthiness test would skip the cleanup for `past_index == 0`.
    if self.args.past_index >= 0 and hasattr(self, "_past"):
        delattr(self, "_past")

    # Gather all remaining tensors and put them back on the CPU
    eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
    if not prediction_loss_only:
        preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
        labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

    eval_loss = eval_losses_gatherer.finalize()
    preds = preds_gatherer.finalize() if not prediction_loss_only else None
    label_ids = labels_gatherer.finalize() if not prediction_loss_only else None

    if self.compute_metrics is not None and preds is not None and label_ids is not None:
        metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
    else:
        metrics = {}

    # To be JSON-serializable, we need to remove numpy types or zero-d tensors
    metrics = denumpify_detensorize(metrics)

    if eval_loss is not None:
        metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

    # Prefix all keys with metric_key_prefix + '_'
    for key in list(metrics.keys()):
        if not key.startswith(f"{metric_key_prefix}_"):
            metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

    return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def _gather_and_numpify(self, tensors, name):
"""
Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
concatenating them to `gathered`
"""
if tensors is None:
return
if is_torch_tpu_available():
tensors = nested_xla_mesh_reduce(tensors, name)
elif self.args.local_rank != -1:
tensors = distributed_concat(tensors)
return nested_numpify(tensors)
def prediction_step(
    self,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Run one evaluation step of :obj:`model` on :obj:`inputs`, without tracking gradients.

    Subclass and override to inject custom behavior.

    Args:
        model (:obj:`nn.Module`):
            The model to evaluate.
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument :obj:`labels`. Check your model's documentation for all accepted arguments.
        prediction_loss_only (:obj:`bool`):
            Whether or not to return the loss only.
        ignore_keys (:obj:`List[str]`, `optional`):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions. When :obj:`None`, the ``keys_to_ignore_at_inference`` attribute of
            ``self.model.config`` is used if available, otherwise an empty list.

    Return:
        Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
        labels (each being optional).
    """
    # Labels are considered present only if every configured label name has a non-None entry in the raw inputs.
    has_labels = all(inputs.get(label_name) is not None for label_name in self.label_names)
    inputs = self._prepare_inputs(inputs)

    if ignore_keys is None:
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
    labels = None
    if has_labels:
        labels = nested_detach(tuple(inputs.get(label_name) for label_name in self.label_names))
        if len(labels) == 1:
            labels = labels[0]

    with torch.no_grad():
        if not has_labels:
            loss = None
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

            if isinstance(outputs, dict):
                logits = tuple(value for key, value in outputs.items() if key not in ignore_keys)
            else:
                logits = outputs

            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index - 1]
        else:
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
            loss = loss.mean().detach()

            if isinstance(outputs, dict):
                logits = tuple(value for key, value in outputs.items() if key not in ignore_keys + ["loss"])
            else:
                # For non-dict outputs the first element is skipped.
                logits = outputs[1:]

    if prediction_loss_only:
        return (loss, None, None)

    logits = nested_detach(logits)
    if len(logits) == 1:
        logits = logits[0]

    return (loss, logits, labels)
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
    """
    Return the number of floating point operations for :obj:`inputs`, as reported by the model.

    The computation is delegated to the ``floating_point_ops`` method of :obj:`self.model` (models that inherit
    from :class:`~transformers.PreTrainedModel` provide one). If using another model, either implement such a
    method in the model or subclass and override this method.

    Args:
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

    Returns:
        :obj:`int`: The number of floating-point operations, or ``0`` when the model has no
        ``floating_point_ops`` attribute.
    """
    model = self.model
    if not hasattr(model, "floating_point_ops"):
        return 0
    return model.floating_point_ops(inputs)