# Source code for transformers.generation_utils

# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from .generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxNewTokensCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from .utils import logging


# Module-level logger, created through the package's own ``logging`` helper and named after this module.
logger = logging.get_logger(__name__)


@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        scores (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`torch.FloatTensor`
            with each tensor of shape :obj:`(batch_size, config.vocab_size)`.
        attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
    weights of the encoder (respectively the decoder) can be accessed via the encoder_attentions and the
    encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes).

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        scores (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. :obj:`(max_length-1,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor
            of shape :obj:`(batch_size, config.vocab_size)`.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape :obj:`(batch_size,
            num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class SampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using sampling.

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        scores (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`torch.FloatTensor`
            with each tensor of shape :obj:`(batch_size*num_return_sequences, config.vocab_size)`.
        attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(num_return_sequences*batch_size, num_heads, generated_length,
            sequence_length)`.
        hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(num_return_sequences*batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
    the encoder (respectively the decoder) can be accessed via the encoder_attentions and the encoder_hidden_states
    attributes (respectively the decoder_attentions and the decoder_hidden_states attributes).

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        scores (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. :obj:`(max_length-1,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor
            of shape :obj:`(batch_size*num_return_sequences, config.vocab_size)`.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape
            :obj:`(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class BeamSearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Final beam scores of the generated ``sequences``.
        scores (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
            beam. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor
            of shape :obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
            sequence_length)`.
        hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
            hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
    of the encoder (respectively the decoder) can be accessed via the encoder_attentions and the encoder_hidden_states
    attributes (respectively the decoder_attentions and the decoder_hidden_states attributes).

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Final beam scores of the generated ``sequences``.
        scores (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
            beam. :obj:`(max_length-1,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape
            :obj:`(batch_size*num_beams, config.vocab_size)`.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape :obj:`(batch_size,
            num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads,
            generated_length, sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
            hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class BeamSampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam sample.

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Final beam scores of the generated ``sequences``.
        scores (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
            beam. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor
            of shape :obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
            sequence_length)`.
        hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class BeamSampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
    weights of the encoder (respectively the decoder) can be accessed via the encoder_attentions and the
    encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes).

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size*num_beams, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Final beam scores of the generated ``sequences``.
        scores (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
            beam. :obj:`(max_length-1,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape
            :obj:`(batch_size*num_beams, config.vocab_size)`.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape :obj:`(batch_size,
            num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size*num_beams, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
# Union aliases pairing the encoder-decoder and decoder-only output classes of each decoding strategy.
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
class GenerationMixin:
    """
    A class containing all of the functions supporting generation, to be used as a mixin in
    :class:`~transformers.PreTrainedModel`.
    """
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
    """
    Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
    generate method.

    Args:
        input_ids (:obj:`torch.LongTensor`):
            The tensor that is returned unchanged under the ``"input_ids"`` key.
        kwargs:
            Additional keyword arguments. They are accepted for signature compatibility and ignored by this default
            implementation.

    Return:
        :obj:`Dict[str, Any]`: A dictionary whose single key ``"input_ids"`` maps to :obj:`input_ids`.
    """
    return {"input_ids": input_ids}
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
    """
    Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in the
    generate method.

    Args:
        logits (:obj:`torch.FloatTensor`):
            The logits to adjust. This default implementation returns them untouched.
        kwargs:
            Additional keyword arguments. They are accepted for signature compatibility and ignored by this default
            implementation.

    Return:
        :obj:`torch.FloatTensor`: The very same :obj:`logits` object that was passed in.
    """
    return logits
def _prepare_input_ids_for_generation(
    self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
) -> torch.LongTensor:
    """
    Build default :obj:`input_ids` when the caller did not provide any.

    For an encoder-decoder config with :obj:`encoder_outputs` available, a tensor filled with ``-100`` is returned,
    shaped like ``encoder_outputs.last_hidden_state`` without its last dimension. Otherwise a ``(1, 1)`` tensor
    holding :obj:`bos_token_id` is returned.

    Raises:
        ValueError: If :obj:`bos_token_id` is :obj:`None` and no encoder outputs can be used.
    """
    if self.config.is_encoder_decoder and encoder_outputs is not None:
        # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
        shape = encoder_outputs.last_hidden_state.size()[:-1]
        return torch.ones(shape, dtype=torch.long, device=self.device) * -100

    if bos_token_id is None:
        raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
    return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id


def _prepare_attention_mask_for_generation(
    self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int
) -> torch.LongTensor:
    """
    Derive an attention mask from :obj:`input_ids`.

    Positions equal to :obj:`pad_token_id` are masked (``0``) only when the pad token occurs in :obj:`input_ids`
    and differs from :obj:`eos_token_id` (or no eos token is defined). In every other case an all-ones mask with
    the shape of :obj:`input_ids` is returned.
    """
    is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
    if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
        return input_ids.ne(pad_token_id).long()
    return input_ids.new_ones(input_ids.shape, dtype=torch.long)


def _prepare_encoder_decoder_kwargs_for_generation(
    self, input_ids: torch.LongTensor, model_kwargs
) -> Dict[str, Any]:
    """
    Make sure ``model_kwargs["encoder_outputs"]`` is populated.

    When the key is missing, the encoder returned by ``self.get_encoder()`` is run on :obj:`input_ids` with every
    entry of :obj:`model_kwargs` whose name does not start with ``decoder_`` or ``cross_attn``. The (possibly
    updated) :obj:`model_kwargs` dictionary is returned; it is modified in place.
    """
    if "encoder_outputs" not in model_kwargs:
        # retrieve encoder hidden states
        encoder = self.get_encoder()
        # kwargs prefixed with `decoder_` / `cross_attn` are not forwarded to the encoder
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not argument.startswith(("decoder_", "cross_attn"))
        }
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
    return model_kwargs


def _prepare_decoder_input_ids_for_generation(
    self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
) -> torch.LongTensor:
    """
    Create the initial decoder input: one column per batch row filled with the decoder start token.

    The start token is resolved by :meth:`_get_decoder_start_token_id`; the batch size and device are taken from
    :obj:`input_ids`.
    """
    decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
    decoder_input_ids = (
        torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) * decoder_start_token_id
    )
    return decoder_input_ids


def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
    """
    Return :obj:`pad_token_id`, falling back to :obj:`eos_token_id` (with a logged warning) when it is undefined.

    If both are :obj:`None`, :obj:`None` is returned.
    """
    if pad_token_id is None and eos_token_id is not None:
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
        pad_token_id = eos_token_id
    return pad_token_id


def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
    """
    Resolve the token id the decoder starts from.

    Lookup order: the :obj:`decoder_start_token_id` argument, ``config.decoder_start_token_id``,
    ``config.decoder.decoder_start_token_id``, the :obj:`bos_token_id` argument, ``config.bos_token_id`` and
    finally ``config.decoder.bos_token_id``.

    Raises:
        ValueError: If none of the candidates above is defined.
    """
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
    )
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

    if decoder_start_token_id is not None:
        return decoder_start_token_id
    elif (
        hasattr(self.config, "decoder")
        and hasattr(self.config.decoder, "decoder_start_token_id")
        and self.config.decoder.decoder_start_token_id is not None
    ):
        return self.config.decoder.decoder_start_token_id
    elif bos_token_id is not None:
        return bos_token_id
    elif (
        hasattr(self.config, "decoder")
        and hasattr(self.config.decoder, "bos_token_id")
        and self.config.decoder.bos_token_id is not None
    ):
        return self.config.decoder.bos_token_id
    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
    )


@staticmethod
def _expand_inputs_for_generation(
    input_ids: torch.LongTensor,
    expand_size: int = 1,
    is_encoder_decoder: bool = False,
    attention_mask: torch.LongTensor = None,
    encoder_outputs: ModelOutput = None,
    **model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
    """
    Repeat every batch row :obj:`expand_size` times, keeping the copies of a row adjacent.

    The same row selection is applied to :obj:`input_ids`, ``model_kwargs["token_type_ids"]`` (if present),
    :obj:`attention_mask` (if given) and, for encoder-decoder models, ``encoder_outputs.last_hidden_state``
    (updated in place on :obj:`encoder_outputs`).

    Returns:
        The expanded :obj:`input_ids` and the updated :obj:`model_kwargs`.

    Raises:
        ValueError: If :obj:`is_encoder_decoder` is :obj:`True` but :obj:`encoder_outputs` is :obj:`None`.
    """
    # [0, 0, ..., 1, 1, ...]: each original row index repeated `expand_size` times
    expanded_return_idx = (
        torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
    )
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

    if is_encoder_decoder:
        # explicit check instead of `assert`, which is stripped when Python runs with -O
        if encoder_outputs is None:
            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs


@staticmethod
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
    """
    Carry state from the latest forward pass over to the kwargs of the next generation step.

    ``model_kwargs["past"]`` is set from ``past_key_values``, ``mems`` or ``past_buckets_states`` of
    :obj:`outputs` (first one present, else :obj:`None`). ``token_type_ids`` is extended by repeating its last
    column, and for decoder-only models the ``attention_mask`` is extended by one column of ones.
    """
    # update past
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs.past_key_values
    elif "mems" in outputs:
        model_kwargs["past"] = outputs.mems
    elif "past_buckets_states" in outputs:
        model_kwargs["past"] = outputs.past_buckets_states
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

    return model_kwargs


def _reorder_cache(self, past, beam_idx):
    """
    Placeholder that always raises; models must provide their own implementation to enable beam search.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
    )


def _get_logits_warper(
    self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
) -> LogitsProcessorList:
    """
    This method returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
    :obj:`~transformers.LogitsWarper` instances used for multinomial sampling.

    Arguments left to :obj:`None` fall back to the attribute of the same name on ``self.config``.
    """

    # init warp parameters
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    temperature = temperature if temperature is not None else self.config.temperature
    # `num_beams` defaults to None; resolve it like the other parameters so the comparison below cannot
    # raise a TypeError
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    min_tokens_to_keep = 2 if num_beams > 1 else 1
    # instantiate warpers list
    warpers = LogitsProcessorList()

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # all samplers can be found in `generation_utils_samplers.py`
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=min_tokens_to_keep))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep))
    return warpers


def _get_logits_processor(
    self,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    encoder_no_repeat_ngram_size: int,
    encoder_input_ids: torch.LongTensor,
    bad_words_ids: List[List[int]],
    min_length: int,
    max_length: int,
    eos_token_id: int,
    forced_bos_token_id: int,
    forced_eos_token_id: int,
    prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
    num_beams: int,
    num_beam_groups: int,
    diversity_penalty: float,
    remove_invalid_values: bool,
) -> LogitsProcessorList:
    """
    This method returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
    :obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.

    Arguments left to :obj:`None` fall back to the attribute of the same name on ``self.config`` where such a
    fallback is implemented below.

    Raises:
        ValueError: If :obj:`encoder_no_repeat_ngram_size` is positive for a model that is not an
            encoder-decoder.
    """
    processors = LogitsProcessorList()

    # init warp parameters
    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
    )
    encoder_no_repeat_ngram_size = (
        encoder_no_repeat_ngram_size
        if encoder_no_repeat_ngram_size is not None
        else self.config.encoder_no_repeat_ngram_size
    )
    bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
    min_length = min_length if min_length is not None else self.config.min_length
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
    forced_bos_token_id = (
        forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
    )
    forced_eos_token_id = (
        forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
    )
    remove_invalid_values = (
        remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
    )
    # instantiate processors list

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # all samplers can be found in `generation_utils_samplers.py`
    if diversity_penalty is not None and diversity_penalty > 0.0:
        processors.append(
            HammingDiversityLogitsProcessor(
                diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
            )
        )
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
        processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
    if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
        if self.config.is_encoder_decoder:
            processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
        else:
            raise ValueError(
                "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
            )
    if bad_words_ids is not None:
        processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
    if min_length is not None and eos_token_id is not None and min_length > -1:
        processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
    if prefix_allowed_tokens_fn is not None:
        processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
    if forced_bos_token_id is not None:
        processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
    if forced_eos_token_id is not None:
        processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
    if remove_invalid_values is True:
        processors.append(InfNanRemoveLogitsProcessor())
    return processors


def _get_stopping_criteria(
    self, max_length: Optional[int], max_time: Optional[float], max_new_tokens: Optional[int], start_length: int
) -> StoppingCriteriaList:
    """
    Build the :obj:`StoppingCriteriaList` for a generation call.

    One criterion is appended for each of :obj:`max_length`, :obj:`max_time` and :obj:`max_new_tokens` that is
    not :obj:`None`; :obj:`start_length` is only used by the :obj:`max_new_tokens` criterion.
    """
    stopping_criteria = StoppingCriteriaList()
    if max_length is not None:
        stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
    if max_time is not None:
        stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
    if max_new_tokens is not None:
        stopping_criteria.append(MaxNewTokensCriteria(start_length=start_length, max_new_tokens=max_new_tokens))
    return stopping_criteria
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[Iterable[int]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    encoder_no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    max_time: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    num_beam_groups: Optional[int] = None,
    diversity_penalty: Optional[float] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    forced_bos_token_id: Optional[int] = None,
    forced_eos_token_id: Optional[int] = None,
    remove_invalid_values: Optional[bool] = None,
    synced_gpus: Optional[bool] = None,
    **model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
    r"""
    Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
    multinomial sampling, beam-search decoding, and beam-search multinomial sampling.

    Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
    attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
    indicated are the default values of those config.

    Most of these parameters are explained in more detail in `this blog post
    <https://huggingface.co/blog/how-to-generate>`__.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            The sequence used as a prompt for the generation. If :obj:`None` the method initializes it with
            :obj:`bos_token_id` and a batch size of 1.
        max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`):
            The maximum length of the sequence to be generated.
        max_new_tokens (:obj:`int`, `optional`, defaults to None):
            The maximum numbers of tokens to generate, ignore the current number of tokens. Use either
            :obj:`max_new_tokens` or :obj:`max_length` but not both, they serve the same purpose.
        min_length (:obj:`int`, `optional`, defaults to 10):
            The minimum length of the sequence to be generated.
        do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use sampling ; use greedy decoding otherwise.
        early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
        num_beams (:obj:`int`, `optional`, defaults to 1):
            Number of beams for beam search. 1 means no beam search.
        temperature (:obj:`float`, `optional`, defaults to 1.0):
            The value used to modulate the next token probabilities.
        top_k (:obj:`int`, `optional`, defaults to 50):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        top_p (:obj:`float`, `optional`, defaults to 1.0):
            If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
            higher are kept for generation.
        repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        bos_token_id (:obj:`int`, `optional`):
            The id of the `beginning-of-sequence` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
            sequences.
        no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
            If set to int > 0, all ngrams of that size can only occur once.
        encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
            If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
            ``decoder_input_ids``.
        bad_words_ids(:obj:`List[List[int]]`, `optional`):
            List of token ids that are not allowed to be generated. In order to get the tokens of the words that
            should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        num_return_sequences(:obj:`int`, `optional`, defaults to 1):
            The number of independently computed returned sequences for each element in the batch.
        max_time(:obj:`float`, `optional`, defaults to None):
            The maximum amount of time you allow the computation to run for in seconds. generation will still
            finish the current pass after allocated time has been passed.
        attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
            tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same
            shape as :obj:`input_ids` that masks the pad token. `What are attention masks?
            <../glossary.html#attention-mask>`__
        decoder_start_token_id (:obj:`int`, `optional`):
            If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
        use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should use the past last key/values attentions (if applicable to the model) to
            speed up decoding.
        num_beam_groups (:obj:`int`, `optional`, defaults to 1):
            Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
            beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
        diversity_penalty (:obj:`float`, `optional`, defaults to 0.0):
            This value is subtracted from a beam's score if it generates a token same as any beam from other group
            at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is
            enabled.
        prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
            If provided, this function constrains the beam search to allowed tokens only at each step. If not
            provided no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
            :obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step
            conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`inputs_ids`. This
            argument is useful for constrained generation conditioned on the prefix, as described in
            `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
        output_attentions (:obj:`bool`, `optional`, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
            returned tensors for more details.
        output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
            for more details.
        output_scores (:obj:`bool`, `optional`, defaults to `False`):
            Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
        return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
        forced_bos_token_id (:obj:`int`, `optional`):
            The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
            Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
            needs to be the target language token.
        forced_eos_token_id (:obj:`int`, `optional`):
            The id of the token to force as the last generated token when :obj:`max_length` is reached.
        remove_invalid_values (:obj:`bool`, `optional`):
            Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method from
            crashing. Note that using ``remove_invalid_values`` can slow down generation.
        synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)

        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
            model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific
            kwargs should be prefixed with `decoder_`.

    Return:
        :class:`~transformers.file_utils.ModelOutput` or :obj:`torch.LongTensor`: A
        :class:`~transformers.file_utils.ModelOutput` (if ``return_dict_in_generate=True`` or when
        ``config.return_dict_in_generate=True``) or a :obj:`torch.LongTensor`.

            If the model is `not` an encoder-decoder model (``model.config.is_encoder_decoder=False``), the
            possible :class:`~transformers.file_utils.ModelOutput` types are:

                - :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`,
                - :class:`~transformers.generation_utils.SampleDecoderOnlyOutput`,
                - :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
                - :class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput`

            If the model is an encoder-decoder model (``model.config.is_encoder_decoder=True``), the possible
            :class:`~transformers.file_utils.ModelOutput` types are:

                - :class:`~transformers.generation_utils.GreedySearchEncoderDecoderOutput`,
                - :class:`~transformers.generation_utils.SampleEncoderDecoderOutput`,
                - :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput`,
                - :class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput`

    Examples::
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> # do greedy decoding without providing a prompt
        >>> outputs = model.generate(max_length=40)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> document = (
        ... "at least two people were killed in a suspected bomb attack on a passenger bus "
        ... "in the strife-torn southern philippines on monday , the military said."
        ... )
        >>> # encode input context
        >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
        >>> # generate 3 independent sequences using beam search decoding (5 beams)
        >>> # with T5 encoder-decoder model conditioned on short news article.
        >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> input_context = "The dog"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate 3 candidates using sampling
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
        >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
        >>> # "Legal" is one of the control codes for ctrl
        >>> input_context = "Legal My neighbor is"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> input_context = "My cute dog"
        >>> # get tokens of words that should not be generated
        >>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]]
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate sequences without allowing bad_words to be generated
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
    """

    # set init values
    if max_length is None and max_new_tokens is None:
        # Both are None, default
        max_length = self.config.max_length
    elif max_length is not None and max_new_tokens is not None:
        # Both are set, this is odd, raise a warning
        warnings.warn(
            "Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning
        )

    max_length = max_length if max_length is not None else self.config.max_length
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )

    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    model_kwargs["output_attentions"] = output_attentions
    model_kwargs["output_hidden_states"] = output_hidden_states

    if input_ids is None and "inputs_embeds" not in model_kwargs:
        # init `input_ids` with bos_token_id
        input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

    if model_kwargs.get("attention_mask", None) is None:
        # init `attention_mask` depending on `pad_token_id`
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            input_ids, pad_token_id, eos_token_id
        )

    # special case if pad_token_id is not defined
    if pad_token_id is None and eos_token_id is not None:
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
        pad_token_id = eos_token_id

    # Storing encoder_input_ids for logits_processor that could use them
    encoder_input_ids = input_ids if self.config.is_encoder_decoder else None

    if self.config.is_encoder_decoder:
        # add encoder_outputs to model_kwargs
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

        # set input_ids as decoder_input_ids
        if "decoder_input_ids" in model_kwargs:
            input_ids = model_kwargs.pop("decoder_input_ids")
        else:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
            )

        if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
            raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")

    if input_ids.shape[-1] >= max_length:
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        # the two literals are concatenated, so the second one starts with a space to separate the sentences
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}."
            " This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
        )

    # determine generation mode
    is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
    is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
    is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
    is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
    is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)
    if num_beam_groups > num_beams:
        raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
    if is_group_beam_gen_mode and do_sample is True:
        raise ValueError(
            "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
        )

    # set model_kwargs
    model_kwargs["use_cache"] = use_cache

    # get distribution pre_processing samplers
    logits_processor = self._get_logits_processor(
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
        encoder_input_ids=encoder_input_ids,
        bad_words_ids=bad_words_ids,
        min_length=min_length,
        max_length=max_length,
        eos_token_id=eos_token_id,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        remove_invalid_values=remove_invalid_values,
    )

    # `cur_len` is the prompt length (decoder prompt for encoder-decoder models) before any token is generated
    cur_len = input_ids.shape[-1]
    stopping_criteria = self._get_stopping_criteria(
        max_length=max_length, max_time=max_time, max_new_tokens=max_new_tokens, start_length=cur_len
    )

    if is_greedy_gen_mode:
        if num_return_sequences > 1:
            raise ValueError(
                f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
            )

        # greedy search
        return self.greedy_search(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_sample_gen_mode:
        # get probability distribution warper
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

        # expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids,
            expand_size=num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # sample
        return self.sample(
            input_ids,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_beam_gen_mode:
        batch_size = input_ids.shape[0]

        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
        )
        # interleave with `num_beams`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
        )
        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_beam_sample_gen_mode:
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

        # each requested return sequence gets its own beam search, hence the enlarged batch
        batch_size = input_ids.shape[0] * num_return_sequences

        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        # resolve from the config like the other beam modes; otherwise `config.early_stopping` is ignored
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
        )

        # interleave with `num_beams * num_return_sequences`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids,
            expand_size=num_beams * num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        return self.beam_sample(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_group_beam_gen_mode:
        batch_size = input_ids.shape[0]

        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        if num_beams % num_beam_groups != 0:
            raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        diverse_beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            max_length=stopping_criteria.max_length,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
            num_beam_groups=num_beam_groups,
        )
        # interleave with `num_beams`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
        )
        return self.group_beam_search(
            input_ids,
            diverse_beam_scorer,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )
def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = None,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    r"""
    Generates sequences for models with a language modeling head using multinomial sampling.

    At every step the logits of the last position are passed through :obj:`logits_processor` and then
    :obj:`logits_warper`, turned into probabilities with a softmax, and one token per sequence is drawn with
    :obj:`torch.multinomial`.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. An empty list is used when omitted.
        stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
            An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
            :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop. An empty
            list is used when omitted.
        logits_warper (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
            modeling head applied before multinomial sampling at each generation step. An empty list is used
            when omitted.
        max_length (:obj:`int`, `optional`):
            **DEPRECATED**. Use :obj:`stopping_criteria` directly to cap the number of generated tokens. When
            given, a :obj:`UserWarning` is emitted and the value is handed to
            :obj:`validate_stopping_criteria` together with :obj:`stopping_criteria`.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Falls back to ``config.pad_token_id``. Sequences that are already
            finished receive this token at every further step; it must be defined whenever
            :obj:`eos_token_id` is.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token. Falls back to ``config.eos_token_id``. A sequence is marked
            as finished once this token has been sampled for it.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. Falls back to
            ``config.output_attentions``. See ``attentions`` under returned tensors for more details.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. Falls back to
            ``config.output_hidden_states``. See ``hidden_states`` under returned tensors for more details.
        output_scores (:obj:`bool`, `optional`):
            Whether or not to return the prediction scores (as they are after processing and warping). Falls
            back to ``config.output_scores``. See ``scores`` under returned tensors for more details.
        return_dict_in_generate (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of the plain
            tensor of generated ids. Falls back to ``config.return_dict_in_generate``.
        synced_gpus (:obj:`bool`, `optional`):
            When truthy, this peer keeps running forward passes after it has finished and only leaves the
            loop once an :obj:`all_reduce` over all peers reports that every peer is finished (needed for
            ZeRO stage 3).
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :class:`~transformers.generation_utils.SampleDecoderOnlyOutput`,
        :class:`~transformers.generation_utils.SampleEncoderDecoderOutput` or obj:`torch.LongTensor`: A
        :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
        :class:`~transformers.generation_utils.SampleDecoderOnlyOutput` if
        ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
        :class:`~transformers.generation_utils.SampleEncoderDecoderOutput` if
        ``model.config.is_encoder_decoder=True``.

    Examples::

        >>> from transformers import (
        ...    AutoTokenizer,
        ...    AutoModelForCausalLM,
        ...    LogitsProcessorList,
        ...    MinLengthLogitsProcessor,
        ...    TopKLogitsWarper,
        ...    TemperatureLogitsWarper,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ... ])
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList([
        ...     TopKLogitsWarper(50),
        ...     TemperatureLogitsWarper(0.7),
        ... ])

        >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """
    # Resolve every argument that was left as None: empty lists for the callables, model config for the rest.
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    if stopping_criteria is None:
        stopping_criteria = StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if logits_warper is None:
        logits_warper = LogitsProcessorList()
    if pad_token_id is None:
        pad_token_id = self.config.pad_token_id
    if eos_token_id is None:
        eos_token_id = self.config.eos_token_id
    if output_scores is None:
        output_scores = self.config.output_scores
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict_in_generate is None:
        return_dict_in_generate = self.config.return_dict_in_generate

    # Per-step outputs are only accumulated when a ModelOutput is going to be returned.
    collect_scores = return_dict_in_generate and output_scores
    collect_attentions = return_dict_in_generate and output_attentions
    collect_hidden_states = return_dict_in_generate and output_hidden_states

    scores = () if collect_scores else None
    decoder_attentions = () if collect_attentions else None
    cross_attentions = () if collect_attentions else None
    decoder_hidden_states = () if collect_hidden_states else None

    # For encoder-decoder models the encoder side is read once from the precomputed `encoder_outputs`.
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = None
        encoder_hidden_states = None
        if output_attentions:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions")
        if output_hidden_states:
            encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states")

    # One entry per sequence: 1 while it is still generating, 0 once EOS has been sampled for it.
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    cur_len = input_ids.shape[-1]

    this_peer_finished = False  # only ever set when synced_gpus is on

    # auto-regressive generation
    while True:
        if synced_gpus:
            # Every peer contributes 1.0 while it is still generating and 0.0 once it is done; a reduced
            # sum of 0.0 therefore means that all peers are finished and the loop can be left.
            peers_still_running = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(peers_still_running, op=dist.ReduceOp.SUM)
            if peers_still_running.item() == 0.0:
                break

        # forward pass on the current sequences
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            # This peer only ran the forward pass to stay in lockstep with the others; skip the rest.
            cur_len += 1
            continue

        # scores of the last position, first processed and then warped
        last_position_logits = outputs.logits[:, -1, :]
        processed_scores = logits_processor(input_ids, last_position_logits)
        next_token_scores = logits_warper(input_ids, processed_scores)

        # record the per-step outputs that were requested
        if collect_scores:
            scores = scores + (next_token_scores,)
        if collect_attentions:
            if self.config.is_encoder_decoder:
                decoder_attentions = decoder_attentions + (outputs.decoder_attentions,)
                cross_attentions = cross_attentions + (outputs.cross_attentions,)
            else:
                decoder_attentions = decoder_attentions + (outputs.attentions,)
        if collect_hidden_states:
            if self.config.is_encoder_decoder:
                decoder_hidden_states = decoder_hidden_states + (outputs.decoder_hidden_states,)
            else:
                decoder_hidden_states = decoder_hidden_states + (outputs.hidden_states,)

        # draw one token per sequence from the warped distribution
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        if eos_token_id is not None:
            assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
            # finished sequences get the padding token instead of the sampled one
            kept_samples = next_tokens * unfinished_sequences
            padding = pad_token_id * (1 - unfinished_sequences)
            next_tokens = kept_samples + padding

        # append the new tokens and let the model update its kwargs for the next step
        input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        cur_len += 1

        if eos_token_id is not None:
            # a sequence that just produced EOS is finished from now on
            not_eos = (next_tokens != eos_token_id).long()
            unfinished_sequences = unfinished_sequences * not_eos

        # stop when every sequence is finished or when a stopping criterion fires
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            if synced_gpus:
                this_peer_finished = True
            else:
                break

    if not return_dict_in_generate:
        return input_ids

    if self.config.is_encoder_decoder:
        return SampleEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )

    return SampleDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
    )
def beam_sample(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = None,
    **model_kwargs,
) -> Union[BeamSampleOutput, torch.LongTensor]:
    r"""
    Generates sequences for models with a language modeling head using beam search with multinomial sampling.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation. Must be two-dimensional.
        beam_scorer (:obj:`BeamScorer`):
            A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
            constructed, stored and sorted during generation. For more information, the documentation of
            :class:`~transformers.BeamScorer` should be read. This method reads its :obj:`num_beams`,
            :obj:`_beam_hyps` and :obj:`is_done` attributes and calls its :obj:`process` and :obj:`finalize`
            methods.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. An empty list is used when omitted.
        stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
            An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
            :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop. An empty
            list is used when omitted.
        logits_warper (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
            modeling head applied before multinomial sampling at each generation step. An empty list is used
            when omitted (same default as :meth:`sample`).
        max_length (:obj:`int`, `optional`):
            **DEPRECATED**. Use :obj:`stopping_criteria` directly to cap the number of generated tokens. When
            given, a :obj:`UserWarning` is emitted and the value is handed to
            :obj:`validate_stopping_criteria` together with :obj:`stopping_criteria`.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Falls back to ``config.pad_token_id``.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token. Falls back to ``config.eos_token_id``.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. Falls back to
            ``config.output_attentions``. See ``attentions`` under returned tensors for more details.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. Falls back to
            ``config.output_hidden_states``. See ``hidden_states`` under returned tensors for more details.
        output_scores (:obj:`bool`, `optional`):
            Whether or not to return the prediction scores. Falls back to ``config.output_scores``. The
            recorded scores are the processed log-probabilities with the running beam scores added, after
            warping. When falsy, ``sequences_scores`` of the returned output is :obj:`None` as well.
        return_dict_in_generate (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of the plain
            tensor of generated ids. Falls back to ``config.return_dict_in_generate``.
        synced_gpus (:obj:`bool`, `optional`):
            When truthy, this peer keeps running forward passes after it has finished and only leaves the
            loop once an :obj:`all_reduce` over all peers reports that every peer is finished (needed for
            ZeRO stage 3).
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput`,
        :class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput` or obj:`torch.LongTensor`: A
        :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
        :class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput` if
        ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
        :class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput` if
        ``model.config.is_encoder_decoder=True``.

    Examples::

        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)
        ... ])
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList([
        ...     TopKLogitsWarper(50),
        ...     TemperatureLogitsWarper(0.7),
        ... ])

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    # `logits_warper` is called unconditionally in the loop below, so an omitted warper must become an empty
    # list here (as `sample` does); otherwise the first step fails with "'NoneType' object is not callable".
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples (only accumulated when a ModelOutput is returned)
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    # The unpacking doubles as a check that `input_ids` is two-dimensional; only the length is needed.
    _, cur_len = input_ids.shape

    # running score of every beam, flattened to (batch_size * num_beams,)
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:

        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        # processors see the plain log-probabilities; warpers see them with the beam scores added
        next_token_scores = logits_processor(input_ids, next_token_scores)
        next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search: all beams of one batch entry compete in a single distribution
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        probs = nn.functional.softmax(next_token_scores, dim=-1)

        # Draw 2 * num_beams candidates per batch entry and order them by score.
        # NOTE(review): presumably oversampled so enough candidates remain once `beam_scorer.process`
        # sets finished hypotheses aside -- confirm against the scorer.
        next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = torch.gather(next_tokens, -1, _indices)

        # split the flat (beam, token) index back into the source beam and the vocabulary id
        next_indices = next_tokens // vocab_size
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # continue each selected beam with its new token
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # the cached state has to follow the same beam permutation as `input_ids`
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        # increase cur_len
        cur_len = cur_len + 1

        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
    )

    if return_dict_in_generate:
        # sequence-level scores are only exposed when scores were requested
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        if self.config.is_encoder_decoder:
            return BeamSampleEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSampleDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (:obj:`int`, `optional`, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering). The top-k step is
            skipped otherwise.
        top_p (:obj:`float`, `optional`, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751). The top-p step is run
            for every value in ``[0, 1.0]`` and skipped for values outside that range.
        filter_value (:obj:`float`, `optional`, defaults to ``-float("Inf")``):
            Value handed to both the top-k and the top-p warper as the replacement for filtered-out logits.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    Returns:
        :obj:`torch.FloatTensor`: The filtered logits. When neither step applies, the input tensor itself is
        returned.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        # Forward `filter_value` here as well, so a custom value is honoured by the nucleus step
        # and not only by the top-k step.
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits