On Coding Your First Attention

Community Article · Published April 21, 2024

You don't necessarily have to code the attention block of a transformer from scratch to understand how it works, but it is the closest you can get to a first-principles understanding of why and how transformers behave the way they do.


@karpathy covered attention in detail in his nanoGPT video (strongly recommend watching). Here I would like to share some thoughts and experience from writing my first attention block.

First, let's zoom out quickly and explain what attention is in transformers: it is a communication mechanism that allows the model to focus on different parts of the input sequence when making predictions.

It assigns weights to each input token based on its relevance to the current context, enabling the model to weigh information selectively. This mechanism helps transformers capture long-range dependencies and contextual information effectively.

The original "Attention Is All You Need" paper introduced two commonly used forms of attention: Scaled Dot-Product Attention (the operation at the heart of self-attention) and Multi-Head Attention, which runs several of these attention operations in parallel.
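
For reference, the scaled dot-product attention from the paper is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

where $Q$, $K$, and $V$ are the query, key, and value matrices and $d_k$ is the key dimension used for scaling.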

The Code

Now, attention, as with most deep learning algorithms, boils down to a math equation, so writing the code is fairly straightforward, especially with a deep learning framework like PyTorch. Below is what's called single-head attention:

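The original post shows this code as an image; below is a minimal sketch of what a single attention head might look like in PyTorch (the class name SingleHeadAttention and the embed_dim/head_dim/dropout parameters are my assumptions, not necessarily the exact code from the screenshot):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, head_dim: int, dropout: float = 0.1):
        super().__init__()
        # Linear projections that turn the input embeddings into queries, keys, and values
        self.q_proj = nn.Linear(embed_dim, head_dim)
        self.k_proj = nn.Linear(embed_dim, head_dim)
        self.v_proj = nn.Linear(embed_dim, head_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        # Project inputs: (batch, seq_len, embed_dim) -> (batch, seq_len, head_dim)
        Q = self.q_proj(q)
        K = self.k_proj(k)
        V = self.v_proj(v)

        # Attention scores: dot product of queries and keys, scaled by sqrt(d_k)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))

        # Optionally mask out positions (e.g. padding or future tokens)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Softmax turns the scores into attention weights that sum to 1 over the keys
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of the values
        return weights @ V
```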

The code defines single-head attention in PyTorch: it projects the input vectors, computes attention scores and weights, and then takes the weighted sum of the values based on these weights (as per the attention equation above).

When you stack several of these in parallel, you get what's called Multi-Head Attention. The code becomes much simpler if you reuse the SingleHeadAttention class:

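Again, the original code is an image; a minimal sketch along those lines, building on the SingleHeadAttention class from above (the num_heads parameter and out_proj name are my assumptions), could look like this:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        head_dim = embed_dim // num_heads
        # One SingleHeadAttention per head, run in parallel
        self.heads = nn.ModuleList(
            [SingleHeadAttention(embed_dim, head_dim, dropout) for _ in range(num_heads)]
        )
        # Final linear layer that mixes the concatenated head outputs
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, q, k, v, mask=None):
        # Each head attends independently; outputs are concatenated on the feature dimension
        out = torch.cat([head(q, k, v, mask) for head in self.heads], dim=-1)
        return self.out_proj(out)
```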

This one creates multiple attention heads (instances of the SingleHeadAttention class) and runs them in parallel. During the forward pass, it applies each head to the input tensors Q, K, and V, concatenates the outputs, and linearly transforms them to produce the final output.
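
Continuing with the sketch above, a self-attention call (the shapes here are purely illustrative) might look like:

```python
batch, seq_len, embed_dim = 2, 16, 64
x = torch.randn(batch, seq_len, embed_dim)

mha = MultiHeadAttention(embed_dim, num_heads=8)
out = mha(x, x, x)   # self-attention: Q, K, and V all come from the same sequence
print(out.shape)     # torch.Size([2, 16, 64])
```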

End Note

So, in essence, having to code this made me revisit some key underlying concepts (PyTorch's matmul, softmax, dropout, even backprop), which helped clear things up. This, along with Karpathy's nanoGPT video, shaped my understanding - which is still a work in progress as new forms of the transformer architecture emerge.