Source code for transformers.trainer

# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""

import collections
import inspect
import math
import os
import re
import shutil
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union


# Integrations must be imported before ML frameworks:
from .integrations import (  # isort: split
    default_hp_search_backend,
    hp_params,
    is_azureml_available,
    is_comet_available,
    is_mlflow_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
)

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedTensorGatherer,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    get_tpu_sampler,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    set_seed,
)
from .training_args import TrainingArguments
from .utils import logging


_use_native_amp = False
_use_apex = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from .file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

if version.parse(torch.__version__) < version.parse("1.2"):
    _use_ddp_no_sync = False
else:
    _use_ddp_no_sync = True

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    from .integrations import TensorBoardCallback

    DEFAULT_CALLBACKS.append(TensorBoardCallback)


if is_wandb_available():
    from .integrations import WandbCallback

    DEFAULT_CALLBACKS.append(WandbCallback)

if is_comet_available():
    from .integrations import CometCallback

    DEFAULT_CALLBACKS.append(CometCallback)

if is_mlflow_available():
    from .integrations import MLflowCallback

    DEFAULT_CALLBACKS.append(MLflowCallback)

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

if is_azureml_available():
    from .integrations import AzureMLCallback

    DEFAULT_CALLBACKS.append(AzureMLCallback)

logger = logging.get_logger(__name__)


[docs]class Trainer: """ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. Args: model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`): The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed. .. note:: :class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel` provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as they work the same way as the 🤗 Transformers models. args (:class:`~transformers.TrainingArguments`, `optional`): The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided. data_collator (:obj:`DataCollator`, `optional`): The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise. train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`): The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`): The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. tokenizer (:class:`PreTrainedTokenizerBase`, `optional`): The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model. model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`): A function that instantiates the model to be used. If provided, each call to :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function. The function may have zero argument, or a single one containing the optuna/Ray Tune trial object, to be able to choose different architectures according to hyper parameters (such as layer count, sizes of inner layers, dropout probabilities etc). compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`): The function that will be used to compute metrics at evaluation. Must take a :class:`~transformers.EvalPrediction` and return a dictionary string to metric values. callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in :doc:`here <callback>`. If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method. optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`): A tuple containing the optimizer and the scheduler to use. Will default to an instance of :class:`~transformers.AdamW` on your model and a scheduler given by :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`. kwargs: Deprecated keyword arguments. """ def __init__( self, model: Union[PreTrainedModel, torch.nn.Module] = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Dataset] = None, tokenizer: Optional["PreTrainedTokenizerBase"] = None, model_init: Callable[[], PreTrainedModel] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), **kwargs, ): if args is None: logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.") args = TrainingArguments("tmp_trainer") self.args = args # Seed must be set before instantiating the model when using model set_seed(self.args.seed) assert ( model is not None or model_init is not None ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument." self.model_init = model_init self.hp_name = None if model is None and model_init is not None: model = self.call_model_init() self.model = model.to(args.device) if model is not None else None default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.tokenizer = tokenizer self.compute_metrics = compute_metrics self.optimizer, self.lr_scheduler = optimizers if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): raise RuntimeError( "Passing a `model_init` is incompatible with providing the `optimizers` argument." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) # Deprecated arguments if "tb_writer" in kwargs: warnings.warn( "Passing `tb_writer` as a keyword argument is deprecated and won't be possible in a " + "future version. Use `TensorBoardCallback(tb_writer=...)` instead and pass it to the `callbacks`" + "argument", FutureWarning, ) tb_writer = kwargs.pop("tb_writer") self.remove_callback(TensorBoardCallback) self.add_callback(TensorBoardCallback(tb_writer=tb_writer)) if "prediction_loss_only" in kwargs: warnings.warn( "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a " + "future version. Use `args.prediction_loss_only` instead. Setting " + f"`args.prediction_loss_only={kwargs['prediction_loss_only']}", FutureWarning, ) self.args.prediction_loss_only = kwargs.pop("prediction_loss_only") assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. self._loggers_initialized = False # Create output directory if needed if self.is_world_process_zero(): os.makedirs(self.args.output_dir, exist_ok=True) if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel): # Set an xla_device flag on the model's config. # We'll find a more elegant and not need to do this in the future. self.model.config.xla_device = True if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): self.data_collator = self.data_collator.collate_batch warnings.warn( ( "The `data_collator` should now be a simple callable (function, class with `__call__`), classes " + "with a `collate_batch` are deprecated and won't be supported in a future version." ), FutureWarning, ) if args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") # Enforce rules on using datasets with no __len__ if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0: raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): raise ValueError("eval_dataset must implement __len__") if is_datasets_available(): if isinstance(train_dataset, datasets.Dataset): self._remove_unused_columns(self.train_dataset, description="training") if isinstance(eval_dataset, datasets.Dataset): self._remove_unused_columns(self.eval_dataset, description="evaluation") self.state = TrainerState() self.control = TrainerControl() # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the # state at each call to self.log. self._total_flos = None if self.args.fp16 and _use_native_amp: self.scaler = torch.cuda.amp.GradScaler() self.hp_search_backend = None self.use_tune_checkpoints = False default_label_names = ( ["start_positions", "end_positions"] if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values() else ["labels"] ) self.label_names = default_label_names if self.args.label_names is None else self.args.label_names self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
[docs] def add_callback(self, callback): """ Add a callback to the current list of :class:`~transformer.TrainerCallback`. Args: callback (:obj:`type` or :class:`~transformer.TrainerCallback`): A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`. In the first case, will instantiate a member of that class. """ self.callback_handler.add_callback(callback)
[docs] def pop_callback(self, callback): """ Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it. If the callback is not found, returns :obj:`None` (and no error is raised). Args: callback (:obj:`type` or :class:`~transformer.TrainerCallback`): A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`. In the first case, will pop the first member of that class found in the list of callbacks. Returns: :class:`~transformer.TrainerCallback`: The callback removed, if found. """ return self.callback_handler.pop_callback(callback)
[docs] def remove_callback(self, callback): """ Remove a callback from the current list of :class:`~transformer.TrainerCallback`. Args: callback (:obj:`type` or :class:`~transformer.TrainerCallback`): A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`. In the first case, will remove the first member of that class found in the list of callbacks. """ self.callback_handler.remove_callback(callback)
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): if not self.args.remove_unused_columns: return # Inspect model forward signature to keep only the arguments it accepts. signature = inspect.signature(self.model.forward) signature_columns = list(signature.parameters.keys()) # Labels may be named label or label_ids, the default data collator handles that. signature_columns += ["label", "label_ids"] columns = [k for k in signature_columns if k in dataset.column_names] ignored_columns = list(set(dataset.column_names) - set(signature_columns)) dset_description = "" if description is None else f"in the {description} set " logger.info( f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." ) dataset.set_format(type=dataset.format["type"], columns=columns) def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance( self.train_dataset, collections.abc.Sized ): return None elif is_torch_tpu_available(): return get_tpu_sampler(self.train_dataset) else: return ( RandomSampler(self.train_dataset) if self.args.local_rank == -1 else DistributedSampler(self.train_dataset) )
[docs] def get_train_dataloader(self) -> DataLoader: """ Returns the training :class:`~torch.utils.data.DataLoader`. Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted to distributed training if necessary) otherwise. Subclass and override this method if you want to inject some custom behavior. """ if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") train_sampler = self._get_train_sampler() return DataLoader( self.train_dataset, batch_size=self.args.train_batch_size, sampler=train_sampler, collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, num_workers=self.args.dataloader_num_workers, )
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]: if is_torch_tpu_available(): return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) elif self.args.local_rank != -1: return SequentialDistributedSampler(eval_dataset) else: return SequentialSampler(eval_dataset)
[docs] def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: """ Returns the evaluation :class:`~torch.utils.data.DataLoader`. Subclass and override this method if you want to inject some custom behavior. Args: eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`): If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`. """ if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): raise ValueError("eval_dataset must implement __len__") elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): self._remove_unused_columns(eval_dataset, description="evaluation") eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset eval_sampler = self._get_eval_sampler(eval_dataset) return DataLoader( eval_dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size, collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, num_workers=self.args.dataloader_num_workers, )
[docs] def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: """ Returns the test :class:`~torch.utils.data.DataLoader`. Subclass and override this method if you want to inject some custom behavior. Args: test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`): The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`. """ if not isinstance(test_dataset, collections.abc.Sized): raise ValueError("test_dataset must implement __len__") elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset): self._remove_unused_columns(test_dataset, description="test") test_sampler = self._get_eval_sampler(test_dataset) # We use the same batch_size as for eval. return DataLoader( test_dataset, sampler=test_sampler, batch_size=self.args.eval_batch_size, collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, )
[docs] def create_optimizer_and_scheduler(self, num_training_steps: int): """ Setup the optimizer and the learning rate scheduler. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. """ if self.optimizer is None: no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": self.args.weight_decay, }, { "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0, }, ] self.optimizer = AdamW( optimizer_grouped_parameters, lr=self.args.learning_rate, betas=(self.args.adam_beta1, self.args.adam_beta2), eps=self.args.adam_epsilon, ) if self.lr_scheduler is None: self.lr_scheduler = get_linear_schedule_with_warmup( self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps )
[docs] def num_examples(self, dataloader: DataLoader) -> int: """ Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset. Will raise an exception if the underlying dataset dese not implement method :obj:`__len__` """ return len(dataloader.dataset)
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): """ HP search setup code """ self._trial = trial if self.hp_search_backend is None or trial is None: return params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial for key, value in params.items(): if not hasattr(self.args, key): raise AttributeError( f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`." ) old_attr = getattr(self.args, key, None) # Casting value to the proper type if old_attr is not None: value = type(old_attr)(value) setattr(self.args, key, value) if self.hp_search_backend == HPSearchBackend.OPTUNA: logger.info("Trial:", trial.params) def _report_to_hp_search( self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float] ): if self.hp_search_backend is None or trial is None: return self.objective = self.compute_objective(metrics.copy()) if self.hp_search_backend == HPSearchBackend.OPTUNA: trial.report(self.objective, epoch) if trial.should_prune(): raise optuna.TrialPruned() elif self.hp_search_backend == HPSearchBackend.RAY: if self.state.global_step % self.args.save_steps == 0: self._tune_save_checkpoint() tune.report(objective=self.objective, **metrics) def _tune_save_checkpoint(self): if not self.use_tune_checkpoints: return with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: self.args.output_dir = checkpoint_dir output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") self.save_model(output_dir) if self.is_world_master(): self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) def call_model_init(self, trial=None): model_init_argcount = len(inspect.signature(self.model_init).parameters) if model_init_argcount == 0: model = self.model_init() elif model_init_argcount == 1: model = self.model_init(trial) else: raise RuntimeError("model_init should have 0 or 1 argument.") if model is None: raise RuntimeError("model_init should not return None.") return model
[docs] def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None): """ Main training entry point. Args: model_path (:obj:`str`, `optional`): Local path to the model if the model to train has been instantiated from a local path. If present, training will resume from the optimizer/scheduler states loaded here. trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): The trial run or the hyperparameter dictionary for hyperparameter search. """ # This might change the seed so needs to run first. self._hp_search_setup(trial) # Model re-init if self.model_init is not None: # Seed must be set before instantiating the model when using model_init. set_seed(self.args.seed) model = self.call_model_init(trial) self.model = model.to(self.args.device) # Reinitializes optimizer and scheduler self.optimizer, self.lr_scheduler = None, None # Keeping track whether we can can len() on the dataset or not train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized) # Data loader and number of training steps train_dataloader = self.get_train_dataloader() # Setting up training control variables: # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps if train_dataset_is_sized: num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) if self.args.max_steps > 0: max_steps = self.args.max_steps num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( self.args.max_steps % num_update_steps_per_epoch > 0 ) else: max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch) num_train_epochs = math.ceil(self.args.num_train_epochs) else: # see __init__. max_steps is set when the dataset has no __len__ max_steps = self.args.max_steps num_train_epochs = 1 num_update_steps_per_epoch = max_steps self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() self.state.is_hyper_param_search = trial is not None # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(model_path) # Mixed precision training with apex (torch < 1.6) model = self.model if self.args.fp16 and _use_apex: if not is_apex_available(): raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) # Multi-gpu training (should be after apex fp16 initialization) if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if self.args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.args.local_rank], output_device=self.args.local_rank, find_unused_parameters=( not getattr(model.config, "gradient_checkpointing", False) if isinstance(model, PreTrainedModel) else True ), ) # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 # Train! if is_torch_tpu_available(): total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() else: total_train_batch_size = ( self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1) ) num_examples = ( self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * self.args.max_steps ) logger.info("***** Running training *****") logger.info(" Num examples = %d", num_examples) logger.info(" Num Epochs = %d", num_train_epochs) logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", max_steps) self.state.epoch = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")): self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json")) epochs_trained = self.state.global_step // num_update_steps_per_epoch steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", self.state.global_step) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) # Update the references self.callback_handler.model = self.model self.callback_handler.optimizer = self.optimizer self.callback_handler.lr_scheduler = self.lr_scheduler self.callback_handler.train_dataloader = train_dataloader self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None self.state.trial_params = hp_params(trial) if trial is not None else None # This should be the same if the state has been saved but in case the training arguments changed, it's safer # to set this after the load. self.state.max_steps = max_steps self.state.num_train_epochs = num_train_epochs self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() tr_loss = torch.tensor(0.0).to(self.args.device) self._logging_loss_scalar = 0 self._globalstep_last_logged = 0 self._total_flos = self.state.total_flos model.zero_grad() self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control) for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) if is_torch_tpu_available(): parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader( self.args.device ) epoch_iterator = parallel_loader else: epoch_iterator = train_dataloader # Reset the past mems state at the beginning of each epoch if necessary. if self.args.past_index >= 0: self._past = None steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue if (step + 1) % self.args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control) if ( ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1 and _use_ddp_no_sync ): with model.no_sync(): tr_loss += self.training_step(model, inputs) else: tr_loss += self.training_step(model, inputs) self._total_flos += self.floating_point_ops(inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps steps_in_epoch <= self.args.gradient_accumulation_steps and (step + 1) == steps_in_epoch ): if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) if is_torch_tpu_available(): xm.optimizer_step(self.optimizer) elif self.args.fp16 and _use_native_amp: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / steps_in_epoch self.control = self.callback_handler.on_step_end(self.args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch) if self.control.should_epoch_stop or self.control.should_training_stop: break self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch) if self.args.tpu_metrics_debug or self.args.debug: if is_torch_tpu_available(): # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) else: logger.warning( "You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected." ) if self.control.should_training_stop: break if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: logger.info( f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." ) if isinstance(model, PreTrainedModel): self.model = model.from_pretrained(self.state.best_model_checkpoint) self.model = self.model.to(self.args.device) else: state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) self.model.load_state_dict(state_dict) if self._total_flos is not None: self.store_flos() self.log({"total_flos": self.state.total_flos}) self.control = self.callback_handler.on_train_end(self.args, self.state, self.control) return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch): if self.control.should_log: logs: Dict[str, float] = {} tr_loss_scalar = tr_loss.item() logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / ( self.state.global_step - self._globalstep_last_logged ) # backward compatibility for pytorch schedulers logs["learning_rate"] = ( self.lr_scheduler.get_last_lr()[0] if version.parse(torch.__version__) >= version.parse("1.4") else self.lr_scheduler.get_lr()[0] ) self._logging_loss_scalar = tr_loss_scalar self._globalstep_last_logged = self.state.global_step self.log(logs) metrics = None if self.control.should_evaluate: metrics = self.evaluate() self._report_to_hp_search(trial, epoch, metrics) if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) def _save_checkpoint(self, model, trial, metrics=None): # In all cases (even distributed/parallel), self.model is always a reference # to the model we want to save. if hasattr(model, "module"): assert model.module is self.model, f"Module {model.module} should be a reference to self.model" else: assert model is self.model, f"Model {model} should be a reference to self.model" # Save model checkpoint checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" if self.hp_search_backend is not None and trial is not None: run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id() run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder) else: output_dir = os.path.join(self.args.output_dir, checkpoint_folder) self.store_flos() self.save_model(output_dir) # Save optimizer and scheduler if is_torch_tpu_available(): xm.rendezvous("saving_optimizer_states") xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) with warnings.catch_warnings(record=True) as caught_warnings: xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) reissue_pt_warnings(caught_warnings) elif self.is_world_process_zero(): torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) with warnings.catch_warnings(record=True) as caught_warnings: torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) reissue_pt_warnings(caught_warnings) # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: metric_to_check = self.args.metric_for_best_model if not metric_to_check.startswith("eval_"): metric_to_check = f"eval_{metric_to_check}" metric_value = metrics[metric_to_check] operator = np.greater if self.args.greater_is_better else np.less if ( self.state.best_metric is None or self.state.best_model_checkpoint is None or operator(metric_value, self.state.best_metric) ): self.state.best_metric = metric_value self.state.best_model_checkpoint = output_dir # Save the Trainer state if self.is_world_process_zero(): self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) # Maybe delete some older checkpoints. if self.is_world_process_zero(): self._rotate_checkpoints(use_mtime=True) def _load_optimizer_and_scheduler(self, model_path): """If optimizer and scheduler states exist, load them.""" if ( model_path is not None and os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(os.path.join(model_path, "scheduler.pt")) ): # Load in optimizer and scheduler states if is_torch_tpu_available(): # On TPU we have to take some extra precautions to properly load the states on the right device. optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu") with warnings.catch_warnings(record=True) as caught_warnings: lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu") reissue_pt_warnings(caught_warnings) xm.send_cpu_data_to_device(optimizer_state, self.args.device) xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) self.optimizer.load_state_dict(optimizer_state) self.lr_scheduler.load_state_dict(lr_scheduler_state) else: self.optimizer.load_state_dict( torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) ) with warnings.catch_warnings(record=True) as caught_warnings: self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) reissue_pt_warnings(caught_warnings)
[docs] def log(self, logs: Dict[str, float]) -> None: """ Log :obj:`logs` on the various objects watching training. Subclass and override this method to inject custom behavior. Args: logs (:obj:`Dict[str, float]`): The values to log. """ if hasattr(self, "_log"): warnings.warn( "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.", FutureWarning, ) return self._log(logs) if self.state.epoch is not None: logs["epoch"] = self.state.epoch self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) output = {**logs, **{"step": self.state.global_step}} self.state.log_history.append(output)
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: """ Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and handling potential state. """ for k, v in inputs.items(): if isinstance(v, torch.Tensor): inputs[k] = v.to(self.args.device) if self.args.past_index >= 0 and self._past is not None: inputs["mems"] = self._past return inputs
[docs] def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: """ Perform a training step on a batch of inputs. Subclass and override to inject custom behavior. Args: model (:obj:`nn.Module`): The model to train. inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument :obj:`labels`. Check your model's documentation for all accepted arguments. Return: :obj:`torch.Tensor`: The tensor with training loss on this batch. """ if hasattr(self, "_training_step"): warnings.warn( "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.", FutureWarning, ) return self._training_step(model, inputs, self.optimizer) model.train() inputs = self._prepare_inputs(inputs) if self.args.fp16 and _use_native_amp: with autocast(): loss = self.compute_loss(model, inputs) else: loss = self.compute_loss(model, inputs) if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training if self.args.gradient_accumulation_steps > 1: loss = loss / self.args.gradient_accumulation_steps if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() return loss.detach()
[docs] def compute_loss(self, model, inputs): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior. """ outputs = model(**inputs) # Save past state if it exists if self.args.past_index >= 0: self._past = outputs[self.args.past_index] # We don't use .loss here since the model may return tuples instead of ModelOutput. return outputs[0]
[docs] def is_local_master(self) -> bool: """ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several machines) main process. .. warning:: This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead. """ warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning) return self.is_local_process_zero()
[docs] def is_local_process_zero(self) -> bool: """ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several machines) main process. """ if is_torch_tpu_available(): return xm.is_master_ordinal(local=True) else: return self.args.local_rank in [-1, 0]
[docs] def is_world_master(self) -> bool: """ Whether or not this process is the global main process (when training in a distributed fashion on several machines, this is only going to be :obj:`True` for one process). .. warning:: This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead. """ warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning) return self.is_world_process_zero()
[docs] def is_world_process_zero(self) -> bool: """ Whether or not this process is the global main process (when training in a distributed fashion on several machines, this is only going to be :obj:`True` for one process). """ if is_torch_tpu_available(): return xm.is_master_ordinal(local=False) else: return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
[docs] def save_model(self, output_dir: Optional[str] = None): """ Will save the model, so you can reload it using :obj:`from_pretrained()`. Will only save from the world_master process (unless in TPUs). """ if is_torch_tpu_available(): self._save_tpu(output_dir) elif self.is_world_process_zero(): self._save(output_dir)
def _save_tpu(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir logger.info("Saving model checkpoint to %s", output_dir) if xm.is_master_ordinal(): os.makedirs(output_dir, exist_ok=True) torch.save(self.args, os.path.join(output_dir, "training_args.bin")) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` xm.rendezvous("saving_checkpoint") if not isinstance(self.model, PreTrainedModel): logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") state_dict = self.model.state_dict() xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: self.model.save_pretrained(output_dir) if self.tokenizer is not None and self.is_world_process_zero(): self.tokenizer.save_pretrained(output_dir) def _save(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info("Saving model checkpoint to %s", output_dir) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") state_dict = self.model.state_dict() torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: self.model.save_pretrained(output_dir) if self.tokenizer is not None and self.is_world_process_zero(): self.tokenizer.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, "training_args.bin")) def store_flos(self): # Storing the number of floating-point operations that went into the model if self._total_flos is not None: if self.args.local_rank != -1: self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item() else: self.state.total_flos = self._total_flos def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: ordering_and_checkpoint_path = [] glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")] for path in glob_checkpoints: if use_mtime: ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) else: regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) if regex_match and regex_match.groups(): ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) checkpoints_sorted = sorted(ordering_and_checkpoint_path) checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] # Make sure we don't delete the best model. if self.state.best_model_checkpoint is not None: best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint))) checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = ( checkpoints_sorted[-1], checkpoints_sorted[best_model_index], ) return checkpoints_sorted def _rotate_checkpoints(self, use_mtime=False) -> None: if self.args.save_total_limit is None or self.args.save_total_limit <= 0: return # Check if we should delete older checkpoint(s) checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime) if len(checkpoints_sorted) <= self.args.save_total_limit: return number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit) checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] for checkpoint in checkpoints_to_be_deleted: logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) shutil.rmtree(checkpoint)
[docs] def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]: """ Run evaluation and returns metrics. The calling script will be responsible for providing a method to compute metrics, as they are task-dependent (pass it to the init :obj:`compute_metrics` argument). You can also subclass and override this method to inject custom behavior. Args: eval_dataset (:obj:`Dataset`, `optional`): Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the :obj:`__len__` method. Returns: A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The dictionary also contains the epoch number which comes from the training state. """ if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): raise ValueError("eval_dataset must implement __len__") eval_dataloader = self.get_eval_dataloader(eval_dataset) output = self.prediction_loop( eval_dataloader, description="Evaluation", # No point gathering the predictions if there are no metrics, otherwise we defer to # self.args.prediction_loss_only prediction_loss_only=True if self.compute_metrics is None else None, ) self.log(output.metrics) if self.args.tpu_metrics_debug or self.args.debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) return output.metrics
[docs] def predict(self, test_dataset: Dataset) -> PredictionOutput: """ Run prediction and returns predictions and potential metrics. Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method will also return metrics, like in :obj:`evaluate()`. Args: test_dataset (:obj:`Dataset`): Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__` .. note:: If your predictions or labels have different sequence length (for instance because you're doing dynamic padding in a token classification task) the predictions will be padded (on the right) to allow for concatenation into one array. The padding index is -100. Returns: `NamedTuple` A namedtuple with the following keys: - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`. - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some). - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset contained labels). """ if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized): raise ValueError("test_dataset must implement __len__") test_dataloader = self.get_test_dataloader(test_dataset) return self.prediction_loop(test_dataloader, description="Prediction")
[docs] def prediction_loop( self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None ) -> PredictionOutput: """ Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. Works both with or without labels. """ if hasattr(self, "_prediction_loop"): warnings.warn( "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.", FutureWarning, ) return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only) if not isinstance(dataloader.dataset, collections.abc.Sized): raise ValueError("dataset must implement __len__") prediction_loss_only = ( prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only ) model = self.model # multi-gpu eval if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. batch_size = dataloader.batch_size num_examples = self.num_examples(dataloader) logger.info("***** Running %s *****", description) logger.info(" Num examples = %d", num_examples) logger.info(" Batch size = %d", batch_size) losses_host: torch.Tensor = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None world_size = 1 if is_torch_tpu_available(): world_size = xm.xrt_world_size() elif self.args.local_rank != -1: world_size = torch.distributed.get_world_size() world_size = max(1, world_size) eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) if not prediction_loss_only: preds_gatherer = DistributedTensorGatherer(world_size, num_examples) labels_gatherer = DistributedTensorGatherer(world_size, num_examples) model.eval() if is_torch_tpu_available(): dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) if self.args.past_index >= 0: self._past = None self.callback_handler.eval_dataloader = dataloader for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only) if loss is not None: losses = loss.repeat(batch_size) losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) if logits is not None: preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) # Set back to None to begin a new accumulation losses_host, preds_host, labels_host = None, None, None if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") # Gather all remaining tensors and put them back on the CPU eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) eval_loss = eval_losses_gatherer.finalize() preds = preds_gatherer.finalize() if not prediction_loss_only else None label_ids = labels_gatherer.finalize() if not prediction_loss_only else None if self.compute_metrics is not None and preds is not None and label_ids is not None: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) else: metrics = {} if eval_loss is not None: metrics["eval_loss"] = eval_loss.mean().item() # Prefix all keys with eval_ for key in list(metrics.keys()): if not key.startswith("eval_"): metrics[f"eval_{key}"] = metrics.pop(key) return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def _gather_and_numpify(self, tensors, name): """ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before concatenating them to `gathered` """ if tensors is None: return if is_torch_tpu_available(): tensors = nested_xla_mesh_reduce(tensors, name) elif self.args.local_rank != -1: tensors = distributed_concat(tensors) return nested_numpify(tensors)
[docs] def prediction_step( self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on :obj:`model` using obj:`inputs`. Subclass and override to inject custom behavior. Args: model (:obj:`nn.Module`): The model to evaluate. inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument :obj:`labels`. Check your model's documentation for all accepted arguments. prediction_loss_only (:obj:`bool`): Whether or not to return the loss only. Return: Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and labels (each being optional). """ has_labels = all(inputs.get(k) is not None for k in self.label_names) inputs = self._prepare_inputs(inputs) with torch.no_grad(): if self.args.fp16 and _use_native_amp: with autocast(): outputs = model(**inputs) else: outputs = model(**inputs) if has_labels: loss = outputs[0].mean().detach() logits = outputs[1:] else: loss = None # Slicing so we get a tuple even if `outputs` is a `ModelOutput`. logits = outputs[:] if self.args.past_index >= 0: self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1] # Remove the past from the logits. logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :] if prediction_loss_only: return (loss, None, None) logits = nested_detach(logits) if len(logits) == 1: logits = logits[0] if has_labels: labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) if len(labels) == 1: labels = labels[0] else: labels = None return (loss, logits, labels)
[docs] def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): """ For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of floating point operations for every backward + forward pass. If using another model, either implement such a method in the model or subclass and override this method. Args: inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. Returns: :obj:`int`: The number of floating-point operations. """ model = self._actual_model(self.model) if hasattr(model, "floating_point_ops"): return model.floating_point_ops(inputs) else: return 0
@staticmethod def _actual_model( model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module] ) -> torch.nn.modules.Module: """ Args: model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`): Model object used during training Returns: :obj:`torch.nn.modules.Module`: unwrapped module """ if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel): model = model.module else: model = model return model