Models

The base classes PreTrainedModel and TFPreTrainedModel implement the common methods for loading/saving a model either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace’s AWS S3 repository).

PreTrainedModel and TFPreTrainedModel also implement a few methods which are common among all the models to:

  • resize the input token embeddings when new tokens are added to the vocabulary

  • prune the attention heads of the model.

The other methods that are common to each model are defined in ModuleUtilsMixin (for the PyTorch models) and TFModelUtilsMixin (for the TensorFlow models), or, for text generation, in GenerationMixin (for the PyTorch models) and TFGenerationMixin (for the TensorFlow models).

PreTrainedModel

class transformers.PreTrainedModel(config: transformers.configuration_utils.PretrainedConfig, *inputs, **kwargs)

Base class for all models.

PreTrainedModel takes care of storing the configuration of the models and handles methods for loading, downloading and saving models as well as a few methods common to all models to:

  • resize the input embeddings,

  • prune heads in the self-attention layers.

Class attributes (overridden by derived classes):
  • config_class (PretrainedConfig) – A subclass of PretrainedConfig to use as configuration class for this model architecture.

  • load_tf_weights (Callable) – A Python method for loading a TensorFlow checkpoint into a PyTorch model, taking as arguments:

    • model (PreTrainedModel) – An instance of the model on which to load the TensorFlow checkpoint.

    • config (PretrainedConfig) – An instance of the configuration associated with the model.

    • path (str) – A path to the TensorFlow checkpoint.

  • base_model_prefix (str) – A string indicating the attribute associated with the base model in derived classes of the same architecture that add modules on top of the base model.

  • authorized_missing_keys (Optional[List[str]]) – A list of regular-expression patterns of tensor names to ignore if they are missing when loading the model (and avoid unnecessary warnings).
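As a rough illustration of how patterns like authorized_missing_keys can be applied, the sketch below drops every missing key that matches at least one regular expression. The helper filter_missing_keys and the tensor names are hypothetical, and the library's loading code may apply the patterns slightly differently.

```python
import re
from typing import List, Optional


def filter_missing_keys(missing_keys: List[str], authorized_patterns: Optional[List[str]]) -> List[str]:
    """Keep only the missing keys that match none of the authorized regex patterns."""
    if not authorized_patterns:
        return list(missing_keys)
    return [
        key
        for key in missing_keys
        if not any(re.search(pattern, key) for pattern in authorized_patterns)
    ]


# Hypothetical tensor names, for illustration only.
missing = ["embeddings.position_ids", "cls.predictions.bias", "classifier.weight"]
print(filter_missing_keys(missing, [r"position_ids"]))  # ['cls.predictions.bias', 'classifier.weight']
```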

property base_model

The main body of the model.

Type

torch.nn.Module

property dummy_inputs

Dummy inputs to do a forward pass in the network.

Type

Dict[str, torch.Tensor]

classmethod from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

Instantiate a pretrained PyTorch model from a pre-trained model configuration.

The model is set in evaluation mode by default using model.eval() (Dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().

The warning “Weights from XXX not initialized from pretrained model” means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

The warning “Weights from XXX not used in YYY” means that the layer XXX is not used by YYY, therefore those weights are discarded.

Parameters
  • pretrained_model_name_or_path (str, optional) –

    Can be either:

    • A string with the shortcut name of a pretrained model to load from cache or download, e.g., bert-base-uncased.

    • A string with the identifier name of a pretrained model that was user-uploaded to our S3, e.g., dbmdz/bert-base-german-cased.

    • A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.

    • A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

    • None if you are both providing the configuration and state dictionary (resp. with keyword arguments config and state_dict).

  • model_args (sequence of positional arguments, optional) – All remaining positional arguments will be passed to the underlying model’s __init__ method.

  • config (Union[PretrainedConfig, str], optional) –

    Can be either:

    • an instance of a class derived from PretrainedConfig,

    • a string valid as input to PretrainedConfig.from_pretrained().

    Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

    • The model is a model provided by the library (loaded with the shortcut name string of a pretrained model).

    • The model was saved using save_pretrained() and is reloaded by supplying the save directory.

    • The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.

  • state_dict (Dict[str, torch.Tensor], optional) –

    A state dictionary to use instead of a state dictionary loaded from saved weights file.

    This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() would be a simpler option.

  • cache_dir (str, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • from_tf (bool, optional, defaults to False) – Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).

  • force_download (bool, optional, defaults to False) – Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.

  • resume_download (bool, optional, defaults to False) – Whether or not to attempt to resume the download if an incompletely received file exists (otherwise the download starts over).

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

  • output_loading_info (bool, optional, defaults to False) – Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.

  • local_files_only (bool, optional, defaults to False) – Whether or not to only look at local files (i.e., do not try to download the model).

  • use_cdn (bool, optional, defaults to True) – Whether or not to use CloudFront (a Content Delivery Network, or CDN) when searching for the model on our S3 (faster). Should be set to False for checkpoints larger than 20GB.

  • kwargs (remaining dictionary of keyword arguments, optional) –

    Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:

    • If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).

    • If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
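The two kwargs behaviors above can be modeled in a few lines of plain Python. This is a simplified sketch, not the library's code: ToyConfig and split_kwargs are hypothetical stand-ins that show how keys matching a configuration attribute override it, while the remaining keys are left over for the model's __init__.

```python
from typing import Any, Dict, Tuple


class ToyConfig:
    """Hypothetical stand-in for a PretrainedConfig subclass."""

    def __init__(self) -> None:
        self.output_attentions = False
        self.hidden_dropout_prob = 0.1


def split_kwargs(config: ToyConfig, kwargs: Dict[str, Any]) -> Tuple[ToyConfig, Dict[str, Any]]:
    """Override matching config attributes; return the leftover kwargs for the model's __init__."""
    model_kwargs = {}
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)  # a configuration attribute: override it
        else:
            model_kwargs[key] = value  # not a configuration attribute: forward it to the model
    return config, model_kwargs


config, model_kwargs = split_kwargs(ToyConfig(), {"output_attentions": True, "my_model_arg": 3})
print(config.output_attentions, model_kwargs)  # True {'my_model_arg': 3}
```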

Examples:

from transformers import BertConfig, BertModel
# Download model and configuration from S3 and cache.
model = BertModel.from_pretrained('bert-base-uncased')
# Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
model = BertModel.from_pretrained('./test/saved_model/')
# Update configuration during loading.
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
assert model.config.output_attentions == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

get_input_embeddings() → torch.nn.modules.module.Module

Returns the model’s input embeddings.

Returns

A torch module mapping vocabulary to hidden states.

Return type

nn.Module

get_output_embeddings() → torch.nn.modules.module.Module

Returns the model’s output embeddings.

Returns

A torch module mapping hidden states to vocabulary.

Return type

nn.Module

init_weights()

Initializes and prunes weights if needed.

prune_heads(heads_to_prune: Dict[int, List[int]])

Prunes heads of the base model.

Parameters

heads_to_prune (Dict[int, List[int]]) – Dictionary with keys being selected layer indices (int) and associated values being the list of heads to prune in said layer (list of int). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
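To make the heads_to_prune format concrete, the hypothetical helper remaining_heads below (not part of the library) validates such a dictionary and lists the heads each layer keeps. With an actual model you would simply pass the dictionary to model.prune_heads().

```python
from typing import Dict, List


def remaining_heads(num_layers: int, num_heads: int, heads_to_prune: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """Validate a heads_to_prune dictionary and return the heads kept in every layer."""
    for layer, heads in heads_to_prune.items():
        if not 0 <= layer < num_layers:
            raise ValueError(f"layer index {layer} is out of range")
        if any(not 0 <= head < num_heads for head in heads):
            raise ValueError(f"head index out of range for layer {layer}: {heads}")
    kept = {}
    for layer in range(num_layers):
        pruned = set(heads_to_prune.get(layer, []))
        kept[layer] = [head for head in range(num_heads) if head not in pruned]
    return kept


# The example from the text: heads 0 and 2 on layer 1, heads 2 and 3 on layer 2.
print(remaining_heads(num_layers=3, num_heads=4, heads_to_prune={1: [0, 2], 2: [2, 3]}))
# {0: [0, 1, 2, 3], 1: [1, 3], 2: [0, 1]}
```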

resize_token_embeddings(new_num_tokens: Optional[int] = None) → torch.nn.modules.sparse.Embedding

Resizes the input token embeddings matrix of the model if new_num_tokens != config.vocab_size.

Takes care of tying the embedding weights afterwards if the model class has a tie_weights() method.

Parameters

new_num_tokens (int, optional) – The new number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or None, just returns a pointer to the input tokens torch.nn.Embedding module of the model without doing anything.

Returns

Pointer to the input tokens Embeddings Module of the model.

Return type

torch.nn.Embedding
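The growing and shrinking behavior described above can be sketched on a plain list-of-lists matrix. This is an illustration under assumptions: resize_embedding_matrix is hypothetical, and the normal initialization with std=0.02 is an illustrative choice, not a guarantee about the library's initializer. With a real model, the usual pattern is to add tokens with tokenizer.add_tokens(...) and then call model.resize_token_embeddings(len(tokenizer)).

```python
import random
from typing import List

Matrix = List[List[float]]


def resize_embedding_matrix(old: Matrix, new_num_tokens: int, std: float = 0.02, seed: int = 0) -> Matrix:
    """Return a copy of `old` with exactly new_num_tokens rows."""
    rng = random.Random(seed)
    dim = len(old[0])
    resized = [list(row) for row in old[:new_num_tokens]]  # shrinking drops rows from the end
    while len(resized) < new_num_tokens:  # growing appends freshly initialized rows at the end
        resized.append([rng.gauss(0.0, std) for _ in range(dim)])
    return resized


vocab = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
grown = resize_embedding_matrix(vocab, 5)
print(len(grown), grown[:3] == vocab)     # 5 True
print(resize_embedding_matrix(vocab, 2))  # [[1.0, 1.0], [2.0, 2.0]]
```

Existing rows are copied unchanged, so previously learned token vectors survive the resize; only the appended rows are new.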

save_pretrained(save_directory)

Save a model and its configuration file to a directory, so that it can be re-loaded using the from_pretrained() class method.

Parameters

save_directory (str) – Directory to which to save. Will be created if it doesn’t exist.

set_input_embeddings(value: torch.nn.modules.module.Module)

Set the model’s input embeddings.

Parameters

value (nn.Module) – A module mapping vocabulary to hidden states.

tie_weights()

Tie the weights between the input embeddings and the output embeddings.

If the torchscript flag is set in the configuration, the weights are cloned instead of shared, because TorchScript cannot handle parameter sharing.
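The difference between tying and cloning can be shown with a toy layer class. In this sketch (ToyLayer and tie_or_clone are hypothetical), tying makes both layers hold the very same weight object, while the TorchScript-style branch gives the output layer an independent copy.

```python
import copy
from typing import List


class ToyLayer:
    """Hypothetical stand-in for a layer that owns a weight matrix."""

    def __init__(self, weight: List[List[float]]) -> None:
        self.weight = weight


def tie_or_clone(output_layer: ToyLayer, input_embeddings: ToyLayer, torchscript: bool) -> None:
    if torchscript:
        # Independent copy: later updates to one layer are not seen by the other.
        output_layer.weight = copy.deepcopy(input_embeddings.weight)
    else:
        # Shared object: both layers read and write the same weights.
        output_layer.weight = input_embeddings.weight


embeddings = ToyLayer([[0.1, 0.2], [0.3, 0.4]])
tied_head = ToyLayer([[0.0, 0.0], [0.0, 0.0]])
cloned_head = ToyLayer([[0.0, 0.0], [0.0, 0.0]])
tie_or_clone(tied_head, embeddings, torchscript=False)
tie_or_clone(cloned_head, embeddings, torchscript=True)
embeddings.weight[0][0] = 9.9
print(tied_head.weight[0][0], cloned_head.weight[0][0])  # 9.9 0.1
```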

ModuleUtilsMixin

class transformers.modeling_utils.ModuleUtilsMixin

A few utilities for torch.nn.Modules, to be used as a mixin.

add_memory_hooks()

Add a memory hook before and after each sub-module forward pass to record the increase in memory consumption.

Increase in memory consumption is stored in a mem_rss_diff attribute for each module and can be reset to zero with model.reset_memory_hooks_state().

property device

The device on which the module is (assuming that all the module parameters are on the same device).

Type

torch.device

property dtype

The dtype of the module (assuming that all the module parameters have the same dtype).

Type

torch.dtype

get_extended_attention_mask(attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device) → torch.Tensor

Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

Parameters
  • attention_mask (torch.Tensor) – Mask with ones indicating tokens to attend to, zeros for tokens to ignore.

  • input_shape (Tuple[int]) – The shape of the input to the model.

  • device (torch.device) – The device of the input to the model.

Returns

The extended attention mask, with the same dtype as attention_mask.dtype.

Return type

torch.Tensor
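The sketch below shows, in plain Python, what broadcastable attention and causal masks look like: a [batch, seq] mask of ones and zeros becomes an additive mask with singleton dimensions for heads (and, in the non-causal case, for query positions), where masked positions receive a large negative value. The helper name and the constant -10000.0 are illustrative assumptions; the library operates on tensors and may handle dtypes differently.

```python
from typing import List

BIG_NEGATIVE = -10000.0  # illustrative "minus infinity" added to masked attention scores


def extended_attention_mask(attention_mask: List[List[int]], causal: bool = False) -> List[List[List[List[float]]]]:
    """Convert a [batch][seq] 0/1 mask into an additive mask.

    Non-causal output has shape [batch][1][1][seq] (broadcast over heads and queries).
    Causal output has shape [batch][1][seq][seq]: query i may only see keys j <= i.
    """
    result = []
    for row in attention_mask:
        seq_len = len(row)
        if causal:
            grid = [[row[j] if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]
        else:
            grid = [list(row)]
        additive = [[0.0 if keep else BIG_NEGATIVE for keep in query_row] for query_row in grid]
        result.append([additive])  # extra singleton dimension for the heads
    return result


print(extended_attention_mask([[1, 1, 0]]))  # [[[[0.0, 0.0, -10000.0]]]]
```

Adding this mask to the raw attention scores before the softmax drives the weights of masked positions to (almost) zero.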

get_head_mask(head_mask: Optional[torch.Tensor], num_hidden_layers: int, is_attention_chunked: bool = False) → torch.Tensor

Prepare the head mask if needed.

Parameters
  • head_mask (torch.Tensor with shape [num_heads] or [num_hidden_layers x num_heads], optional) – The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).

  • num_hidden_layers (int) – The number of hidden layers in the model.

  • is_attention_chunked (bool, optional, defaults to False) – Whether or not the attention scores are computed by chunks.

Returns

torch.Tensor with shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] or list with [None] for each layer.
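The accepted head_mask shapes can be sketched as follows: a single [num_heads] mask is reused on every layer, a [num_hidden_layers x num_heads] mask supplies one row per layer, and None yields [None] for each layer. This hypothetical helper deliberately omits the singleton batch and sequence dimensions that the returned tensor carries for broadcasting.

```python
from typing import List, Optional, Sequence, Union

HeadMask = Union[Sequence[float], Sequence[Sequence[float]]]


def per_layer_head_mask(head_mask: Optional[HeadMask], num_hidden_layers: int) -> List[Optional[List[float]]]:
    """Return one head mask per layer (1.0 keeps a head, 0.0 discards it), or None per layer."""
    if head_mask is None:
        return [None] * num_hidden_layers
    if all(isinstance(value, (int, float)) for value in head_mask):
        # Shape [num_heads]: the same mask is applied to every layer.
        return [list(head_mask) for _ in range(num_hidden_layers)]
    # Shape [num_hidden_layers x num_heads]: one row per layer.
    if len(head_mask) != num_hidden_layers:
        raise ValueError("expected one row of head weights per hidden layer")
    return [list(row) for row in head_mask]


print(per_layer_head_mask([1.0, 0.0, 1.0], num_hidden_layers=2))  # [[1.0, 0.0, 1.0], [1.0, 0.0, 1.0]]
```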

invert_attention_mask(encoder_attention_mask: torch.Tensor) → torch.Tensor

Invert an attention mask (e.g., switches 0. and 1.).

Parameters

encoder_attention_mask (torch.Tensor) – An attention mask.

Returns

The inverted attention mask.

Return type

torch.Tensor

num_parameters(only_trainable: bool = False) → int

Get the number of (optionally, trainable) parameters in the model.

Parameters

only_trainable (bool, optional, defaults to False) – Whether or not to return only the number of trainable parameters.

Returns

The number of parameters.

Return type

int
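The counting rule behind num_parameters() can be sketched without a framework: every parameter contributes the product of its shape, optionally skipping non-trainable ones. The (shape, trainable) pairs below are hypothetical; in PyTorch the same idea is commonly written as sum(p.numel() for p in model.parameters() if p.requires_grad). The sketch uses math.prod, which requires Python 3.8 or later.

```python
import math
from typing import Iterable, Tuple

Param = Tuple[Tuple[int, ...], bool]  # (shape, is_trainable)


def count_parameters(params: Iterable[Param], only_trainable: bool = False) -> int:
    """Sum the number of elements of each parameter, optionally only the trainable ones."""
    return sum(math.prod(shape) for shape, trainable in params if trainable or not only_trainable)


params = [((10, 4), True), ((4,), True), ((2, 4), False)]  # a weight, a bias, a frozen table
print(count_parameters(params), count_parameters(params, only_trainable=True))  # 52 44
```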

reset_memory_hooks_state()

Reset the mem_rss_diff attribute of each module (see add_memory_hooks()).

TFPreTrainedModel

class transformers.TFPreTrainedModel(*args, **kwargs)

Base class for all TF models.

TFPreTrainedModel takes care of storing the configuration of the models and handles methods for loading, downloading and saving models as well as a few methods common to all models to:

  • resize the input embeddings,

  • prune heads in the self-attention layers.

Class attributes (overridden by derived classes):
  • config_class (PretrainedConfig) – A subclass of PretrainedConfig to use as configuration class for this model architecture.

  • base_model_prefix (str) – A string indicating the attribute associated with the base model in derived classes of the same architecture that add modules on top of the base model.

property dummy_inputs

Dummy inputs to build the network.

Returns

The dummy inputs.

Return type

Dict[str, tf.Tensor]

classmethod from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

The warning “Weights from XXX not initialized from pretrained model” means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

The warning “Weights from XXX not used in YYY” means that the layer XXX is not used by YYY, therefore those weights are discarded.

Parameters
  • pretrained_model_name_or_path (str, optional) –

    Can be either:

    • A string with the shortcut name of a pretrained model to load from cache or download, e.g., bert-base-uncased.

    • A string with the identifier name of a pretrained model that was user-uploaded to our S3, e.g., dbmdz/bert-base-german-cased.

    • A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.

    • A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

    • None if you are both providing the configuration and state dictionary (resp. with keyword arguments config and state_dict).

  • model_args (sequence of positional arguments, optional) – All remaining positional arguments will be passed to the underlying model’s __init__ method.

  • config (Union[PretrainedConfig, str], optional) –

    Can be either:

    • an instance of a class derived from PretrainedConfig,

    • a string valid as input to PretrainedConfig.from_pretrained().

    Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

    • The model is a model provided by the library (loaded with the shortcut name string of a pretrained model).

    • The model was saved using save_pretrained() and is reloaded by supplying the save directory.

    • The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.

  • from_pt (bool, optional, defaults to False) – Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

  • cache_dir (str, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.

  • resume_download (bool, optional, defaults to False) – Whether or not to attempt to resume the download if an incompletely received file exists (otherwise the download starts over).

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

  • output_loading_info (bool, optional, defaults to False) – Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.

  • local_files_only (bool, optional, defaults to False) – Whether or not to only look at local files (i.e., do not try to download the model).

  • use_cdn (bool, optional, defaults to True) – Whether or not to use CloudFront (a Content Delivery Network, or CDN) when searching for the model on our S3 (faster). Should be set to False for checkpoints larger than 20GB.

  • kwargs (remaining dictionary of keyword arguments, optional) –

    Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:

    • If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).

    • If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.

Examples:

from transformers import BertConfig, TFBertModel
# Download model and configuration from S3 and cache.
model = TFBertModel.from_pretrained('bert-base-uncased')
# Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
model = TFBertModel.from_pretrained('./test/saved_model/')
# Update configuration during loading.
model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
assert model.config.output_attentions == True
# Loading from a PyTorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

get_input_embeddings() → tensorflow.python.keras.engine.base_layer.Layer

Returns the model’s input embeddings.

Returns

A Keras layer mapping vocabulary to hidden states.

Return type

tf.keras.layers.Layer

get_output_embeddings() → tensorflow.python.keras.engine.base_layer.Layer

Returns the model’s output embeddings.

Returns

A Keras layer mapping hidden states to vocabulary.

Return type

tf.keras.layers.Layer

prune_heads(heads_to_prune)

Prunes heads of the base model.

Parameters

heads_to_prune (Dict[int, List[int]]) – Dictionary with keys being selected layer indices (int) and associated values being the list of heads to prune in said layer (list of int). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.

resize_token_embeddings(new_num_tokens=None) → tensorflow.python.ops.variables.Variable

Resizes the input token embeddings matrix of the model if new_num_tokens != config.vocab_size.

Takes care of tying the embedding weights afterwards if the model class has a tie_weights() method.

Parameters

new_num_tokens (int, optional) – The new number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or None, just returns a pointer to the input tokens tf.Variable of the model without doing anything.

Returns

Pointer to the input token embeddings variable of the model.

Return type

tf.Variable

save_pretrained(save_directory)

Save a model and its configuration file to a directory, so that it can be re-loaded using the from_pretrained() class method.

Parameters

save_directory (str) – Directory to which to save. Will be created if it doesn’t exist.

set_input_embeddings(value)

Set model’s input embeddings.

Parameters

value (tf.keras.layers.Layer) – A layer mapping vocabulary to hidden states.

TFModelUtilsMixin

class transformers.modeling_tf_utils.TFModelUtilsMixin

A few utilities for tf.keras.Model, to be used as a mixin.

num_parameters(only_trainable: bool = False) → int

Get the number of (optionally, trainable) parameters in the model.

Parameters

only_trainable (bool, optional, defaults to False) – Whether or not to return only the number of trainable parameters.

Returns

The number of parameters.

Return type

int

Generative models

class transformers.generation_utils.GenerationMixin

A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.

adjust_logits_during_generation(logits, **kwargs)

Implement in subclasses of PreTrainedModel for custom behavior to adjust the logits in the generate method.

enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)

Enforce the repetition penalty (from the CTRL paper).
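As we understand the CTRL-style rule, each token that already appears in the sequence has its score divided by repetition_penalty if the score is positive and multiplied by it if negative, so a penalty above 1.0 pushes repeated tokens down (a score of exactly zero is left unchanged). The pure-Python helper below is a hypothetical sketch of that rule for a single sequence, not the method's exact code.

```python
from typing import List, Sequence


def apply_repetition_penalty(scores: List[float], prev_tokens: Sequence[int], penalty: float) -> List[float]:
    """Penalize every token id that already occurs in prev_tokens (each id only once)."""
    penalized = list(scores)
    for token in set(prev_tokens):
        if penalized[token] < 0:
            penalized[token] *= penalty  # negative score: multiplying pushes it further down
        else:
            penalized[token] /= penalty  # positive score: dividing pulls it toward zero
    return penalized


print(apply_repetition_penalty([2.0, -1.0, 0.5], prev_tokens=[0, 1, 1], penalty=2.0))  # [1.0, -2.0, 0.5]
```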

generate(input_ids: Optional[torch.LongTensor] = None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, early_stopping: Optional[bool] = None, num_beams: Optional[int] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None, bad_words_ids: Optional[Iterable[int]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, attention_mask: Optional[torch.LongTensor] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, **model_kwargs) → torch.LongTensor

Generates sequences for models with a language modeling head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

Adapted in part from Facebook’s XLM beam search code.

Apart from input_ids and attention_mask, all the arguments below will default to the value of the attribute of the same name inside the PretrainedConfig of the model. The default values indicated are the default values of that configuration.

Most of these parameters are explained in more detail in this blog post.

Parameters
  • input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) – The sequence used as a prompt for the generation. If None, the method initializes it as an empty torch.LongTensor of shape (1,).

  • max_length (int, optional, defaults to 20) – The maximum length of the sequence to be generated.

  • min_length (int, optional, defaults to 0) – The minimum length of the sequence to be generated.

  • do_sample (bool, optional, defaults to False) – Whether or not to use sampling; use greedy decoding otherwise.

  • early_stopping (bool, optional, defaults to False) – Whether to stop the beam search when at least num_beams sentences are finished per batch or not.

  • num_beams (int, optional, defaults to 1) – Number of beams for beam search. 1 means no beam search.

  • temperature (float, optional, defaults to 1.0) – The value used to modulate the next token probabilities.

  • top_k (int, optional, defaults to 50) – The number of highest probability vocabulary tokens to keep for top-k filtering.

  • top_p (float, optional, defaults to 1.0) – If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.

  • repetition_penalty (float, optional, defaults to 1.0) – The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.

  • pad_token_id (int, optional) – The id of the padding token.

  • bos_token_id (int, optional) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional) – The id of the end-of-sequence token.

  • length_penalty (float, optional, defaults to 1.0) –

    Exponential penalty to the length. 1.0 means no penalty.

    Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer sequences.

  • no_repeat_ngram_size (int, optional, defaults to 0) – If set to int > 0, all ngrams of that size can only occur once.

  • bad_words_ids (List[List[int]], optional) – List of token id sequences that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use tokenizer.encode(bad_word, add_prefix_space=True).

  • num_return_sequences (int, optional, defaults to 1) – The number of independently computed returned sequences for each element in the batch.

  • attention_mask (torch.LongTensor of shape (batch_size, sequence_length), optional) –

    Mask to avoid performing attention on padding token indices. Mask values are in [0, 1], 1 for tokens that are not masked, and 0 for masked tokens.

    If not provided, will default to a tensor the same shape as input_ids that masks the pad token.

    What are attention masks?

  • decoder_start_token_id (int, optional) – If an encoder-decoder model starts decoding with a different token than bos, the id of that token.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should use the past key/value attentions (if applicable to the model) to speed up decoding.

  • model_kwargs – Additional model specific kwargs will be forwarded to the forward function of the model.
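The sampling-related parameters above (temperature, top_k, top_p) are typically combined in a fixed order before one next token is drawn. The pure-Python sketch below follows that common recipe for a single row of logits: divide by the temperature, mask everything outside the top-k, then mask everything outside the nucleus. filter_logits and sample_next_token are hypothetical names, and the library's own filtering may treat edge cases such as ties or the exact top_p threshold differently.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filter_logits(logits: List[float], temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0) -> List[float]:
    """Apply temperature, then top-k, then top-p (nucleus) filtering; filtered tokens get -inf."""
    filtered = [x / temperature for x in logits]
    order = sorted(range(len(filtered)), key=lambda i: filtered[i], reverse=True)
    if top_k > 0:
        for i in order[top_k:]:
            filtered[i] = -math.inf
    if top_p < 1.0:
        probs = softmax(filtered)
        cumulative = 0.0
        for rank, i in enumerate(order):
            # Drop a token once the more probable ones already reach top_p (always keep the first).
            if rank > 0 and cumulative >= top_p:
                filtered[i] = -math.inf
            cumulative += probs[i]
    return filtered


def sample_next_token(logits: List[float], rng: random.Random, **filter_kwargs) -> int:
    probs = softmax(filter_logits(logits, **filter_kwargs))
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [3.0, 2.0, 1.0, 0.0]
print(filter_logits(logits, top_k=2))    # [3.0, 2.0, -inf, -inf]
print(filter_logits(logits, top_p=0.5))  # [3.0, -inf, -inf, -inf]
print(sample_next_token(logits, random.Random(0), temperature=0.7, top_k=2))
```

Lower temperatures sharpen the distribution before filtering, so they interact with top_p: the same top_p keeps fewer tokens when the temperature is low.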

Returns

The generated sequences. The second dimension (sequence_length) is either equal to max_length or shorter if all batches finished early due to the eos_token_id.

Return type

torch.LongTensor of shape (batch_size * num_return_sequences, sequence_length)
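Both no_repeat_ngram_size and bad_words_ids work by computing, at each decoding step, a set of token ids that must not come next; the scores of those tokens are then set to -inf. The two hypothetical helpers below sketch that computation for a single sequence of already generated ids. They illustrate the idea and are not the library's internal functions.

```python
from typing import List, Sequence, Set


def banned_by_no_repeat_ngram(generated: Sequence[int], n: int) -> Set[int]:
    """Tokens that would complete an n-gram already present in `generated`."""
    if n <= 0 or len(generated) < n:
        return set()
    prefix = tuple(generated[len(generated) - n + 1:])  # the last n - 1 tokens
    banned = set()
    for start in range(len(generated) - n + 1):
        ngram = tuple(generated[start:start + n])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned


def banned_by_bad_words(generated: Sequence[int], bad_words_ids: List[List[int]]) -> Set[int]:
    """Last token of each bad sequence whose preceding tokens match the end of `generated`."""
    banned = set()
    for bad in bad_words_ids:
        if not bad:
            continue
        prefix = list(bad[:-1])
        if not prefix or list(generated[len(generated) - len(prefix):]) == prefix:
            banned.add(bad[-1])
    return banned


print(sorted(banned_by_no_repeat_ngram([5, 6, 7, 5, 6], n=3)))        # [7]
print(sorted(banned_by_bad_words([1, 2, 3], [[9], [3, 4], [2, 4]])))  # [4, 9]
```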

Examples:

from transformers import AutoTokenizer, AutoModelWithLMHead

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
outputs = model.generate(max_length=40)  # do greedy decoding
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5, do_sample=True)  # generate 3 sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
for i in range(3): #  3 output sequences were generated
    print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 candidates using sampling
for i in range(3): #  3 output sequences were generated
    print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
input_context = 'My cute dog'
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated

prepare_inputs_for_generation(input_ids, **kwargs)

Implement in subclasses of PreTrainedModel for custom behavior to prepare inputs in the generate method.

class transformers.generation_tf_utils.TFGenerationMixin

A class containing all of the functions supporting generation, to be used as a mixin in TFPreTrainedModel.

generate(input_ids=None, max_length=None, min_length=None, do_sample=None, early_stopping=None, num_beams=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None, bad_words_ids=None, bos_token_id=None, pad_token_id=None, eos_token_id=None, length_penalty=None, no_repeat_ngram_size=None, num_return_sequences=None, attention_mask=None, decoder_start_token_id=None, use_cache=None)

Generates sequences for models with a language modeling head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

Adapted in part from Facebook’s XLM beam search code.

Apart from input_ids and attention_mask, all the arguments below will default to the value of the attribute of the same name inside the PretrainedConfig of the model. The default values indicated are the default values of that configuration.

Most of these parameters are explained in more detail in this blog post.

Parameters
  • input_ids (tf.Tensor of dtype=tf.int32 and shape (batch_size, sequence_length), optional) – The sequence used as a prompt for the generation. If None, the method initializes it as an empty tf.Tensor of shape (1,).

  • max_length (int, optional, defaults to 20) – The maximum length of the sequence to be generated.

  • min_length (int, optional, defaults to 0) – The minimum length of the sequence to be generated.

  • do_sample (bool, optional, defaults to False) – Whether or not to use sampling; use greedy decoding otherwise.

  • early_stopping (bool, optional, defaults to False) – Whether to stop the beam search when at least num_beams sentences are finished per batch or not.

  • num_beams (int, optional, defaults to 1) – Number of beams for beam search. 1 means no beam search.

  • temperature (float, optional, defaults to 1.0) – The value used to modulate the next token probabilities.

  • top_k (int, optional, defaults to 50) – The number of highest probability vocabulary tokens to keep for top-k filtering.

  • top_p (float, optional, defaults to 1.0) – If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.

  • repetition_penalty (float, optional, defaults to 1.0) – The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.

  • pad_token_id (int, optional) – The id of the padding token.

  • bos_token_id (int, optional) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional) – The id of the end-of-sequence token.

  • length_penalty (float, optional, defaults to 1.0) –

    Exponential penalty to the length. 1.0 means no penalty.

    Set it to a value < 1.0 to encourage the model to generate shorter sequences, or to a value > 1.0 to encourage the model to produce longer sequences.

  • no_repeat_ngram_size (int, optional, defaults to 0) – If set to int > 0, all ngrams of that size can only occur once.

  • bad_words_ids (List[List[int]], optional) – List of lists of token ids that are not allowed to be generated. To get the token ids of the words that should not appear in the generated text, use tokenizer.encode(bad_word, add_prefix_space=True).

  • num_return_sequences (int, optional, defaults to 1) – The number of independently computed returned sequences for each element in the batch.

  • attention_mask (tf.Tensor of dtype=tf.int32 and shape (batch_size, sequence_length), optional) –

    Mask to avoid performing attention on padding token indices. Mask values are in [0, 1], 1 for tokens that are not masked, and 0 for masked tokens.

    If not provided, will default to a tensor the same shape as input_ids that masks the pad token.

    What are attention masks?

  • decoder_start_token_id (int, optional) – If an encoder-decoder model starts decoding with a different token than bos, the id of that token.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should use the past key/value attentions (if applicable to the model) to speed up decoding.

  • model_specific_kwargs – Additional model specific kwargs will be forwarded to the forward function of the model.
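To make temperature, top_k and top_p more concrete, here is a framework-agnostic sketch that works on a single list of next-token logits. Filtered-out tokens get a logit of -inf, which becomes zero probability after the softmax. The function names are hypothetical, and details such as tie handling may differ from the library's own filtering code.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def filter_logits(logits: List[float], temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0) -> List[float]:
    neg_inf = float("-inf")
    # Temperature: < 1.0 sharpens the distribution, > 1.0 flattens it.
    scores = [x / temperature for x in logits]
    if top_k > 0:
        # Keep the top_k highest scores (ties with the k-th score are kept too).
        kth = sorted(scores, reverse=True)[min(top_k, len(scores)) - 1]
        scores = [x if x >= kth else neg_inf for x in scores]
    if top_p < 1.0:
        # Nucleus filtering: keep the smallest set of most probable tokens
        # whose cumulative probability reaches top_p.
        probs = softmax(scores)
        keep = set()
        cumulative = 0.0
        for i in sorted(range(len(scores)), key=lambda j: probs[j], reverse=True):
            keep.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        scores = [x if i in keep else neg_inf for i, x in enumerate(scores)]
    return scores


def sample_next_token(logits: List[float], rng: random.Random, **filter_kwargs) -> int:
    # Draw one token id from the filtered, renormalized distribution.
    probs = softmax(filter_logits(logits, **filter_kwargs))
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
print(filter_logits(logits, top_k=2))    # [2.0, 1.0, -inf, -inf]
print(filter_logits(logits, top_p=0.8))  # [2.0, 1.0, -inf, -inf]
print(sample_next_token(logits, random.Random(0), temperature=0.7, top_k=2))
```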
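The ideas behind repetition_penalty, no_repeat_ngram_size, bad_words_ids and length_penalty can be sketched on plain Python lists as follows. All function names are hypothetical; the sketch assumes that, in a decoding loop, every banned token id has its logit set to -inf (mask_banned) before the next token is chosen, and that finished beam hypotheses are ranked by a length-normalized score.

```python
from typing import Dict, List, Set, Tuple


def apply_repetition_penalty(logits: List[float], prev_ids: List[int], penalty: float) -> List[float]:
    # CTRL-style penalty: scores of already generated tokens are made less likely.
    out = list(logits)
    for tok in set(prev_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def banned_ngram_tokens(prev_ids: List[int], n: int) -> Set[int]:
    # Tokens that would complete an n-gram already present in prev_ids.
    if n <= 0 or len(prev_ids) + 1 < n:
        return set()
    next_tokens: Dict[Tuple[int, ...], Set[int]] = {}
    for i in range(len(prev_ids) - n + 1):
        ngram = tuple(prev_ids[i:i + n])
        next_tokens.setdefault(ngram[:-1], set()).add(ngram[-1])
    # The (n-1)-token prefix that the next token would extend.
    prefix = tuple(prev_ids[len(prev_ids) - n + 1:])
    return next_tokens.get(prefix, set())


def banned_bad_word_tokens(prev_ids: List[int], bad_words_ids: List[List[int]]) -> Set[int]:
    # Ban the last token of each (non-empty) bad word whose preceding tokens were just generated.
    banned = set()
    for word in bad_words_ids:
        prefix = word[:-1]
        if not prefix or prev_ids[-len(prefix):] == prefix:
            banned.add(word[-1])
    return banned


def mask_banned(logits: List[float], banned: Set[int]) -> List[float]:
    return [float("-inf") if i in banned else x for i, x in enumerate(logits)]


def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # Length-normalized score of a finished hypothesis; higher is better.
    return sum_logprobs / (length ** length_penalty)


print(apply_repetition_penalty([2.0, -2.0, 1.0], prev_ids=[0, 1], penalty=2.0))  # [1.0, -4.0, 1.0]
print(banned_ngram_tokens([1, 2, 3, 1, 2], n=3))                                  # {3}
print(sorted(banned_bad_word_tokens([7, 4], [[9], [4, 5], [3, 6]])))             # [5, 9]
print(beam_score(-4.0, 4, 1.0), beam_score(-4.0, 4, 2.0))                        # -1.0 -0.25
```

Because log-probabilities are negative, a length_penalty above 1.0 divides long hypotheses by a larger denominator and moves their scores closer to zero, which is why such values favor longer sequences.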

Returns

The generated sequences. The second dimension (sequence_length) is either equal to max_length or shorter if all batches finished early due to the eos_token_id.

Return type

tf.Tensor of dtype=tf.int32 and shape (batch_size * num_return_sequences, sequence_length)
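When num_return_sequences > 1, each prompt is effectively repeated that many times, which is where the batch_size * num_return_sequences rows come from. The sketch below uses plain lists and hypothetical helper names, and assumes that the copies of each prompt are kept next to each other, so that rows i * num_return_sequences through (i + 1) * num_return_sequences - 1 belong to prompt i.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def expand_for_return_sequences(input_ids: List[List[int]], num_return_sequences: int) -> List[List[int]]:
    # Repeat each prompt consecutively: B prompts become B * N rows.
    return [list(row) for row in input_ids for _ in range(num_return_sequences)]


def group_by_prompt(outputs: Sequence[T], num_return_sequences: int) -> List[List[T]]:
    # Rows i * N .. (i + 1) * N - 1 are assumed to belong to prompt i.
    return [list(outputs[i:i + num_return_sequences]) for i in range(0, len(outputs), num_return_sequences)]


print(expand_for_return_sequences([[1, 2], [3]], num_return_sequences=2))  # [[1, 2], [1, 2], [3], [3]]
print(group_by_prompt(["a1", "a2", "b1", "b2"], 2))                       # [['a1', 'a2'], ['b1', 'b2']]
```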

Examples:

from transformers import AutoTokenizer, TFAutoModelWithLMHead

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
outputs = model.generate(max_length=40)  # do greedy decoding
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5, do_sample=True)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
for i in range(3): #  3 output sequences were generated
    print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 candidates using sampling
for i in range(3): #  3 output sequences were generated
    print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
input_context = 'My cute dog'
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

prepare_inputs_for_generation(inputs, **kwargs)

Implement this method in subclasses of TFPreTrainedModel to customize how inputs are prepared in the generate method.