# From attention to transformers


In this tutorial, we focus on the inner workings of the attention mechanism. By the end, you will be able to create a self-attention layer and build your own transformer model from scratch.

In many well-established libraries like **torch**, the code can be hard to read because of efficiency optimizations and the many conditional paths using **if** and **else**. Here, we will write **a more readable yet functionally equivalent model** and verify that its outputs match the official implementation.



### General note for GPU training (in colab)

* First, make sure you are using a GPU runtime. If you are, `!nvidia-smi` will run without error. If not, switch to one:
 1. Click on "Runtime" in the top menu bar.
 2. Select "Change runtime type" from the drop-down menu.
 3. In the "Runtime type" section, select "GPU" as the hardware accelerator.
 4. Click "Save" to apply the changes.


* What should I do about a **CUDA out-of-memory (OOM) error**? (This is THE most common error in deep learning.)
![](https://miro.medium.com/v2/resize:fit:828/format:webp/1*enMsxkgJ1eb9XvtWju5V8Q.png)
 1. In a Colab notebook, **unfortunately, you need to restart the kernel after an OOM error occurs**. Otherwise it will keep happening no matter what.
 2. Change the model to use less memory, usually by decreasing the batch size, the number of layers, the maximum sequence length, or the hidden/embedding dimension.
 3. If you are familiar with mixed-precision training, you can switch to low-precision `fp16` numbers for weights and inputs.

* What should I do about the **device-side assert triggered** error?
 > RuntimeError: CUDA error: device-side assert triggered
 > CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
 > For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
 
 * Usually it's because an embedding layer receives an index (token ID or position ID) outside its range, i.e. negative or not smaller than its `num_embeddings`.
 * It could also be something else, which is harder to debug.
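Because a device-side assert is reported asynchronously, a quick sanity check is to validate the indices before they reach the embedding layer, or to rerun the failing step on CPU, where an out-of-range index raises an ordinary Python `IndexError` at the offending line. A minimal sketch, assuming the model uses `nn.Embedding`; the helper `check_embedding_indices` is hypothetical, not part of PyTorch:

```python
import torch
import torch.nn as nn

def check_embedding_indices(ids: torch.Tensor, emb: nn.Embedding, name: str = "ids") -> None:
    """Hypothetical helper: fail early, with a readable message, if any index is out of range."""
    lo, hi = int(ids.min()), int(ids.max())
    if lo < 0 or hi >= emb.num_embeddings:
        raise ValueError(f"{name} must lie in [0, {emb.num_embeddings - 1}], got min={lo}, max={hi}")

emb = nn.Embedding(num_embeddings=100, embedding_dim=16)  # e.g. a vocabulary of 100 tokens
good_ids = torch.tensor([[1, 5, 99]])
bad_ids = torch.tensor([[1, 5, 100]])                       # 100 is one past the last valid index

check_embedding_indices(good_ids, emb)                      # passes silently
try:
    check_embedding_indices(bad_ids, emb, name="token ids")
except ValueError as err:
    print(err)  # token ids must lie in [0, 99], got min=1, max=100

# On CPU the same mistake surfaces as a normal IndexError instead of an asynchronous CUDA assert.
try:
    emb(bad_ids)
except IndexError as err:
    print("CPU error:", err)
```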

In [1]:
# import locale
# locale.getpreferredencoding = lambda: "UTF-8" # to fix a potential locale bug
!nvidia-smi

Tue Jan 23 11:37:03 2024 
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 63C P8 11W / 70W | 0MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
 
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
| No running processes found |
+---------------------------------------------------------------------------------------+


### Imports

In [None]:
!pip install torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import matplotlib.pyplot as plt

In [None]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

## Self-Attention Mechanism: Single Head

![](https://raw.githubusercontent.com/Animadversio/TransformerFromScratch/main/media/AttentionSchematics_white-01.png)

In [None]:
embdim = 256
headdim = 64
tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding
Wq = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wk = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wv = torch.randn(embdim, embdim) / math.sqrt(embdim)

Fill in the score matrix computation

In [None]:
qis = torch.einsum("BSE,EH->BSH", tokens, Wq) # batch x seqlen x headdim
kis = torch.einsum("BTE,EH->BTH", tokens, Wk) # batch x seqlen x headdim
vis = torch.einsum("BTE,EF->BTF", tokens, Wv) # batch x seqlen x embeddim
#### ------ Add your code here:compute query-key similarities. ------ ####
scoremat = # output: batch x seqlen (Query) x seqlen (Key)
#### ------ End ------ ####
attmat = F.softmax(scoremat / math.sqrt(headdim), dim=2)
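Why divide by `math.sqrt(headdim)` in the cell above? If the entries of a query and a key are independent with zero mean and unit variance, their dot product is a sum of `headdim` unit-variance terms, so its variance is about `headdim`; dividing by its square root brings the variance back to about 1 and keeps the softmax from saturating. A quick numerical check, assuming i.i.d. standard-normal queries and keys (an idealization; learned projections need not behave exactly like this):

```python
import math
import torch

torch.manual_seed(0)
headdim = 64
n_pairs = 10_000
q = torch.randn(n_pairs, headdim)  # i.i.d. N(0, 1) entries
k = torch.randn(n_pairs, headdim)

raw_scores = (q * k).sum(dim=-1)                 # one dot product per (query, key) pair
scaled_scores = raw_scores / math.sqrt(headdim)

print(f"variance of raw scores:    {raw_scores.var().item():.2f}")     # close to headdim = 64
print(f"variance of scaled scores: {scaled_scores.var().item():.3f}")  # close to 1

# Larger-magnitude logits push the softmax towards a one-hot distribution.
logits = torch.randn(5) * math.sqrt(headdim)
p_raw = torch.softmax(logits, dim=0)
p_scaled = torch.softmax(logits / math.sqrt(headdim), dim=0)
print(p_raw.max().item(), p_scaled.max().item())  # the unscaled softmax is at least as peaked
```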

Some checks to make sure each score corresponds to the dot product of the right query-key pair.

In [None]:
assert(torch.isclose(scoremat[0,1,2], qis[0,1,:]@kis[0,2,:]))
assert(torch.isclose(scoremat[0,3,4], qis[0,3,:]@kis[0,4,:]))
assert(torch.isclose(scoremat[0,2,2], qis[0,2,:]@kis[0,2,:]))

In [None]:
zis = torch.einsum("BST,BTF->BSF", attmat, vis)

In PyTorch, these operations are packed into the function `F.scaled_dot_product_attention`, so let's test our single-head attention implementation against it.

In [None]:
attn_torch = F.scaled_dot_product_attention(qis,kis,vis)
assert(torch.allclose(attn_torch, zis, atol=1E-6,rtol=1E-6))
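For reference, here is the same computation as a small standalone function using plain matrix multiplications instead of `einsum`. This is a sketch: `naive_sdpa` is a hypothetical name, and it only covers the default 1/sqrt(d) scaling and an optional additive float mask, not the dropout and other options of `F.scaled_dot_product_attention`.

```python
import math
import torch
import torch.nn.functional as F

def naive_sdpa(q, k, v, attn_mask=None):
    """softmax(q @ k^T / sqrt(d) + mask) @ v for tensors shaped (..., seq, dim)."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # (..., seq_q, seq_k)
    if attn_mask is not None:
        scores = scores + attn_mask       # additive float mask, -inf where attention is forbidden
    weights = F.softmax(scores, dim=-1)   # each query's weights over the keys sum to 1
    return weights @ v                    # (..., seq_q, dim_v)

torch.manual_seed(0)
q = torch.randn(2, 5, 64)
k = torch.randn(2, 5, 64)
v = torch.randn(2, 5, 256)  # the value dimension may differ from the query/key dimension

ours = naive_sdpa(q, k, v)
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(ours, ref, atol=1e-5))  # expected: True

# The same function with an explicit causal mask matches is_causal=True.
causal = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
print(torch.allclose(naive_sdpa(q, k, v, causal),
                     F.scaled_dot_product_attention(q, k, v, is_causal=True), atol=1e-5))
```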

## Multi-head attention

In [None]:
embdim = 768
headcnt = 12
headdim = embdim // headcnt
assert headdim * headcnt == embdim
tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding
Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim
Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim
Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim

In [None]:
batch, token_num, _ = tokens.shape
qis = torch.einsum("BSE,EH->BSH", tokens, Wq)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk)
vis = torch.einsum("BTE,EH->BTH", tokens, Wv)
# split the single hidden dim into the heads
qis_mh = qis.view(batch, token_num, headcnt, headdim)
kis_mh = kis.view(batch, token_num, headcnt, headdim)
vis_mh = vis.view(batch, token_num, headcnt, headdim)
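What does `view(batch, token_num, headcnt, headdim)` actually do? The last dimension is split in row-major order, so head `h` simply gets the contiguous columns `h*headdim:(h+1)*headdim` of the packed projection. In other words, packing the heads is the same as giving each head its own smaller weight matrix. A small check on toy sizes (the sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
batch, token_num, headcnt, headdim = 2, 5, 4, 8
embdim = headcnt * headdim
tokens = torch.randn(batch, token_num, embdim)
Wq = torch.randn(embdim, headcnt * headdim)  # all heads packed along the output dimension

qis = tokens @ Wq                                      # (batch, token_num, headcnt * headdim)
qis_mh = qis.view(batch, token_num, headcnt, headdim)  # split the packed dim into heads

for h in range(headcnt):
    Wq_h = Wq[:, h * headdim:(h + 1) * headdim]        # the weight slice belonging to head h
    per_head = tokens @ Wq_h                           # (batch, token_num, headdim)
    assert torch.allclose(qis_mh[:, :, h], per_head, atol=1e-5)
print("each head uses its own contiguous slice of Wq")
```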

Now your challenge is to compute the multi-head attention scores using `einsum`.

In [None]:
#### ------ Add your code here: compute query-key similarities. ------ ####
scoremat_mh = # Output: batch x headcnt x seqlen (query) x seqlen (key)
#### ------ End ------ ####
attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)
zis_mh = torch.einsum("BCST,BTCH->BSCH", attmat_mh, vis_mh) # batch x seqlen (query) x headcnt x headdim
zis = zis_mh.reshape(batch, token_num, headcnt * headdim)

Let's validate that the tensor contraction is correct.

In [None]:
# raw attention scores of head index 1 (the 2nd head, since heads are 0-indexed)
assert (torch.allclose(scoremat_mh[0, 1], qis_mh[0,:,1] @ kis_mh[0,:,1,:].T))

In [None]:
print(tokens.shape)
print(qis_mh.shape)
print(kis_mh.shape)
print(vis_mh.shape)
print(attmat_mh.shape)
print(zis_mh.shape)
print(zis.shape)

In `torch`, this operation is packed into `nn.MultiheadAttention`, including the input projection, the attention, and the output projection. Note that the inputs to `mha.forward` are the *token embeddings*, not the Q, K, Vs that we passed to `F.scaled_dot_product_attention`.

In [None]:
mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True,)
print(mha.in_proj_weight.shape) # 3 * embdim x embdim
mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T

In [None]:
attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False,)
assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6)

`nn.MultiheadAttention` also has an output projection, `out_proj`, a linear layer (with bias) applied to the concatenated head outputs. We can verify that passing our outputs `zis` through this projection reproduces the output of `mha`.

In [None]:
print(mha.out_proj)
assert torch.allclose(attn_out, mha.out_proj(zis), atol=1e-6, rtol=1e-6)
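Putting the pieces together, below is a sketch of a self-contained multi-head attention module that reproduces `nn.MultiheadAttention` (with `batch_first=True` and no dropout) once its weights are copied over. `SimpleMHA` is a hypothetical name. Unlike the cells above, it also applies the input-projection bias `in_proj_bias`; PyTorch initializes that bias to zero, which is why the cells above could ignore it.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMHA(nn.Module):
    """Hypothetical minimal multi-head attention for batch-first tensors."""

    def __init__(self, embdim, headcnt):
        super().__init__()
        assert embdim % headcnt == 0
        self.headcnt, self.headdim = headcnt, embdim // headcnt
        self.in_proj = nn.Linear(embdim, 3 * embdim)  # packed Q, K, V projections
        self.out_proj = nn.Linear(embdim, embdim)

    def split_heads(self, x):
        # (B, T, E) -> (B, heads, T, headdim)
        B, T, _ = x.shape
        return x.view(B, T, self.headcnt, self.headdim).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None):
        E = query.shape[-1]
        W, b = self.in_proj.weight, self.in_proj.bias
        q = self.split_heads(F.linear(query, W[:E], b[:E]))
        k = self.split_heads(F.linear(key, W[E:2 * E], b[E:2 * E]))
        v = self.split_heads(F.linear(value, W[2 * E:], b[2 * E:]))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.headdim)  # (B, heads, Tq, Tk)
        if attn_mask is not None:
            scores = scores + attn_mask  # additive float mask
        weights = F.softmax(scores, dim=-1)
        # merge heads: (B, heads, Tq, headdim) -> (B, Tq, E)
        z = (weights @ v).transpose(1, 2).reshape(query.shape[0], query.shape[1], E)
        return self.out_proj(z), weights

torch.manual_seed(0)
embdim, headcnt = 64, 4
ref = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
ours = SimpleMHA(embdim, headcnt)
with torch.no_grad():  # copy the reference weights into our module
    ours.in_proj.weight.copy_(ref.in_proj_weight)
    ours.in_proj.bias.copy_(ref.in_proj_bias)
    ours.out_proj.weight.copy_(ref.out_proj.weight)
    ours.out_proj.bias.copy_(ref.out_proj.bias)

x = torch.randn(2, 7, embdim)
out_ref, w_ref = ref(x, x, x, average_attn_weights=False)
out_ours, w_ours = ours(x, x, x)
print(torch.allclose(out_ref, out_ours, atol=1e-5), torch.allclose(w_ref, w_ours, atol=1e-5))
```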

### Causal attention mask

For models such as GPT, each token can only attend to itself and the tokens before it, so the attention scores need to be modified before entering the softmax.

The common way of masking is to add a large negative number to the positions the model should not attend to; after the softmax, their attention weights become practically zero.

In [None]:
attn_mask = torch.ones(token_num,token_num,)
attn_mask = -1E4 * torch.triu(attn_mask,1)
attn_mask
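Why does adding `-1E4` work? After the softmax, a masked score turns into `exp(-1E4 + s)`, which underflows to zero in float32 next to the unmasked entries. Using `-inf` gives exactly zero, and `F.scaled_dot_product_attention` can build the same causal pattern itself via `is_causal=True`. The check below compares the three options on random data (a sketch; the sizes are arbitrary):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 5, 16
q, k, v = torch.randn(1, T, d), torch.randn(1, T, d), torch.randn(1, T, d)

# Strictly upper-triangular positions (key index > query index) are the "future" tokens.
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
mask_large = torch.zeros(T, T).masked_fill(future, -1e4)          # large negative number
mask_inf = torch.zeros(T, T).masked_fill(future, float("-inf"))   # exact -inf

scores = q @ k.transpose(-2, -1) / math.sqrt(d)
w_large = F.softmax(scores + mask_large, dim=-1)
w_inf = F.softmax(scores + mask_inf, dim=-1)

print(w_inf[0])                                   # upper triangle is exactly 0
print(torch.allclose(w_large, w_inf, atol=1e-6))  # expected: True

# The fused function builds the causal mask itself when is_causal=True.
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(w_inf @ v, out_sdpa, atol=1e-5))  # expected: True
```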

In [None]:
scoremat_mh_msk = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh) # batch x headcnt x seqlen (query) x seqlen (key)
scoremat_mh_msk += attn_mask # add the attn mask to the scores before SoftMax normalization
attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1)
zis_mh_msk = torch.einsum("BCST,BTCH->BSCH", attmat_mh_msk, vis_mh) # batch x seqlen (query) x headcnt x headdim
zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)

**Note:** The `is_causal` parameter should create a causal mask automatically, but because of a recent PyTorch bug it doesn't work, so beware!
https://github.com/pytorch/pytorch/issues/99282

In [None]:
attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask)

In [None]:
assert torch.allclose(attn_weights_causal, attmat_mh_msk, atol=1e-6, rtol=1e-6)
assert torch.allclose(attn_out_causal, mha.out_proj(zis_msk), atol=1e-6, rtol=1e-6)

In [None]:
plt.figure()
for head in range(headcnt):
    plt.subplot(3, 4, head + 1)  # 3 x 4 grid, one panel per head
    plt.imshow(attn_weights_causal[0, head].detach().numpy())
    plt.title(f"head {head}")
    plt.axis("off")
plt.show()

## Transformer Block

Having gained some intuition about the attention layer, let's build it into a transformer. A vanilla transformer block usually looks like this. There are slight differences between the transformer blocks in GPT-2, BERT, and other models, but they generally have the following components:

* Transformer Block
 * Layernorm
 * Skip connections
 * Multi-head attention
 * MLP, Feedforward net


In [None]:
class TransformerBlock_simple(nn.Module):

    def __init__(self, embdim, headcnt, *args, dropout=0.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.ln1 = nn.LayerNorm(embdim)
        self.ln2 = nn.LayerNorm(embdim)
        self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True,)
        self.ffn = nn.Sequential(
            nn.Linear(embdim, 4 * embdim),
            nn.GELU(),
            nn.Linear(4 * embdim, embdim),
            nn.Dropout(dropout),
        )

    def forward(self, x, is_causal=True):
        batch, token_num, hidden_dim = x.shape
        if is_causal:
            # additive causal mask, created on the same device as the input
            attn_mask = torch.ones(token_num, token_num, device=x.device)
            attn_mask = -1E4 * torch.triu(attn_mask, 1)
        else:
            attn_mask = None

        residue = x
        x = self.ln1(x)  # LayerNorm before attention
        #### ------ Add your code here: multihead attention ------ ####
        attn_output, attn_weights = # first output is the output latent states
        #### ------ End ------ ####
        x = residue + attn_output  # skip connection

        residue = x
        x = self.ln2(x)  # LayerNorm before the feedforward net
        ffn_output = self.ffn(x)
        output = residue + ffn_output  # skip connection
        return output

Compare the implementation with the schematic and see if it makes more sense!


*Attention Block*


![BERT (Transformer encoder)](https://iq.opengenus.org/content/images/2020/06/encoder-1.png)
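PyTorch also ships a ready-made block, `nn.TransformerEncoderLayer`. With `norm_first=True` it applies LayerNorm before the attention and feedforward sublayers, as `TransformerBlock_simple` does, and `activation="gelu"` with `dim_feedforward=4 * embdim` matches its MLP. This is a comparison sketch, not a drop-in replacement: the built-in layer also applies dropout inside attention and after each sublayer (disabled here with `dropout=0.0`). The check at the end confirms that, with a causal mask, changing the last token leaves the outputs at earlier positions untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embdim, headcnt, T = 64, 4, 6
block = nn.TransformerEncoderLayer(
    d_model=embdim, nhead=headcnt, dim_feedforward=4 * embdim,
    dropout=0.0, activation="gelu", batch_first=True, norm_first=True,
)
block.eval()

# Additive causal mask: -inf above the diagonal, 0 elsewhere.
causal_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

x = torch.randn(2, T, embdim)
x_changed = x.clone()
x_changed[:, -1] += 1.0  # perturb only the last token

y = block(x, src_mask=causal_mask)
y_changed = block(x_changed, src_mask=causal_mask)

print(y.shape)                                                  # torch.Size([2, 6, 64])
print(torch.allclose(y[:, :-1], y_changed[:, :-1], atol=1e-6))  # earlier positions unchanged: True
print(torch.allclose(y[:, -1], y_changed[:, -1]))               # last position changed: False
```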


# Image Classification

Now we use the Transformer architecture for image classification.

### Imports

In [None]:
!pip install transformers
!pip install torchvision

## Import transformers
from transformers import get_linear_schedule_with_warmup
from transformers import BertForSequenceClassification
from transformers import BertModel, BertTokenizer, BertConfig

import os
from os.path import join
from tqdm.notebook import tqdm, trange
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST, CIFAR10
from torchvision import datasets, transforms


### Preparing Image Dataset
Load the dataset. Note that the augmentations are necessary: without augmentation, the Transformer overfits very quickly.

In [None]:
!mkdir -p data
dataset = CIFAR10(root='./data/', train=True, download=True, transform=
transforms.Compose([
 transforms.RandomHorizontalFlip(),
 transforms.RandomCrop(32, padding=4),
 transforms.ToTensor(),
 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]))
# augmentations are very important when training vision models; without them the model overfits quickly and never reaches good generalization accuracy
val_dataset = CIFAR10(root='./data/', train=False, download=True, transform=transforms.Compose(
 [transforms.ToTensor(),
 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),]))

Quoting https://openreview.net/pdf?id=SCN8UaetXx:

> "Visual Transformers. Despite some previous work in which attention is used inside the convolutional layers of a CNN [57, 26], the first fully-transformer architectures for vision are iGPT [8] and ViT [17]. The former is trained using a "masked-pixel" self-supervised approach, similar in spirit to the common masked-word task used, for instance, in BERT [15] and in GPT [45] (see below). On the other hand, ViT is trained in a supervised way, using a special "class token" and a classification head attached to the final embedding of this token. Both methods are computationally expensive and, despite their very good results when trained on huge datasets, they underperform ResNet architectures when trained from scratch using only ImageNet-1K [17, 8]. VideoBERT [51] is conceptually similar to iGPT, but, rather than using pixels as tokens, each frame of a video is holistically represented by a feature vector, which is quantized using an off-the-shelf pretrained video classification model. DeiT [53] trains ViT using distillation information provided by a pretrained CNN."

### Transformer model for images

In [None]:
config = BertConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=12,
 num_attention_heads=8, max_position_embeddings=256,
 vocab_size=100, bos_token_id=101, eos_token_id=102,
 cls_token_id=103, )
model = BertModel(config).cuda()
patch_embed = nn.Conv2d(3, config.hidden_size, kernel_size=4, stride=4).cuda()
CLS_token = nn.Parameter(torch.randn(1, 1, config.hidden_size, device="cuda") / math.sqrt(config.hidden_size))
readout = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
 nn.GELU(),
 nn.Linear(config.hidden_size, 10)
 ).cuda()
for module in [patch_embed, readout, model, CLS_token]:
 module.cuda()

optimizer = AdamW([*model.parameters(),
 *patch_embed.parameters(),
 *readout.parameters(),
 CLS_token], lr=5e-4)
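Why a `Conv2d` with `kernel_size=4, stride=4` as the patch embedding? Because the stride equals the kernel size, the 4×4 patches do not overlap, so the convolution is one shared linear layer applied to every flattened 3×4×4 patch. For a 32×32 CIFAR-10 image this gives (32/4)² = 64 patch tokens, plus the CLS token, i.e. 65 positions, well below `max_position_embeddings=256`. The sketch below checks the equivalence with `F.unfold` on toy sizes (the hidden size of 8 is arbitrary). Note that `conv_out` is still channel-first; rearranging it to `(batch_size, HW, hidden)` is the `permute` you fill in below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden, patch = 8, 4
patch_embed = nn.Conv2d(3, hidden, kernel_size=patch, stride=patch)
imgs = torch.randn(2, 3, 32, 32)

# Convolution view: (B, hidden, 8, 8) -> flatten the spatial grid -> (B, hidden, 64)
conv_out = patch_embed(imgs).flatten(2)

# Linear view: cut out non-overlapping patches, then apply one shared linear map.
patches = F.unfold(imgs, kernel_size=patch, stride=patch)  # (B, 3*4*4, 64)
W = patch_embed.weight.reshape(hidden, -1)                 # (hidden, 3*4*4)
linear_out = W @ patches + patch_embed.bias[:, None]       # (B, hidden, 64)

print(conv_out.shape)                                   # torch.Size([2, 8, 64])
print(torch.allclose(conv_out, linear_out, atol=1e-5))  # expected: True
num_tokens = (32 // patch) ** 2 + 1                     # patch tokens + one CLS token
print(num_tokens)                                       # 65
```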

In [None]:
batch_size = 192 # 96
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
loss_list = []
acc_list = []
for epoch in trange(10, leave=False):
    # switch back to training mode (validation below sets eval mode) and reset the running statistics
    model.train()
    correct_cnt = 0
    total_loss = 0
    pbar = tqdm(train_loader, leave=False)
    for i, (imgs, labels) in enumerate(pbar):
        patch_embs = patch_embed(imgs.cuda())
        #### ------ Add your code here: replace the None with the correct order of the embedding dimension. ------ ####
        patch_embs = patch_embs.flatten(2).permute(None, None, None) # hint: (batch_size, HW, hidden)
        #### ------ End ------ ####
        # print(patch_embs.shape)
        input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)  # prepend the CLS token
        # print(input_embs.shape)
        output = model(inputs_embeds=input_embs)
        logit = readout(output.last_hidden_state[:, 0, :])  # classify from the CLS token's final embedding
        loss = F.cross_entropy(logit, labels.cuda())
        # print(loss)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_description(f"loss: {loss.item():.4f}")
        total_loss += loss.item() * imgs.shape[0]
        correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()

    loss_list.append(round(total_loss / len(dataset), 4))
    acc_list.append(round(correct_cnt / len(dataset), 4))
    # test on validation set
    model.eval()
    correct_cnt = 0
    total_loss = 0
    with torch.no_grad():  # no gradients needed for evaluation, which saves memory
        for i, (imgs, labels) in enumerate(val_loader):
            patch_embs = patch_embed(imgs.cuda())
            #### ------ Add your code here: replace the None with the correct order of the embedding dimension. ------ ####
            patch_embs = patch_embs.flatten(2).permute(None, None, None) # hint: (batch_size, HW, hidden)
            #### ------ End ------ ####
            input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)
            output = model(inputs_embeds=input_embs)
            logit = readout(output.last_hidden_state[:, 0, :])
            loss = F.cross_entropy(logit, labels.cuda())
            total_loss += loss.item() * imgs.shape[0]
            correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()

    print(f"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}")


In [None]:
#### ------ Add your code here: plot the training loss curve to show its variation with the epoch. ------ ####
# hint: use the data in the list `loss_list` and plot the curve with plt.plot()
#### ------ End ------ ####

In [None]:
#### ------ Add your code here: plot the accuracy score curve to show its variation with the epoch. ------ ####
# hint: use the data in the list `acc_list` and plot the curve with plt.plot()
#### ------ End ------ ####

In [None]:
torch.save(model.state_dict(),"bert.pth")
!du -sh bert.pth
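`model.state_dict()` only contains the BERT encoder. The patch embedding, the CLS token, and the readout head were created outside of it, so a checkpoint that can actually classify images needs them too. A minimal sketch; the helpers `save_checkpoint`/`load_checkpoint`, the dictionary keys, and the file name `vit_cifar10.pth` are hypothetical choices, and tiny stand-in modules replace the notebook's objects so the example runs on its own:

```python
import os
import tempfile
import torch
import torch.nn as nn

def save_checkpoint(path, model, patch_embed, readout, cls_token):
    """Hypothetical helper: store every trainable piece of the image classifier in one file."""
    torch.save({
        "model": model.state_dict(),
        "patch_embed": patch_embed.state_dict(),
        "readout": readout.state_dict(),
        "cls_token": cls_token.detach().cpu(),
    }, path)

def load_checkpoint(path, model, patch_embed, readout, cls_token):
    """Hypothetical helper: restore, in place, the pieces written by save_checkpoint."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    patch_embed.load_state_dict(ckpt["patch_embed"])
    readout.load_state_dict(ckpt["readout"])
    with torch.no_grad():
        cls_token.copy_(ckpt["cls_token"])

def make_parts():
    # Tiny stand-ins; in the notebook pass model, patch_embed, readout, CLS_token instead.
    return (nn.Linear(8, 8), nn.Conv2d(3, 8, 4, 4), nn.Linear(8, 10),
            nn.Parameter(torch.randn(1, 1, 8)))

src, dst = make_parts(), make_parts()
path = os.path.join(tempfile.mkdtemp(), "vit_cifar10.pth")
save_checkpoint(path, *src)
load_checkpoint(path, *dst)
print(torch.equal(src[3], dst[3]), torch.equal(src[1].weight, dst[1].weight))  # True True
```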

**Reference:**
Tutorial for Harvard Medical School ML from Scratch Series: Transformer from Scratch (https://github.com/Animadversio/TransformerFromScratch?tab=readme-ov-file).