Source code for transformers.trainer_seq2seq

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from packaging import version
from torch import nn
from torch.utils.data.dataset import Dataset

from .trainer import Trainer
from .trainer_utils import PredictionOutput
from .utils import logging


# `autocast` is only imported when the installed torch version is at least 1.6; on older versions
# the name `autocast` is left undefined in this module.
if version.parse(torch.__version__) >= version.parse("1.6"):
    from torch.cuda.amp import autocast


# Module-level logger, obtained from the package logging utility with this module's name.
logger = logging.get_logger(__name__)


class Seq2SeqTrainer(Trainer):
    """
    :class:`Trainer` subclass whose evaluation/prediction step can produce predictions with
    ``model.generate()`` (when :obj:`args.predict_with_generate` is set) instead of a plain forward pass.

    :meth:`evaluate` and :meth:`predict` accept the extra :obj:`max_length` and :obj:`num_beams` arguments, which
    are stored on the instance and consumed by :meth:`prediction_step`.
    """

    # Generation overrides for the current `evaluate()` / `predict()` call. `None` means "fall back to the value
    # in `self.model.config`". The class-level defaults keep `prediction_step` usable even when it is invoked
    # without going through `evaluate()` or `predict()` first.
    _max_length: Optional[int] = None
    _num_beams: Optional[int] = None

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
                the :obj:`__len__` method.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is ``"eval"`` (default)
            max_length (:obj:`int`, `optional`):
                The maximum target length to use when predicting with the generate method. When :obj:`None`, the
                value from :obj:`self.model.config.max_length` is used.
            num_beams (:obj:`int`, `optional`):
                Number of beams for beam search that will be used when predicting with the generate method. 1 means
                no beam search. When :obj:`None`, the value from :obj:`self.model.config.num_beams` is used.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # Stash the generation settings so that `prediction_step` can pick them up for every batch.
        self._max_length = max_length
        self._num_beams = num_beams
        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is ``"eval"`` (default)
            max_length (:obj:`int`, `optional`):
                The maximum target length to use when predicting with the generate method. When :obj:`None`, the
                value from :obj:`self.model.config.max_length` is used.
            num_beams (:obj:`int`, `optional`):
                Number of beams for beam search that will be used when predicting with the generate method. 1 means
                no beam search. When :obj:`None`, the value from :obj:`self.model.config.num_beams` is used.

        .. note::

            If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
            padding in a token classification task) the predictions will be padded (on the right) to allow for
            concatenation into one array. The padding index is -100.

        Returns: `NamedTuple` A namedtuple with the following keys:

            - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
            - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
            - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
              contained labels).
        """
        # Stash the generation settings so that `prediction_step` can pick them up for every batch.
        self._max_length = max_length
        self._num_beams = num_beams
        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under
                the argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, `optional`):
                Forwarded to the parent implementation when generation is not used; not used by the generation path.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, the
            generated tokens and the labels (each being optional). The loss and the labels are :obj:`None` when
            :obj:`inputs` contains no ``"labels"`` entry.
        """
        # Without generation (or when only the loss is wanted) the parent implementation does everything needed.
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # Per-call overrides win; otherwise fall back to the model configuration.
        gen_kwargs = {
            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
        }

        generated_tokens = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

        # The forward pass is only needed to compute a loss, which in turn requires labels.
        loss = None
        if has_labels:
            with torch.no_grad():
                if self.use_amp:
                    with autocast():
                        outputs = model(**inputs)
                else:
                    outputs = model(**inputs)

                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()

        if self.args.prediction_loss_only:
            return (loss, None, None)

        # Only touch `inputs["labels"]` when it exists; unlabeled batches yield `labels=None`.
        labels = None
        if has_labels:
            labels = inputs["labels"]
            if labels.shape[-1] < gen_kwargs["max_length"]:
                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

        return (loss, generated_tokens, labels)

    def _pad_tensors_to_max_len(self, tensor, max_length):
        """
        Right-pad :obj:`tensor` along its last dimension up to :obj:`max_length`.

        The fill value is the tokenizer's PAD token id, or its EOS token id when no PAD token is defined.

        Args:
            tensor (:obj:`torch.Tensor`):
                Tensor to pad. It is indexed as two-dimensional (first dimension kept, last dimension padded).
            max_length (:obj:`int`):
                Size of the last dimension of the returned tensor. Callers only invoke this when the last dimension
                of :obj:`tensor` is smaller than :obj:`max_length`.

        Returns:
            :obj:`torch.Tensor`: A new tensor of shape ``(tensor.shape[0], max_length)`` with the same dtype and
            device as :obj:`tensor`.

        Raises:
            ValueError: If no tokenizer was passed to the trainer, or if the tokenizer defines neither a PAD nor an
            EOS token id.
        """
        if self.tokenizer is None:
            raise ValueError(
                f"Tensor needs to be padded to `max_length={max_length}` but no tokenizer was passed when creating "
                "this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
            )
        # If PAD token is not defined at least EOS token has to be defined
        pad_token_id = (
            self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
        )
        if pad_token_id is None:
            raise ValueError(
                f"Tensor needs to be padded to `max_length={max_length}` but the tokenizer defines neither a "
                "`pad_token_id` nor an `eos_token_id`."
            )

        padded_tensor = torch.full(
            (tensor.shape[0], max_length), pad_token_id, dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor