makiisthebes committed
Commit
336cbca
1 Parent(s): 7d6a371

Transformers from Scratch

Files changed (1):
scratch_transformer.py (+187 -0)
scratch_transformer.py ADDED
@@ -0,0 +1,187 @@
# Transformers from Scratch using the "Attention is All You Need" paper
# Modelling Scaled Dot-Product Attention, Multi-Head Attention, Position-wise Feed-Forward Networks.

# Import Modules
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
import math

# Making Single and Multi-Head Attention modules from scratch using pure PyTorch

# Initialise the seed for reproducibility
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Self-Attention Mechanism: Single Head
embdim = 256   # D
headdim = 64   # internal head dimension
tokens = torch.randn(1, 5, embdim)  # batch, tokens, embedding

# Define the weights associated with query, key, value
Wq = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wk = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wv = torch.randn(embdim, embdim) / math.sqrt(embdim)

# Query, Key, Value projections
qis = torch.einsum("BSE,EH->BSH", tokens, Wq)  # batch x seqlen x headdim; queries, (1, 5, 64)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk)  # batch x seqlen x headdim; keys
vis = torch.einsum("BTE,EF->BTF", tokens, Wv)  # batch x seqlen x embdim; values

# Start: Testing Code
random_mat1 = torch.randn(2, 5, 4)  # BATCH, TOKENS, DIMENSIONS
random_mat2 = torch.randn(2, 5, 4)

# (2, 5, 4) @ (2, 4, 5) -> (2, 5, 5)
torch.matmul(random_mat1, random_mat2.transpose(1, 2))
print(qis.shape)
print(kis.shape)
# (Q) N, D @ (K^T) D, N -> N, N
# End: Testing Code


scoremat = torch.matmul(qis, kis.transpose(1, 2))  # output: batch x seqlen (query) x seqlen (key)
attmat = F.softmax(scoremat / math.sqrt(headdim), dim=2)  # attention matrix

# Output of the attention mechanism
zis = torch.einsum("BST,BTF->BSF", attmat, vis)

# We can verify the output against PyTorch's scaled dot-product attention
attn_torch = F.scaled_dot_product_attention(qis, kis, vis)
assert torch.allclose(attn_torch, zis, atol=1E-6, rtol=1E-6)  # True
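
# (Added note, not in the original commit) The block above implements the paper's
# scaled dot-product attention: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V.
# As a small sanity sketch, the einsum "BST,BTF->BSF" is just a batched matmul:
zis_matmul = torch.matmul(attmat, vis)
print(torch.allclose(zis_matmul, zis, atol=1E-6))  # expected to print True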

# Multi-Head Attention
embdim = 768
headcnt = 12
headdim = embdim // headcnt
# print(headdim)
assert headdim * headcnt == embdim
tokens = torch.randn(1, 5, embdim)  # batch, tokens, embedding

# embdim (768) is split across the heads: headcnt (12) * headdim (64) = 768
Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed in a single dim
Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed in a single dim
Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed in a single dim

print(Wq.shape)
print(Wk.shape)
print(Wv.shape)

batch, token_num, _ = tokens.shape  # batch, tokens (n), embedding
# tokens: B, N, E

# Wq: E x (headcnt * headdim)
qis = torch.einsum("BSE,EH->BSH", tokens, Wq)  # batch, N, H ~ (1, 5, 768)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk)  # batch, N, H
vis = torch.einsum("BTE,EH->BTH", tokens, Wv)  # batch, N, H
# split the single hidden dim into the heads

# Converting dimensions from (B, N, H) to (B, N, headcnt, headdim)
# So now, for each batch and each token, every head has its own slice of the projection.
qis_mh = qis.view(batch, token_num, headcnt, headdim)  # B, N, headcnt, headdim
kis_mh = kis.view(batch, token_num, headcnt, headdim)
vis_mh = vis.view(batch, token_num, headcnt, headdim)

scoremat_mh = torch.einsum("BSHC,BTHC->BHST", qis_mh, kis_mh)  # input: (B, N, headcnt, headdim), output: (B, headcnt, S, T)
print(scoremat_mh.shape)  # (1, 12, 5, 5): 12 heads, each giving a 5x5 attention matrix

# batch x headcnt x seqlen (query) x seqlen (key)

attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)
zis_mh = torch.einsum("BCST,BTCH->BSCH", attmat_mh, vis_mh)  # batch x seqlen (query) x headcnt x headdim
zis = zis_mh.reshape(batch, token_num, headcnt * headdim)

# Note: this stops at the concatenation of heads; the final output projection (W^O) is not applied here.

# We can verify the attention weights against nn.MultiheadAttention
mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
print(mha.in_proj_weight.shape)  # (3 * embdim, embdim)
mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T
attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False)

# The returned per-head weights match attmat_mh
assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6)  # True

print(attn_weights.shape)  # batch, heads, tokens, tokens
print(attn_out.shape)
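
# (Added check, not in the original commit) In the paper, MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O.
# zis above is only the concatenation; nn.MultiheadAttention additionally applies its out_proj layer.
# Since the in_proj weights were copied over and the biases default to zero, applying out_proj to zis
# should reproduce attn_out (a hedged sanity check, not a guarantee):
print(torch.allclose(mha.out_proj(zis), attn_out, atol=1e-5, rtol=1e-5))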

# Causal Mask from Scratch
# Compute the causal mask, described in the paper: in the decoder, a position must not attend to future tokens.

attn_mask = torch.ones(token_num, token_num)
attn_mask = -1E4 * torch.triu(attn_mask, 1)  # large negative value on the strict upper triangle, ~0 after softmax
print(attn_mask)
scoremat_mh_msk = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh)  # batch x headcnt x seqlen (query) x seqlen (key)
scoremat_mh_msk += attn_mask  # add the attention mask to the scores before the softmax normalisation
attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1)
zis_mh_msk = torch.einsum("BCST,BTCH->BSCH", attmat_mh_msk, vis_mh)  # batch x seqlen (query) x headcnt x headdim
zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)

attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask)
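
# (Added check, not in the original commit) The hand-rolled masked weights should approximately match the
# weights returned by nn.MultiheadAttention: here the -1e4 mask is divided by sqrt(headdim), while PyTorch
# adds the mask after scaling, but both drive the masked scores far enough below zero to vanish under softmax.
print(torch.allclose(attmat_mh_msk, attn_weights_causal, atol=1e-5, rtol=1e-5))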

# Plotting all heads of the attention mechanism
plt.figure()
for head in range(headcnt):
    plt.subplot(3, 4, head + 1)
    plt.imshow(attn_weights_causal[0, head].detach().numpy())
    plt.title(f"head {head}")
    plt.axis("off")
plt.show()

# Transformer Block from Scratch

# Modelling the Transformer Block from scratch using PyTorch
# A Transformer block contains:
# - Layer norm
# - Skip connections
# - Multi-head attention
# - MLP / feed-forward net (see the note below)
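
# (Added note, not in the original commit) In the paper the feed-forward sub-layer is
# FFN(x) = max(0, x W1 + b1) W2 + b2, i.e. two linear layers around a ReLU, with an inner
# dimension of 4 x d_model. The block below follows the same shape but uses GELU instead of ReLU,
# as is common in GPT/BERT-style implementations.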


class TransformerBlock(nn.Module):

    def __init__(self, embdim: int, headcnt: int, *args, dropout=0.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.ln1 = nn.LayerNorm(embdim)
        self.ln2 = nn.LayerNorm(embdim)
        self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embdim, 4 * embdim),
            nn.GELU(),
            nn.Linear(4 * embdim, embdim),
            nn.Dropout(dropout),
        )

    def forward(self, x, is_causal=True):
        """
        Input is a tensor of shape (B, S, E); token and positional embeddings are assumed
        to have already been added (see the sketch after this class).
        """
        batch, token_num, hidden_dim = x.shape
        if is_causal:
            attn_mask = torch.ones(token_num, token_num, device=x.device)
            attn_mask = -1E4 * torch.triu(attn_mask, 1)
        else:
            attn_mask = None

        # Attention sub-layer with residual connection and post-layer-norm, as in the paper
        residue = x
        attn_output, attn_weights = self.attn(x, x, x, attn_mask=attn_mask, average_attn_weights=False)
        x = self.ln1(residue + attn_output)

        # Feed-forward sub-layer with residual connection and post-layer-norm
        residue = x
        ffn_output = self.ffn(x)
        output = self.ln2(residue + ffn_output)
        return output
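

# (Added sketch, not in the original commit) The forward() docstring assumes positional information has
# already been added to the token embeddings. A minimal sketch of the paper's sinusoidal encoding,
# PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)):
def sinusoidal_positional_encoding(seq_len: int, dim: int) -> torch.Tensor:
    position = torch.arange(seq_len).unsqueeze(1)                                # (seq_len, 1)
    div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))   # (dim / 2,)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # (seq_len, dim); add to the token embeddings before the first block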


if __name__ == "__main__":
    # Testing the Transformer Block
    print("Testing the Transformer Block")
    transformer_block = TransformerBlock(embdim, headcnt)
    tokens = torch.randn(1, 5, embdim)
    output = transformer_block(tokens)
    print(output.shape)
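
    # (Added usage sketch, not in the original commit) Blocks can be stacked into a deeper model;
    # each block preserves the (B, S, E) shape, so nn.Sequential works here.
    stack = nn.Sequential(*[TransformerBlock(embdim, headcnt) for _ in range(2)])
    print(stack(tokens).shape)  # expected: torch.Size([1, 5, 768])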