Models
The base classes PreTrainedModel and TFPreTrainedModel implement the common methods for loading/saving a model either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace’s AWS S3 repository).
PreTrainedModel and TFPreTrainedModel also implement a few methods which are common among all the models to:
- resize the input token embeddings when new tokens are added to the vocabulary,
- prune the attention heads of the model.
The other methods that are common to each model are defined in ModuleUtilsMixin (for the PyTorch models) and TFModelUtilsMixin (for the TensorFlow models) or, for text generation, GenerationMixin (for the PyTorch models) and TFGenerationMixin (for the TensorFlow models).
PreTrainedModel
-
class transformers.PreTrainedModel(config: transformers.configuration_utils.PretrainedConfig, *inputs, **kwargs)
Base class for all models.
PreTrainedModel takes care of storing the configuration of the models and handles methods for loading, downloading and saving models as well as a few methods common to all models to:
- resize the input embeddings,
- prune heads in the self-attention layers.
- Class attributes (overridden by derived classes):
  - config_class (PretrainedConfig) – A subclass of PretrainedConfig to use as configuration class for this model architecture.
  - load_tf_weights (Callable) – A Python method for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
    - model (PreTrainedModel) – An instance of the model on which to load the TensorFlow checkpoint.
    - config (PretrainedConfig) – An instance of the configuration associated to the model.
    - path (str) – A path to the TensorFlow checkpoint.
  - base_model_prefix (str) – A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
  - authorized_missing_keys (Optional[List[str]]) – A list of re patterns of tensor names to ignore when loading the model (and avoid unnecessary warnings).
  - keys_to_never_save (Optional[List[str]]) – A list of tensor names to ignore when saving the model (useful for keys that aren’t trained, but which are deterministic).
-
property base_model
The main body of the model.
- Type
torch.nn.Module
-
property dummy_inputs
Dummy inputs to do a forward pass in the network.
- Type
Dict[str, torch.Tensor]
-
classmethod from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
Instantiate a pretrained PyTorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using model.eval() (Dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
The warning Weights from XXX not initialized from pretrained model means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.
The warning Weights from XXX not used in YYY means that the layer XXX is not used by YYY, therefore those weights are discarded.
- Parameters
pretrained_model_name_or_path (str, optional) – Can be either:
  - A string with the shortcut name of a pretrained model to load from cache or download, e.g., bert-base-uncased.
  - A string with the identifier name of a pretrained model that was user-uploaded to our S3, e.g., dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
  - None if you are both providing the configuration and state dictionary (resp. with keyword arguments config and state_dict).
model_args (sequence of positional arguments, optional) – All remaining positional arguments will be passed to the underlying model’s __init__ method.
config (Union[PretrainedConfig, str], optional) – Can be either:
  - an instance of a class derived from PretrainedConfig,
  - a string valid as input to from_pretrained().
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the shortcut name string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
state_dict (Dict[str, torch.Tensor], optional) – A state dictionary to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is a simpler option.
cache_dir (str, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
from_tf (bool, optional, defaults to False) – Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
force_download (bool, optional, defaults to False) – Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
resume_download (bool, optional, defaults to False) – Whether or not to keep incompletely received files and attempt to resume the download if such a file exists.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
output_loading_info (bool, optional, defaults to False) – Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only (bool, optional, defaults to False) – Whether or not to only look at local files (i.e., do not try to download the model).
use_cdn (bool, optional, defaults to True) – Whether or not to use CloudFront (a Content Delivery Network, or CDN) when searching for the model on our S3 (faster). Should be set to False for checkpoints larger than 20GB.
mirror (str, optional, defaults to None) – Mirror source to accelerate downloads in China. If you are in China and have trouble accessing the default servers, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety of the mirror. Please refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, optional) –
Can be used to update the configuration object (after it is loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Examples:
>>> from transformers import BertConfig, BertModel
>>> # Download model and configuration from S3 and cache.
>>> model = BertModel.from_pretrained('bert-base-uncased')
>>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
>>> model = BertModel.from_pretrained('./test/saved_model/')
>>> # Update configuration during loading.
>>> model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
>>> assert model.config.output_attentions
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
>>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
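When no config is passed, the routing of kwargs described above can be sketched in plain Python. This is a hypothetical illustration, not the library's code: SketchConfig and split_kwargs are made-up names, and a real PretrainedConfig subclass carries many more attributes.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Tuple


@dataclass
class SketchConfig:
    """Hypothetical stand-in for a PretrainedConfig subclass."""
    output_attentions: bool = False
    hidden_size: int = 768


def split_kwargs(config: SketchConfig, kwargs: Dict[str, Any]) -> Tuple[SketchConfig, Dict[str, Any]]:
    """Override config attributes named in kwargs; return the leftovers for the model's __init__."""
    config_keys = {f.name for f in fields(config)}
    model_kwargs = {}
    for key, value in kwargs.items():
        if key in config_keys:
            setattr(config, key, value)  # key matches a config attribute: override it
        else:
            model_kwargs[key] = value  # unknown key: forwarded to the model constructor
    return config, model_kwargs


config, model_kwargs = split_kwargs(SketchConfig(), {"output_attentions": True, "my_flag": 3})
print(config.output_attentions, model_kwargs)
```

Here output_attentions updates the configuration, while the hypothetical my_flag is left over for the model's constructor.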
-
get_input_embeddings() → torch.nn.modules.module.Module
Returns the model’s input embeddings.
- Returns
A torch module mapping vocabulary to hidden states.
- Return type
nn.Module
-
get_output_embeddings() → torch.nn.modules.module.Module
Returns the model’s output embeddings.
- Returns
A torch module mapping hidden states to vocabulary.
- Return type
nn.Module
-
prune_heads(heads_to_prune: Dict[int, List[int]])
Prunes heads of the base model.
- Parameters
heads_to_prune (Dict[int, List[int]]) – Dictionary with keys being selected layer indices (int) and associated values being the list of heads to prune in said layer (list of int). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
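To see what a heads_to_prune dictionary means for each layer, here is a small pure-Python sketch. It validates the dictionary and lists the heads each layer keeps. remaining_heads is a hypothetical helper, and the layer and head counts are assumed example values. On an actual model you would pass the dictionary itself, e.g. model.prune_heads({1: [0, 2], 2: [2, 3]}).

```python
from typing import Dict, List


def remaining_heads(num_layers: int, num_heads: int,
                    heads_to_prune: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """Hypothetical helper: head indices each layer keeps after pruning."""
    for layer, heads in heads_to_prune.items():
        if not 0 <= layer < num_layers:
            raise ValueError(f"layer index {layer} is out of range")
        if any(not 0 <= head < num_heads for head in heads):
            raise ValueError(f"layer {layer} lists a head index out of range")
    remaining = {}
    for layer in range(num_layers):
        pruned = set(heads_to_prune.get(layer, []))  # layers not listed keep every head
        remaining[layer] = [h for h in range(num_heads) if h not in pruned]
    return remaining


# Docstring example: prune heads 0 and 2 on layer 1, and heads 2 and 3 on layer 2.
kept = remaining_heads(num_layers=3, num_heads=4, heads_to_prune={1: [0, 2], 2: [2, 3]})
print(kept)
```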
-
resize_token_embeddings(new_num_tokens: Optional[int] = None) → torch.nn.modules.sparse.Embedding
Resizes the input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Takes care of tying weights embeddings afterwards if the model class has a tie_weights() method.
- Parameters
new_num_tokens (int, optional) – The new number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or None, just returns a pointer to the input tokens torch.nn.Embedding module of the model without doing anything.
- Returns
Pointer to the input tokens Embeddings Module of the model.
- Return type
torch.nn.Embedding
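The grow/shrink rule above can be illustrated on a plain list-of-rows matrix. This is a sketch only. resize_embedding_rows is a hypothetical helper, and drawing new rows from a normal distribution with standard deviation 0.02 is an assumption that stands in for the model's own weight initialization. In practice, the method is typically called after adding tokens to the tokenizer, e.g. model.resize_token_embeddings(len(tokenizer)).

```python
import random
from typing import List

Matrix = List[List[float]]


def resize_embedding_rows(weights: Matrix, new_num_tokens: int,
                          std: float = 0.02, seed: int = 0) -> Matrix:
    """Hypothetical helper: keep the first min(old, new) rows, append new rows if growing."""
    rng = random.Random(seed)
    dim = len(weights[0])
    num_to_copy = min(len(weights), new_num_tokens)
    resized = [list(row) for row in weights[:num_to_copy]]  # existing vectors are kept as-is
    for _ in range(new_num_tokens - num_to_copy):  # only runs when the matrix grows
        resized.append([rng.gauss(0.0, std) for _ in range(dim)])  # assumed initialization
    return resized


old = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
grown = resize_embedding_rows(old, 5)   # 2 new rows appended at the end
shrunk = resize_embedding_rows(old, 2)  # last row removed
print(len(grown), grown[:3] == old, shrunk)
```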
-
save_pretrained(save_directory)
Save a model and its configuration file to a directory, so that it can be re-loaded using the from_pretrained() class method.
- Parameters
save_directory (
str
) – Directory to which to save. Will be created if it doesn’t exist.
ModuleUtilsMixin
-
class transformers.modeling_utils.ModuleUtilsMixin
A few utilities for torch.nn.Modules, to be used as a mixin.
-
add_memory_hooks()
Add a memory hook before and after each sub-module forward pass to record the increase in memory consumption.
The increase in memory consumption is stored in a mem_rss_diff attribute for each module and can be reset to zero with model.reset_memory_hooks_state().
-
property device
The device on which the module is (assuming that all the module parameters are on the same device).
- Type
torch.device
-
property dtype
The dtype of the module (assuming that all the module parameters have the same dtype).
- Type
torch.dtype
-
estimate_tokens(input_dict: Dict[str, Union[torch.Tensor, Any]]) → int
Helper function to estimate the total number of tokens from the model inputs.
- Parameters
input_dict (dict) – The model inputs.
- Returns
The total number of tokens.
- Return type
int
-
floating_point_ops(input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True) → int
Get the number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a batch with this transformer model. The default approximation neglects the quadratic dependency on the number of tokens (valid if sequence_length << 12 * d_model), as laid out in section 2.1 of this paper. It should be overridden for transformers with parameter re-use, e.g., ALBERT or Universal Transformers, or if doing long-range modeling with very high sequence lengths.
- Parameters
input_dict (dict) – The model inputs.
exclude_embeddings (bool, optional, defaults to True) – Whether or not to exclude embedding and softmax operations from the count.
- Returns
The number of floating-point operations.
- Return type
int
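Below is a back-of-the-envelope version of this estimate. It assumes the common rule of thumb of about 6 floating-point operations per parameter per token: roughly 2 for the forward pass and 4 for the backward pass. The functions are hypothetical stand-alone sketches, not the method itself. The 110M parameter count and the BERT-base-like sizes are illustrative assumptions.

```python
def approx_training_flops(batch_size: int, sequence_length: int, num_parameters: int) -> int:
    """~6 * tokens * parameters for one forward + backward pass (quadratic attention term ignored)."""
    num_tokens = batch_size * sequence_length  # one token per position in the batch
    return 6 * num_tokens * num_parameters


def quadratic_term_small(sequence_length: int, d_model: int) -> bool:
    """The approximation assumes sequence_length is well below 12 * d_model."""
    return sequence_length < 12 * d_model


flops = approx_training_flops(batch_size=8, sequence_length=512, num_parameters=110_000_000)
print(f"{flops:.3e}", quadratic_term_small(512, d_model=768))
```

With d_model = 768, the bound 12 * d_model is 9216, so a 512-token batch sits comfortably inside the regime where the approximation is reasonable.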
-
get_extended_attention_mask(attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device) → torch.Tensor
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
- Parameters
attention_mask (torch.Tensor) – Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (Tuple[int]) – The shape of the input to the model.
device (torch.device) – The device of the input to the model.
- Returns
torch.Tensor
The extended attention mask, with the same dtype as attention_mask.dtype.
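The idea of a broadcastable, optionally causal mask can be sketched with nested lists instead of tensors. This sketch assumes the additive convention: kept positions become 0.0, and masked positions become a large negative number (-10000.0 here) that is added to the attention scores before the softmax. extended_attention_mask is a hypothetical helper, and the exact constant the library uses is not stated in this section.

```python
from typing import List

NEG = -10000.0  # assumed "large negative" value added to masked attention scores


def extended_attention_mask(mask: List[List[int]], causal: bool = False) -> List:
    """[batch, seq] 0/1 mask -> additive mask of shape [batch, 1, 1 or seq, seq]."""
    extended = []
    for row in mask:
        seq_len = len(row)
        if causal:
            # query i may attend to key j only if j <= i and j is not padding
            grid = [[row[j] if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]
        else:
            # a single query row, broadcast over every query position and head
            grid = [list(row)]
        # kept positions contribute 0.0, masked positions a large negative number
        extended.append([[[0.0 if keep else NEG for keep in q] for q in grid]])
    return extended


padding_only = extended_attention_mask([[1, 1, 0]])
causal = extended_attention_mask([[1, 1, 0]], causal=True)
print(padding_only)
print(causal)
```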
-
get_head_mask(head_mask: Optional[torch.Tensor], num_hidden_layers: int, is_attention_chunked: bool = False) → torch.Tensor
Prepare the head mask if needed.
- Parameters
head_mask (torch.Tensor with shape [num_heads] or [num_hidden_layers x num_heads], optional) – The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
num_hidden_layers (int) – The number of hidden layers in the model.
is_attention_chunked (bool, optional, defaults to False) – Whether or not the attention scores are computed by chunks.
- Returns
torch.Tensor with shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] or list with [None] for each layer.
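The first step of this preparation, deciding which mask applies to which layer, can be sketched without tensors. per_layer_head_mask is a hypothetical helper. It stops before the reshaping into the broadcastable [num_hidden_layers x batch x num_heads x seq_length x seq_length] form that the real method returns.

```python
from typing import List, Optional, Sequence, Union

HeadMask = Union[Sequence[float], Sequence[Sequence[float]]]


def per_layer_head_mask(head_mask: Optional[HeadMask], num_hidden_layers: int) -> List:
    """Hypothetical helper: return one head mask (or None) per layer."""
    if head_mask is None:
        return [None] * num_hidden_layers  # no masking: one placeholder per layer
    if head_mask and isinstance(head_mask[0], (int, float)):
        # shape [num_heads]: the same mask is reused for every layer
        return [list(head_mask) for _ in range(num_hidden_layers)]
    if len(head_mask) != num_hidden_layers:
        raise ValueError("expected one row of head weights per layer")
    return [list(row) for row in head_mask]  # shape [num_hidden_layers x num_heads]


print(per_layer_head_mask([1.0, 0.0, 1.0], num_hidden_layers=2))
```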
-
invert_attention_mask(encoder_attention_mask: torch.Tensor) → torch.Tensor
Invert an attention mask (e.g., switches 0. and 1.).
- Parameters
encoder_attention_mask (torch.Tensor) – An attention mask.
- Returns
The inverted attention mask.
- Return type
torch.Tensor
-
num_parameters(only_trainable: bool = False, exclude_embeddings: bool = False) → int
Get the number of (optionally, trainable or non-embeddings) parameters in the module.
- Parameters
only_trainable (bool, optional, defaults to False) – Whether or not to return only the number of trainable parameters.
exclude_embeddings (bool, optional, defaults to False) – Whether or not to return only the number of non-embeddings parameters.
- Returns
The number of parameters.
- Return type
int
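How the two flags combine can be shown with a hypothetical list of parameter records. Param is a made-up stand-in for a parameter tensor, and marking embedding parameters with a flag is an assumption about how they are identified. With a real PyTorch module, the trainable count is often written as sum(p.numel() for p in model.parameters() if p.requires_grad).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Param:
    """Hypothetical stand-in for a parameter tensor."""
    name: str
    numel: int
    requires_grad: bool = True
    is_embedding: bool = False


def count_parameters(params: List[Param], only_trainable: bool = False,
                     exclude_embeddings: bool = False) -> int:
    total = 0
    for p in params:
        if only_trainable and not p.requires_grad:
            continue  # skip frozen parameters
        if exclude_embeddings and p.is_embedding:
            continue  # skip embedding matrices
        total += p.numel
    return total


params = [
    Param("embeddings.word_embeddings.weight", 1000, is_embedding=True),
    Param("encoder.layer.0.weight", 500),
    Param("pooler.weight", 200, requires_grad=False),
]
print(count_parameters(params),
      count_parameters(params, only_trainable=True),
      count_parameters(params, exclude_embeddings=True))
```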
-
reset_memory_hooks_state()
Reset the mem_rss_diff attribute of each module (see add_memory_hooks()).
-
TFPreTrainedModel
-
class transformers.TFPreTrainedModel(*args, **kwargs)
Base class for all TF models.
TFPreTrainedModel takes care of storing the configuration of the models and handles methods for loading, downloading and saving models as well as a few methods common to all models to:
- resize the input embeddings,
- prune heads in the self-attention layers.
- Class attributes (overridden by derived classes):
  - config_class (PretrainedConfig) – A subclass of PretrainedConfig to use as configuration class for this model architecture.
  - base_model_prefix (str) – A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
-
property dummy_inputs
Dummy inputs to build the network.
- Returns
The dummy inputs.
- Return type
Dict[str, tf.Tensor]
-
classmethod from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
The warning Weights from XXX not initialized from pretrained model means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.
The warning Weights from XXX not used in YYY means that the layer XXX is not used by YYY, therefore those weights are discarded.
- Parameters
pretrained_model_name_or_path (str, optional) – Can be either:
  - A string with the shortcut name of a pretrained model to load from cache or download, e.g., bert-base-uncased.
  - A string with the identifier name of a pretrained model that was user-uploaded to our S3, e.g., dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
  - None if you are both providing the configuration and state dictionary (resp. with keyword arguments config and state_dict).
model_args (sequence of positional arguments, optional) – All remaining positional arguments will be passed to the underlying model’s __init__ method.
config (Union[PretrainedConfig, str], optional) – Can be either:
  - an instance of a class derived from PretrainedConfig,
  - a string valid as input to from_pretrained().
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the shortcut name string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
from_pt (bool, optional, defaults to False) – Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).
cache_dir (str, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
force_download (bool, optional, defaults to False) – Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
resume_download (bool, optional, defaults to False) – Whether or not to keep incompletely received files and attempt to resume the download if such a file exists.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
output_loading_info (bool, optional, defaults to False) – Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only (bool, optional, defaults to False) – Whether or not to only look at local files (i.e., do not try to download the model).
use_cdn (bool, optional, defaults to True) – Whether or not to use CloudFront (a Content Delivery Network, or CDN) when searching for the model on our S3 (faster). Should be set to False for checkpoints larger than 20GB.
mirror (str, optional, defaults to None) – Mirror source to accelerate downloads in China. If you are in China and have trouble accessing the default servers, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety of the mirror. Please refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, optional) –
Can be used to update the configuration object (after it is loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Examples:
>>> from transformers import BertConfig, TFBertModel
>>> # Download model and configuration from S3 and cache.
>>> model = TFBertModel.from_pretrained('bert-base-uncased')
>>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
>>> model = TFBertModel.from_pretrained('./test/saved_model/')
>>> # Update configuration during loading.
>>> model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
>>> assert model.config.output_attentions
>>> # Loading from a PyTorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
>>> model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)
-
get_input_embeddings() → tensorflow.python.keras.engine.base_layer.Layer
Returns the model’s input embeddings.
- Returns
A Keras layer mapping vocabulary to hidden states.
- Return type
tf.keras.layers.Layer
-
get_output_embeddings() → tensorflow.python.keras.engine.base_layer.Layer
Returns the model’s output embeddings.
- Returns
A Keras layer mapping hidden states to vocabulary.
- Return type
tf.keras.layers.Layer
-
prune_heads(heads_to_prune)
Prunes heads of the base model.
- Parameters
heads_to_prune (Dict[int, List[int]]) – Dictionary with keys being selected layer indices (int) and associated values being the list of heads to prune in said layer (list of int). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
-
resize_token_embeddings(new_num_tokens=None) → tensorflow.python.ops.variables.Variable
Resizes the input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Takes care of tying weights embeddings afterwards if the model class has a tie_weights() method.
- Parameters
new_num_tokens (int, optional) – The new number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or None, just returns a pointer to the input tokens tf.Variable of the model without doing anything.
- Returns
Pointer to the input tokens embeddings variable of the model.
- Return type
tf.Variable
-
save_pretrained(save_directory)
Save a model and its configuration file to a directory, so that it can be re-loaded using the from_pretrained() class method.
- Parameters
save_directory (
str
) – Directory to which to save. Will be created if it doesn’t exist.
TFModelUtilsMixin
-
class transformers.modeling_tf_utils.TFModelUtilsMixin
A few utilities for tf.keras.Model, to be used as a mixin.
-
num_parameters(only_trainable: bool = False) → int
Get the number of (optionally, trainable) parameters in the model.
- Parameters
only_trainable (bool, optional, defaults to False) – Whether or not to return only the number of trainable parameters.
- Returns
The number of parameters.
- Return type
int
-
Generative models
-
class transformers.generation_utils.GenerationMixin
A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
-
adjust_logits_during_generation(logits, **kwargs)
Implement in subclasses of PreTrainedModel for custom behavior to adjust the logits in the generate method.
-
enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
Enforce the repetition penalty (from the CTRL paper).
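The rule behind a CTRL-style repetition penalty can be sketched for a single sequence in plain Python. A score that is already negative is multiplied by the penalty, and a positive one is divided by it, so either way a previously generated token becomes less likely. apply_repetition_penalty is a hypothetical helper that operates on a list of scores indexed by token id. It is not the in-place tensor method documented above.

```python
from typing import Iterable, List


def apply_repetition_penalty(scores: List[float], prev_tokens: Iterable[int],
                             penalty: float) -> List[float]:
    """Hypothetical helper: penalize every token that was already generated."""
    out = list(scores)
    for token in set(prev_tokens):  # each previous token is penalized once
        if out[token] < 0:
            out[token] *= penalty  # negative score: multiply to push it further down
        else:
            out[token] /= penalty  # positive score: divide to shrink it
    return out


penalized = apply_repetition_penalty([2.0, -1.0, 0.5], prev_tokens=[0, 1, 1], penalty=1.2)
print(penalized)
```

A penalty of 1.0 leaves the scores unchanged, which matches the documented "1.0 means no penalty".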
-
generate(input_ids: Optional[torch.LongTensor] = None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, early_stopping: Optional[bool] = None, num_beams: Optional[int] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None, bad_words_ids: Optional[Iterable[int]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, attention_mask: Optional[torch.LongTensor] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, **model_kwargs) → torch.LongTensor
Generates sequences for models with a language modeling head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, and sampling with top-k or nucleus sampling.
Adapted in part from Facebook’s XLM beam search code.
Apart from input_ids and attention_mask, all the arguments below will default to the value of the attribute of the same name inside the PretrainedConfig of the model. The default values indicated are the default values of those config attributes.
Most of these parameters are explained in more detail in this blog post.
- Parameters
input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) – The sequence used as a prompt for the generation. If None, the method initializes it as an empty torch.LongTensor of shape (1,).
max_length (int, optional, defaults to 20) – The maximum length of the sequence to be generated.
min_length (int, optional, defaults to 0) – The minimum length of the sequence to be generated.
do_sample (bool, optional, defaults to False) – Whether or not to use sampling; use greedy decoding otherwise.
early_stopping (bool, optional, defaults to False) – Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
num_beams (int, optional, defaults to 1) – Number of beams for beam search. 1 means no beam search.
temperature (float, optional, defaults to 1.0) – The value used to modulate the next token probabilities.
top_k (int, optional, defaults to 50) – The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (float, optional, defaults to 1.0) – If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
repetition_penalty (float, optional, defaults to 1.0) – The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.
pad_token_id (int, optional) – The id of the padding token.
bos_token_id (int, optional) – The id of the beginning-of-sequence token.
eos_token_id (int, optional) – The id of the end-of-sequence token.
length_penalty (float, optional, defaults to 1.0) – Exponential penalty to the length. 1.0 means no penalty.
Set to values < 1.0 in order to encourage the model to generate shorter sequences, or to a value > 1.0 in order to encourage the model to produce longer sequences.
no_repeat_ngram_size (int, optional, defaults to 0) – If set to int > 0, all ngrams of that size can only occur once.
bad_words_ids (List[List[int]], optional) – List of lists of token ids that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use tokenizer.encode(bad_word, add_prefix_space=True).
num_return_sequences (int, optional, defaults to 1) – The number of independently computed returned sequences for each element in the batch.
attention_mask (torch.LongTensor of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on padding token indices. Mask values are in [0, 1], 1 for tokens that are not masked, and 0 for masked tokens.
If not provided, will default to a tensor the same shape as input_ids that masks the pad token.
decoder_start_token_id (int, optional) – If an encoder-decoder model starts decoding with a different token than bos, the id of that token.
use_cache (bool, optional, defaults to True) – Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
model_kwargs – Additional model-specific kwargs will be forwarded to the forward function of the model.
- Returns
The generated sequences. The second dimension (sequence_length) is either equal to max_length or shorter if all batches finished early due to the eos_token_id.
- Return type
torch.LongTensor of shape (batch_size * num_return_sequences, sequence_length)
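The no_repeat_ngram_size constraint can be illustrated with a small pure-Python sketch. Before each decoding step, any token that would complete an n-gram already present in the sequence is banned; a decoder would then set that token's score to minus infinity. banned_next_tokens is a hypothetical helper for a single sequence, not the library's internal function.

```python
from typing import Dict, List, Tuple


def banned_next_tokens(tokens: List[int], ngram_size: int) -> List[int]:
    """Hypothetical helper: tokens that would complete an n-gram already in `tokens`."""
    if ngram_size <= 0 or len(tokens) + 1 < ngram_size:
        return []  # constraint disabled, or sequence too short to repeat an n-gram
    seen: Dict[Tuple[int, ...], List[int]] = {}
    for i in range(len(tokens) - ngram_size + 1):
        prefix = tuple(tokens[i:i + ngram_size - 1])
        seen.setdefault(prefix, []).append(tokens[i + ngram_size - 1])
    # the last (ngram_size - 1) tokens are the prefix of the n-gram about to be generated
    current_prefix = tuple(tokens[len(tokens) - ngram_size + 1:])
    return seen.get(current_prefix, [])


print(banned_next_tokens([5, 7, 9, 5, 7], ngram_size=3))
```

For [5, 7, 9, 5, 7] with ngram_size=3, the trigram (5, 7, 9) already occurred, so 9 is banned after the trailing (5, 7).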
Examples:
from transformers import AutoModelWithLMHead, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
outputs = model.generate(max_length=40)  # do greedy decoding
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('openai-gpt')  # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('openai-gpt')  # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5, do_sample=True)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
for i in range(3):  # 3 output sequences were generated
    print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 candidates using sampling
for i in range(3):  # 3 output sequences were generated
    print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('ctrl')  # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('ctrl')  # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('gpt2')  # Download model and configuration from S3 and cache.
input_context = 'My cute dog'
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
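When do_sample=True, temperature, top_k and top_p together shape the next-token distribution. The sketch below shows this on a plain list of logits. It is an illustration under stated assumptions, not the library's implementation: the helper names are hypothetical, and it assumes top-k filtering runs before nucleus filtering, with the nucleus computed over the tokens that survive top-k.

```python
import math
import random
from typing import List


def top_k_top_p_filter(logits: List[float], top_k: int = 0, top_p: float = 1.0,
                       temperature: float = 1.0) -> List[float]:
    """Scale by temperature, then keep the top-k tokens and the nucleus among them."""
    scaled = [x / temperature for x in logits]
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]  # top-k: keep only the k highest-scoring tokens
    if top_p < 1.0:
        # softmax over the tokens that survived top-k (max subtracted for stability)
        m = scaled[order[0]]
        exps = {i: math.exp(scaled[i] - m) for i in order}
        total = sum(exps.values())
        nucleus, cumulative = [], 0.0
        for i in order:
            nucleus.append(i)  # the token that crosses top_p is still kept
            cumulative += exps[i] / total
            if cumulative > top_p:
                break
        order = nucleus
    keep = set(order)
    return [x if i in keep else -math.inf for i, x in enumerate(scaled)]


def sample_token(logits: List[float], rng: random.Random) -> int:
    """Draw one token id from softmax(logits); exp(-inf) is 0.0, so filtered ids never appear."""
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


filtered = top_k_top_p_filter([2.0, 1.0, 0.5, -1.0], top_k=3, top_p=0.7)
print(filtered, sample_token(filtered, random.Random(0)))
```

In this example, top-k keeps three tokens. Token 0 alone carries about 63% of their probability mass, so the nucleus for top_p=0.7 stops after token 1.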
-
-
class
transformers.generation_tf_utils.
TFGenerationMixin
[source]¶ A class contraining all of the functions supporting generation, to be used as a mixin in
TFPreTrainedModel
.-
generate
(input_ids=None, max_length=None, min_length=None, do_sample=None, early_stopping=None, num_beams=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None, bad_words_ids=None, bos_token_id=None, pad_token_id=None, eos_token_id=None, length_penalty=None, no_repeat_ngram_size=None, num_return_sequences=None, attention_mask=None, decoder_start_token_id=None, use_cache=None)[source]¶ Generates sequences for models with a language modeling head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
Adapted in part from Facebook’s XLM beam search code.
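The difference between greedy decoding and sampling with temperature comes down to how the next token is picked from one step's logits. Below is a minimal pure-Python sketch of that single step; the function names are hypothetical, and beam search as well as top-k/nucleus filtering are deliberately left out.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float) -> List[float]:
    """Softmax of logits / temperature: temperature < 1 sharpens, > 1 flattens."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def pick_next_token(logits: Sequence[float], do_sample: bool, temperature: float = 1.0,
                    rng: Optional[random.Random] = None) -> int:
    """Choose one token id from a single step's logits."""
    if not do_sample:
        # Greedy decoding: dividing by a positive temperature never changes the
        # argmax, so temperature is irrelevant here.
        return max(range(len(logits)), key=lambda i: logits[i])
    if rng is None:
        rng = random.Random()
    probs = softmax_with_temperature(logits, temperature)
    # Sampling: draw one token id with probability proportional to probs.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


print(pick_next_token([0.1, 3.0, 0.2], do_sample=False))  # prints 1
```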
Apart from input_ids and attention_mask, all the arguments below will default to the value of the attribute of the same name inside the PretrainedConfig of the model. The default values indicated are the default values of that config.
Most of these parameters are explained in more detail in this blog post.
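The config fallback rule above can be sketched as follows, assuming a plain dataclass in place of PretrainedConfig. ToyGenerationConfig and resolve_generation_args are hypothetical names, and only three of the arguments are shown.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyGenerationConfig:
    """Hypothetical stand-in for the relevant PretrainedConfig attributes."""
    max_length: int = 20
    num_beams: int = 1
    do_sample: bool = False


def resolve_generation_args(config: ToyGenerationConfig, max_length: Optional[int] = None,
                            num_beams: Optional[int] = None,
                            do_sample: Optional[bool] = None) -> dict:
    # An explicit argument wins; None means "use the config attribute of the same name".
    return {
        "max_length": max_length if max_length is not None else config.max_length,
        "num_beams": num_beams if num_beams is not None else config.num_beams,
        "do_sample": do_sample if do_sample is not None else config.do_sample,
    }


print(resolve_generation_args(ToyGenerationConfig(max_length=50), do_sample=True))
# prints {'max_length': 50, 'num_beams': 1, 'do_sample': True}
```

Testing `is not None` rather than truthiness matters: an explicit do_sample=False must not be silently replaced by a config value of True.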
- Parameters
input_ids (tf.Tensor of dtype=tf.int32 and shape (batch_size, sequence_length), optional) – The sequence used as a prompt for the generation. If None, the method initializes it with bos_token_id and a batch size of 1.
max_length (int, optional, defaults to 20) – The maximum length of the sequence to be generated.
min_length (int, optional, defaults to 10) – The minimum length of the sequence to be generated.
do_sample (bool, optional, defaults to False) – Whether or not to use sampling; use greedy decoding otherwise.
early_stopping (bool, optional, defaults to False) – Whether or not to stop the beam search when at least num_beams sentences are finished per batch.
num_beams (int, optional, defaults to 1) – Number of beams for beam search. 1 means no beam search.
temperature (float, optional, defaults to 1.0) – The value used to modulate the next token probabilities.
top_k (int, optional, defaults to 50) – The number of highest probability vocabulary tokens to keep for top-k filtering.
top_p (float, optional, defaults to 1.0) – If set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
repetition_penalty (float, optional, defaults to 1.0) – The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.
pad_token_id (int, optional) – The id of the padding token.
bos_token_id (int, optional) – The id of the beginning-of-sequence token.
eos_token_id (int, optional) – The id of the end-of-sequence token.
length_penalty (float, optional, defaults to 1.0) – Exponential penalty to the length. 1.0 means no penalty. Set it to a value < 1.0 to encourage the model to generate shorter sequences, or to a value > 1.0 to encourage it to produce longer sequences.
no_repeat_ngram_size (int, optional, defaults to 0) – If set to an int > 0, all n-grams of that size can only occur once.
bad_words_ids (List[List[int]], optional) – List of token-id lists that are not allowed to be generated. To get the token ids of the words that should not appear in the generated text, use tokenizer.encode(bad_word, add_prefix_space=True).
num_return_sequences (int, optional, defaults to 1) – The number of independently computed returned sequences for each element in the batch.
attention_mask (tf.Tensor of dtype=tf.int32 and shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on padding token indices. Mask values are in [0, 1]: 1 for tokens that are not masked, and 0 for masked tokens. If not provided, it defaults to a tensor of the same shape as input_ids that masks the pad token.
decoder_start_token_id (int, optional) – If an encoder-decoder model starts decoding with a different token than bos, the id of that token.
use_cache (bool, optional, defaults to True) – Whether or not the model should reuse the previously computed key/value attentions (if applicable to the model) to speed up decoding.
model_specific_kwargs – Additional model-specific kwargs that will be forwarded to the forward function of the model.
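The top_k and top_p parameters above both restrict which tokens sampling may pick. Here is a minimal pure-Python sketch for one step's logits, assuming top-k is applied first and top-p is then computed over the surviving tokens. The function name is hypothetical, and treating top_k=0 as "disabled" is a convention of this sketch.

```python
import math
from typing import List, Sequence

NEG_INF = float("-inf")


def top_k_top_p_filter(logits: Sequence[float], top_k: int = 0, top_p: float = 1.0) -> List[float]:
    """Return a copy of logits with filtered-out entries set to -inf.

    top_k=0 disables top-k filtering and top_p=1.0 disables nucleus (top-p) filtering.
    """
    filtered = list(logits)
    # Token ids sorted from most to least likely.
    order = sorted(range(len(filtered)), key=lambda i: filtered[i], reverse=True)
    if top_k > 0:
        for i in order[top_k:]:  # drop everything outside the k highest scores
            filtered[i] = NEG_INF
    if top_p < 1.0:
        kept = [i for i in order if filtered[i] != NEG_INF]
        if kept:
            # Softmax over the surviving tokens only.
            peak = filtered[kept[0]]
            exps = [math.exp(filtered[i] - peak) for i in kept]
            total = sum(exps)
            cumulative = 0.0
            for rank, (i, e) in enumerate(zip(kept, exps)):
                # Keep tokens until the kept mass reaches top_p; always keep the best one.
                if rank > 0 and cumulative >= top_p:
                    filtered[i] = NEG_INF
                cumulative += e / total
    return filtered


print(top_k_top_p_filter([3.0, 2.0, 1.0, 0.0], top_k=2))  # prints [3.0, 2.0, -inf, -inf]
```

The filtered logits would then go through a softmax and be sampled from; tokens set to -inf receive zero probability.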
- Returns
The generated sequences. The second dimension (sequence_length) is either equal to max_length or shorter if all batches finished early due to the eos_token_id.
- Return type
tf.Tensor of dtype=tf.int32 and shape (batch_size * num_return_sequences, sequence_length)
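Because the return value flattens the batch and the returned sequences into a single first dimension, callers often need to regroup the rows per input. The following minimal sketch works on plain Python lists. It relies on an unverified assumption: the num_return_sequences rows for each input are contiguous. group_by_input is a hypothetical helper. With a tf.Tensor, tf.reshape to (batch_size, num_return_sequences, -1) would express the same idea under the same assumption.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_by_input(sequences: Sequence[T], num_return_sequences: int) -> List[List[T]]:
    """Split a flat list of batch_size * num_return_sequences rows into
    batch_size groups, assuming rows for the same input are contiguous."""
    if len(sequences) % num_return_sequences != 0:
        raise ValueError("length must be a multiple of num_return_sequences")
    return [list(sequences[i:i + num_return_sequences])
            for i in range(0, len(sequences), num_return_sequences)]


print(group_by_input(['a1', 'a2', 'b1', 'b2'], 2))  # prints [['a1', 'a2'], ['b1', 'b2']]
```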
Examples:
from transformers import AutoTokenizer, TFAutoModelWithLMHead

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
outputs = model.generate(max_length=40)  # do greedy decoding
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('openai-gpt')  # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('openai-gpt')  # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5, do_sample=True)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
for i in range(3):  # 3 output sequences were generated
    print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 candidates using sampling
for i in range(3):  # 3 output sequences were generated
    print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('ctrl')  # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('ctrl')  # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('gpt2')  # Download model and configuration from S3 and cache.
input_context = 'My cute dog'
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
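The last example passes bad_words_ids as a list of token-id lists. Below is a minimal pure-Python sketch of one way such a constraint could be enforced at each decoding step. This is an assumption about the approach, not the library's code: a bad word is blocked by forbidding its final token whenever the tokens generated so far end with the rest of that word. Both function names are hypothetical.

```python
from typing import List, Sequence


def banned_next_tokens(generated: Sequence[int], bad_words_ids: Sequence[Sequence[int]]) -> List[int]:
    """Token ids that must not come next: the last token of each bad word whose
    earlier tokens match the end of the sequence generated so far."""
    banned = []
    for bad_word in bad_words_ids:
        if not bad_word:
            continue  # ignore empty entries
        prefix, last = list(bad_word[:-1]), bad_word[-1]
        # Single-token bad words (empty prefix) are always banned.
        if len(prefix) <= len(generated) and list(generated[len(generated) - len(prefix):]) == prefix:
            banned.append(last)
    return banned


def mask_banned(logits: Sequence[float], banned: Sequence[int]) -> List[float]:
    """Set the logits of banned token ids to -inf so they can never be chosen."""
    masked = list(logits)
    for token_id in banned:
        masked[token_id] = float("-inf")
    return masked


# Hypothetical ids: bad word [7] is one token, bad word [3, 4] is two tokens.
print(banned_next_tokens([1, 3], [[7], [3, 4]]))  # prints [7, 4]
```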