Encoder Decoder Models¶

The EncoderDecoderModel can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder.

The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation tasks was shown in Leveraging Pre-trained Checkpoints for Sequence Generation Tasks by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.

After such an EncoderDecoderModel has been trained/fine-tuned, it can be saved/loaded just like any other model (see the examples for more information).

An application of this architecture could be to leverage two pretrained BertModel instances as the encoder and decoder for a summarization model, as was shown in Text Summarization with Pretrained Encoders by Yang Liu and Mirella Lapata.

EncoderDecoderConfig¶

class transformers.EncoderDecoderConfig(**kwargs)[source]¶

EncoderDecoderConfig is the configuration class to store the configuration of an EncoderDecoderModel. It is used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder configs.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Parameters

kwargs (optional) –

Dictionary of keyword arguments. Notably:

  • encoder (PretrainedConfig, optional) – An instance of a configuration object that defines the encoder config.

  • decoder (PretrainedConfig, optional) – An instance of a configuration object that defines the decoder config.

Examples:

>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

>>> # Initializing a BERT bert-base-uncased style configuration
>>> config_encoder = BertConfig()
>>> config_decoder = BertConfig()

>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

>>> # Initializing a Bert2Bert model from the bert-base-uncased style configurations
>>> model = EncoderDecoderModel(config=config)

>>> # Accessing the model configuration
>>> config_encoder = model.config.encoder
>>> config_decoder = model.config.decoder
>>> # set decoder config to causal lm
>>> config_decoder.is_decoder = True
>>> config_decoder.add_cross_attention = True

>>> # Saving the model, including its configuration
>>> model.save_pretrained('my-model')

>>> # loading model and config from pretrained folder
>>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained('my-model')
>>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
classmethod from_encoder_decoder_configs(encoder_config: transformers.configuration_utils.PretrainedConfig, decoder_config: transformers.configuration_utils.PretrainedConfig, **kwargs) → transformers.configuration_utils.PretrainedConfig[source]¶

Instantiate an EncoderDecoderConfig (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.

Returns

An instance of a configuration object

Return type

EncoderDecoderConfig

to_dict()[source]¶

Serializes this instance to a Python dictionary. Overrides the default to_dict() from PretrainedConfig.

Returns

Dictionary of all the attributes that make up this configuration instance.

Return type

Dict[str, Any]
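
As a rough illustration of what overriding to_dict() means for a composite configuration, the sketch below uses two hypothetical toy classes (ToyConfig and ToyEncoderDecoderConfig, neither of which is part of the library). It assumes that the override replaces the nested encoder and decoder config objects with their own dictionaries so that the result contains only plain, serializable values.

```python
import copy


class ToyConfig:
    """Hypothetical stand-in for a PretrainedConfig-like object."""

    model_type = "toy"

    def __init__(self, **kwargs):
        # Store every keyword argument as an instance attribute.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self):
        # Default behavior: copy the instance attributes and record the type.
        output = copy.deepcopy(self.__dict__)
        output["model_type"] = self.model_type
        return output


class ToyEncoderDecoderConfig(ToyConfig):
    """Hypothetical composite config holding two nested configs."""

    model_type = "toy-encoder-decoder"

    def __init__(self, encoder, decoder):
        super().__init__(encoder=encoder, decoder=decoder)

    def to_dict(self):
        # Override: swap the nested config objects for plain dictionaries.
        output = copy.deepcopy(self.__dict__)
        output["encoder"] = self.encoder.to_dict()
        output["decoder"] = self.decoder.to_dict()
        output["model_type"] = self.model_type
        return output


config = ToyEncoderDecoderConfig(
    encoder=ToyConfig(hidden_size=8),
    decoder=ToyConfig(hidden_size=8, is_decoder=True),
)
config_dict = config.to_dict()
```

Without the override, the "encoder" and "decoder" entries would still be config objects, which a JSON serializer could not write out directly.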

EncoderDecoderModel¶

class transformers.EncoderDecoderModel(config: Optional[transformers.configuration_utils.PretrainedConfig] = None, encoder: Optional[transformers.modeling_utils.PreTrainedModel] = None, decoder: Optional[transformers.modeling_utils.PreTrainedModel] = None)[source]¶

This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via the AutoModel.from_pretrained() function and the decoder is loaded via the AutoModelForCausalLM.from_pretrained() function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream generative task, like summarization.

The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation tasks was shown in Leveraging Pre-trained Checkpoints for Sequence Generation Tasks by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.

After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other model (see the examples for more information).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

config (EncoderDecoderConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

EncoderDecoder is a generic model class that will be instantiated as a transformer architecture with one of the base model classes of the library as encoder and another one as decoder when created with the AutoModel.from_pretrained() class method for the encoder and the AutoModelForCausalLM.from_pretrained() class method for the decoder.

forward(input_ids=None, inputs_embeds=None, attention_mask=None, encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, decoder_inputs_embeds=None, labels=None, return_dict=None, **kwargs)[source]¶

The EncoderDecoderModel forward method, overrides the __call__() special method.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance rather than this function directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
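
The difference between calling the instance and calling forward() directly can be sketched with a hypothetical, heavily simplified stand-in for a module class. ToyModule, pre_hooks and post_hooks are illustrative names only; the real torch.nn.Module hook mechanism has a different interface.

```python
class ToyModule:
    """Hypothetical, simplified stand-in for torch.nn.Module."""

    def __init__(self):
        self.pre_hooks = []   # callables applied to the input before forward
        self.post_hooks = []  # callables applied to the output after forward

    def forward(self, x):
        # The "recipe" for the forward pass lives here.
        return x * 2

    def __call__(self, x):
        # Calling the instance wraps forward with the registered hooks.
        for hook in self.pre_hooks:
            x = hook(x)
        out = self.forward(x)
        for hook in self.post_hooks:
            out = hook(out)
        return out


module = ToyModule()
module.pre_hooks.append(lambda x: x + 1)
module.post_hooks.append(lambda out: out - 3)

via_call = module(5)             # hooks run: ((5 + 1) * 2) - 3
via_forward = module.forward(5)  # hooks skipped: 5 * 2
```

In this toy setup the two calls return different values, which is the kind of silent discrepancy the note warns about.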

Parameters
  • input_ids (torch.LongTensor of shape (batch_size, sequence_length)) –

    Indices of input sequence tokens in the vocabulary.

    Indices can be obtained using PreTrainedTokenizer. See transformers.PreTrainedTokenizer.encode() and transformers.PreTrainedTokenizer.__call__() for details.

    What are input IDs?

  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) – Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.

  • attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) –

    Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:

    • 1 for tokens that are not masked,

    • 0 for tokens that are masked.

    What are attention masks?

  • encoder_outputs (tuple(torch.FloatTensor), optional) – This tuple must consist of (last_hidden_state, optional: hidden_states, optional: attentions). last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) is a tensor of hidden-states at the output of the last layer of the encoder. It is used in the cross-attention of the decoder.

  • decoder_input_ids (torch.LongTensor of shape (batch_size, target_sequence_length), optional) – Indices of the decoder input sequence tokens, provided to the decoder for sequence-to-sequence training. Indices can be obtained using PreTrainedTokenizer. See transformers.PreTrainedTokenizer.encode() and transformers.PreTrainedTokenizer.__call__() for details.

  • decoder_attention_mask (torch.BoolTensor of shape (batch_size, tgt_seq_len), optional) – Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.

  • decoder_inputs_embeds (torch.FloatTensor of shape (batch_size, target_sequence_length, hidden_size), optional) – Optionally, instead of passing decoder_input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert decoder_input_ids indices into associated vectors than the model’s internal embedding lookup matrix.

  • labels (torch.LongTensor of shape (batch_size, sequence_length), optional) – Labels for computing the language modeling loss for the decoder. Indices should be in [-100, 0, ..., config.vocab_size] (see input_ids docstring). Tokens with indices set to -100 are ignored (masked); the loss is only computed for the tokens with labels in [0, ..., config.vocab_size].

  • return_dict (bool, optional) – If set to True, the model will return a Seq2SeqLMOutput instead of a plain tuple.

  • kwargs – (optional) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:

    • Without a prefix, which will be input as **encoder_kwargs for the encoder forward function.

    • With a decoder_ prefix, which will be input as **decoder_kwargs for the decoder forward function.
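
The routing of the remaining keyword arguments can be sketched roughly as follows. split_kwargs is a hypothetical helper written for this illustration, and stripping the decoder_ prefix before the argument reaches the decoder is an assumption of the sketch rather than something stated above.

```python
def split_kwargs(**kwargs):
    """Route keyword arguments to the encoder or the decoder by prefix."""
    prefix = "decoder_"
    # Arguments without the prefix go to the encoder unchanged.
    encoder_kwargs = {
        key: value for key, value in kwargs.items() if not key.startswith(prefix)
    }
    # Arguments with the prefix go to the decoder, with the prefix removed
    # (assumption of this sketch).
    decoder_kwargs = {
        key[len(prefix):]: value
        for key, value in kwargs.items()
        if key.startswith(prefix)
    }
    return encoder_kwargs, decoder_kwargs


encoder_kwargs, decoder_kwargs = split_kwargs(
    output_attentions=True,
    decoder_use_cache=False,
)
```

Here output_attentions would reach only the encoder, while decoder_use_cache would reach only the decoder as use_cache.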

Returns

A Seq2SeqLMOutput (if return_dict=True is passed or when config.return_dict=True) or a tuple of torch.FloatTensor comprising various elements depending on the configuration (EncoderDecoderConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Language modeling loss.

  • logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (List[torch.FloatTensor], optional, returned when use_cache=True is passed or when config.use_cache=True) – List of torch.FloatTensor of length config.n_layers, with each tensor of shape (2, batch_size, num_heads, sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be used (see past_key_values input) to speed up sequential decoding.

  • decoder_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.

  • decoder_attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads.

  • encoder_last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) – Sequence of hidden-states at the output of the last layer of the encoder of the model.

  • encoder_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.

  • encoder_attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
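
To illustrate how label positions set to -100 drop out of the loss, here is a minimal pure-Python sketch. It assumes the loss is a cross-entropy averaged over the non-ignored positions; language_modeling_loss is a hypothetical helper, not the library's implementation, and it leaves out any label shifting the model may apply.

```python
import math

IGNORE_INDEX = -100  # label value that is excluded from the loss


def language_modeling_loss(logits, labels):
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX.

    logits: list of per-position score lists (one score per vocabulary token).
    labels: list of target token ids, or IGNORE_INDEX to skip a position.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # masked position: contributes nothing to the loss
        # Numerically stable log of the softmax normalizer.
        max_score = max(scores)
        log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
        # Negative log-probability of the target token.
        total += log_norm - scores[label]
        count += 1
    # Returning 0.0 when every position is ignored is a choice of this sketch.
    return total / count if count else 0.0


logits = [[2.0, 0.5, 0.1], [0.3, 1.7, 0.2], [9.0, -9.0, 0.0]]
labels = [0, 1, IGNORE_INDEX]
loss = language_modeling_loss(logits, labels)
```

The third position has no effect on the result, so the loss equals the one computed from the first two positions alone.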

Examples:

>>> from transformers import EncoderDecoderModel, BertTokenizer
>>> import torch

>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints

>>> # forward
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)

>>> # training
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, return_dict=True)
>>> loss, logits = outputs.loss, outputs.logits

>>> # save and load from pretrained
>>> model.save_pretrained("bert2bert")
>>> model = EncoderDecoderModel.from_pretrained("bert2bert")

>>> # generation
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)

Return type

Seq2SeqLMOutput or tuple(torch.FloatTensor)