The transformer-based encoder-decoder model was introduced by Vaswani et al. in the famous "Attention Is All You Need" paper and is today the de-facto standard encoder-decoder architecture in natural language processing (NLP).
Recently, there has been a lot of research on different pre-training objectives for transformer-based encoder-decoder models, e.g. T5, Bart, Pegasus, ProphetNet, and Marge, but the model architecture has stayed largely the same.
The goal of this blog post is to give a detailed explanation of how the transformer-based encoder-decoder architecture models sequence-to-sequence problems. We will focus on the mathematical model defined by the architecture and how the model can be used in inference. Along the way, we will give some background on sequence-to-sequence models in NLP and break down the transformer-based encoder-decoder architecture into its encoder and decoder parts. We provide many illustrations and establish the link between the theory of transformer-based encoder-decoder models and their practical usage in 🤗Transformers for inference. Note that this blog post does not explain how such models can be trained; this will be the topic of a future blog post.
Transformer-based encoder-decoder models are the result of years of research on representation learning and model architectures. This notebook provides a short summary of the history of neural encoder-decoder models. For more context, the reader is advised to read this awesome blog post by Sebastian Ruder. Additionally, a basic understanding of the self-attention architecture is recommended. The following blog post by Jay Alammar serves as a good refresher on the original Transformer model.
At the time of writing this notebook, 🤗Transformers comprises the encoder-decoder models T5, Bart, MarianMT, and Pegasus, which are summarized in the docs under model summaries.
The notebook is divided into four parts:
Each part builds upon the previous part, but can also be read on its own.
Tasks in natural language generation (NLG), a subfield of NLP, are best expressed as sequence-to-sequence problems. Such tasks can be defined as finding a model that maps a sequence of input words to a sequence of target words. Some classic examples are summarization and translation. In the following, we assume that each word is encoded into a vector representation. n input words can then be represented as a sequence of n input vectors:
$$ \mathbf{X}_{1:n} = \{\mathbf{x}_1, \ldots, \mathbf{x}_n\}. $$
Consequently, sequence-to-sequence problems can be solved by finding a mapping $f$ from an input sequence of $n$ vectors $\mathbf{X}_{1:n}$ to a sequence of $m$ target vectors $\mathbf{Y}_{1:m}$, where the number of target vectors $m$ is unknown a priori and depends on the input sequence:
$$ f: \mathbf{X}_{1:n} \to \mathbf{Y}_{1:m}. $$
Sutskever et al. (2014) noted that deep neural networks (DNNs), "*despite their flexibility and power can only define a mapping whose inputs and targets can be sensibly encoded with vectors of fixed dimensionality.*" 1
Using a DNN model 2 to solve sequence-to-sequence problems would therefore mean that the number of target vectors $m$ has to be known a priori and would have to be independent of the input $\mathbf{X}_{1:n}$. This is suboptimal because, for tasks in NLG, the number of target words usually depends on the input $\mathbf{X}_{1:n}$ and not just on the input length $n$. For example, an article of 1000 words can be summarized in 200 words or in 100 words, depending on its content.
In 2014, Cho et al. and Sutskever et al. proposed to use an encoder-decoder model purely based on recurrent neural networks (RNNs) for sequence-to-sequence tasks. In contrast to DNNs, RNNs are capable of modeling a mapping to a variable number of target vectors. Let's dive a bit deeper into the functioning of RNN-based encoder-decoder models.
During inference, the encoder RNN encodes an input sequence $\mathbf{X}_{1:n}$ by successively updating its hidden state 3. After having processed the last input vector $\mathbf{x}_n$, the encoder's hidden state defines the input encoding $\mathbf{c}$. Thus, the encoder defines the mapping:
$$ f_{\theta_{enc}}: \mathbf{X}_{1:n} \to \mathbf{c}. $$
Then, the decoder's hidden state is initialized with the input encoding and during inference, the decoder RNN is used to auto-regressively generate the target sequence. Let's explain.
Mathematically, the decoder defines the probability distribution of a target sequence $\mathbf{Y}_{1:m}$ given the hidden state $\mathbf{c}$:
$$ p_{\theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{c}). $$
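By the chain rule of probability, the distribution can be decomposed into conditional distributions of single target vectors as follows:
$$ p_{\theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{c}) = \prod_{i=1}^{m} p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{c}). $$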
Thus, if the architecture can model the conditional distribution of the next target vector, given all previous target vectors:
$$ p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{c}), \forall i \in \{1, \ldots, m\}, $$
then it can model the distribution of any target vector sequence given the hidden state $\mathbf{c}$ by simply multiplying all conditional probabilities.
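As a small numeric illustration of this factorization, the sketch below multiplies per-step conditional probabilities to obtain the probability of a whole sequence. The probability values are invented for illustration and do not come from any model. In practice, one usually sums log-probabilities instead, which avoids numerical underflow for long sequences.

```python
import math

# Hypothetical conditional probabilities p(y_i | Y_{0:i-1}, c) for a
# three-step target sequence; the numbers are invented for illustration.
step_probs = [0.5, 0.4, 0.9]

# Factorization: the sequence probability is the product of the step probabilities.
seq_prob = math.prod(step_probs)

# Numerically safer equivalent: sum the log-probabilities, then exponentiate.
seq_log_prob = sum(math.log(p) for p in step_probs)

print(seq_prob)
print(math.exp(seq_log_prob))
```

Both printed values agree up to floating-point rounding.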
So how does the RNN-based decoder architecture model $p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{c})$?
In computational terms, the model sequentially maps the previous inner hidden state $\mathbf{c}_{i-1}$ and the previous target vector $\mathbf{y}_{i-1}$ to the current inner hidden state $\mathbf{c}_i$ and a logit vector $\mathbf{l}_i$ (shown in dark red below):
$$ f_{\theta_{dec}}(\mathbf{y}_{i-1}, \mathbf{c}_{i-1}) \to \mathbf{l}_i, \mathbf{c}_i. $$
$\mathbf{c}_0$ is thereby defined as $\mathbf{c}$, the output hidden state of the RNN-based encoder. Subsequently, the softmax operation is used to transform the logit vector $\mathbf{l}_i$ to a conditional probability distribution of the next target vector:
$$ p(\mathbf{y}_i | \mathbf{l}_i) = \text{Softmax}(\mathbf{l}_i), \text{ with } \mathbf{l}_i = f_{\theta_{dec}}(\mathbf{y}_{i-1}, \mathbf{c}_{i-1}). $$
For more detail on the logit vector and the resulting probability distribution, please see footnote 4. From the above equation, we can see that the distribution of the current target vector $\mathbf{y}_i$ is directly conditioned on the previous target vector $\mathbf{y}_{i-1}$ and the previous hidden state $\mathbf{c}_{i-1}$. Because the previous hidden state $\mathbf{c}_{i-1}$ depends on all previous target vectors $\mathbf{y}_0, \ldots, \mathbf{y}_{i-2}$, it can be stated that the RNN-based decoder implicitly (i.e. indirectly) models the conditional distribution $p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{c})$.
The space of possible target vector sequences $\mathbf{Y}_{1:m}$ is prohibitively large, so at inference one has to rely on decoding methods 5 that efficiently sample high-probability target vector sequences from $p_{\theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{c})$.
Given such a decoding method, during inference, the next input vector $\mathbf{y}_i$ can then be sampled from $p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{c})$ and is consequently appended to the input sequence so that the decoder RNN then models $p_{\theta_{dec}}(\mathbf{y}_{i+1} | \mathbf{Y}_{0:i}, \mathbf{c})$ to sample the next input vector $\mathbf{y}_{i+1}$, and so on in auto-regressive fashion.
An important feature of RNN-based encoder-decoder models is the definition of special vectors, such as the $\text{EOS}$ and $\text{BOS}$ vector. The $\text{EOS}$ vector often represents the final input vector $\mathbf{x}_n$ to "cue" the encoder that the input sequence has ended and also defines the end of the target sequence. As soon as the $\text{EOS}$ is sampled from a logit vector, the generation is complete. The $\text{BOS}$ vector represents the input vector $\mathbf{y}_0$ fed to the decoder RNN at the very first decoding step. To output the first logit $\mathbf{l}_1$, an input is required, and since no input has been generated at the first step, a special $\text{BOS}$ input vector is fed to the decoder RNN. Ok, quite complicated! Let's illustrate and walk through an example.
The unfolded RNN encoder is colored in green and the unfolded RNN decoder is colored in red.
The English sentence "I want to buy a car", represented by $\mathbf{x}_1 = \text{I}$, $\mathbf{x}_2 = \text{want}$, $\mathbf{x}_3 = \text{to}$, $\mathbf{x}_4 = \text{buy}$, $\mathbf{x}_5 = \text{a}$, $\mathbf{x}_6 = \text{car}$ and $\mathbf{x}_7 = \text{EOS}$, is translated into German: "Ich will ein Auto kaufen", defined as $\mathbf{y}_0 = \text{BOS}$, $\mathbf{y}_1 = \text{Ich}$, $\mathbf{y}_2 = \text{will}$, $\mathbf{y}_3 = \text{ein}$, $\mathbf{y}_4 = \text{Auto}$, $\mathbf{y}_5 = \text{kaufen}$ and $\mathbf{y}_6 = \text{EOS}$. To begin with, the input vector $\mathbf{x}_1 = \text{I}$ is processed by the encoder RNN and updates its hidden state. Note that because we are only interested in the encoder's final hidden state $\mathbf{c}$, we can disregard the RNN encoder's target vector. The encoder RNN then processes the rest of the input sentence $\text{want}$, $\text{to}$, $\text{buy}$, $\text{a}$, $\text{car}$, $\text{EOS}$ in the same fashion, updating its hidden state at each step until the vector $\mathbf{x}_7 = \text{EOS}$ is reached 6. In the illustration above, the horizontal arrow connecting the unfolded encoder RNN represents the sequential updates of the hidden state. The final hidden state of the encoder RNN, represented by $\mathbf{c}$, then completely defines the encoding of the input sequence and is used as the initial hidden state of the decoder RNN. This can be seen as conditioning the decoder RNN on the encoded input.
To generate the first target vector, the decoder is fed the $\text{BOS}$ vector, illustrated as $\mathbf{y}_0$ in the illustration above. The target vector of the RNN is then further mapped to the logit vector $\mathbf{l}_1$ by means of the LM Head feed-forward layer to define the conditional distribution of the first target vector as explained above:
$$ p_{\theta_{dec}}(\mathbf{y} | \text{BOS}, \mathbf{c}). $$
The word $\text{Ich}$ is sampled (shown by the grey arrow, connecting $\mathbf{l}_1$ and $\mathbf{y}_1$) and consequently the second target vector can be sampled:
$$ \text{will} \sim p_{\theta_{dec}}(\mathbf{y} | \text{BOS}, \text{Ich}, \mathbf{c}). $$
And so on until at step $i=6$, the $\text{EOS}$ vector is sampled from $\mathbf{l}_6$ and the decoding is finished. The resulting target sequence amounts to $\mathbf{Y}_{1:6} = \{\mathbf{y}_1, \ldots, \mathbf{y}_6\}$, which is "Ich will ein Auto kaufen" in our example above.
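To sum it up, an RNN-based encoder-decoder model, represented by $f_{\theta_{enc}}$ and $p_{\theta_{dec}}$, defines the distribution $p(\mathbf{Y}_{1:m} | \mathbf{X}_{1:n})$ by factorization:
$$ p_{\theta_{enc}, \theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{X}_{1:n}) = \prod_{i=1}^{m} p_{\theta_{enc}, \theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{X}_{1:n}) = \prod_{i=1}^{m} p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{c}), \text{ with } \mathbf{c} = f_{\theta_{enc}}(\mathbf{X}_{1:n}). $$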
During inference, efficient decoding methods can auto-regressively generate the target sequence $\mathbf{Y}_{1:m}$.
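The auto-regressive loop described in this section can be sketched with a toy stand-in for the decoder RNN. Everything below is hypothetical: `toy_decoder_step` is not a trained model, its "hidden state" is simply the tuple of tokens seen so far, and its logits are hard-coded so that the example sentence is produced. The sketch only shows the control flow, namely feeding $\text{BOS}$ first, feeding each chosen token back in, passing the hidden state along, and stopping at $\text{EOS}$. It uses greedy decoding (argmax) as one simple decoding method.

```python
BOS, EOS = "BOS", "EOS"
VOCAB = ["Ich", "will", "ein", "Auto", "kaufen", EOS]


def toy_decoder_step(prev_token, hidden):
    """Hypothetical stand-in for f_dec: maps (y_{i-1}, c_{i-1}) to (l_i, c_i).

    The hidden state is the tuple of all tokens fed so far, and the logits
    are hard-coded to favor the next word of a fixed sentence. A real RNN
    would compute both from learned weights.
    """
    new_hidden = hidden + (prev_token,)
    target = VOCAB[min(len(new_hidden) - 1, len(VOCAB) - 1)]
    logits = [5.0 if word == target else 0.0 for word in VOCAB]
    return logits, new_hidden


def greedy_decode(initial_hidden, max_steps=10):
    """Auto-regressive generation: each chosen token becomes the next input."""
    token, hidden, output = BOS, initial_hidden, []
    for _ in range(max_steps):
        logits, hidden = toy_decoder_step(token, hidden)
        # Greedy choice: softmax is monotonic, so argmax over logits suffices.
        token = VOCAB[logits.index(max(logits))]
        output.append(token)
        if token == EOS:  # generation is complete once EOS is produced
            break
    return output


# c_0 would be the encoder's final hidden state; here it is an empty tuple.
print(greedy_decode(initial_hidden=()))
```

The `max_steps` cap is a design choice of this sketch: it guarantees termination even if a decoder never produces $\text{EOS}$.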
The RNN-based encoder-decoder model took the NLG community by storm. In 2016, Google announced that it would fully replace its heavily feature-engineered translation service with a single RNN-based encoder-decoder model (see here).
Nevertheless, RNN-based encoder-decoder models have two pitfalls. First, RNNs suffer from the vanishing gradient problem, making it very difficult to capture long-range dependencies, cf. Hochreiter et al. (2001). Second, the inherent recurrent architecture of RNNs prevents efficient parallelization when encoding, cf. Vaswani et al. (2017).
1 The original quote from the paper is "Despite their flexibility and power, DNNs can only be applied to problems whose inputs and targets can be sensibly encoded with vectors of fixed dimensionality", which is slightly adapted here.
2 The same holds essentially true for convolutional neural networks (CNNs). While an input sequence of variable length can be fed into a CNN, the dimensionality of the target will always be dependent on the input dimensionality or fixed to a specific value.
3 At the first step, the hidden state is initialized as a zero vector and fed to the RNN together with the first input vector $\mathbf{x}_1$.
4 A neural network can define a probability distribution over all words, i.e. $p(\mathbf{y} | \mathbf{c}, \mathbf{Y}_{0:i-1})$, as follows. First, the network defines a mapping from the inputs $\mathbf{c}, \mathbf{Y}_{0:i-1}$ to an embedded vector representation $\mathbf{y'}$, which corresponds to the RNN target vector. The embedded vector representation $\mathbf{y'}$ is then passed to the "language model head" layer, which means that it is multiplied by the word embedding matrix, i.e. $\mathbf{Y}^{\text{vocab}}$, so that a score between $\mathbf{y'}$ and each encoded vector $\mathbf{y} \in \mathbf{Y}^{\text{vocab}}$ is computed. The resulting vector is called the logit vector $\mathbf{l} = \mathbf{Y}^{\text{vocab}} \mathbf{y'}$ and can be mapped to a probability distribution over all words by applying a softmax operation: $p(\mathbf{y} | \mathbf{c}) = \text{Softmax}(\mathbf{Y}^{\text{vocab}} \mathbf{y'}) = \text{Softmax}(\mathbf{l})$.
5 Beam-search decoding is an example of such a decoding method. Different decoding methods are out of scope for this notebook. The reader is advised to refer to this interactive notebook on decoding methods.
6 Sutskever et al. (2014) reverse the order of the input so that in the above example the input vectors would correspond to $\mathbf{x}_1 = \text{car}$, $\mathbf{x}_2 = \text{a}$, $\mathbf{x}_3 = \text{buy}$, $\mathbf{x}_4 = \text{to}$, $\mathbf{x}_5 = \text{want}$, $\mathbf{x}_6 = \text{I}$ and $\mathbf{x}_7 = \text{EOS}$. The motivation is to allow for a shorter connection between corresponding word pairs such as $\mathbf{x}_6 = \text{I}$ and $\mathbf{y}_1 = \text{Ich}$. The research group emphasizes that the reversal of the input sequence was a key reason for their model's improved performance on machine translation.
In 2017, Vaswani et al. introduced the Transformer and thereby gave birth to transformer-based encoder-decoder models.
Analogous to RNN-based encoder-decoder models, transformer-based encoder-decoder models consist of an encoder and a decoder which are both stacks of residual attention blocks. The key innovation of transformer-based encoder-decoder models is that such residual attention blocks can process an input sequence $\mathbf{X}_{1:n}$ of variable length $n$ without exhibiting a recurrent structure. Not relying on a recurrent structure allows transformer-based encoder-decoders to be highly parallelizable, which makes the model orders of magnitude more computationally efficient than RNN-based encoder-decoder models on modern hardware.
As a reminder, to solve a sequence-to-sequence problem, we need to find a mapping of an input sequence $\mathbf{X}_{1:n}$ to an output sequence $\mathbf{Y}_{1:m}$ of variable length $m$. Let's see how transformer-based encoder-decoder models are used to find such a mapping.
Similar to RNN-based encoder-decoder models, the transformer-based encoder-decoder models define a conditional distribution of target vectors $\mathbf{Y}_{1:m}$ given an input sequence $\mathbf{X}_{1:n}$:
$$ p_{\theta_{enc}, \theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{X}_{1:n}). $$
The transformer-based encoder part encodes the input sequence $\mathbf{X}_{1:n}$ to a sequence of hidden states $\mathbf{\overline{X}}_{1:n}$, thus defining the mapping:
$$ f_{\theta_{enc}}: \mathbf{X}_{1:n} \to \mathbf{\overline{X}}_{1:n}. $$
The transformer-based decoder part then models the conditional probability distribution of the target vector sequence $\mathbf{Y}_{1:m}$ given the sequence of encoded hidden states $\mathbf{\overline{X}}_{1:n}$:
$$ p_{\theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{\overline{X}}_{1:n}). $$
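By the chain rule of probability, this distribution can be factorized into a product of conditional probability distributions of the target vector $\mathbf{y}_i$ given the encoded hidden states $\mathbf{\overline{X}}_{1:n}$ and all previous target vectors $\mathbf{Y}_{0:i-1}$:
$$ p_{\theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{\overline{X}}_{1:n}) = \prod_{i=1}^{m} p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n}). $$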
The transformer-based decoder hereby maps the sequence of encoded hidden states $\mathbf{\overline{X}}_{1:n}$ and all previous target vectors $\mathbf{Y}_{0:i-1}$ to the logit vector $\mathbf{l}_i$. The logit vector $\mathbf{l}_i$ is then processed by the softmax operation to define the conditional distribution $p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n})$, just as it is done for RNN-based decoders. However, in contrast to RNN-based decoders, the distribution of the target vector $\mathbf{y}_i$ is explicitly (or directly) conditioned on all previous target vectors $\mathbf{y}_0, \ldots, \mathbf{y}_{i-1}$, as we will see later in more detail. The 0th target vector $\mathbf{y}_0$ is hereby represented by a special "begin-of-sentence" $\text{BOS}$ vector.
Having defined the conditional distribution $p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n})$, we can now auto-regressively generate the output and thus define a mapping of an input sequence $\mathbf{X}_{1:n}$ to an output sequence $\mathbf{Y}_{1:m}$ at inference.
Let's visualize the complete process of auto-regressive generation of transformer-based encoder-decoder models.
The transformer-based encoder is colored in green and the transformer-based decoder is colored in red. As in the previous section, we show how the English sentence "I want to buy a car", represented by $\mathbf{x}_1 = \text{I}$, $\mathbf{x}_2 = \text{want}$, $\mathbf{x}_3 = \text{to}$, $\mathbf{x}_4 = \text{buy}$, $\mathbf{x}_5 = \text{a}$, $\mathbf{x}_6 = \text{car}$, and $\mathbf{x}_7 = \text{EOS}$, is translated into German: "Ich will ein Auto kaufen", defined as $\mathbf{y}_0 = \text{BOS}$, $\mathbf{y}_1 = \text{Ich}$, $\mathbf{y}_2 = \text{will}$, $\mathbf{y}_3 = \text{ein}$, $\mathbf{y}_4 = \text{Auto}$, $\mathbf{y}_5 = \text{kaufen}$, and $\mathbf{y}_6 = \text{EOS}$.
To begin with, the encoder processes the complete input sequence $\mathbf{X}_{1:7}$ = "I want to buy a car EOS" (represented by the light green vectors) to a contextualized encoded sequence $\mathbf{\overline{X}}_{1:7}$. For example, $\mathbf{\overline{x}}_4$ defines an encoding that depends not only on the input $\mathbf{x}_4$ = "buy", but also on all other words "I", "want", "to", "a", "car" and "EOS", i.e. the context.
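Next, the input encoding $\mathbf{\overline{X}}_{1:7}$ together with the $\text{BOS}$ vector, i.e. $\mathbf{y}_0$, is fed to the decoder. The decoder processes the inputs $\mathbf{\overline{X}}_{1:7}$ and $\mathbf{y}_0$ to the first logit $\mathbf{l}_1$ (shown in darker red) to define the conditional distribution of the first target vector $\mathbf{y}_1$:
$$ p_{\theta_{enc, dec}}(\mathbf{y} | \mathbf{y}_0, \mathbf{X}_{1:7}) = p_{\theta_{enc, dec}}(\mathbf{y} | \text{BOS}, \text{I want to buy a car EOS}) = p_{\theta_{dec}}(\mathbf{y} | \text{BOS}, \mathbf{\overline{X}}_{1:7}). $$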
Next, the first target vector $\mathbf{y}_1$ = $\text{Ich}$ is sampled from the distribution (represented by the grey arrows) and can now be fed to the decoder again. The decoder now processes both $\mathbf{y}_0$ = "BOS" and $\mathbf{y}_1$ = "Ich" to define the conditional distribution of the second target vector $\mathbf{y}_2$:
$$ p_{\theta_{dec}}(\mathbf{y} | \text{BOS Ich}, \mathbf{\overline{X}}_{1:7}). $$
We can sample again and produce the target vector $\mathbf{y}_2$ = "will". We continue in auto-regressive fashion until at step 6 the $\text{EOS}$ vector is sampled from the conditional distribution:
$$ \text{EOS} \sim p_{\theta_{dec}}(\mathbf{y} | \text{BOS Ich will ein Auto kaufen}, \mathbf{\overline{X}}_{1:7}). $$
And so on in auto-regressive fashion.
It is important to understand that the encoder is only used in the first forward pass to map $\mathbf{X}_{1:n}$ to $\mathbf{\overline{X}}_{1:n}$. As of the second forward pass, the decoder can directly make use of the previously calculated encoding $\mathbf{\overline{X}}_{1:n}$. For clarity, let's illustrate the first and the second forward pass for our example above.
As can be seen, only in step $i=1$ do we have to encode "I want to buy a car EOS" to $\mathbf{\overline{X}}_{1:7}$. At step $i=2$, the contextualized encodings of "I want to buy a car EOS" are simply reused by the decoder.
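The reuse of the encoding can be made concrete with a small sketch that counts how often each part is called during generation. The functions `toy_encoder` and `toy_decoder` are hypothetical placeholders rather than real models: the "encoding" is just a list of strings, and the decoder returns the next word of a fixed sentence. Only the call pattern is meant to carry over.

```python
calls = {"encoder": 0, "decoder": 0}
TARGET = ["Ich", "will", "ein", "Auto", "kaufen", "EOS"]


def toy_encoder(input_tokens):
    """Hypothetical encoder: returns one placeholder 'encoding' per input token."""
    calls["encoder"] += 1
    return [f"enc({tok})" for tok in input_tokens]


def toy_decoder(encoded_inputs, prefix):
    """Hypothetical decoder: receives the full encoding and ALL previous targets."""
    calls["decoder"] += 1
    # prefix always starts with BOS, so len(prefix) - 1 targets exist so far.
    return TARGET[min(len(prefix) - 1, len(TARGET) - 1)]


def generate(input_tokens, max_steps=10):
    encoded = toy_encoder(input_tokens)  # computed once, then reused at every step
    prefix = ["BOS"]
    for _ in range(max_steps):
        next_token = toy_decoder(encoded, prefix)
        prefix.append(next_token)
        if next_token == "EOS":
            break
    return prefix[1:]


print(generate("I want to buy a car EOS".split()))
print(calls)
```

In this sketch the encoder runs once while the decoder runs once per generated token, each time seeing the whole prefix generated so far.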
In 🤗Transformers, this auto-regressive generation is done under the hood when calling the `.generate()` method. Let's use one of our translation models to see this in action.
from transformers import MarianMTModel, MarianTokenizer
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
# create ids of encoded input vectors
input_ids = tokenizer("I want to buy a car", return_tensors="pt").input_ids
# translate example
output_ids = model.generate(input_ids)[0]
# decode and print
print(tokenizer.decode(output_ids))
Output:
<pad> Ich will ein Auto kaufen
Calling `.generate()` does many things under the hood. First, it passes the `input_ids` to the encoder. Second, it passes a pre-defined token, which is the `<pad>` symbol in the case of `MarianMTModel`, along with the encoded `input_ids` to the decoder. Third, it applies the beam search decoding mechanism to auto-regressively sample the next output word of the last decoder output 1. For more detail on how beam search decoding works, one is advised to read this blog post.
In the Appendix, we have included a code snippet that shows how a simple generation method can be implemented "from scratch". To fully understand how auto-regressive generation works under-the-hood, it is highly recommended to read the Appendix.
To sum it up:
Great, now that we have gotten a general overview of how transformer-based encoder-decoder models work, we can dive deeper into both the encoder and decoder part of the model. More specifically, we will see exactly how the encoder makes use of the self-attention layer to yield a sequence of context-dependent vector encodings and how self-attention layers allow for efficient parallelization. Then, we will explain in detail how the self-attention layer works in the decoder model and how the decoder is conditioned on the encoder's output with cross-attention layers to define the conditional distribution $p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n})$. Along the way, it will become obvious how transformer-based encoder-decoder models solve the long-range dependencies problem of RNN-based encoder-decoder models.
1 In the case of `"Helsinki-NLP/opus-mt-en-de"`, the decoding parameters can be accessed here, where we can see that the model applies beam search with `num_beams=6`.
As mentioned in the previous section, the transformer-based encoder maps the input sequence to a contextualized encoding sequence:
$$ f_{\theta_{enc}}: \mathbf{X}_{1:n} \to \mathbf{\overline{X}}_{1:n}. $$
Taking a closer look at the architecture, the transformer-based encoder is a stack of residual encoder blocks. Each encoder block consists of a bi-directional self-attention layer, followed by two feed-forward layers. For simplicity, we disregard the normalization layers in this notebook. Also, we will not further discuss the role of the two feed-forward layers, but simply see them as a final vector-to-vector mapping required in each encoder block 1. The bi-directional self-attention layer puts each input vector $\mathbf{x'}_j, \forall j \in \{1, \ldots, n\}$ into relation with all input vectors $\mathbf{x'}_1, \ldots, \mathbf{x'}_n$ and by doing so transforms the input vector $\mathbf{x'}_j$ to a more "refined" contextual representation of itself, defined as $\mathbf{x''}_j$. Thereby, the first encoder block transforms each input vector of the input sequence $\mathbf{X}_{1:n}$ (shown in light green below) from a context-independent vector representation to a context-dependent vector representation, and the following encoder blocks further refine this contextual representation until the last encoder block outputs the final contextual encoding $\mathbf{\overline{X}}_{1:n}$ (shown in darker green below).
Let's visualize how the encoder processes the input sequence "I want to buy a car EOS" to a contextualized encoding sequence. Similar to RNN-based encoders, transformer-based encoders also add a special "end-of-sequence" input vector to the input sequence to hint to the model that the input vector sequence is finished 2.
Our exemplary transformer-based encoder is composed of three encoder blocks, where the second encoder block is shown in more detail in the red box on the right for the first three input vectors $\mathbf{x}_1$, $\mathbf{x}_2$ and $\mathbf{x}_3$. The bi-directional self-attention mechanism is illustrated by the fully-connected graph in the lower part of the red box and the two feed-forward layers are shown in the upper part of the red box. As stated before, we will focus only on the bi-directional self-attention mechanism.
As can be seen, each output vector of the self-attention layer $\mathbf{x''}_i, \forall i \in \{1, \ldots, 7\}$ depends directly on all input vectors $\mathbf{x'}_1, \ldots, \mathbf{x'}_7$. This means, for example, that the input vector representation of the word "want", i.e. $\mathbf{x'}_2$, is put into direct relation with the word "buy", i.e. $\mathbf{x'}_4$, but also with the word "I", i.e. $\mathbf{x'}_1$. The output vector representation of "want", i.e. $\mathbf{x''}_2$, thus represents a more refined contextual representation for the word "want".
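Let's take a deeper look at how bi-directional self-attention works. Each input vector $\mathbf{x'}_i$ of an input sequence $\mathbf{X'}_{1:n}$ of an encoder block is projected to a key vector $\mathbf{k}_i$, value vector $\mathbf{v}_i$ and query vector $\mathbf{q}_i$ (shown in orange, blue, and purple respectively below) through three trainable weight matrices $\mathbf{W}_q, \mathbf{W}_v, \mathbf{W}_k$:
$$ \mathbf{q}_i = \mathbf{W}_q \mathbf{x'}_i, \quad \mathbf{v}_i = \mathbf{W}_v \mathbf{x'}_i, \quad \mathbf{k}_i = \mathbf{W}_k \mathbf{x'}_i, \quad \forall i \in \{1, \ldots, n\}. $$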
Note that the same weight matrices are applied to each input vector $\mathbf{x'}_i, \forall i \in \{1, \ldots, n\}$. After projecting each input vector $\mathbf{x'}_i$ to a query, key, and value vector, each query vector $\mathbf{q}_j, \forall j \in \{1, \ldots, n\}$ is compared to all key vectors $\mathbf{k}_1, \ldots, \mathbf{k}_n$. The more similar one of the key vectors $\mathbf{k}_i$ is to a query vector $\mathbf{q}_j$, the more important the corresponding value vector $\mathbf{v}_i$ is for the output vector $\mathbf{x''}_j$. More specifically, an output vector $\mathbf{x''}_j$ is defined as the weighted sum of all value vectors $\mathbf{v}_1, \ldots, \mathbf{v}_n$ plus the input vector $\mathbf{x'}_j$. The weights are obtained by applying a softmax to the dot-product similarities between $\mathbf{q}_j$ and the respective key vectors $\mathbf{k}_1, \ldots, \mathbf{k}_n$, which is mathematically expressed by $\text{Softmax}(\mathbf{K}_{1:n}^\intercal \mathbf{q}_j)$ as illustrated in the equation below. For a complete description of the self-attention layer, the reader is advised to take a look at this blog post or the original paper.
Alright, this sounds quite complicated. Let's illustrate the bi-directional self-attention layer for one of the query vectors of our example above. For simplicity, it is assumed that our exemplary transformer-based encoder uses only a single attention head (`config.num_heads = 1`) and that no normalization is applied.
On the left, the previously illustrated second encoder block is shown again and on the right, a detailed visualization of the bi-directional self-attention mechanism is given for the second input vector $\mathbf{x'}_2$ that corresponds to the input word "want". At first, all input vectors $\mathbf{x'}_1, \ldots, \mathbf{x'}_7$ are projected to their respective query vectors $\mathbf{q}_1, \ldots, \mathbf{q}_7$ (only the first three query vectors are shown in purple above), value vectors $\mathbf{v}_1, \ldots, \mathbf{v}_7$ (shown in blue), and key vectors $\mathbf{k}_1, \ldots, \mathbf{k}_7$ (shown in orange). The query vector $\mathbf{q}_2$ is then multiplied by the transpose of all key vectors, i.e. $\mathbf{K}_{1:7}^{\intercal}$, followed by the softmax operation to yield the self-attention weights. The self-attention weights are finally multiplied by the respective value vectors and the input vector $\mathbf{x'}_2$ is added to output the "refined" representation of the word "want", i.e. $\mathbf{x''}_2$ (shown in dark green on the right). The whole equation is illustrated in the upper part of the box on the right. The multiplication of $\mathbf{K}_{1:7}^{\intercal}$ and $\mathbf{q}_2$ thereby makes it possible to compare the vector representation of "want" to all other input vector representations "I", "to", "buy", "a", "car", "EOS" so that the self-attention weights mirror the importance of each of the other input vector representations $\mathbf{x'}_j \text{, with } j \ne 2$ for the refined representation $\mathbf{x''}_2$ of the word "want".
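The computation just described for a single query vector can be sketched in plain Python. This is a minimal illustration under simplifying assumptions: the inputs are toy 2-dimensional vectors, the projection matrices are arbitrary untrained values, a single attention head is used, and no scaling or normalization is applied. The helper names (`matvec`, `dot`, `self_attention_output`) are made up for this sketch, and the position `j` is zero-based, so `j=1` plays the role of the second input vector.

```python
import math


def matvec(W, x):
    """Multiply matrix W (given as a list of rows) with vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention_output(X, j, W_q, W_k, W_v):
    """Compute x''_j = sum_i Softmax(K^T q_j)_i * v_i + x'_j for one position j."""
    q_j = matvec(W_q, X[j])
    keys = [matvec(W_k, x) for x in X]
    values = [matvec(W_v, x) for x in X]
    weights = softmax([dot(k, q_j) for k in keys])  # one weight per input vector
    dim = len(values[0])
    attended = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return [a + x for a, x in zip(attended, X[j])], weights  # residual connection


# Toy 2-dimensional inputs and arbitrary (not trained) projection matrices.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_q = [[1.0, 0.0], [0.0, 1.0]]
W_k = [[1.0, 0.0], [0.0, 1.0]]
W_v = [[0.5, 0.0], [0.0, 0.5]]

output, weights = self_attention_output(X, j=1, W_q=W_q, W_k=W_k, W_v=W_v)
print("attention weights:", weights)
print("refined vector:", output)
```

With these toy values, the second and third input vectors have the same dot product with the query, so they receive equal attention weights that are larger than the weight of the first input vector.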
To further understand the implications of the bi-directional self-attention layer, let's assume the following sentence is processed: "The house is beautiful and well located in the middle of the city where it is easily accessible by public transport". The word "it" refers to "house", which is 12 "positions away". In transformer-based encoders, the bi-directional self-attention layer performs a single mathematical operation to put the input vector of "house" into relation with the input vector of "it" (compare to the first illustration of this section). In contrast, in an RNN-based encoder, a word that is 12 "positions away" would require at least 12 mathematical operations, meaning that in an RNN-based encoder a linear number of mathematical operations is required. This makes it much harder for an RNN-based encoder to model long-range contextual representations. Also, it becomes clear that a transformer-based encoder is much less prone to lose important information than an RNN-based encoder-decoder model because the sequence length of the encoding is kept the same, i.e. $\text{len}(\mathbf{X}_{1:n}) = \text{len}(\mathbf{\overline{X}}_{1:n}) = n$, while an RNN compresses the length from $\text{len}(\mathbf{X}_{1:n}) = n$ to just $\text{len}(\mathbf{c}) = 1$, which makes it very difficult for RNNs to effectively encode long-range dependencies between input words.
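In addition to making long-range dependencies more easily learnable, we can see that the Transformer architecture is able to process text in parallel. Mathematically, this can easily be shown by writing the self-attention formula as a product of query, key, and value matrices, where the softmax is applied to each column of the score matrix so that column $j$ equals $\text{Softmax}(\mathbf{K}_{1:n}^\intercal \mathbf{q}_j)$:
$$ \mathbf{X''}_{1:n} = \mathbf{V}_{1:n} \text{Softmax}(\mathbf{K}_{1:n}^\intercal \mathbf{Q}_{1:n}) + \mathbf{X'}_{1:n}. $$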
The output $\mathbf{X''}_{1:n} = \mathbf{x''}_1, \ldots, \mathbf{x''}_n$ is computed via a series of matrix multiplications and a softmax operation, which can be parallelized effectively. Note that in an RNN-based encoder model, the computation of the hidden state $\mathbf{c}$ has to be done sequentially: compute the hidden state of the first input vector $\mathbf{x}_1$, then compute the hidden state of the second input vector, which depends on the hidden state of the first input vector, etc. The sequential nature of RNNs prevents effective parallelization and makes them much more inefficient compared to transformer-based encoder models on modern GPU hardware.
Great, now we should have a better understanding of a) how transformer-based encoder models effectively model long-range contextual representations and b) how they efficiently process long sequences of input vectors.
Now, let's code up a short example of the encoder part of our `MarianMT` encoder-decoder models to verify that the explained theory holds in practice.
1 A detailed explanation of the role the feed-forward layers play in transformer-based models is out of scope for this notebook. It is argued in Yun et al. (2019) that feed-forward layers are crucial to map each contextual vector $\mathbf{x'}_i$ individually to the desired output space, which the self-attention layer does not manage to do on its own. It should be noted here that each output token $\mathbf{x'}$ is processed by the same feed-forward layer. For more detail, the reader is advised to read the paper.
2 Strictly speaking, the $\text{EOS}$ input vector does not have to be appended to the input sequence, but doing so has been shown to improve performance in many cases. In contrast, the 0th $\text{BOS}$ target vector of the transformer-based decoder is required as a starting input vector to predict a first target vector.
from transformers import MarianMTModel, MarianTokenizer
import torch
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
embeddings = model.get_input_embeddings()
# create ids of encoded input vectors
input_ids = tokenizer("I want to buy a car", return_tensors="pt").input_ids
# pass input_ids to encoder
encoder_hidden_states = model.base_model.encoder(input_ids, return_dict=True).last_hidden_state
# change the input slightly and pass to encoder
input_ids_perturbed = tokenizer("I want to buy a house", return_tensors="pt").input_ids
encoder_hidden_states_perturbed = model.base_model.encoder(input_ids_perturbed, return_dict=True).last_hidden_state
# compare shape and encoding of first vector
print(f"Length of input embeddings {embeddings(input_ids).shape[1]}. Length of encoder_hidden_states {encoder_hidden_states.shape[1]}")
# compare values of word embedding of "I" for input_ids and perturbed input_ids
print("Is encoding for `I` equal to its perturbed version?: ", torch.allclose(encoder_hidden_states[0, 0], encoder_hidden_states_perturbed[0, 0], atol=1e-3))
Outputs:
Length of input embeddings 7. Length of encoder_hidden_states 7
Is encoding for `I` equal to its perturbed version?: False
We compare the length of the input word embeddings, i.e. `embeddings(input_ids)`, corresponding to $\mathbf{X}_{1:n}$, with the length of the `encoder_hidden_states`, corresponding to $\mathbf{\overline{X}}_{1:n}$. Also, we have forwarded the word sequence "I want to buy a car" and a slightly perturbed version "I want to buy a house" through the encoder to check if the first output encoding, corresponding to "I", differs when only the last word is changed in the input sequence.
As expected, the length of the input word embeddings and the length of the encoder output encodings, i.e. $\text{len}(\mathbf{X}_{1:n})$ and $\text{len}(\mathbf{\overline{X}}_{1:n})$, are equal. Second, it can be noted that the values of the encoded output vector of $\mathbf{\overline{x}}_1 = \text{"I"}$ are different when the last word is changed from "car" to "house". This however should not come as a surprise if one has understood bi-directional self-attention.
On a side-note, autoencoding models, such as BERT, have the exact same architecture as transformer-based encoder models. Autoencoding models leverage this architecture for massive self-supervised pre-training on open-domain text data so that they can map any word sequence to a deep bi-directional representation. In Devlin et al. (2018), the authors show that a pre-trained BERT model with a single task-specific classification layer on top can achieve SOTA results on eleven NLP tasks. All autoencoding models of 🤗Transformers can be found here.
As mentioned in the Encoder-Decoder section, the transformer-based decoder defines the conditional probability distribution of a target sequence given the contextualized encoding sequence:
$$ p_{\theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{\overline{X}}_{1:n}), $$
which by the chain rule of probability can be decomposed into a product of conditional distributions of the next target vector given the contextualized encoding sequence and all previous target vectors:
$$ p_{\theta_{dec}}(\mathbf{Y}_{1:m} | \mathbf{\overline{X}}_{1:n}) = \prod_{i=1}^{m} p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n}). $$
Let's first understand how the transformer-based decoder defines a probability distribution. The transformer-based decoder is a stack of decoder blocks followed by a dense layer, the "LM head". The stack of decoder blocks maps the contextualized encoding sequence $\mathbf{\overline{X}}_{1:n}$ and a target vector sequence prepended by the $\text{BOS}$ vector and cut to the last target vector, i.e. $\mathbf{Y}_{0:i-1}$, to an encoded sequence of target vectors $\mathbf{\overline{Y}}_{0:i-1}$. Then, the "LM head" maps the encoded sequence of target vectors $\mathbf{\overline{Y}}_{0:i-1}$ to a sequence of logit vectors $\mathbf{L}_{1:m} = \mathbf{l}_1, \ldots, \mathbf{l}_m$, where the dimensionality of each logit vector $\mathbf{l}_i$ corresponds to the size of the vocabulary. This way, for each $i \in \{1, \ldots, m\}$ a probability distribution over the whole vocabulary can be obtained by applying a softmax operation on $\mathbf{l}_i$. These distributions define the conditional distribution:
$$ p_{\theta_{dec}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n}), \forall i \in \{1, \ldots, m\}, $$
respectively. The "LM head" is often tied to the transpose of the word embedding matrix, i.e. WembβΊβ=[y1,β¦,yvocab]βΊ 1. Intuitively this means that for all iβ{0,β¦,nβ1} the "LM Head" layer compares the encoded output vector yβiβ to all word embeddings in the vocabulary y1,β¦,yvocab so that the logit vector li+1β represents the similarity scores between the encoded output vector and each word embedding. The softmax operation simply transformers the similarity scores to a probability distribution. For each iβ{1,β¦,n}, the following equations hold:
pΞΈdecββ(yβ£X1:nβ,Y0:iβ1β) =Softmax(fΞΈdecββ(X1:nβ,Y0:iβ1β)) =Softmax(WembβΊβyβiβ1β) =Softmax(liβ).
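The tied "LM head" can be sketched in a few lines of NumPy. This is only a toy illustration: the sizes, the random matrix `W_emb`, and the random vector `y_bar` are assumptions made up for the example and are not taken from any real model. The embedding matrix is stored with one row per vocabulary entry, so `W_emb @ y_bar` plays the role of $\mathbf{W}_{\text{emb}}^{\intercal} \mathbf{\overline{y}}_{i-1}$ in the equation above.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden_size = 10, 4  # toy sizes, chosen only for illustration

# Toy word embedding matrix: one row per vocabulary entry (hypothetical values).
W_emb = rng.normal(size=(vocab_size, hidden_size))

# Toy encoded output vector, standing in for the decoder stack's output at one position.
y_bar = rng.normal(size=(hidden_size,))


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()


# Tied LM head: one similarity score between y_bar and every word embedding.
logits = W_emb @ y_bar   # shape: (vocab_size,)
probs = softmax(logits)  # probability distribution over the vocabulary
```

Because the softmax is monotonic, the most probable token is always the one whose embedding has the largest dot product with the encoded output vector.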
Putting it all together, in order to model the conditional distribution of a target vector sequence $\mathbf{Y}_{1:m}$, the target vectors $\mathbf{Y}_{1:m-1}$ prepended by the special BOS vector, i.e. $\mathbf{y}_0$, are first mapped together with the contextualized encoding sequence $\mathbf{\overline{X}}_{1:n}$ to the logit vector sequence $\mathbf{L}_{1:m}$. Subsequently, each logit vector $\mathbf{l}_i$ is transformed into a conditional probability distribution of the target vector $\mathbf{y}_i$ using the softmax operation. Finally, the conditional probabilities of all target vectors $\mathbf{y}_1, \ldots, \mathbf{y}_m$ are multiplied together to yield the conditional probability of the complete target vector sequence:
$$ p_{\theta_{\text{dec}}}(\mathbf{Y}_{1:m} | \mathbf{\overline{X}}_{1:n}) = \prod_{i=1}^{m} p_{\theta_{\text{dec}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n}). $$
In contrast to transformer-based encoders, in transformer-based decoders, the encoded output vector $\mathbf{\overline{y}}_i$ should be a good representation of the next target vector $\mathbf{y}_{i+1}$ and not of the input vector itself. Additionally, the encoded output vector $\mathbf{\overline{y}}_i$ should be conditioned on the whole contextualized encoding sequence $\mathbf{\overline{X}}_{1:n}$. To meet these requirements, each decoder block consists of a uni-directional self-attention layer, followed by a cross-attention layer and two feed-forward layers ${}^2$. The uni-directional self-attention layer puts each of its input vectors $\mathbf{y'}_j$ only into relation with itself and all previous input vectors, i.e. $\mathbf{y'}_i, \text{ with } i \le j$, for all $j \in \{0, \ldots, m-1\}$, to model the probability distribution of the next target vectors. The cross-attention layer puts each of its input vectors $\mathbf{y''}_j$ into relation with all contextualized encoding vectors $\mathbf{\overline{X}}_{1:n}$ to condition the probability distribution of the next target vectors on the input of the encoder as well.
Alright, let's visualize the transformer-based decoder for our English to German translation example.
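We can see that the decoder maps the input $\mathbf{Y}_{0:5}$ "BOS", "Ich", "will", "ein", "Auto", "kaufen" (shown in light red) together with the contextualized sequence of "I", "want", "to", "buy", "a", "car", "EOS", i.e. $\mathbf{\overline{X}}_{1:7}$ (shown in dark green) to the logit vectors $\mathbf{L}_{1:6}$ (shown in dark red).
Applying a softmax operation on each $\mathbf{l}_1, \mathbf{l}_2, \ldots, \mathbf{l}_6$ can thus define the conditional probability distributions:
$$ p_{\theta_{\text{dec}}}(\mathbf{y} | \text{BOS}, \mathbf{\overline{X}}_{1:7}), \quad p_{\theta_{\text{dec}}}(\mathbf{y} | \text{BOS Ich}, \mathbf{\overline{X}}_{1:7}), \quad \ldots, \quad p_{\theta_{\text{dec}}}(\mathbf{y} | \text{BOS Ich will ein Auto kaufen}, \mathbf{\overline{X}}_{1:7}). $$
The overall conditional probability of:
$$ p_{\theta_{\text{dec}}}(\text{Ich will ein Auto kaufen EOS} | \mathbf{\overline{X}}_{1:7}) $$
can therefore be computed as the following product:
$$ p_{\theta_{\text{dec}}}(\text{Ich} | \text{BOS}, \mathbf{\overline{X}}_{1:7}) \times \ldots \times p_{\theta_{\text{dec}}}(\text{EOS} | \text{BOS Ich will ein Auto kaufen}, \mathbf{\overline{X}}_{1:7}). $$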
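As a small worked example of this product, the sketch below multiplies six per-step probabilities. The numbers in `step_probs` are invented purely for illustration and do not come from any model. It also shows the usual trick of summing log-probabilities instead of multiplying probabilities, which avoids numerical underflow for long sequences.

```python
import math

# Hypothetical values of p(y_i | Y_{0:i-1}, X_{1:7}) for the correct next token.
# These numbers are made up for illustration only.
step_probs = {
    "Ich": 0.9,
    "will": 0.8,
    "ein": 0.7,
    "Auto": 0.9,
    "kaufen": 0.8,
    "EOS": 0.95,
}


def sequence_log_prob(probs):
    """Sum of log-probabilities, i.e. the log of the product of the conditionals."""
    return sum(math.log(p) for p in probs)


log_prob = sequence_log_prob(step_probs.values())
sequence_prob = math.exp(log_prob)

# Multiplying the conditionals directly gives the same value up to rounding error.
direct_product = math.prod(step_probs.values())
```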
The red box on the right shows a decoder block for the first three target vectors $\mathbf{y}_0, \mathbf{y}_1, \mathbf{y}_2$. In the lower part, the uni-directional self-attention mechanism is illustrated and in the middle, the cross-attention mechanism is illustrated. Let's first focus on uni-directional self-attention.
As in bi-directional self-attention, in uni-directional self-attention, the query vectors $\mathbf{q}_0, \ldots, \mathbf{q}_{m-1}$ (shown in purple below), key vectors $\mathbf{k}_0, \ldots, \mathbf{k}_{m-1}$ (shown in orange below), and value vectors $\mathbf{v}_0, \ldots, \mathbf{v}_{m-1}$ (shown in blue below) are projected from their respective input vectors $\mathbf{y'}_0, \ldots, \mathbf{y'}_{m-1}$ (shown in light red below). However, in uni-directional self-attention, each query vector $\mathbf{q}_i$ is compared only to its respective key vector and all previous ones, namely $\mathbf{k}_0, \ldots, \mathbf{k}_i$, to yield the respective attention weights. This prevents an output vector $\mathbf{y''}_j$ (shown in dark red below) from including any information about a following input vector $\mathbf{y}_i, \text{ with } i > j$, for all $j \in \{0, \ldots, m-1\}$. As is the case in bi-directional self-attention, the attention weights are then multiplied by their respective value vectors and summed together.
We can summarize uni-directional self-attention as follows:
$$ \mathbf{y''}_i = \mathbf{V}_{0:i} \, \text{Softmax}(\mathbf{K}_{0:i}^{\intercal} \mathbf{q}_i) + \mathbf{y'}_i. $$
Note that the index range of the key and value vectors is $0:i$ instead of $0:m-1$, which would be the range of the key vectors in bi-directional self-attention.
Let's illustrate uni-directional self-attention for the input vector $\mathbf{y'}_1$ for our example above.
As can be seen, $\mathbf{y''}_1$ only depends on $\mathbf{y'}_0$ and $\mathbf{y'}_1$. Therefore, we put the vector representation of the word "Ich", i.e. $\mathbf{y'}_1$, only into relation with itself and the "BOS" target vector, i.e. $\mathbf{y'}_0$, but not with the vector representation of the word "will", i.e. $\mathbf{y'}_2$.
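The following NumPy sketch implements the uni-directional self-attention formula for a single head. It is a simplified illustration under stated assumptions: the projection matrices `W_q`, `W_k`, `W_v` and the inputs are random toy values, and, as in the formula above, the scaling factor and multi-head structure used in practice are left out. The sketch then perturbs only the last input vector to check that earlier outputs are unaffected.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()


def causal_self_attention(Y, W_q, W_k, W_v):
    """Y has shape (m, d): one row per decoder input vector y'_0 ... y'_{m-1}."""
    Q, K, V = Y @ W_q, Y @ W_k, Y @ W_v
    out = np.empty_like(Y)
    for i in range(len(Y)):
        # Query i is compared only to keys 0..i (itself and all previous positions).
        weights = softmax(K[: i + 1] @ Q[i])   # shape: (i + 1,)
        out[i] = weights @ V[: i + 1] + Y[i]   # weighted sum of values plus residual
    return out


rng = np.random.default_rng(0)
m, d = 4, 8  # toy sequence length and hidden size
W_q, W_k, W_v = (rng.normal(size=(d, d)) * 0.1 for _ in range(3))
Y = rng.normal(size=(m, d))

out = causal_self_attention(Y, W_q, W_k, W_v)

# Change only the last input vector and recompute.
Y_perturbed = Y.copy()
Y_perturbed[-1] += 1.0
out_perturbed = causal_self_attention(Y_perturbed, W_q, W_k, W_v)

earlier_unchanged = bool(np.allclose(out[:-1], out_perturbed[:-1]))
last_changed = not bool(np.allclose(out[-1], out_perturbed[-1]))
```

Real implementations typically compute all positions at once and apply a triangular mask to the attention scores instead of looping; the explicit loop is used here only because it mirrors the index range $0:i$ of the formula.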
So why is it important that we use uni-directional self-attention in the decoder instead of bi-directional self-attention? As stated above, a transformer-based decoder defines a mapping from a sequence of input vectors $\mathbf{Y}_{0:m-1}$ to the logits corresponding to the next decoder input vectors, namely $\mathbf{L}_{1:m}$. In our example, this means, e.g., that the input vector $\mathbf{y}_1$ = "Ich" is mapped to the logit vector $\mathbf{l}_2$, which is then used to predict the input vector $\mathbf{y}_2$. Thus, if $\mathbf{y'}_1$ had access to the following input vectors $\mathbf{Y'}_{2:5}$, the decoder would simply copy the vector representation of "will", i.e. $\mathbf{y'}_2$, to be its output $\mathbf{y''}_1$. This would be forwarded to the last layer so that the encoded output vector $\mathbf{\overline{y}}_1$ would essentially just correspond to the vector representation $\mathbf{y}_2$.
This is obviously disadvantageous as the transformer-based decoder would never learn to predict the next word given all previous words, but would just copy the target vector $\mathbf{y}_i$ through the network to $\mathbf{\overline{y}}_{i-1}$ for all $i \in \{1, \ldots, m\}$. In order to define a conditional distribution of the next target vector, the distribution cannot be conditioned on the next target vector itself. It does not make much sense to predict $\mathbf{y}_i$ from $p(\mathbf{y} | \mathbf{Y}_{0:i}, \mathbf{\overline{X}})$ because the distribution is conditioned on the target vector it is supposed to model. The uni-directional self-attention architecture, therefore, allows us to define a causal probability distribution, which is necessary to effectively model a conditional distribution of the next target vector.
Great! Now we can move to the layer that connects the encoder and decoder - the cross-attention mechanism!
The cross-attention layer takes two vector sequences as inputs: the outputs of the uni-directional self-attention layer, i.e. $\mathbf{Y''}_{0:m-1}$, and the contextualized encoding vectors $\mathbf{\overline{X}}_{1:n}$. As in the self-attention layer, the query vectors $\mathbf{q}_0, \ldots, \mathbf{q}_{m-1}$ are projections of the output vectors of the previous layer, i.e. $\mathbf{Y''}_{0:m-1}$. However, the key and value vectors $\mathbf{k}_1, \ldots, \mathbf{k}_n$ and $\mathbf{v}_1, \ldots, \mathbf{v}_n$ are projections of the contextualized encoding vectors $\mathbf{\overline{X}}_{1:n}$. Having defined key, value, and query vectors, a query vector $\mathbf{q}_i$ is then compared to all key vectors and the corresponding score is used to weight the respective value vectors, just as is the case for bi-directional self-attention, to give the output vector $\mathbf{y'''}_i$ for all $i \in \{0, \ldots, m-1\}$. Cross-attention can be summarized as follows:
$$ \mathbf{y'''}_i = \mathbf{V}_{1:n} \, \text{Softmax}(\mathbf{K}_{1:n}^{\intercal} \mathbf{q}_i) + \mathbf{y''}_i. $$
Note that the index range of the key and value vectors is $1:n$, corresponding to the number of contextualized encoding vectors.
Let's visualize the cross-attention mechanism for the input vector $\mathbf{y''}_1$ for our example above.
We can see that the query vector $\mathbf{q}_1$ (shown in purple) is derived from $\mathbf{y''}_1$ and therefore relies on a vector representation of the word "Ich". The query vector $\mathbf{q}_1$ is then compared to the key vectors $\mathbf{k}_1, \ldots, \mathbf{k}_7$ (shown in yellow) corresponding to the contextual encoding representation of all encoder input vectors $\mathbf{\overline{X}}_{1:n}$ = "I want to buy a car EOS". This puts the vector representation of "Ich" into direct relation with all encoder input vectors. Finally, the attention weights are multiplied by the value vectors $\mathbf{v}_1, \ldots, \mathbf{v}_7$ (shown in turquoise) and the result is added to the input vector $\mathbf{y''}_1$ to yield the output vector $\mathbf{y'''}_1$ (shown in dark red).
So intuitively, what happens here exactly? Each output vector $\mathbf{y'''}_i$ is a weighted sum of all value projections of the encoder inputs $\mathbf{v}_1, \ldots, \mathbf{v}_7$ plus the input vector itself $\mathbf{y''}_i$ (cf. the illustrated formula above). The key mechanism to understand is the following: the more similar a query projection of the decoder input vector $\mathbf{q}_i$ is to a key projection of the encoder input vector $\mathbf{k}_j$, the more important is the value projection of the encoder input vector $\mathbf{v}_j$. In loose terms, this means that the more "related" a decoder input representation is to an encoder input representation, the more that encoder input representation influences the decoder output representation.
Cool! Now we can see how this architecture nicely conditions each output vector $\mathbf{y'''}_i$ on the interaction between the encoder input vectors $\mathbf{\overline{X}}_{1:n}$ and the input vector $\mathbf{y''}_i$. Another important observation at this point is that the architecture is completely independent of the number $n$ of contextualized encoding vectors $\mathbf{\overline{X}}_{1:n}$ on which the output vector $\mathbf{y'''}_i$ is conditioned. All projection matrices $\mathbf{W}^{\text{cross}}_{k}$ and $\mathbf{W}^{\text{cross}}_{v}$ to derive the key vectors $\mathbf{k}_1, \ldots, \mathbf{k}_n$ and the value vectors $\mathbf{v}_1, \ldots, \mathbf{v}_n$ respectively are shared across all positions $1, \ldots, n$, and all value vectors $\mathbf{v}_1, \ldots, \mathbf{v}_n$ are summed together into a single weighted average vector. Now it becomes obvious as well why the transformer-based decoder does not suffer from the long-range dependency problem that the RNN-based decoder suffers from. Because each decoder logit vector is directly dependent on every single encoded output vector, the number of mathematical operations to compare the first encoded output vector and the last decoder logit vector amounts essentially to just one.
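The cross-attention formula and its independence from the encoder length $n$ can be checked with a short NumPy sketch. As before, this is a single-head toy version with random, made-up projection matrices (`W_q`, `W_k`, `W_v`) and no scaling factor, so it should be read as an illustration of the formula rather than as the library's implementation.

```python
import numpy as np


def softmax_rows(scores):
    """Numerically stable softmax applied to each row of a 2-D array."""
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def cross_attention(Y2, X_enc, W_q, W_k, W_v):
    """Y2: (m, d) self-attention outputs y''; X_enc: (n, d) contextualized encodings."""
    Q = Y2 @ W_q      # queries are projected from the decoder side
    K = X_enc @ W_k   # keys are projected from the encoder side
    V = X_enc @ W_v   # values are projected from the encoder side
    weights = softmax_rows(Q @ K.T)  # (m, n): every query sees all n encodings
    return weights @ V + Y2          # weighted sum of values plus residual, (m, d)


rng = np.random.default_rng(0)
m, d = 3, 8  # toy number of decoder positions and hidden size
W_q, W_k, W_v = (rng.normal(size=(d, d)) * 0.1 for _ in range(3))
Y2 = rng.normal(size=(m, d))

# The same projection matrices work for encoder sequences of any length n.
out_n7 = cross_attention(Y2, rng.normal(size=(7, d)), W_q, W_k, W_v)
out_n3 = cross_attention(Y2, rng.normal(size=(3, d)), W_q, W_k, W_v)
```

Both calls return an array of shape `(m, d)`: the encoder length only changes how many value vectors are averaged, not the size of the output or of any weight matrix.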
To conclude, the uni-directional self-attention layer is responsible for conditioning each output vector on all previous decoder input vectors and the current input vector, and the cross-attention layer is responsible for further conditioning each output vector on all encoded input vectors.
To verify our theoretical understanding, let's continue our code example from the encoder section above.
${}^1$ The word embedding matrix $\mathbf{W}_{\text{emb}}$ gives each input word a unique context-independent vector representation. This matrix is often reused as the "LM head" layer. However, the "LM head" layer can very well consist of a completely independent "encoded vector-to-logit" weight mapping.
${}^2$ Again, an in-detail explanation of the role the feed-forward layers play in transformer-based models is out-of-scope for this notebook. It is argued in Yun et al. (2017) that feed-forward layers are crucial to map each contextual vector $\mathbf{x'}_i$ individually to the desired output space, which the self-attention layer does not manage to do on its own. It should be noted here that each output token $\mathbf{x'}$ is processed by the same feed-forward layer. For more detail, the reader is advised to read the paper.
from transformers import MarianMTModel, MarianTokenizer
import torch
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
embeddings = model.get_input_embeddings()
# get encoded input vectors
input_ids = tokenizer("I want to buy a car", return_tensors="pt").input_ids
encoded_output_vectors = model.base_model.encoder(input_ids, return_dict=True).last_hidden_state
# create token ids for the decoder input
decoder_input_ids = tokenizer("<pad> Ich will ein", return_tensors="pt", add_special_tokens=False).input_ids
# pass decoder input ids and encoded input vectors to decoder (keyword argument avoids relying on positional order)
decoder_output_vectors = model.base_model.decoder(decoder_input_ids, encoder_hidden_states=encoded_output_vectors, return_dict=True).last_hidden_state
# derive logits by multiplying decoder outputs with the (tied) embedding weights
lm_logits = torch.nn.functional.linear(decoder_output_vectors, embeddings.weight, bias=model.final_logits_bias)
# change the decoder input slightly: only the last word differs ("ein" -> "das")
decoder_input_ids_perturbed = tokenizer("<pad> Ich will das", return_tensors="pt", add_special_tokens=False).input_ids
decoder_output_vectors_perturbed = model.base_model.decoder(decoder_input_ids_perturbed, encoder_hidden_states=encoded_output_vectors, return_dict=True).last_hidden_state
lm_logits_perturbed = torch.nn.functional.linear(decoder_output_vectors_perturbed, embeddings.weight, bias=model.final_logits_bias)
# compare shape and encoding of first vector
print(f"Shape of decoder input vectors {embeddings(decoder_input_ids).shape}. Shape of decoder logits {lm_logits.shape}")
# compare the first logit vector (the one used to predict "Ich") for decoder_input_ids and its perturbed version
print("Is encoding for `Ich` equal to its perturbed version?: ", torch.allclose(lm_logits[0, 0], lm_logits_perturbed[0, 0], atol=1e-3))
Output:
Shape of decoder input vectors torch.Size([1, 5, 512]). Shape of decoder logits torch.Size([1, 5, 58101])
Is encoding for `Ich` equal to its perturbed version?: True
We compare the output shape of the decoder input word embeddings, i.e. `embeddings(decoder_input_ids)` (corresponds to $\mathbf{Y}_{0:4}$; here `<pad>` corresponds to BOS and "Ich will ein" is tokenized to 4 tokens), with the dimensionality of the `lm_logits` (corresponds to $\mathbf{L}_{1:5}$). Also, we have passed the word sequence "`<pad>` Ich will ein" and a slightly perturbed version "`<pad>` Ich will das" together with the `encoded_output_vectors` through the decoder to check if the first `lm_logit`, the one used to predict "Ich", differs when only the last word is changed in the input sequence ("ein" -> "das").
As expected, the output shapes of the decoder input word embeddings and `lm_logits`, i.e. the dimensionality of $\mathbf{Y}_{0:4}$ and $\mathbf{L}_{1:5}$, are different in the last dimension. While the sequence length is the same (=5), the dimensionality of a decoder input word embedding corresponds to `model.config.hidden_size`, whereas the dimensionality of an `lm_logit` corresponds to the vocabulary size `model.config.vocab_size`, as explained above. Second, it can be noted that the values of the first logit vector $\mathbf{l}_1$, which is used to predict "Ich", are the same when the last word is changed from "ein" to "das". This, however, should not come as a surprise if one has understood uni-directional self-attention.
On a final side-note, auto-regressive models, such as GPT2, have the same architecture as transformer-based decoder models if one removes the cross-attention layer because stand-alone auto-regressive models are not conditioned on any encoder outputs. So auto-regressive models are essentially the same as auto-encoding models but replace bi-directional attention with uni-directional attention. These models can also be pre-trained on massive open-domain text data to show impressive performances on natural language generation (NLG) tasks. In Radford et al. (2019), the authors show that a pre-trained GPT2 model can achieve SOTA or close to SOTA results on a variety of NLG tasks without much fine-tuning. All auto-regressive models of 🤗Transformers can be found here.
Alright, that's it! Now, you should have gotten a good understanding of transformer-based encoder-decoder models and how to use them with the 🤗Transformers library.
Thanks a lot to Victor Sanh, Sasha Rush, Sam Shleifer, Oliver Åstrand, Ted Moskovitz and Kristian Kyvik for giving valuable feedback.
As mentioned above, the following code snippet shows how one can program a simple generation method for transformer-based encoder-decoder models. Here, we implement a simple greedy decoding method using `torch.argmax` to sample the target vector.
from transformers import MarianMTModel, MarianTokenizer
import torch
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
# create ids of encoded input vectors
input_ids = tokenizer("I want to buy a car", return_tensors="pt").input_ids
# create BOS token
decoder_input_ids = tokenizer("<pad>", add_special_tokens=False, return_tensors="pt").input_ids
assert decoder_input_ids[0, 0].item() == model.config.decoder_start_token_id, "`decoder_input_ids` should correspond to `model.config.decoder_start_token_id`"
# STEP 1
# pass input_ids to encoder and to decoder and pass BOS token to decoder to retrieve first logit
outputs = model(input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
# get encoded sequence
encoded_sequence = (outputs.encoder_last_hidden_state,)
# get logits
lm_logits = outputs.logits
# greedily pick the token with the highest probability at the last position
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], dim=-1)
# concat
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], dim=-1)
# STEP 2
# reuse encoded_sequence and pass BOS + "Ich" to decoder to retrieve the second logit
lm_logits = model(None, encoder_outputs=encoded_sequence, decoder_input_ids=decoder_input_ids, return_dict=True).logits
# greedily pick the token with the highest probability again
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], dim=-1)
# concat again
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], dim=-1)
# STEP 3
lm_logits = model(None, encoder_outputs=encoded_sequence, decoder_input_ids=decoder_input_ids, return_dict=True).logits
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], dim=-1)
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], dim=-1)
# let's see what we have generated so far!
print(f"Generated so far: {tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)}")
# This can be written in a loop as well.
Outputs:
Generated so far: Ich Ich
In this code example, we show exactly what was described earlier. We pass an input "I want to buy a car" together with the BOS token to the encoder-decoder model and sample from the first logit $\mathbf{l}_1$ (i.e. the first `lm_logits` line). Here, our sampling strategy is simply to greedily choose the next decoder input vector that has the highest probability. In an auto-regressive fashion, we then pass the sampled decoder input vector together with the previous inputs to the encoder-decoder model and sample again. We repeat this a third time. As a result, the model has generated the words "Ich Ich". The first word is spot-on! The second word is not that great. We can see here that a good decoding method is key for a successful sequence generation from a given model distribution.
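The three manual steps above follow one pattern, which can be written as a loop. The sketch below is model-agnostic and uses only plain Python: `greedy_decode` and `toy_next_token_logits` are hypothetical helper names introduced here for illustration, and the toy scoring function merely stands in for a real model call that would return the logits of the last position.

```python
def greedy_decode(next_token_logits, bos_id, eos_id, max_new_tokens=10):
    """Repeatedly append the highest-scoring token until EOS or the length limit."""
    ids = [bos_id]
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)  # scores for the next token given all ids so far
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


# Toy stand-in for the model: token t is always followed by t + 1, up to EOS.
VOCAB_SIZE, BOS_ID, EOS_ID = 6, 0, 5


def toy_next_token_logits(ids):
    target = min(ids[-1] + 1, EOS_ID)
    return [1.0 if token == target else 0.0 for token in range(VOCAB_SIZE)]


generated = greedy_decode(toy_next_token_logits, BOS_ID, EOS_ID)
# generated == [0, 1, 2, 3, 4, 5]
```

With the Marian model from the snippet above, the scoring function would wrap the `model(...)` call and return the logits of the last position. In practice one would normally rely on the library's `model.generate` method instead of a hand-written loop.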
In practice, more complicated decoding methods are used to sample from the `lm_logits`, most of which are covered in this blog post.