# Source code for transformers.generation_utils

# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch.nn import functional as F

from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from .utils import logging


# Module-level logger from the package's own logging utility (imported from `.utils` at the top of the file);
# used in this module for warnings such as the `pad_token_id` -> `eos_token_id` fallback.
logger = logging.get_logger(__name__)


[docs]class GenerationMixin: """ A class containing all of the functions supporting generation, to be used as a mixin in :class:`~transformers.PreTrainedModel`. """
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
    """
    Build the keyword arguments passed to the model's forward call at every generation step.

    The base implementation forwards only :obj:`input_ids`; every other keyword argument (for example the
    ``past`` entry that the generation loop stores in ``model_kwargs``) is ignored. Implement in subclasses of
    :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the generate method.
    """
    model_inputs: Dict[str, Any] = dict(input_ids=input_ids)
    return model_inputs
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
    """
    Hook for modifying next-token logits inside the generation loop.

    The default is the identity: the very same tensor object is returned and all keyword arguments (e.g.
    ``cur_len`` and ``max_length``) are ignored. Implement in subclasses of
    :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in the generate method.
    """
    adjusted_logits = logits
    return adjusted_logits
def _prepare_input_ids_for_generation(self, bos_token_id: int) -> torch.LongTensor:
    """
    Build a default prompt when :meth:`generate` is called without :obj:`input_ids`.

    Returns:
        :obj:`torch.LongTensor` of shape :obj:`(1, 1)` on :obj:`self.device` holding :obj:`bos_token_id`.

    Raises:
        :obj:`ValueError`: if :obj:`bos_token_id` is :obj:`None`.
    """
    if bos_token_id is None:
        raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
    return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id


def _prepare_attention_mask_for_generation(
    self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int
) -> torch.LongTensor:
    """
    Derive a default attention mask for :obj:`input_ids`.

    Positions equal to :obj:`pad_token_id` are masked (0) only when the pad token actually occurs in
    :obj:`input_ids` and differs from :obj:`eos_token_id`; otherwise an all-ones mask with the same shape and
    dtype as :obj:`input_ids` is returned.
    """
    is_pad_token_in_inputs_ids = pad_token_id is not None and pad_token_id in input_ids
    # equivalent to `(eos is None) or (eos is not None and pad != eos)`; when pad == eos a pad-valued
    # position may be a genuine EOS, so (presumably for that reason) nothing is masked in that case
    is_pad_token_not_equal_to_eos_token_id = eos_token_id is None or pad_token_id != eos_token_id
    if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
        return input_ids.ne(pad_token_id).long()
    return input_ids.new_ones(input_ids.shape)


def _prepare_encoder_decoder_kwargs_for_generation(
    self, input_ids: torch.LongTensor, model_kwargs
) -> Dict[str, Any]:
    """
    Run the encoder once and store its output under ``model_kwargs["encoder_outputs"]``.

    Every entry of :obj:`model_kwargs` whose key does not start with ``decoder_`` is forwarded to the encoder.
    :obj:`model_kwargs` is updated in place and also returned.
    """
    # retrieve encoder hidden states
    encoder = self.get_encoder()
    encoder_kwargs = {
        argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
    }
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
    return model_kwargs


def _prepare_decoder_input_ids_for_generation(
    self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs
) -> torch.LongTensor:
    """
    Return the initial decoder input ids for an encoder-decoder model.

    An explicit ``decoder_input_ids`` entry in :obj:`model_kwargs` wins; otherwise a column of shape
    :obj:`(batch_size, 1)` filled with the resolved decoder start token is built (see
    :meth:`_get_decoder_start_token_id`), matching the dtype and device of :obj:`input_ids`.
    """
    if "decoder_input_ids" in model_kwargs:
        return model_kwargs["decoder_input_ids"]
    decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
    decoder_input_ids = (
        torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
        * decoder_start_token_id
    )
    return decoder_input_ids


def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
    """
    Return :obj:`pad_token_id`, falling back to :obj:`eos_token_id` (with a logged warning) when only the
    latter is defined.
    """
    if pad_token_id is None and eos_token_id is not None:
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
        pad_token_id = eos_token_id
    return pad_token_id


def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
    """
    Resolve the token the decoder starts generating from.

    Resolution order: the :obj:`decoder_start_token_id` argument, ``config.decoder_start_token_id``,
    ``config.decoder.decoder_start_token_id``, the :obj:`bos_token_id` argument, ``config.bos_token_id``,
    ``config.decoder.bos_token_id``.

    Raises:
        :obj:`ValueError`: if none of the above is defined.
    """
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
    )
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

    if decoder_start_token_id is not None:
        return decoder_start_token_id
    elif (
        hasattr(self.config, "decoder")
        and hasattr(self.config.decoder, "decoder_start_token_id")
        and self.config.decoder.decoder_start_token_id is not None
    ):
        return self.config.decoder.decoder_start_token_id
    elif bos_token_id is not None:
        return bos_token_id
    elif (
        hasattr(self.config, "decoder")
        and hasattr(self.config.decoder, "bos_token_id")
        and self.config.decoder.bos_token_id is not None
    ):
        return self.config.decoder.bos_token_id
    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
    )


@staticmethod
def _expand_inputs_for_generation(
    input_ids: torch.LongTensor,
    expand_size: int = 1,
    is_encoder_decoder: bool = False,
    attention_mask: torch.LongTensor = None,
    encoder_outputs: ModelOutput = None,
    **model_kwargs
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
    """
    Repeat every batch row ``expand_size`` times, consecutively (row order 0, 0, ..., 1, 1, ...).

    Applies to :obj:`input_ids` and, when present, ``token_type_ids``, :obj:`attention_mask` and (for
    encoder-decoder models) ``encoder_outputs.last_hidden_state``. Note that :obj:`encoder_outputs` is
    modified in place before being stored back into the returned kwargs.

    Returns:
        :obj:`(input_ids, model_kwargs)` with the expanded tensors.
    """
    # build the index directly on the target device instead of creating it on CPU and copying it over
    expanded_return_idx = (
        torch.arange(input_ids.shape[0], device=input_ids.device).view(-1, 1).repeat(1, expand_size).view(-1)
    )
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

    if is_encoder_decoder:
        assert encoder_outputs is not None
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_return_idx
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs


@staticmethod
def _init_sequence_length_for_generation(
    input_ids: torch.LongTensor, max_length: int
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Initialize the per-row bookkeeping used by the generation loops.

    Returns:
        :obj:`(sequence_lengths, unfinished_sequences, cur_len)`: ``sequence_lengths`` filled with
        :obj:`max_length` and ``unfinished_sequences`` filled with 1 (both of shape :obj:`(batch_size,)` with
        the dtype/device of :obj:`input_ids`), and ``cur_len``, the size of the last dimension of
        :obj:`input_ids`.
    """
    batch_size = input_ids.shape[0]
    # `new_ones` / `new_full` replace the legacy `Tensor.new(n).fill_(v)` idiom with identical results
    unfinished_sequences = input_ids.new_ones((batch_size,))
    sequence_lengths = input_ids.new_full((batch_size,), max_length)
    cur_len = input_ids.shape[-1]
    return sequence_lengths, unfinished_sequences, cur_len


@staticmethod
def _update_seq_length_for_generation(
    sequence_lengths: torch.LongTensor,
    unfinished_sequences: torch.LongTensor,
    cur_len: int,
    is_eos_in_next_token: torch.BoolTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Record ``cur_len`` as the final length of rows that emit EOS at this step and mark them finished.

    Rows that had already finished keep their previously recorded length.

    Returns:
        :obj:`(sequence_lengths, unfinished_sequences)`, both newly allocated.
    """
    # rows that were still running and produced EOS at this step
    just_finished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
    sequence_lengths = sequence_lengths.masked_fill(just_finished, cur_len)
    unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())
    return sequence_lengths, unfinished_sequences


@staticmethod
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
    """
    Carry state from one generation step to the next.

    Stores the model's cache (``past_key_values``, ``mems`` or ``past_buckets_states``, whichever the output
    has, else :obj:`None`) under ``model_kwargs["past"]``, extends ``token_type_ids`` by repeating its last
    column, and, for decoder-only models, appends a column of ones to ``attention_mask``.
    """
    # update past
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs.past_key_values
    elif "mems" in outputs:
        model_kwargs["past"] = outputs.mems
    elif "past_buckets_states" in outputs:
        model_kwargs["past"] = outputs.past_buckets_states
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask (encoder-decoder models keep their encoder mask unchanged)
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

    return model_kwargs


@staticmethod
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
    """
    This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
    :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
    called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
    generation step.

    The default selects ``beam_idx`` along dimension 1 of every layer's cache tensor. For custom re-ordering of
    :obj:`past_key_values` or :obj:`mems`, the function should be implemented in subclasses of
    :class:`~transformers.PreTrainedModel`.
    """
    return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)


def _get_logits_warper(
    self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
) -> LogitsProcessorList:
    """
    Return a :obj:`~transformers.LogitsProcessorList` holding the :obj:`~transformers.LogitsWarper` instances
    used for multinomial sampling. Any argument left as :obj:`None` falls back to the model config.
    """
    # init warp parameters
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    temperature = temperature if temperature is not None else self.config.temperature
    # `num_beams` defaults to None, and `None > 1` raises TypeError, so resolve it like the other parameters
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    # keep at least 2 tokens when beam searching (presumably so every beam still has more than one candidate)
    min_tokens_to_keep = 2 if num_beams > 1 else 1

    # instantiate warpers list
    warpers = LogitsProcessorList()

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # all warpers are defined in `generation_logits_process.py`
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=min_tokens_to_keep))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep))
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    return warpers


def _get_logits_processor(
    self,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    bad_words_ids: List[List[int]],
    min_length: int,
    eos_token_id: int,
    prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
    num_beams: int,
) -> LogitsProcessorList:
    """
    Return a :obj:`~transformers.LogitsProcessorList` holding the :obj:`~transformers.LogitsProcessor`
    instances used to modify the scores of the language model head. Every argument except
    :obj:`prefix_allowed_tokens_fn` and :obj:`num_beams` falls back to the model config when :obj:`None`.
    """
    # init processor parameters
    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
    )
    bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
    min_length = min_length if min_length is not None else self.config.min_length
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    # instantiate processors list
    processors = LogitsProcessorList()

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # all processors are defined in `generation_logits_process.py`
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
        processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
    if bad_words_ids is not None:
        processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
    if min_length is not None and eos_token_id is not None and min_length > -1:
        processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
    if prefix_allowed_tokens_fn is not None:
        processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
    return processors
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[List[List[int]]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    **model_kwargs
) -> torch.LongTensor:
    r"""
    Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
    multinomial sampling, beam-search decoding, and beam-search multinomial sampling.

    Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
    attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
    indicated are the default values of those config.

    Most of these parameters are explained in more detail in `this blog post
    <https://huggingface.co/blog/how-to-generate>`__.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            The sequence used as a prompt for the generation. If :obj:`None` the method initializes it with
            :obj:`bos_token_id` as a tensor of shape :obj:`(1, 1)`.
        max_length (:obj:`int`, `optional`, defaults to 20):
            The maximum length of the sequence to be generated.
        min_length (:obj:`int`, `optional`, defaults to 10):
            The minimum length of the sequence to be generated.
        do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use sampling; use greedy decoding otherwise.
        early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
            Used by both beam-search decoding and beam-search multinomial sampling.
        num_beams (:obj:`int`, `optional`, defaults to 1):
            Number of beams for beam search. 1 means no beam search.
        temperature (:obj:`float`, `optional`, defaults to 1.0):
            The value used to modulate the next token probabilities.
        top_k (:obj:`int`, `optional`, defaults to 50):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        top_p (:obj:`float`, `optional`, defaults to 1.0):
            If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
            higher are kept for generation.
        repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. If undefined while :obj:`eos_token_id` is defined, :obj:`eos_token_id`
            is used instead (a warning is logged).
        bos_token_id (:obj:`int`, `optional`):
            The id of the `beginning-of-sequence` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
            sequences.
        no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
            If set to int > 0, all ngrams of that size can only occur once.
        bad_words_ids (:obj:`List[List[int]]`, `optional`):
            List of token ids that are not allowed to be generated. In order to get the tokens of the words that
            should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        num_return_sequences (:obj:`int`, `optional`, defaults to 1):
            The number of independently computed returned sequences for each element in the batch.
        attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
            tokens that are not masked, and 0 for masked tokens.

            If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.

            `What are attention masks? <../glossary.html#attention-mask>`__
        decoder_start_token_id (:obj:`int`, `optional`):
            If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should use the past last key/values attentions (if applicable to the model) to
            speed up decoding.
        prefix_allowed_tokens_fn (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
            If provided, this function constrains generation to allowed tokens only at each step. If not provided
            no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
            :obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step
            conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`input_ids`. This
            argument is useful for constrained generation conditioned on the prefix, as described in
            `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
            model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
            kwargs should be prefixed with `decoder_`.

    Return:
        :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
        sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
        batches finished early due to the :obj:`eos_token_id`.

    Raises:
        :obj:`ValueError`: if :obj:`num_beams`/:obj:`do_sample` do not select a generation mode, if
        :obj:`num_return_sequences` is incompatible with the selected mode, if :obj:`input_ids` and
        :obj:`bos_token_id` are both undefined, or if encoder outputs could not be prepared.

    Examples::
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> # do greedy decoding without providing a prompt
        >>> outputs = model.generate(max_length=40)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> document = (
        ... "at least two people were killed in a suspected bomb attack on a passenger bus "
        ... "in the strife-torn southern philippines on monday , the military said."
        ... )
        >>> # encode input context
        >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
        >>> # generate 3 independent sequences using beam search decoding (5 beams)
        >>> # with T5 encoder-decoder model conditioned on short news article.
        >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> input_context = "The dog"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate 3 candidates using sampling
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
        >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
        >>> # "Legal" is one of the control codes for ctrl
        >>> input_context = "Legal My neighbor is"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> input_context = "My cute dog"
        >>> # get tokens of words that should not be generated
        >>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]]
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate sequences without allowing bad_words to be generated
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
    """

    # set init values (explicit arguments take precedence over the model config)
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    max_length = max_length if max_length is not None else self.config.max_length
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )
    # resolved here once so that *both* beam modes honor the config; beam-sample used to receive a raw None
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty

    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    # determine generation mode; fail fast instead of silently returning None when nothing matches
    is_greedy_gen_mode = (num_beams == 1) and do_sample is False
    is_sample_gen_mode = (num_beams == 1) and do_sample is True
    is_beam_gen_mode = (num_beams > 1) and do_sample is False
    is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True
    if not (is_greedy_gen_mode or is_sample_gen_mode or is_beam_gen_mode or is_beam_sample_gen_mode):
        raise ValueError(
            f"`num_beams` has to be a positive integer and `do_sample` a boolean, "
            f"but got num_beams={num_beams} and do_sample={do_sample}."
        )

    if input_ids is None:
        # init `input_ids` with bos_token_id
        input_ids = self._prepare_input_ids_for_generation(bos_token_id)

    if model_kwargs.get("attention_mask", None) is None:
        # init `attention_mask` depending on `pad_token_id` (before the eos fallback below is applied)
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            input_ids, pad_token_id, eos_token_id
        )

    # special case if pad_token_id is not defined: fall back to eos_token_id (logs a warning)
    pad_token_id = self._get_pad_token_id(pad_token_id, eos_token_id)

    if self.config.is_encoder_decoder:
        # add encoder_outputs to model_kwargs
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

        # set input_ids as decoder_input_ids
        input_ids = self._prepare_decoder_input_ids_for_generation(
            input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs
        )

        if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
            raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")

    # set model_kwargs
    model_kwargs["use_cache"] = use_cache

    # get distribution pre_processing samplers
    logits_processor = self._get_logits_processor(
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        bad_words_ids=bad_words_ids,
        min_length=min_length,
        eos_token_id=eos_token_id,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        num_beams=num_beams,
    )

    if is_greedy_gen_mode:
        if num_return_sequences > 1:
            raise ValueError(
                f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
            )

        # greedy search
        return self.greedy_search(
            input_ids,
            logits_processor=logits_processor,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    elif is_sample_gen_mode:
        # get probability distribution warper
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

        # expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids,
            expand_size=num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # sample
        return self.sample(
            input_ids,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    elif is_beam_gen_mode:
        batch_size = input_ids.shape[0]

        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
        )
        # interleave with `num_beams`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
        )
        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    else:
        # is_beam_sample_gen_mode: the only mode left after the validation above
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

        # every returned sequence gets its own independent set of `num_beams` beams
        batch_size = input_ids.shape[0] * num_return_sequences

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
        )

        # interleave with `num_beams * num_return_sequences`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids,
            expand_size=num_beams * num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        return self.beam_sample(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )
def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using multinomial sampling.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step.
        logits_warper (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
            modeling head applied before multinomial sampling at each generation step.
        max_length (:obj:`int`, `optional`, defaults to 20):
            The maximum length of the sequence to be generated.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Rows that already produced :obj:`eos_token_id` are filled with it.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)` (same number of rows as
        :obj:`input_ids`): The generated sequences. The second dimension (sequence_length) is either equal to
        :obj:`max_length` or shorter if all batches finished early due to the :obj:`eos_token_id`.

    Raises:
        :obj:`ValueError`: if :obj:`eos_token_id` is defined but :obj:`pad_token_id` is not.

    Examples::

        >>> from transformers import (
        ...    AutoTokenizer,
        ...    AutoModelForCausalLM,
        ...    LogitsProcessorList,
        ...    MinLengthLogitsProcessor,
        ...    TopKLogitsWarper,
        ...    TemperatureLogitsWarper,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ... ])
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList([
        ...     TopKLogitsWarper(50),
        ...     TemperatureLogitsWarper(0.7),
        ... ])

        >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """

    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    # finished rows are padded with `pad_token_id`, so it must exist whenever EOS can finish a row.
    # Checked once up front (an `assert` inside the loop was re-evaluated each step and vanished under -O).
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If eos_token_id is defined, make sure that pad_token_id is defined.")

    # init sequence length tensors
    sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
        input_ids, max_length
    )

    # auto-regressive generation
    while cur_len < max_length:
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        scores = logits_processor(input_ids, next_token_logits)
        scores = logits_warper(input_ids, scores)

        # sample
        probs = F.softmax(scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # rows that already finished emit `pad_token_id` instead of the sampled token
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # add token and increase length by one
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        cur_len = cur_len + 1

        # update sequence length
        if eos_token_id is not None:
            sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
            )

        # stop when there is a </s> in each sentence, or if we exceed the maximum length
        if unfinished_sequences.max() == 0:
            break

        # update model kwargs
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

    return input_ids
def beam_sample(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using beam search with multinomial sampling.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            The sequence used as a prompt for the generation, with one row per (batch item, beam). This argument
            is required: its shape is read immediately to determine the current sequence length.
        beam_scorer (:obj:`BeamScorer`):
            A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
            constructed, stored and sorted during generation. For more information, the documentation of
            :class:`~transformers.BeamScorer` should be read.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. Defaults to an empty :class:`~transformers.LogitsProcessorList`.
        logits_warper (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
            modeling head applied before multinomial sampling at each generation step. Defaults to an empty
            :class:`~transformers.LogitsProcessorList`.
        max_length (:obj:`int`, `optional`):
            The maximum length of the sequence to be generated. Defaults to :obj:`self.config.max_length`.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Defaults to :obj:`self.config.pad_token_id`.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token. Defaults to :obj:`self.config.eos_token_id`.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The
        generated sequences, as produced by :obj:`beam_scorer.finalize`. The second dimension (sequence_length)
        is either equal to :obj:`max_length` or shorter if all batches finished early due to the
        :obj:`eos_token_id`.

    Examples::

        >>> from transformers import (
        ...    AutoTokenizer,
        ...    AutoModelForSeq2SeqLM,
        ...    LogitsProcessorList,
        ...    MinLengthLogitsProcessor,
        ...    TopKLogitsWarper,
        ...    TemperatureLogitsWarper,
        ...    BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)
        ... ])
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList([
        ...     TopKLogitsWarper(50),
        ...     TemperatureLogitsWarper(0.7),
        ... ])

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    # logits_warper is documented as optional but is called unconditionally in the loop below, so give it the
    # same empty-list default as logits_processor instead of failing with "'NoneType' object is not callable".
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    # input_ids must be 2-D (one row per batch item and beam); only the current length is needed here.
    _, cur_len = input_ids.shape

    # Accumulated score of every beam, flattened so it lines up with the rows of input_ids.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores = beam_scores.view((batch_size * num_beams,))

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(**model_inputs, return_dict=True)
        # Only the logits for the last position are needed to choose the next token.
        next_token_logits = outputs.logits[:, -1, :]

        # adjust token scores (a no-op by default)
        next_token_logits = self.adjust_logits_during_generation(
            next_token_logits, cur_len=cur_len, max_length=max_length
        )

        next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

        next_token_scores = logits_processor(input_ids, next_token_scores)
        # Add each beam's running score so candidates are compared on whole-sequence scores, then warp.
        next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # reshape for beam search: all (beam, token) candidates of one batch item share a single axis
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Sample 2 * num_beams candidates per batch item and keep their (warped) scores.
        probs = F.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

        # Order the sampled candidates from best to worst score.
        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = torch.gather(next_tokens, -1, _indices)

        # Split each flat candidate index into (source beam, token id).
        next_indices = next_tokens // vocab_size
        next_tokens = next_tokens % vocab_size

        # Let the scorer choose which candidates continue; it returns the next beam scores, tokens and
        # the source-beam index of each continuing row.
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # Reorder rows to follow their source beams, then append the chosen token to each row.
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        cur_len = cur_len + 1

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # Any cached past state must be permuted exactly like the rows of input_ids.
        if model_kwargs.get("past") is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        if beam_scorer.is_done:
            break

    # NOTE(review): if the loop body never runs (cur_len >= max_length on entry), next_tokens and next_indices
    # are unbound here and this raises NameError -- confirm whether finalize actually needs them.
    decoded = beam_scorer.finalize(
        input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
    )

    return decoded
def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (:obj:`int`, `optional`, defaults to 0):
            If > 0, keep only the top k tokens with highest probability (top-k filtering).
        top_p (:obj:`float`, `optional`, defaults to 1.0):
            If 0 <= top_p < 1.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751). A value of 1.0
            keeps the whole distribution, so no nucleus filtering is applied.
        filter_value (:obj:`float`, `optional`, defaults to ``-inf``):
            Value given to filtered-out logits. Used by both the top-k and the top-p filter.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Make sure we keep at least min_tokens_to_keep per batch example in the output (both filters).

    Returns:
        :obj:`torch.FloatTensor`: the filtered logits.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    # top_p == 1.0 means "keep everything", so the nucleus warper is skipped. filter_value is forwarded so that
    # both filters mask removed tokens with the same caller-chosen value.
    if 0 <= top_p < 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits