VarunAIML committed
Commit 08571aa
1 Parent(s): 8b81b3c

minor changes

Files changed (1)
  1. model.py +270 -0
model.py ADDED
@@ -0,0 +1,270 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ class LayerNormalization(nn.Module):
+
+     def __init__(self, eps: float = 10**-6) -> None:
+         super().__init__()
+         self.eps = eps
+         self.alpha = nn.Parameter(torch.ones(1)) # alpha is a learnable parameter
+         self.bias = nn.Parameter(torch.zeros(1)) # bias is a learnable parameter
+
+     def forward(self, x):
+         # x: (batch, seq_len, hidden_size)
+         # Keep the dimension for broadcasting
+         mean = x.mean(dim=-1, keepdim=True) # (batch, seq_len, 1)
+         # Keep the dimension for broadcasting
+         std = x.std(dim=-1, keepdim=True) # (batch, seq_len, 1)
+         # eps is to prevent dividing by zero or when std is very small
+         return self.alpha * (x - mean) / (std + self.eps) + self.bias
+
+
+ class FeedForwardBlock(nn.Module):
+
+     def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
+         super().__init__()
+         self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
+         self.dropout = nn.Dropout(dropout)
+         self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2
+
+     def forward(self, x):
+         # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
+         return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
+
+
+ class InputEmbeddings(nn.Module):
+
+     def __init__(self, d_model: int, vocab_size: int) -> None:
+         super().__init__()
+         self.d_model = d_model
+         self.vocab_size = vocab_size
+         self.embedding = nn.Embedding(vocab_size, d_model)
+
+     def forward(self, x):
+         # (batch, seq_len) --> (batch, seq_len, d_model)
+         # Multiply by sqrt(d_model) to scale the embeddings according to the paper
+         return self.embedding(x) * math.sqrt(self.d_model)
+
+ class PositionalEncoding(nn.Module):
+
+     def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
+         super().__init__()
+         self.d_model = d_model
+         self.seq_len = seq_len
+         self.dropout = nn.Dropout(dropout)
+         pe = torch.zeros(seq_len, d_model) # Positional encoding matrix of shape (seq_len, d_model)
+         position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
+         # Create the frequency vector of shape (d_model / 2)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         # Apply sine to even indices
+         pe[:, 0::2] = torch.sin(position * div_term) # sin(position / (10000 ** (2i / d_model)))
+         # Apply cosine to odd indices
+         pe[:, 1::2] = torch.cos(position * div_term) # cos(position / (10000 ** (2i / d_model)))
+         # Add a batch dimension to the positional encoding
+         pe = pe.unsqueeze(0) # (1, seq_len, d_model)
+         # Register the positional encoding as a buffer
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
+         return self.dropout(x)
+
+ class ResidualConnection(nn.Module):
+
+     def __init__(self, dropout: float) -> None:
+         super().__init__()
+         self.dropout = nn.Dropout(dropout)
+         self.norm = LayerNormalization()
+
+     def forward(self, x, sublayer):
+         # Pre-norm residual: x + dropout(sublayer(norm(x)))
+         return x + self.dropout(sublayer(self.norm(x)))
+
+ class MultiHeadAttentionBlock(nn.Module):
+
+     def __init__(self, d_model: int, h: int, dropout: float) -> None:
+         super().__init__()
+         self.d_model = d_model # Embedding vector size
+         self.h = h # Number of heads
+         # Make sure d_model is divisible by h
+         assert d_model % h == 0, "d_model is not divisible by h"
+
+         self.d_k = d_model // h # Dimension of vector seen by each head
+         self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq
+         self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk
+         self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv
+         self.w_o = nn.Linear(d_model, d_model, bias=False) # Wo
+         self.dropout = nn.Dropout(dropout)
+
+     @staticmethod
+     def attention(query, key, value, mask, dropout: nn.Dropout):
+         d_k = query.shape[-1]
+         # Just apply the formula from the paper
+         # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
+         attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
+         if mask is not None:
+             # Write a very low value (indicating -inf) to the positions where mask == 0
+             _MASKING_VALUE = -1e9 if attention_scores.dtype == torch.float32 else -1e+4
+             attention_scores.masked_fill_(mask == 0, _MASKING_VALUE)
+         attention_scores = attention_scores.softmax(dim=-1) # Apply softmax: (batch, h, seq_len, seq_len)
+         if dropout is not None:
+             attention_scores = dropout(attention_scores)
+         # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
+         # Also return the attention scores, which can be used for visualization
+         return (attention_scores @ value), attention_scores
+
+     def forward(self, q, k, v, mask):
+         query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
+         key = self.w_k(k) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
+         value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
+
+         # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
+         query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
+         key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
+         value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
+
+         # Calculate attention
+         x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
+
+         # Combine all the heads together
+         # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
+         x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
+
+         # Multiply by Wo
+         # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
+         return self.w_o(x)
+
+ class EncoderBlock(nn.Module):
+     def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
+         super().__init__()
+         self.self_attention_block = self_attention_block
+         self.feed_forward_block = feed_forward_block
+         self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])
+
+     def forward(self, x, src_mask):
+         x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
+         x = self.residual_connections[1](x, self.feed_forward_block)
+         return x
+
+ class Encoder(nn.Module):
+     def __init__(self, layers: nn.ModuleList) -> None:
+         super().__init__()
+         self.layers = layers
+         self.norm = LayerNormalization()
+
+     def forward(self, x, mask):
+         for layer in self.layers:
+             x = layer(x, mask)
+         return self.norm(x)
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
+         super().__init__()
+         self.self_attention_block = self_attention_block
+         self.cross_attention_block = cross_attention_block
+         self.feed_forward_block = feed_forward_block
+         self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])
+
+     def forward(self, x, encoder_output, src_mask, tgt_mask):
+         x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
+         x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
+         x = self.residual_connections[2](x, self.feed_forward_block)
+         return x
+
+ class Decoder(nn.Module):
+     def __init__(self, layers: nn.ModuleList) -> None:
+         super().__init__()
+         self.layers = layers
+         self.norm = LayerNormalization()
+
+     def forward(self, x, encoder_output, src_mask, tgt_mask):
+         for layer in self.layers:
+             x = layer(x, encoder_output, src_mask, tgt_mask)
+         return self.norm(x)
+
+ class ProjectionLayer(nn.Module):
+     def __init__(self, d_model, vocab_size) -> None:
+         super().__init__()
+         self.proj = nn.Linear(d_model, vocab_size)
+
+     def forward(self, x):
+         # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
+         return torch.log_softmax(self.proj(x), dim=-1)
+
+ class Transformer(nn.Module):
+     def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.src_embed = src_embed
+         self.tgt_embed = tgt_embed
+         self.src_pos = src_pos
+         self.tgt_pos = tgt_pos
+         self.projection_layer = projection_layer
+
+     def encode(self, src, src_mask):
+         # (batch, seq_len, d_model)
+         src = self.src_embed(src)
+         src = self.src_pos(src)
+         return self.encoder(src, src_mask)
+
+     def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
+         # (batch, seq_len, d_model)
+         tgt = self.tgt_embed(tgt)
+         tgt = self.tgt_pos(tgt)
+         return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
+
+     def project(self, x):
+         # (batch, seq_len, vocab_size)
+         return self.projection_layer(x)
+
+ def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
+
+     # Create the embedding layers
+     src_embed = InputEmbeddings(d_model, src_vocab_size)
+     tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
+
+     # Create the positional encoding layers
+     src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
+     tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
+
+     # Create N // 2 unique encoder blocks (they are reused in a mirrored order below, which assumes N == 6)
+     encoder_blocks = []
+     for _ in range(N // 2):
+         encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
+         feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
+         encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
+         encoder_blocks.append(encoder_block)
+
+     # Create N // 2 unique decoder blocks (mirrored below in the same way)
+     decoder_blocks = []
+     for _ in range(N // 2):
+         decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
+         decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
+         feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
+         decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
+         decoder_blocks.append(decoder_block)
+
+     # Arrange the blocks symmetrically, so mirrored layers share the same parameters
+     e1, e2, e3 = encoder_blocks
+     d1, d2, d3 = decoder_blocks
+
+     encoder_blocks1 = [e1, e2, e3, e3, e2, e1]
+     decoder_blocks1 = [d1, d2, d3, d3, d2, d1]
+
+     # Create the encoder and decoder
+     encoder = Encoder(nn.ModuleList(encoder_blocks1))
+     decoder = Decoder(nn.ModuleList(decoder_blocks1))
+
+     # Create the projection layer
+     projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
+
+     # Create the transformer
+     transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
+
+     # Initialize the parameters
+     for p in transformer.parameters():
+         if p.dim() > 1:
+             nn.init.xavier_uniform_(p)
+
+     return transformer
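
A minimal usage sketch of the build_transformer API above, assuming model.py is on the import path; the batch size, vocabulary sizes, sequence length, and masks are illustrative placeholders, not values shipped with this commit.

import torch
from model import build_transformer

src_vocab, tgt_vocab, seq_len = 1000, 1000, 16
model = build_transformer(src_vocab, tgt_vocab, seq_len, seq_len)  # defaults: d_model=512, N=6, h=8

src = torch.randint(0, src_vocab, (2, seq_len))           # (batch, seq_len) of token ids
tgt = torch.randint(0, tgt_vocab, (2, seq_len))           # (batch, seq_len) of token ids
src_mask = torch.ones(2, 1, 1, seq_len, dtype=torch.int)  # placeholder: no source padding masked
tgt_mask = torch.tril(torch.ones(1, seq_len, seq_len, dtype=torch.int))  # causal mask only, padding ignored here

encoder_output = model.encode(src, src_mask)                            # (2, seq_len, 512)
decoder_output = model.decode(encoder_output, src_mask, tgt, tgt_mask)  # (2, seq_len, 512)
log_probs = model.project(decoder_output)                               # (2, seq_len, tgt_vocab)

Since ProjectionLayer already applies log_softmax, the projected output pairs naturally with nn.NLLLoss during training.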