# transformers_scratch/scratch_transformer.py
# Transformers from Scratch using "Attention is All You Need" paper
# Modelling Scaled Dot-Product Attention, Multi-Head Attention, Position-wise Feed-Forward Networks.
# Import Modules
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
import math
# Making Single and Multi-Head Attention modules from scratch using Pure PyTorch
# Initialise the seed for reproducibility
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Self-Attention Mechanism: Single Head
embdim = 256 # D
headdim = 64 # Internal D
tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding
# Defining the weight matrices associated with query, key, and value
Wq = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wk = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wv = torch.randn(embdim, embdim) / math.sqrt(embdim)
# Query, Key, Value
qis = torch.einsum("BSE,EH->BSH", tokens, Wq) # batch x seqlen x headdim; queries, (1, 5, 64)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk) # batch x seqlen x headdim; keys
vis = torch.einsum("BTE,EF->BTF", tokens, Wv) # batch x seqlen x embeddim; values
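# Hedged aside (my addition): each einsum above is just a batched matrix product over the
# embedding dimension, e.g. the query projection is equivalent to tokens @ Wq.
assert torch.allclose(qis, tokens @ Wq, atol=1E-6, rtol=1E-6)
assert torch.allclose(kis, tokens @ Wk, atol=1E-6, rtol=1E-6)
assert torch.allclose(vis, tokens @ Wv, atol=1E-6, rtol=1E-6)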
# Start: Testing Code
random_mat1 = torch.randn(2, 5, 4) # BATCH, TOKENS, DIMENSIONS
random_mat2 = torch.randn(2, 5, 4)
# (2, 5, 4) @ (2, 4, 5) -> (2, 5, 5)
torch.matmul(random_mat1, random_mat2.transpose(1, 2)) # 2, 5, 5
print(qis.shape)
print(kis.shape)
# (Q) N, D * (K^T) D, N -> N, N
# End: Testing Code
scoremat = torch.matmul(qis, kis.transpose(1, 2)) # output: batch x seqlen (Query) x seqlen (Key)
attmat = F.softmax(scoremat / math.sqrt(headdim), dim=2) # attention matrix given.
# Output of the attention mechanism
zis = torch.einsum("BST,BTF->BSF", attmat, vis)
# We can verify the output against F.scaled_dot_product_attention
attn_torch = F.scaled_dot_product_attention(qis, kis, vis)
assert (torch.allclose(attn_torch, zis, atol=1E-6, rtol=1E-6)) # True
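# Hedged sketch (my addition): the same computation packaged as a reusable helper implementing
# Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V from the paper. The name
# `single_head_attention` is my own choice for illustration.
def single_head_attention(q, k, v):
    d_k = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, S, T)
    weights = F.softmax(scores, dim=-1)  # each query row sums to 1
    return torch.matmul(weights, v)  # (B, S, F)

assert torch.allclose(single_head_attention(qis, kis, vis), zis, atol=1E-6, rtol=1E-6)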
# Multi-Head Attention
embdim = 768
headcnt = 12
headdim = embdim // headcnt
# print(headdim)
assert headdim * headcnt == embdim
tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding
# Unlike the single-head example (256-dim), the projection width here is headcnt * headdim = 12 * 64 = 768, with all heads packed into one dimension.
Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim
Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim
Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim
print(Wq.shape)
print(Wk.shape)
print(Wv.shape)
batch, token_num, _ = tokens.shape # batch, tokens (n), embedding shape.
# tokens: (B, N, E)
# Wq: (E, headcnt * headdim), i.e. all head projections stacked along the output dimension
qis = torch.einsum("BSE,EH->BSH", tokens, Wq) # Batch, N, H ~ 1, 5, 768
kis = torch.einsum("BTE,EH->BTH", tokens, Wk) # Batch N, H
vis = torch.einsum("BTE,EH->BTH", tokens, Wv) # Batch, N, H
# Split the packed hidden dim into separate heads: (B, N, H) -> (B, N, headcnt, headdim),
# so each token now carries a headdim-sized vector per head.
qis_mh = qis.view(batch, token_num, headcnt, headdim) # B, N, HC, HD
kis_mh = kis.view(batch, token_num, headcnt, headdim)
vis_mh = vis.view(batch, token_num, headcnt, headdim)
scoremat_mh = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh) # Inputs: (B, N, HC, HD) & Output: (B, HC, S_query, T_key)
print(scoremat_mh.shape) # 1, 12, 5, 5 # Now I have 12 heads, which have given me attention matrices of shape 5x5.
# batch x headcnt x seqlen (query) x seqlen (key)
attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)
zis_mh = torch.einsum("BCST,BTCH->BSCH", attmat_mh, vis_mh) # batch x seqlen (query) x headcnt x headdim
zis = zis_mh.reshape(batch, token_num, headcnt * headdim)
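# Hedged check (my addition), reusing the single_head_attention helper defined above:
# head 0 of the packed projection behaves exactly like a standalone head on its weight slice.
z_head0 = single_head_attention(qis_mh[:, :, 0], kis_mh[:, :, 0], vis_mh[:, :, 0])
assert torch.allclose(z_head0, zis_mh[:, :, 0], atol=1E-6, rtol=1E-6)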
# Note: this does not yet apply the final output projection (W^O) that follows the concat in the paper.
# We can verify the per-head attention weights against nn.MultiheadAttention.
mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True, )
print(mha.in_proj_weight.shape) # 3 * embdim x embdim
mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T
attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False, )
# Which is the same as attmat_mh
assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6) # True
print(attn_weights.shape) # batch, heads, tokens, tokens.
print(attn_out.shape)
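# Hedged check (my addition): with the default initialisation, in_proj_bias and out_proj.bias
# are zero, so attn_out should equal our concatenated heads passed through mha.out_proj.
attn_out_manual = zis @ mha.out_proj.weight.T + mha.out_proj.bias
assert torch.allclose(attn_out, attn_out_manual, atol=1E-5, rtol=1E-5)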
# Causal Mask from Scratch
# The paper applies a causal mask in the decoder so that a position cannot attend to future tokens.
attn_mask = torch.ones(token_num, token_num, )
attn_mask = -1E4 * torch.triu(attn_mask, 1)
print(attn_mask)
scoremat_mh_msk = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh) # batch x headcnt x seqlen (query) x seqlen (key)
scoremat_mh_msk += attn_mask # add the attn mask to the scores before SoftMax normalization
attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1)
zis_mh_msk = torch.einsum("BCST,BTCH->BSCH", attmat_mh_msk, vis_mh) # batch x seqlen (query) x headcnt x headdim
zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)
attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask)
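# Hedged check (my addition): the masked entries underflow to exactly zero in both versions,
# so the hand-rolled causal weights should match PyTorch's (assuming the float attn_mask is
# simply added to the scaled scores before the softmax inside nn.MultiheadAttention).
assert torch.allclose(attmat_mh_msk, attn_weights_causal, atol=1E-6, rtol=1E-6)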
# Plot the causal attention map of each head.
plt.figure()
for head in range(headcnt):
    plt.subplot(3, 4, head + 1)
    plt.imshow(attn_weights_causal[0, head].detach().numpy())
    plt.title(f"head {head}")
    plt.axis("off")
plt.show()
# Transformer Block from Scratch
# Modeling the Transformer Block from Scratch using PyTorch
# Transformer Block contains:
# - Layer norm
# - Skip connections
# - Multi-head attention
# - MLP, Feedforward net
class TransformerBlock(nn.Module):
    def __init__(self, embdim: int, headcnt: int, *args, dropout=0.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.ln1 = nn.LayerNorm(embdim)
        self.ln2 = nn.LayerNorm(embdim)
        self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embdim, 4 * embdim),
            nn.GELU(),
            nn.Linear(4 * embdim, embdim),
            nn.Dropout(dropout),
        )

    def forward(self, x, is_causal=True):
        """
        The input is a tensor of shape (B, S, E); we assume token and positional embeddings
        have already been added.
        """
        batch, token_num, hidden_dim = x.shape
        if is_causal:
            # Additive float mask: large negative values above the diagonal block attention to future tokens.
            attn_mask = torch.ones(token_num, token_num, device=x.device)
            attn_mask = -1E4 * torch.triu(attn_mask, 1)
        else:
            attn_mask = None
        residue = x
        attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask, average_attn_weights=False)
        x = residue + attn_output
        x = self.ln1(x)
        residue = x
        ffn_output = self.ffn(x)
        output = self.ln2(residue + ffn_output)
        return output
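# Minimal sketch (my addition, not part of the original file) of the sinusoidal positional
# encoding from "Attention is All You Need", since forward() assumes positional information
# has already been added: PE(pos, 2i) = sin(pos / 10000^(2i/d_model)),
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)).
def sinusoidal_positional_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (S, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(positions * div_term)
    pe[:, 1::2] = torch.cos(positions * div_term)
    return pe  # (S, D); add to the token embeddings before the first block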
if __name__ == "__main__":
    # Testing the Transformer Block
    print("Testing the Transformer Block")
    transformer_block = TransformerBlock(embdim, headcnt)
    tokens = torch.randn(1, 5, embdim)
    output = transformer_block(tokens)
    print(output.shape)
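    # Hedged sanity check (my addition): with the default is_causal=True, perturbing a later
    # token must not change the block's output at earlier positions.
    tokens_perturbed = tokens.clone()
    tokens_perturbed[:, -1, :] += 1.0  # change only the last token
    with torch.no_grad():
        out_a = transformer_block(tokens)
        out_b = transformer_block(tokens_perturbed)
    assert torch.allclose(out_a[:, :-1], out_b[:, :-1], atol=1E-6)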