# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """
from typing import Optional
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
from .modeling_outputs import Seq2SeqLMOutput
from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via the :meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via the :meth:`~transformers.AutoModelForCausalLM.from_pretrained` function.
Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream generative task, *e.g.* summarization.
The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
After such an Encoder Decoder model has been trained / fine-tuned, it can be saved / loaded just like any other model (see the Examples for more information).
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#module>`__ sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
Parameters:
config (:class:`~transformers.EncoderDecoderConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the encoder.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`.
See :meth:`~transformers.PreTrainedTokenizer.encode` and
:meth:`~transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices for the encoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`):
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
`last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`) is a tensor of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Provide for sequence to sequence training to the decoder.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`.
See :meth:`~transformers.PreTrainedTokenizer.encode` and
:meth:`~transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``.
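For example, positions corresponding to padding tokens are typically set to ``-100`` so that they do not contribute to the loss.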
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a :class:`~transformers.modeling_outputs.Seq2SeqLMOutput` instead of a
plain tuple.
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function.
- With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function.
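For example, a keyword argument passed as ``decoder_token_type_ids`` is forwarded to the decoder's forward function as ``token_type_ids`` (an illustrative argument name that the chosen decoder, e.g. BERT, must accept), while unprefixed keyword arguments are forwarded to the encoder unchanged.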
"""
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class EncoderDecoderModel(PreTrainedModel):
r"""
:class:`~transformers.EncoderDecoderModel` is a generic model class that will be
instantiated as a transformer architecture with one of the base model
classes of the library as encoder and another one as
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
class method for the encoder and `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
"""
config_class = EncoderDecoderConfig
base_model_prefix = "encoder_decoder"
def __init__(
self,
config: Optional[PretrainedConfig] = None,
encoder: Optional[PreTrainedModel] = None,
decoder: Optional[PreTrainedModel] = None,
):
assert config is not None or (
encoder is not None and decoder is not None
), "Either a configuration or an Encoder and a decoder has to be provided"
if config is None:
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
else:
assert isinstance(config, self.config_class), "config: {} has to be of type {}".format(
config, self.config_class
)
# initialize with config
super().__init__(config)
if encoder is None:
from .modeling_auto import AutoModel
encoder = AutoModel.from_config(config.encoder)
if decoder is None:
from .modeling_auto import AutoModelForCausalLM
decoder = AutoModelForCausalLM.from_config(config.decoder)
self.encoder = encoder
self.decoder = decoder
assert (
self.encoder.get_output_embeddings() is None
), "The encoder {} should not have a LM Head. Please use a model without LM Head"
# tie encoder, decoder weights if config set accordingly
self.tie_weights()
def tie_weights(self):
# tie encoder & decoder if needed
if self.config.tie_encoder_decoder:
# tie encoder and decoder base model
decoder_base_model_prefix = self.decoder.base_model_prefix
self._tie_encoder_decoder_weights(
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
)
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: str = None,
decoder_pretrained_model_name_or_path: str = None,
*model_args,
**kwargs
) -> PreTrainedModel:
r"""Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
To train the model, you need to first set it back in training mode with `model.train()`.
Params:
encoder_pretrained_model_name_or_path (:obj:`str`, `optional`, defaults to `None`):
Information necessary to initialize the encoder. Either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/encoder``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
decoder_pretrained_model_name_or_path (:obj:`str`, `optional`, defaults to `None`):
Information necessary to initialize the decoder. Either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/decoder``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
model_args: (`optional`) Sequence of positional arguments:
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
kwargs: (`optional`) Remaining dictionary of keyword arguments.
Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attentions=True``).
- To update the encoder configuration, use the prefix `encoder_` for each configuration parameter
- To update the decoder configuration, use the prefix `decoder_` for each configuration parameter
- To update the parent model configuration, do not use a prefix for each configuration parameter
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded (see the examples below).
Examples::
>>> from transformers import EncoderDecoderModel
>>> # initialize a bert2bert from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')
>>> # saving model after fine-tuning
>>> model.save_pretrained("./bert2bert")
>>> # load fine-tuned model
>>> model = EncoderDecoderModel.from_pretrained("./bert2bert")
"""
kwargs_encoder = {
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# remove encoder, decoder kwargs from kwargs
for key in kwargs_encoder.keys():
del kwargs["encoder_" + key]
for key in kwargs_decoder.keys():
del kwargs["decoder_" + key]
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
assert (
encoder_pretrained_model_name_or_path is not None
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
from .modeling_auto import AutoModel
if "config" not in kwargs_encoder:
from .configuration_auto import AutoConfig
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
kwargs_encoder["config"] = encoder_config
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
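# Load and initialize the decoder, making sure it is configured as a decoder with cross-attention layers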
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
assert (
decoder_pretrained_model_name_or_path is not None
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
from .modeling_auto import AutoModelForCausalLM
if "config" not in kwargs_decoder:
from .configuration_auto import AutoConfig
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
kwargs_decoder["config"] = decoder_config
if kwargs_decoder["config"].is_decoder is False or decoder_config.add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
return cls(encoder=encoder, decoder=decoder, config=config)
@add_start_docstrings_to_callable(ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_inputs_embeds=None,
labels=None,
return_dict=None,
**kwargs,
):
r"""
Returns:
Examples::
>>> from transformers import EncoderDecoderModel, BertTokenizer
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
>>> # forward
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
>>> # training
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, return_dict=True)
>>> loss, logits = outputs.loss, outputs.logits
>>> # save and load from pretrained
>>> model.save_pretrained("bert2bert")
>>> model = EncoderDecoderModel.from_pretrained("bert2bert")
>>> # generation
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
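# split keyword arguments: unprefixed kwargs go to the encoder, `decoder_`-prefixed kwargs go to the decoder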
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
return_dict=return_dict,
**kwargs_encoder,
)
hidden_states = encoder_outputs[0]
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
inputs_embeds=decoder_inputs_embeds,
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
labels=labels,
return_dict=return_dict,
**kwargs_decoder,
)
# TODO(PVP): currently it is not possible to use `past`
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
return self.decoder._reorder_cache(past, beam_idx)