# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import torch
from torch.nn import functional as F
from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from .utils import logging
# Module-level logger for the generation utilities in this file (e.g. the pad-token fallback warning).
logger = logging.get_logger(__name__)
[docs]class GenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in
:class:`~transformers.PreTrainedModel`.
"""
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
    """
    Hook that lets subclasses of :class:`~transformers.PreTrainedModel` tweak the next-token logits inside the
    generate method.

    The base implementation is the identity: it ignores every keyword argument and hands back the very same
    :obj:`logits` object it received.
    """
    adjusted_logits = logits
    return adjusted_logits
def _prepare_input_ids_for_generation(self, bos_token_id: int) -> torch.LongTensor:
    """
    Build the default prompt used when :obj:`generate` is called without :obj:`input_ids`.

    Args:
        bos_token_id: id placed in the single prompt position.

    Returns:
        A long tensor of shape ``(1, 1)`` on ``self.device`` holding :obj:`bos_token_id`.

    Raises:
        ValueError: if :obj:`bos_token_id` is :obj:`None`.
    """
    if bos_token_id is None:
        raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
    single_slot = torch.ones((1, 1), dtype=torch.long, device=self.device)
    return single_slot.mul(bos_token_id)
def _prepare_attention_mask_for_generation(
    self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int
) -> torch.LongTensor:
    """
    Derive a default attention mask for :obj:`input_ids`.

    Positions equal to :obj:`pad_token_id` get 0 and all others 1. This only happens when the pad token actually
    occurs in :obj:`input_ids` and is not the same id as :obj:`eos_token_id`. Otherwise the returned mask is all
    ones.

    Returns:
        A tensor with the shape of :obj:`input_ids`.
    """
    all_ones = input_ids.new_ones
    # No pad id, or it never occurs in the prompt: nothing to mask.
    if pad_token_id is None or pad_token_id not in input_ids:
        return all_ones(input_ids.shape)
    # pad and eos share an id: the positions cannot be told apart, so nothing is masked.
    if eos_token_id is not None and pad_token_id == eos_token_id:
        return all_ones(input_ids.shape)
    return input_ids.ne(pad_token_id).long()
def _prepare_encoder_decoder_kwargs_for_generation(
    self, input_ids: torch.LongTensor, model_kwargs
) -> Dict[str, Any]:
    """
    Run the encoder once and stash its output under ``model_kwargs["encoder_outputs"]``.

    Every entry of :obj:`model_kwargs` whose name does not start with ``decoder_`` is forwarded to the encoder,
    which is called with ``return_dict=True``.

    Returns:
        The same :obj:`model_kwargs` dict, updated in place.
    """
    encoder = self.get_encoder()
    forwarded = dict(
        (name, value) for name, value in model_kwargs.items() if not name.startswith("decoder_")
    )
    model_kwargs["encoder_outputs"] = encoder(input_ids, return_dict=True, **forwarded)
    return model_kwargs
def _prepare_decoder_input_ids_for_generation(
    self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs
) -> torch.LongTensor:
    """
    Return the initial decoder inputs for encoder-decoder generation.

    An explicit ``decoder_input_ids`` entry in :obj:`model_kwargs` is returned unchanged. Otherwise a
    ``(batch_size, 1)`` tensor filled with the start id from :meth:`_get_decoder_start_token_id` is built. It has
    the dtype and device of :obj:`input_ids`.
    """
    if "decoder_input_ids" in model_kwargs:
        return model_kwargs["decoder_input_ids"]
    start_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
    batch_size = input_ids.shape[0]
    return input_ids.new_ones((batch_size, 1)) * start_id
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
    """
    Resolve the padding id.

    :obj:`pad_token_id` wins when set. Otherwise :obj:`eos_token_id` is used as a stand-in, with a warning. If
    neither is set, :obj:`None` is returned.
    """
    if pad_token_id is not None or eos_token_id is None:
        return pad_token_id
    logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
    return eos_token_id
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
    """
    Pick the token id the decoder starts from.

    Missing arguments are first filled from ``self.config``. Candidates are then tried in this order, and the
    first one that is not :obj:`None` wins:

    1. the (possibly config-filled) :obj:`decoder_start_token_id`;
    2. ``self.config.decoder.decoder_start_token_id``, if a nested decoder config exists;
    3. the (possibly config-filled) :obj:`bos_token_id`;
    4. ``self.config.decoder.bos_token_id``.

    Raises:
        ValueError: if every candidate is :obj:`None`.
    """
    if decoder_start_token_id is None:
        decoder_start_token_id = self.config.decoder_start_token_id
    if bos_token_id is None:
        bos_token_id = self.config.bos_token_id

    if decoder_start_token_id is not None:
        return decoder_start_token_id
    nested_config = getattr(self.config, "decoder", None)
    nested_start_id = getattr(nested_config, "decoder_start_token_id", None)
    if nested_start_id is not None:
        return nested_start_id
    if bos_token_id is not None:
        return bos_token_id
    nested_bos_id = getattr(nested_config, "bos_token_id", None)
    if nested_bos_id is not None:
        return nested_bos_id
    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
    )
@staticmethod
def _expand_inputs_for_generation(
    input_ids: torch.LongTensor,
    expand_size: int = 1,
    is_encoder_decoder: bool = False,
    attention_mask: torch.LongTensor = None,
    encoder_outputs: ModelOutput = None,
    **model_kwargs
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
    """
    Repeat every batch row ``expand_size`` times along dim 0, keeping copies of a row adjacent (row order
    ``0, 0, ..., 1, 1, ...``).

    Expanded tensors are :obj:`input_ids`, ``model_kwargs["token_type_ids"]`` (when present), and
    :obj:`attention_mask` (when given; it is stored back into the returned kwargs). For encoder-decoder models,
    ``encoder_outputs.last_hidden_state`` is expanded too.

    Note: ``encoder_outputs`` is modified in place (its ``last_hidden_state`` entry is replaced). A :obj:`None`
    :obj:`attention_mask` is not put back into the returned kwargs. For decoder-only models, neither is
    ``encoder_outputs``.

    Returns:
        ``(input_ids, model_kwargs)`` after expansion.
    """
    row_index = torch.arange(input_ids.shape[0]).repeat_interleave(expand_size).to(input_ids.device)

    def _expand(tensor):
        return tensor.index_select(0, row_index)

    input_ids = _expand(input_ids)
    if "token_type_ids" in model_kwargs:
        model_kwargs["token_type_ids"] = _expand(model_kwargs["token_type_ids"])
    if attention_mask is not None:
        model_kwargs["attention_mask"] = _expand(attention_mask)
    if is_encoder_decoder:
        assert encoder_outputs is not None
        encoder_outputs["last_hidden_state"] = _expand(encoder_outputs.last_hidden_state)
        model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs
@staticmethod
def _init_sequence_length_for_generation(
    input_ids: torch.LongTensor, max_length: int
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Create the per-row bookkeeping used by the decoding loops.

    Returns:
        ``(sequence_lengths, unfinished_sequences, cur_len)``:

        - ``sequence_lengths``: one entry per row, all set to :obj:`max_length`.
        - ``unfinished_sequences``: one entry per row, all set to 1.
        - ``cur_len``: the current prompt length, ``input_ids.shape[-1]``.

        Both tensors share the dtype and device of :obj:`input_ids`.
    """
    batch_size = input_ids.shape[0]
    still_running = input_ids.new_ones((batch_size,))
    lengths = input_ids.new_full((batch_size,), max_length)
    return lengths, still_running, input_ids.shape[-1]
@staticmethod
def _update_seq_length_for_generation(
    sequence_lengths: torch.LongTensor,
    unfinished_sequences: torch.LongTensor,
    cur_len: int,
    is_eos_in_next_token: torch.BoolTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Record rows that emit EOS on this step.

    A row that was still unfinished and whose next token is EOS gets its length set to :obj:`cur_len`.
    Every row whose next token is EOS is then marked finished (0) in :obj:`unfinished_sequences`.

    Returns:
        New ``(sequence_lengths, unfinished_sequences)`` tensors; the inputs are not modified.
    """
    eos_as_long = is_eos_in_next_token.long()
    # Rows that were still running and produced EOS right now.
    finishing_now = (unfinished_sequences * eos_as_long).bool()
    new_lengths = sequence_lengths.masked_fill(finishing_now, cur_len)
    new_unfinished = unfinished_sequences * (~is_eos_in_next_token).long()
    return new_lengths, new_unfinished
@staticmethod
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
    """
    Carry state from one decoding step's outputs into the kwargs of the next step.

    - ``model_kwargs["past"]`` is set from the first of ``past_key_values``, ``mems`` or ``past_buckets_states``
      present in :obj:`outputs`, or to :obj:`None` if none is present.
    - ``token_type_ids`` (if present) gains one column repeating its last column.
    - For decoder-only models, ``attention_mask`` (if present) gains one column of ones.

    :obj:`model_kwargs` is updated in place and also returned.
    """
    cache = None
    for cache_name in ("past_key_values", "mems", "past_buckets_states"):
        if cache_name in outputs:
            cache = getattr(outputs, cache_name)
            break
    model_kwargs["past"] = cache

    if "token_type_ids" in model_kwargs:
        type_ids = model_kwargs["token_type_ids"]
        repeated_last = type_ids[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.cat([type_ids, repeated_last], dim=-1)

    if not is_encoder_decoder and "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        ones_column = mask.new_ones((mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat([mask, ones_column], dim=-1)
    return model_kwargs
@staticmethod
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
    """
    Re-order the :obj:`past_key_values` or :obj:`mems` cache so that it follows the surviving beams.

    :meth:`~transformers.PreTrainedModel.beam_search` and :meth:`~transformers.PreTrainedModel.beam_sample` need
    this at every generation step. The default selects :obj:`beam_idx` along dimension 1 of each cache entry.
    Models whose cache is laid out differently should override this method.
    """
    return tuple(layer.index_select(dim=1, index=beam_idx) for layer in past)
def _get_logits_warper(
    self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
) -> LogitsProcessorList:
    """
    Build the :obj:`~transformers.LogitsProcessorList` of :obj:`~transformers.LogitsWarper` instances used to
    reshape the next-token distribution before multinomial sampling.

    Any argument left as :obj:`None` falls back to the attribute of the same name on ``self.config``.

    Args:
        top_k: keep only the ``top_k`` highest-scoring tokens. ``0`` or :obj:`None` disables top-k filtering.
        top_p: nucleus-filtering threshold. Values ``>= 1.0`` or :obj:`None` disable it.
        temperature: logits temperature. ``1.0`` or :obj:`None` leaves the logits unscaled.
        num_beams: number of beams in use. With more than one beam, the top-k/top-p warpers keep at least 2
            tokens, otherwise at least 1.

    Returns:
        A :obj:`~transformers.LogitsProcessorList` holding, in order, the top-k, top-p and temperature warpers
        that apply. It may be empty.
    """
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    temperature = temperature if temperature is not None else self.config.temperature
    # Previously a missing `num_beams` crashed with `None > 1` as soon as top-k/top-p was active.
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    min_tokens_to_keep = 2 if num_beams is not None and num_beams > 1 else 1

    # The warper classes live in `generation_logits_process.py`.
    warpers = LogitsProcessorList()
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=min_tokens_to_keep))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep))
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    return warpers
def _get_logits_processor(
    self,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    bad_words_ids: List[List[int]],
    min_length: int,
    eos_token_id: int,
    prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
    num_beams: int,
) -> LogitsProcessorList:
    """
    Build the :obj:`~transformers.LogitsProcessorList` of :obj:`~transformers.LogitsProcessor` instances that
    modify the language-model-head scores at every generation step.

    :obj:`repetition_penalty`, :obj:`no_repeat_ngram_size`, :obj:`bad_words_ids`, :obj:`min_length` and
    :obj:`eos_token_id` fall back to ``self.config`` when :obj:`None`. A processor is added only when its
    resolved setting is active:

    - repetition penalty, when set and not ``1.0``;
    - no-repeat n-gram, when set and ``> 0``;
    - bad words, when set;
    - minimum length, when both it and :obj:`eos_token_id` are set and ``min_length > -1``;
    - prefix constraint, when :obj:`prefix_allowed_tokens_fn` is given (it also receives :obj:`num_beams`).
    """
    if repetition_penalty is None:
        repetition_penalty = self.config.repetition_penalty
    if no_repeat_ngram_size is None:
        no_repeat_ngram_size = self.config.no_repeat_ngram_size
    if bad_words_ids is None:
        bad_words_ids = self.config.bad_words_ids
    if min_length is None:
        min_length = self.config.min_length
    if eos_token_id is None:
        eos_token_id = self.config.eos_token_id

    # The processor classes live in `generation_logits_process.py`.
    processors = LogitsProcessorList()
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
        processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
    if bad_words_ids is not None:
        processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
    if min_length is not None and eos_token_id is not None and min_length > -1:
        processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
    if prefix_allowed_tokens_fn is not None:
        processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
    return processors
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[List[List[int]]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    **model_kwargs
) -> torch.LongTensor:
    r"""
    Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
    multinomial sampling, beam-search decoding, and beam-search multinomial sampling.

    Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
    attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
    indicated are the default values of those config.

    Most of these parameters are explained in more detail in `this blog post
    <https://huggingface.co/blog/how-to-generate>`__.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as a
            :obj:`torch.LongTensor` of shape :obj:`(1, 1)` filled with :obj:`bos_token_id`.
        max_length (:obj:`int`, `optional`, defaults to 20):
            The maximum length of the sequence to be generated.
        min_length (:obj:`int`, `optional`, defaults to 10):
            The minimum length of the sequence to be generated.
        do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use sampling; use greedy decoding otherwise. The value is interpreted by its
            truthiness.
        early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
            Used by both beam search and beam-search multinomial sampling.
        num_beams (:obj:`int`, `optional`, defaults to 1):
            Number of beams for beam search. 1 means no beam search. Must be at least 1.
        temperature (:obj:`float`, `optional`, defaults to 1.0):
            The value used to modulate the next token probabilities.
        top_k (:obj:`int`, `optional`, defaults to 50):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        top_p (:obj:`float`, `optional`, defaults to 1.0):
            If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
            higher are kept for generation.
        repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. If it is not set but :obj:`eos_token_id` is, :obj:`eos_token_id` is used
            instead (with a warning).
        bos_token_id (:obj:`int`, `optional`):
            The id of the `beginning-of-sequence` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
            sequences.
        no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
            If set to int > 0, all ngrams of that size can only occur once.
        bad_words_ids (:obj:`List[List[int]]`, `optional`):
            List of token ids that are not allowed to be generated. In order to get the tokens of the words that
            should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        num_return_sequences (:obj:`int`, `optional`, defaults to 1):
            The number of independently computed returned sequences for each element in the batch.
        attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
            tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same
            shape as :obj:`input_ids` that masks the pad token. `What are attention masks?
            <../glossary.html#attention-mask>`__
        decoder_start_token_id (:obj:`int`, `optional`):
            If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should use the past last key/values attentions (if applicable to the model) to
            speed up decoding.
        prefix_allowed_tokens_fn (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
            If provided, this function constrains the beam search to allowed tokens only at each step. If not
            provided no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
            :obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step
            conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`input_ids`. This
            argument is useful for constrained generation conditioned on the prefix, as described in
            `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
            model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
            kwargs should be prefixed with `decoder_`.

    Return:
        :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
        sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
        batches finished early due to the :obj:`eos_token_id`.

    Raises:
        ValueError: if :obj:`num_beams` is smaller than 1, if :obj:`num_return_sequences` > 1 with greedy
            decoding, if :obj:`num_return_sequences` > :obj:`num_beams` with beam search, if no
            :obj:`input_ids` and no :obj:`bos_token_id` are available, or if an encoder-decoder model ends up
            without valid :obj:`encoder_outputs`.

    Examples::
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> # do greedy decoding without providing a prompt
        >>> outputs = model.generate(max_length=40)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> document = (
        ... "at least two people were killed in a suspected bomb attack on a passenger bus "
        ... "in the strife-torn southern philippines on monday , the military said."
        ... )
        >>> # encode input context
        >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
        >>> # generate 3 independent sequences using beam search decoding (5 beams)
        >>> # with T5 encoder-decoder model conditioned on short news article.
        >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> input_context = "The dog"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate 3 candidates using sampling
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
        >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
        >>> # "Legal" is one of the control codes for ctrl
        >>> input_context = "Legal My neighbor is"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> input_context = "My cute dog"
        >>> # get tokens of words that should not be generated
        >>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]]
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate sequences without allowing bad_words to be generated
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
    """
    # set init values
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    max_length = max_length if max_length is not None else self.config.max_length
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    # Resolve beam-search settings once, so beam search and beam sampling both honour the config.
    # (Beam sampling previously passed a raw `None` early_stopping to the scorer.)
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    # The mode selection below must be exhaustive. Identity checks like `do_sample is False` silently matched
    # no mode (returning None) for 0/1 or numpy booleans, and so did num_beams < 1.
    do_sample = bool(do_sample)
    if num_beams < 1:
        raise ValueError(f"`num_beams` has to be a strictly positive integer, but is {num_beams}.")

    if input_ids is None:
        # init `input_ids` with bos_token_id
        input_ids = self._prepare_input_ids_for_generation(bos_token_id)

    if model_kwargs.get("attention_mask", None) is None:
        # init `attention_mask` depending on `pad_token_id`
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            input_ids, pad_token_id, eos_token_id
        )

    # special case if pad_token_id is not defined: fall back to eos_token_id (logs a warning)
    pad_token_id = self._get_pad_token_id(pad_token_id, eos_token_id)

    if self.config.is_encoder_decoder:
        # add encoder_outputs to model_kwargs
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

        # set input_ids as decoder_input_ids
        input_ids = self._prepare_decoder_input_ids_for_generation(
            input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs
        )

        if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
            raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")

    # determine generation mode; exactly one of these is True after the validation above
    is_greedy_gen_mode = num_beams == 1 and not do_sample
    is_sample_gen_mode = num_beams == 1 and do_sample
    is_beam_gen_mode = num_beams > 1 and not do_sample
    is_beam_sample_gen_mode = num_beams > 1 and do_sample

    # set model_kwargs
    model_kwargs["use_cache"] = use_cache

    # get distribution pre_processing samplers
    logits_processor = self._get_logits_processor(
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        bad_words_ids=bad_words_ids,
        min_length=min_length,
        eos_token_id=eos_token_id,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        num_beams=num_beams,
    )

    if is_greedy_gen_mode:
        if num_return_sequences > 1:
            raise ValueError(
                f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
            )

        # greedy search
        return self.greedy_search(
            input_ids,
            logits_processor=logits_processor,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    elif is_sample_gen_mode:
        # get probability distribution warper
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

        # expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids,
            expand_size=num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # sample
        return self.sample(
            input_ids,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    elif is_beam_gen_mode:
        batch_size = input_ids.shape[0]

        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
        )
        # interleave with `num_beams`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
        )
        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )

    elif is_beam_sample_gen_mode:
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

        batch_size = input_ids.shape[0] * num_return_sequences

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
        )

        # interleave with `num_beams * num_return_sequences`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids,
            expand_size=num_beams * num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        return self.beam_sample(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **model_kwargs,
        )
def greedy_search(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using greedy decoding.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation. Required: unlike :meth:`generate`, this method does
            not build a default prompt.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step.
        max_length (:obj:`int`, `optional`, defaults to 20):
            The maximum length of the sequence to be generated.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token. Required whenever :obj:`eos_token_id` is set, because rows that have
            already produced EOS are filled with it.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the
            model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, where ``batch_size`` is
        ``input_ids.shape[0]``: the prompt followed by the generated tokens. ``sequence_length`` is at most
        :obj:`max_length` (shorter if every row produced :obj:`eos_token_id` early). The exception is a prompt
        already at least :obj:`max_length` long, which is returned unchanged.

    Raises:
        ValueError: if :obj:`eos_token_id` is set but :obj:`pad_token_id` is not (after config fallback).

    Examples::

        >>> from transformers import (
        ... AutoTokenizer,
        ... AutoModelForCausalLM,
        ... LogitsProcessorList,
        ... MinLengthLogitsProcessor,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ... ])

        >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    # Loop-invariant check, done once and not stripped under `python -O` (it used to be an `assert` in the loop).
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If eos_token_id is defined, make sure that pad_token_id is defined.")

    # init sequence length tensors
    sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
        input_ids, max_length
    )

    while cur_len < max_length:
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        scores = logits_processor(input_ids, next_token_logits)

        # argmax
        next_tokens = torch.argmax(scores, dim=-1)

        # rows that already finished keep emitting `pad_token_id` instead of the argmax token
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # add token and increase length by one
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        # update sequence length
        if eos_token_id is not None:
            sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
            )

        # update model kwargs
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # stop when every sentence has produced EOS, or when the maximum length is reached
        if unfinished_sequences.max() == 0:
            break

        # increase cur_len
        cur_len = cur_len + 1

    return input_ids
[docs] def sample(
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
**model_kwargs
):
r"""
Generates sequences for models with a language modeling head using multinomial sampling.
Parameters:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
logits_processor (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
logits_warper (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
modeling head applied before multinomial sampling at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
batches finished early due to the :obj:`eos_token_id`.
Examples::
>>> from transformers import (
... AutoTokenizer,
... AutoModelForCausalLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... TopKLogitsWarper,
... TemperatureLogitsWarper,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "Today is a beautiful day, and"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList([
... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
... ])
>>> # instantiate logits processors
>>> logits_warper = LogitsProcessorList([
... TopKLogitsWarper(50),
... TemperatureLogitsWarper(0.7),
... ])
>>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
# init sequence length tensors
sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
input_ids, max_length
)
# auto-regressive generation
while cur_len < max_length:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
scores = logits_processor(input_ids, next_token_logits)
scores = logits_warper(input_ids, scores)
# sample
probs = F.softmax(scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# add code that transfomers next_tokens to tokens_to_add
if eos_token_id is not None:
assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)
# add token and increase length by one
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
cur_len = cur_len + 1
# update sequence length
if eos_token_id is not None:
sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
)
# stop when there is a </s> in each sentence, or if we exceed the maximul length
if unfinished_sequences.max() == 0:
break
# update model kwargs
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
return input_ids
def beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using beam search decoding.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            The sequence used as a prompt for the generation. Required: its first dimension must equal
            ``beam_scorer.num_beams`` times the number of batch entries tracked by ``beam_scorer``.
        beam_scorer (:obj:`BeamScorer`):
            A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
            constructed, stored and sorted during generation. For more information, the documentation of
            :class:`~transformers.BeamScorer` should be read.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. Defaults to an empty list.
        max_length (:obj:`int`, `optional`, defaults to :obj:`self.config.max_length`):
            The maximum length of the sequence to be generated.
        pad_token_id (:obj:`int`, `optional`, defaults to :obj:`self.config.pad_token_id`):
            The id of the `padding` token.
        eos_token_id (:obj:`int`, `optional`, defaults to :obj:`self.config.eos_token_id`):
            The id of the `end-of-sequence` token.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :obj:`torch.LongTensor`: the sequences produced by ``beam_scorer.finalize``, of shape
        :obj:`(batch_size * num_return_sequences, sequence_length)`. The second dimension (sequence_length) is
        either equal to :obj:`max_length` or shorter if all batches finished early due to the
        :obj:`eos_token_id`.

    Examples::

        >>> from transformers import (
        ...    AutoTokenizer,
        ...    AutoModelForSeq2SeqLM,
        ...    LogitsProcessorList,
        ...    MinLengthLogitsProcessor,
        ...    BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ... ])

        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """
    # init values (fall back to the model config for anything not passed explicitly)
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    # f-string so the actual numbers appear in the error message
    assert (
        num_beams * batch_size == batch_beam_size
    ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

    # Every beam except the first starts with a very low score. Beams typically start from the same prompt,
    # so on the first step top-k then draws its candidates from beam 0 only instead of from num_beams
    # identical copies.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    # NOTE(review): if cur_len >= max_length on entry the loop never runs and `next_tokens` /
    # `next_indices` below are unbound when passed to `finalize` -- confirm callers never do this.
    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]

        # adjust tokens for Bart, *e.g.*
        next_token_logits = self.adjust_logits_during_generation(
            next_token_logits, cur_len=cur_len, max_length=max_length
        )

        next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

        next_token_scores = logits_processor(input_ids, next_token_scores)
        # accumulate: each candidate's score is its beam's running log-prob plus the new token's log-prob
        next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

        # reshape for beam search: flatten (beam, token) so top-k ranks across all beams of a batch entry
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # 2 * num_beams candidates, presumably so enough non-EOS continuations remain after
        # beam_scorer.process sets finished hypotheses aside
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # split the flattened index back into (source beam, token id)
        next_indices = next_tokens // vocab_size
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # re-order the running sequences to follow the selected beams, then append the chosen tokens
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        cur_len = cur_len + 1

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # cached past states must follow the same beam re-ordering as input_ids
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        if beam_scorer.is_done:
            break

    decoded = beam_scorer.finalize(
        input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
    )

    return decoded
def beam_sample(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    **model_kwargs
):
    r"""
    Generates sequences for models with a language modeling head using beam search with multinomial sampling.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            The sequence used as a prompt for the generation. Required: its first dimension must equal
            ``beam_scorer.num_beams`` times the number of batch entries tracked by ``beam_scorer``.
        beam_scorer (:obj:`BeamScorer`):
            A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
            constructed, stored and sorted during generation. For more information, the documentation of
            :class:`~transformers.BeamScorer` should be read.
        logits_processor (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
            head applied at each generation step. Defaults to an empty list.
        logits_warper (:obj:`LogitsProcessorList`, `optional`):
            An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
            :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
            modeling head applied before multinomial sampling at each generation step. Defaults to an empty
            list (no warping).
        max_length (:obj:`int`, `optional`, defaults to :obj:`self.config.max_length`):
            The maximum length of the sequence to be generated.
        pad_token_id (:obj:`int`, `optional`, defaults to :obj:`self.config.pad_token_id`):
            The id of the `padding` token.
        eos_token_id (:obj:`int`, `optional`, defaults to :obj:`self.config.eos_token_id`):
            The id of the `end-of-sequence` token.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
            model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

    Return:
        :obj:`torch.LongTensor`: the sequences produced by ``beam_scorer.finalize``, of shape
        :obj:`(batch_size * num_return_sequences, sequence_length)`. The second dimension (sequence_length) is
        either equal to :obj:`max_length` or shorter if all batches finished early due to the
        :obj:`eos_token_id`.

    Examples::

        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList([
        ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)
        ... ])
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList([
        ...     TopKLogitsWarper(50),
        ...     TemperatureLogitsWarper(0.7),
        ... ])

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    """
    # init values (fall back to empty processor lists / model config for anything not passed explicitly)
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    # without this default the warper call inside the loop would fail on None
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    # same sanity check as beam_search; the reshape below requires this equality anyway
    assert (
        num_beams * batch_size == batch_beam_size
    ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

    # Unlike beam_search, all beams start at score 0; sampling (rather than top-k) keeps the first
    # step from collapsing onto identical candidates.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores = beam_scores.view((batch_size * num_beams,))

    # NOTE(review): if cur_len >= max_length on entry the loop never runs and `next_tokens` /
    # `next_indices` below are unbound when passed to `finalize` -- confirm callers never do this.
    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]

        # adjust token scores (a no-op by default)
        next_token_logits = self.adjust_logits_during_generation(
            next_token_logits, cur_len=cur_len, max_length=max_length
        )

        next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

        next_token_scores = logits_processor(input_ids, next_token_scores)
        # accumulate running beam log-probs, then warp the accumulated scores before sampling
        next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # reshape for beam search: flatten (beam, token) so sampling covers all beams of a batch entry
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        probs = F.softmax(next_token_scores, dim=-1)

        # sample 2 * num_beams distinct flattened indices, then order them by score (highest first)
        next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = torch.gather(next_tokens, -1, _indices)

        # split the flattened index back into (source beam, token id)
        next_indices = next_tokens // vocab_size
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # re-order the running sequences to follow the selected beams, then append the chosen tokens
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        cur_len = cur_len + 1

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # cached past states must follow the same beam re-ordering as input_ids
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        if beam_scorer.is_done:
            break

    decoded = beam_scorer.finalize(
        input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
    )

    return decoded
def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    Args:
        logits (:obj:`torch.FloatTensor`):
            Logits distribution of shape (batch size, vocabulary size).
        top_k (:obj:`int`, defaults to 0):
            If > 0, keep only the ``top_k`` tokens with the highest probability (top-k filtering).
        top_p (:obj:`float`, defaults to 1.0):
            If < 1.0, keep the top tokens with cumulative probability >= ``top_p`` (nucleus filtering).
            A value >= 1.0 disables nucleus filtering.
        filter_value (:obj:`float`, defaults to ``-inf``):
            Value given to filtered-out logits, used by both the top-k and the top-p filter.
        min_tokens_to_keep (:obj:`int`, defaults to 1):
            Minimum number of tokens kept per batch example by each filter.

    Returns:
        :obj:`torch.FloatTensor`: the filtered logits.
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    # top_p >= 1.0 means "keep everything", so the warper is not run at all. Negative values are still
    # forwarded so the warper can reject them (presumably it validates its argument -- TODO confirm).
    if top_p < 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits