|
|
|
|
|
|
|
|
|
import math

import matplotlib.pyplot as plt

import numpy as np

import torch

import torch.nn as nn

import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
seed = 42 |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
|
|
|
|
# Single-head attention on a toy batch of 5 tokens.
embdim = 256

headdim = 64

tokens = torch.randn(1, 5, embdim)
|
|
|
|
|
# Random projection weights, scaled by 1/sqrt(embdim) so the projected features
# stay roughly at unit variance. Query/key map to the head dimension; the value
# projection here keeps the full embedding dimension.
Wq = torch.randn(embdim, headdim) / math.sqrt(embdim)

Wk = torch.randn(embdim, headdim) / math.sqrt(embdim)

Wv = torch.randn(embdim, embdim) / math.sqrt(embdim)
|
|
|
|
|
# Project the tokens to queries, keys, and values.
# Index letters: B=batch, S=query position, T=key position, E=embedding dim,
# H=head dim, F=value dim.
qis = torch.einsum("BSE,EH->BSH", tokens, Wq)

kis = torch.einsum("BTE,EH->BTH", tokens, Wk)

vis = torch.einsum("BTE,EF->BTF", tokens, Wv)
|
|
|
|
|
# Warm-up: batched matmul contracts the last dim of the first operand with the
# second-to-last dim of the second, so transposing dims 1 and 2 of the second
# operand yields a (2, 5, 5) score-like matrix.
random_mat1 = torch.randn(2, 5, 4)

random_mat2 = torch.randn(2, 5, 4)

print(torch.matmul(random_mat1, random_mat2.transpose(1, 2)).shape)

print(qis.shape)

print(kis.shape)
|
|
|
|
|
|
|
|
|
# Scaled dot-product attention: raw similarity scores, then a softmax over the
# key positions (dim=2) after dividing by sqrt(headdim) to keep the logits well-behaved.
scoremat = torch.matmul(qis, kis.transpose(1, 2))

attmat = F.softmax(scoremat / math.sqrt(headdim), dim=2)


# Attention output: for each query position, a weighted sum of the values.
zis = torch.einsum("BST,BTF->BSF", attmat, vis)
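
# Added check (not in the original): each row of attmat is a probability
# distribution over the 5 key positions, so the rows should sum to one.
print(attmat.sum(dim=-1))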
|
|
|
|
|
# Cross-check against PyTorch's fused kernel; it also scales by 1/sqrt of the
# query's last dimension, so the results should match.
attn_torch = F.scaled_dot_product_attention(qis, kis, vis)

assert torch.allclose(attn_torch, zis, atol=1E-6, rtol=1E-6)
|
|
|
|
|
# Multi-head attention: 12 heads of 64 dimensions each, embedding dim 768.
embdim = 768

headcnt = 12

headdim = embdim // headcnt

assert headdim * headcnt == embdim

tokens = torch.randn(1, 5, embdim)
|
|
|
|
|
# One weight matrix per projection holds all heads side by side
# (headcnt * headdim = embdim columns).
Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)

Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)

Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)
|
|
|
print(Wq.shape) |
|
print(Wk.shape) |
|
print(Wv.shape) |
|
|
|
batch, token_num, _ = tokens.shape |
|
|
|
|
|
|
|
qis = torch.einsum("BSE,EH->BSH", tokens, Wq) |
|
kis = torch.einsum("BTE,EH->BTH", tokens, Wk) |
|
vis = torch.einsum("BTE,EH->BTH", tokens, Wv) |
|
|
|
|
|
|
|
|
|
# Split the stacked projections into separate heads:
# (B, tokens, headcnt * headdim) -> (B, tokens, headcnt, headdim).
qis_mh = qis.view(batch, token_num, headcnt, headdim)

kis_mh = kis.view(batch, token_num, headcnt, headdim)

vis_mh = vis.view(batch, token_num, headcnt, headdim)

# Per-head scores. Index letters: B=batch, S=query position, T=key position,
# H=head index, C=per-head channel.
scoremat_mh = torch.einsum("BSHC,BTHC->BHST", qis_mh, kis_mh)

print(scoremat_mh.shape)
|
|
|
|
|
|
|
attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)

# Weighted sum of the values per head (same index letters as above), then
# concatenate the heads back into a single embedding-sized vector.
zis_mh = torch.einsum("BHST,BTHC->BSHC", attmat_mh, vis_mh)

zis = zis_mh.reshape(batch, token_num, headcnt * headdim)
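
# Cross-check (an added sketch, not part of the original): PyTorch's fused
# scaled_dot_product_attention expects the head axis at dim 1, so move it
# there, run the fused kernel, and fold the heads back before comparing.
zis_sdpa = F.scaled_dot_product_attention(
    qis_mh.transpose(1, 2), kis_mh.transpose(1, 2), vis_mh.transpose(1, 2)
)
zis_sdpa = zis_sdpa.transpose(1, 2).reshape(batch, token_num, headcnt * headdim)
print(torch.allclose(zis_sdpa, zis, atol=1e-6, rtol=1e-6))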
|
|
|
|
|
|
|
|
|
# PyTorch's nn.MultiheadAttention stores the q/k/v projections stacked row-wise
# in a single (3 * embdim, embdim) matrix, so copy our weights in transposed.
# The default init zeroes in_proj_bias and out_proj.bias.
mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True)

print(mha.in_proj_weight.shape)

mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T

attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False)

# With the projection weights matched, the per-head attention weights agree
# with the manual computation.
assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6)
|
|
|
print(attn_weights.shape) |
|
print(attn_out.shape) |
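
# Added sketch (not in the original): attn_out differs from the manual zis only
# by the final output projection, whose weight nn.MultiheadAttention initialises
# randomly (its bias is zeroed). Pushing zis through that same out_proj should
# reproduce attn_out.
zis_proj = zis @ mha.out_proj.weight.T + mha.out_proj.bias
print(torch.allclose(attn_out, zis_proj, atol=1e-5, rtol=1e-5))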
|
|
|
|
|
|
|
|
|
# Additive causal mask: large negative values above the diagonal suppress
# attention to future positions after the softmax.
attn_mask = torch.ones(token_num, token_num)

attn_mask = -1E4 * torch.triu(attn_mask, 1)

print(attn_mask)

scoremat_mh_msk = torch.einsum("BSHC,BTHC->BHST", qis_mh, kis_mh)

# nn.MultiheadAttention adds a float mask to the already-scaled scores, so do
# the same here before the softmax.
attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim) + attn_mask, dim=-1)

zis_mh_msk = torch.einsum("BHST,BTHC->BSHC", attmat_mh_msk, vis_mh)

zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)
|
|
|
attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask) |
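
# Added check (not in the original): under the same mask, the manually masked
# attention weights should agree with what nn.MultiheadAttention computes. The
# outputs themselves still differ by the random out_proj, so only the weights
# are compared here.
print(torch.allclose(attmat_mh_msk, attn_weights_causal, atol=1e-5, rtol=1e-5))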
|
|
|
|
|
# Visualize each head's causal attention pattern; entries above the diagonal
# should be (near) zero because of the mask.
plt.figure(figsize=(8, 6))
|
for head in range(headcnt): |
|
plt.subplot(3, 4, head + 1) |
|
plt.imshow(attn_weights_causal[0, head].detach().numpy()) |
|
plt.title(f"head {head}") |
|
plt.axis("off") |
|
plt.show() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module): |
|
|
|
    def __init__(self, embdim: int, headcnt: int, *args, dropout: float = 0.0, **kwargs) -> None:
|
super().__init__(*args, **kwargs) |
|
self.ln1 = nn.LayerNorm(embdim) |
|
self.ln2 = nn.LayerNorm(embdim) |
|
self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True,) |
|
self.ffn = nn.Sequential( |
|
nn.Linear(embdim, 4 * embdim), |
|
nn.GELU(), |
|
nn.Linear(4 * embdim, embdim), |
|
nn.Dropout(dropout), |
|
) |
|
|
|
    def forward(self, x, is_causal=True):

        """

        Input has shape (B, S, E); token and positional embeddings are assumed

        to have been added already.

        """

        batch, token_num, hidden_dim = x.shape

        if is_causal:

            # Additive causal mask: large negative values above the diagonal.

            attn_mask = torch.ones(token_num, token_num)

            attn_mask = -1E4 * torch.triu(attn_mask, 1)

        else:

            attn_mask = None



        # Self-attention sub-layer: residual connection followed by LayerNorm.

        residue = x

        attn_output, attn_weights = self.attn(x, x, x, attn_mask=attn_mask, average_attn_weights=False)

        x = residue + attn_output

        x = self.ln1(x)

        # Feed-forward sub-layer: residual connection followed by LayerNorm.

        residue = x

        ffn_output = self.ffn(x)

        output = self.ln2(residue + ffn_output)

        return output
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
print("Testing the Transformer Block") |
|
transformer_block = TransformerBlock(embdim, headcnt) |
|
tokens = torch.randn(1, 5, embdim) |
|
output = transformer_block(tokens) |
|
print(output.shape) |
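
    # Added sanity checks (not in the original): the block keeps the input shape,
    # and turning the causal mask off should generally change the output.
    assert output.shape == tokens.shape
    output_bidir = transformer_block(tokens, is_causal=False)
    print(torch.allclose(output, output_bidir))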
|
|