BertGeneration
Overview
The BertGeneration model is a BERT model that can be leveraged for sequence-to-sequence tasks using EncoderDecoderModel, as proposed in Leveraging Pre-trained Checkpoints for Sequence Generation Tasks by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
The abstract from the paper is the following:
Unsupervised pretraining of large neural models has recently revolutionized Natural Language Processing. By warm-starting from the publicly released checkpoints, NLP practitioners have pushed the state-of-the-art on multiple benchmarks while saving significant amounts of compute time. So far the focus has been mainly on the Natural Language Understanding tasks. In this paper, we demonstrate the efficacy of pre-trained checkpoints for Sequence Generation. We developed a Transformer-based sequence-to-sequence model that is compatible with publicly available pre-trained BERT, GPT-2 and RoBERTa checkpoints and conducted an extensive empirical study on the utility of initializing our model, both encoder and decoder, with these checkpoints. Our models result in new state-of-the-art results on Machine Translation, Text Summarization, Sentence Splitting, and Sentence Fusion.
Usage:
The model can be used in combination with the EncoderDecoderModel to leverage two pretrained BERT checkpoints for subsequent fine-tuning.
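from transformers import BertGenerationEncoder, BertGenerationDecoder, BertTokenizer, EncoderDecoderModel
# leverage checkpoints for Bert2Bert model...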
# use BERT's cls token as BOS token and sep token as EOS token
encoder = BertGenerationEncoder.from_pretrained("bert-large-uncased", bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
decoder = BertGenerationDecoder.from_pretrained("bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
# create tokenizer...
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
input_ids = tokenizer('This is a long article to summarize', add_special_tokens=False, return_tensors="pt").input_ids
labels = tokenizer('This is a short summary', return_tensors="pt").input_ids
# train...
loss = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels).loss
loss.backward()
Pretrained EncoderDecoderModel checkpoints are also directly available in the model hub, e.g.,
from transformers import AutoTokenizer, EncoderDecoderModel
# instantiate sentence fusion model
sentence_fuser = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_discofuse")
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_discofuse")
input_ids = tokenizer('This is the first sentence. This is the second sentence.', add_special_tokens=False, return_tensors="pt").input_ids
outputs = sentence_fuser.generate(input_ids)
print(tokenizer.decode(outputs[0]))
Tips:
BertGenerationEncoder and BertGenerationDecoder should be used in combination with EncoderDecoderModel.
For summarization, sentence splitting, sentence fusion and translation, no special tokens are required for the input. Therefore, no EOS token should be added to the end of the input.
The original code can be found here.
BertGenerationConfig
class transformers.BertGenerationConfig(vocab_size=50358, hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096, hidden_act='gelu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, bos_token_id=2, eos_token_id=1, gradient_checkpointing=False, **kwargs)
This is the configuration class to store the configuration of a BertGenerationPreTrainedModel. It is used to instantiate a BertGeneration model according to the specified arguments, defining the model architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
- Parameters
vocab_size (int, optional, defaults to 50358) – Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the input_ids passed when calling BertGeneration.
hidden_size (int, optional, defaults to 1024) – Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (int, optional, defaults to 4096) – Dimensionality of the “intermediate” (often called feed-forward) layer in the Transformer encoder.
hidden_act (str or function, optional, defaults to "gelu") – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "silu" and "gelu_new" are supported.
hidden_dropout_prob (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (float, optional, defaults to 0.1) – The dropout ratio for the attention probabilities.
max_position_embeddings (int, optional, defaults to 512) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-12) – The epsilon used by the layer normalization layers.
gradient_checkpointing (bool, optional, defaults to False) – If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
Examples:
>>> from transformers import BertGenerationConfig, BertGenerationEncoder
>>> # Initializing a BertGeneration config
>>> configuration = BertGenerationConfig()
>>> # Initializing a model from the config
>>> model = BertGenerationEncoder(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
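The size parameters above constrain each other: in BERT-style attention, each head works on an equal slice of the hidden vector, so hidden_size should divide evenly by num_attention_heads. The following is a minimal sketch of that check using only the standard library. ConfigSketch and head_dim are hypothetical names introduced for illustration and are not part of the library; only the default values are taken from the signature above.

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    """Hypothetical stand-in that mirrors a few BertGenerationConfig defaults."""

    hidden_size: int = 1024
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    intermediate_size: int = 4096
    max_position_embeddings: int = 512

    def head_dim(self) -> int:
        # Each attention head gets an equal slice of the hidden vector,
        # so hidden_size has to be a multiple of num_attention_heads.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError("hidden_size must be a multiple of num_attention_heads")
        return self.hidden_size // self.num_attention_heads


default = ConfigSketch()
small = ConfigSketch(hidden_size=256, num_attention_heads=4, intermediate_size=1024)
print(default.head_dim())  # 1024 // 16
print(small.head_dim())    # 256 // 4
```

With the documented defaults, each of the 16 heads operates on a 64-dimensional slice, and the feed-forward layer is four times wider than the hidden size.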
BertGenerationTokenizer
class transformers.BertGenerationTokenizer(vocab_file, bos_token='<s>', eos_token='</s>', unk_token='<unk>', pad_token='<pad>', sep_token='<::::>', **kwargs)
Construct a BertGeneration tokenizer. Based on SentencePiece.
This tokenizer inherits from PreTrainedTokenizer which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.
- Parameters
vocab_file (str) – SentencePiece file (generally has a .spm extension) that contains the vocabulary necessary to instantiate a tokenizer.
eos_token (str, optional, defaults to "</s>") – The end of sequence token.
bos_token (str, optional, defaults to "<s>") – The begin of sequence token.
unk_token (str, optional, defaults to "<unk>") – The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.
pad_token (str, optional, defaults to "<pad>") – The token used for padding, for example when batching sequences of different lengths.
save_vocabulary(save_directory: str, filename_prefix: Optional[str] = None) → Tuple[str]
Save only the vocabulary of the tokenizer (vocabulary + added tokens).
This method won’t save the configuration and special token mappings of the tokenizer. Use _save_pretrained() to save the whole state of the tokenizer.
- Parameters
save_directory (str) – The directory in which to save the vocabulary.
filename_prefix (str, optional) – An optional prefix to add to the names of the saved files.
- Returns
Paths to the files saved.
- Return type
Tuple(str)
BertGenerationEncoder
class transformers.BertGenerationEncoder(config)
The bare BertGeneration model transformer outputting raw hidden-states without any specific head on top.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
- Parameters
config (BertGenerationConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in Attention is all you need by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
This model should be used when leveraging Bert or Roberta checkpoints for the EncoderDecoderModel class as described in Leveraging Pre-trained Checkpoints for Sequence Generation Tasks by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
To behave as a decoder, the model needs to be initialized with the is_decoder argument of the configuration set to True. To be used in a Seq2Seq model, the model needs to be initialized with both the is_decoder argument and add_cross_attention set to True; an encoder_hidden_states is then expected as an input to the forward pass.
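The practical difference between the encoder and decoder modes is which positions self-attention may look at. The sketch below, in plain Python, assumes the usual Transformer convention that a decoder applies a causal (left-to-right) mask while an encoder attends in both directions. The function self_attention_visibility is a hypothetical helper for illustration, not a library call.

```python
def self_attention_visibility(seq_len, is_decoder):
    """Return a seq_len x seq_len matrix where entry [i][j] == 1 means
    query position i may attend to key position j."""
    if is_decoder:
        # Causal mask: position i only sees positions 0..i.
        return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]
    # Encoder: every position sees every other position.
    return [[1] * seq_len for _ in range(seq_len)]


encoder_view = self_attention_visibility(3, is_decoder=False)
decoder_view = self_attention_visibility(3, is_decoder=True)
for row in decoder_view:
    print(row)
```

Cross-attention is not restricted this way: when add_cross_attention is set, each decoder position can attend to the full encoder_hidden_states sequence, subject to encoder_attention_mask.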
forward(input_ids=None, attention_mask=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)
The BertGenerationEncoder forward method, overrides the __call__() special method.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
- Parameters
input_ids (torch.LongTensor of shape (batch_size, sequence_length)) – Indices of input sequence tokens in the vocabulary. Indices can be obtained using BertGenerationTokenizer. See transformers.PreTrainedTokenizer.__call__() and transformers.PreTrainedTokenizer.encode() for details.
attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked.
position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) – Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) – Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]: 1 indicates the head is not masked, 0 indicates the head is masked.
inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) – Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.
encoder_hidden_states (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) – Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder.
encoder_attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked.
- Returns
A BaseModelOutputWithCrossAttentions (if return_dict=True is passed or when config.return_dict=True) or a tuple of torch.FloatTensor comprising various elements depending on the configuration (BertGenerationConfig) and inputs.
last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.
hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
cross_attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True and config.add_cross_attention=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attentions weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.
- Return type
BaseModelOutputWithCrossAttentions or tuple(torch.FloatTensor)
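The head_mask parameter of forward is documented with two accepted shapes, (num_heads,) and (num_layers, num_heads). One way to read this is that a one-dimensional mask applies to every layer, while a two-dimensional mask sets each layer separately. The plain-Python sketch below illustrates that reading with nested lists instead of tensors. The helper expand_head_mask is hypothetical and is not the library's internal implementation.

```python
def expand_head_mask(head_mask, num_layers, num_heads):
    """Return a num_layers x num_heads nested list; 1.0 keeps a head, 0.0 masks it."""
    if head_mask is None:
        # No mask given: every head in every layer stays active.
        return [[1.0] * num_heads for _ in range(num_layers)]
    if all(isinstance(v, (int, float)) for v in head_mask):
        # 1-D mask of shape (num_heads,): reuse it for every layer.
        if len(head_mask) != num_heads:
            raise ValueError("1-D head_mask must have num_heads entries")
        return [[float(v) for v in head_mask] for _ in range(num_layers)]
    # 2-D mask of shape (num_layers, num_heads): use it as given.
    if len(head_mask) != num_layers or any(len(row) != num_heads for row in head_mask):
        raise ValueError("2-D head_mask must have shape (num_layers, num_heads)")
    return [[float(v) for v in row] for row in head_mask]


# Mask head 1 in every layer.
shared = expand_head_mask([1, 0, 1, 1], num_layers=2, num_heads=4)
# Mask head 0 in the second layer only.
per_layer = expand_head_mask([[1, 1, 1, 1], [0, 1, 1, 1]], num_layers=2, num_heads=4)
print(shared)
print(per_layer)
```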
Example:
>>> from transformers import BertGenerationTokenizer, BertGenerationEncoder
>>> import torch
>>> tokenizer = BertGenerationTokenizer.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder')
>>> model = BertGenerationEncoder.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
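When several sequences of different lengths are batched, the attention_mask and position_ids parameters described above take the following form. This is a minimal plain-Python sketch using nested lists instead of tensors. The token ids are made up, pad_batch is a hypothetical helper, and the padding id of 0 is assumed from the pad_token_id default in the configuration signature.

```python
PAD_ID = 0  # assumed from the pad_token_id default of BertGenerationConfig


def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad token id lists to a common length and build the matching
    attention_mask (1 = real token, 0 = padding) and position_ids."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask, position_ids = [], [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(list(seq) + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
        # Positions count up from 0 and must stay below max_position_embeddings.
        position_ids.append(list(range(max_len)))
    return input_ids, attention_mask, position_ids


ids, mask, pos = pad_batch([[5, 6, 7], [8]])  # made-up token ids
print(ids)
print(mask)
print(pos)
```

In practice the tokenizer call with padding enabled returns input_ids and attention_mask directly, so a helper like this is only needed when batches are assembled by hand.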
BertGenerationDecoder
class transformers.BertGenerationDecoder(config)
BertGeneration Model with a language modeling head on top for CLM fine-tuning.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
- Parameters
config (BertGenerationConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
forward(input_ids=None, attention_mask=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None)
The BertGenerationDecoder forward method, overrides the __call__() special method.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
- Parameters
input_ids (torch.LongTensor of shape (batch_size, sequence_length)) – Indices of input sequence tokens in the vocabulary. Indices can be obtained using BertGenerationTokenizer. See transformers.PreTrainedTokenizer.__call__() and transformers.PreTrainedTokenizer.encode() for details.
attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked.
position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) – Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) – Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]: 1 indicates the head is not masked, 0 indicates the head is masked.
inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) – Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.
encoder_hidden_states (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) – Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder.
encoder_attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked.
labels (torch.LongTensor of shape (batch_size, sequence_length), optional) – Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in [-100, 0, ..., config.vocab_size] (see input_ids docstring). Tokens with indices set to -100 are ignored (masked); the loss is only computed for the tokens with labels in [0, ..., config.vocab_size].
- Returns
A CausalLMOutputWithCrossAttentions (if return_dict=True is passed or when config.return_dict=True) or a tuple of torch.FloatTensor comprising various elements depending on the configuration (BertGenerationConfig) and inputs.
loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Language modeling loss (for next-token prediction).
logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
cross_attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Cross attentions weights after the attention softmax, used to compute the weighted average in the cross-attention heads.
- Return type
CausalLMOutputWithCrossAttentions or tuple(torch.FloatTensor)
Example:
>>> from transformers import BertGenerationTokenizer, BertGenerationDecoder, BertGenerationConfig
>>> import torch
>>> tokenizer = BertGenerationTokenizer.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder')
>>> config = BertGenerationConfig.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
>>> config.is_decoder = True
>>> model = BertGenerationDecoder.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
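The labels parameter of the decoder's forward method drives a left-to-right language modeling loss in which entries set to -100 are skipped. The sketch below shows that computation for a single sequence in plain Python. It assumes the common next-token convention that the scores at position t are compared with the label at position t + 1. The function causal_lm_loss is a hypothetical illustration, not the library's code.

```python
import math

IGNORE_INDEX = -100  # label value that is excluded from the loss


def causal_lm_loss(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean next-token cross-entropy for one sequence.

    logits: one list of vocabulary scores per position
    labels: one token id per position; ignore_index entries are skipped
    """
    # Position t predicts token t + 1: drop the last score row and the first label.
    shifted_logits = logits[:-1]
    shifted_labels = labels[1:]
    total, count = 0.0, 0
    for scores, target in zip(shifted_logits, shifted_labels):
        if target == ignore_index:
            continue  # masked position contributes nothing
        log_normalizer = math.log(sum(math.exp(s) for s in scores))
        total += log_normalizer - scores[target]
        count += 1
    # This sketch returns 0.0 when every label is ignored.
    return total / count if count else 0.0


# Three positions, vocabulary of four tokens, uniform scores.
uniform_logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
loss = causal_lm_loss(uniform_logits, [2, 1, IGNORE_INDEX])
print(loss)  # one counted position with uniform scores, so log(4)
```

Setting padded or prompt positions to -100 in labels is therefore a way to keep them out of the training signal without changing input_ids.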