Neu256 committed on
Commit
b1ac81f
1 Parent(s): af97934

Delete model.py

Files changed (1)
  1. model.py +0 -359
model.py DELETED
@@ -1,359 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from utils import DEVICE

class PromeLayerNorm(nn.Module):
    def __init__(self, dim, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon
        # Learnable gain and bias are registered once here instead of being
        # re-created on every forward pass (where they would never be trained).
        self.g = nn.Parameter(torch.ones(dim))
        self.b = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) * torch.rsqrt(s + self.epsilon)
        x = x * self.g + self.b

        return x

class PromeStand(nn.Module):
    def __init__(self, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """
        x: Input tensor
        """
        # Standardize to zero mean and unit variance; epsilon only guards
        # against division by a zero standard deviation.
        mean = x.mean()
        std = x.std() + self.epsilon
        x = x - mean
        x = x / std
        return x

class PromeEmbedding(nn.Module):
    """
    This class implements a Prome embedding layer.

    Args:
        vocab_size (int): The size of the vocabulary.
        embedding_dim (int): The dimension of the embedding.
        padding_idx (int, optional): The padding index. If this is not None,
            positions equal to the padding index are remapped to index 0
            before the lookup.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
    """
    def __init__(self, vocab_size, embedding_dim, padding_idx=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(torch.randn(vocab_size, embedding_dim))
        self.padding_idx = padding_idx
        self.context_matrix = nn.Parameter(torch.randn(vocab_size, embedding_dim))

    def forward(self, input_ids):
        """
        Calculates the embedding for the given input IDs.

        Args:
            input_ids (torch.Tensor): A tensor of shape (batch_size, sequence_length).

        Returns:
            torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        input_ids = input_ids.long()
        if self.padding_idx is not None:
            input_ids = input_ids.masked_fill(input_ids == self.padding_idx, 0)

        # look up the symbol (token) vectors
        embeddings = self.weight[input_ids]

        # look up the per-token context vectors
        context_vectors = self.context_matrix[input_ids]

        # modify the embeddings with the context vectors
        output = embeddings + context_vectors

        return output

class AttentionHead(nn.Module):
    """
    One head of the self-attention layer
    """

    def __init__(self, head_size, num_embed, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(num_embed, head_size, bias=False)
        self.query = nn.Linear(num_embed, head_size, bias=False)
        self.value = nn.Linear(num_embed, head_size, bias=False)
        # tril is a lower triangular matrix. It is not a parameter
        # of the model, so we assign it to the module using register_buffer
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

        # layer norm
        self.norm = PromeStand()

        # dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        key = self.key(x)
        query = self.query(x)
        # compute attention scores
        # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = (query @ key.transpose(-2, -1)) * C ** -0.5
        # the tril (lower triangular) matrix masks future positions
        # (setting them to -inf) so that the decoder learns to predict
        # the next token from past context only
        wei = wei.masked_fill(self.tril[:T, :T] == 0, -float("inf"))  # (B, T, T)
        wei = F.silu(F.softmax(wei, dim=-1))
        # scale (multiplicative attention)
        score = -1 / (C ** -0.5)
        wei.mul_(score)
        # apply dropout to the attention weights
        wei = self.dropout(wei)
        # weighted aggregation of the values
        value = self.value(x)
        out = wei @ value  # (B, T, T) @ (B, T, C) ---> (B, T, C)

        return out

class MultiHeadAttention(nn.Module):
    """
    Multiple heads of self-attention in parallel
    """

    def __init__(self, num_heads, head_size, num_embed, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                AttentionHead(
                    head_size=head_size,
                    num_embed=num_embed,
                    block_size=block_size,
                    dropout=dropout
                )
                for _ in range(num_heads)
            ]
        )
        self.proj = nn.Linear(num_embed, num_embed)
        self.dropout = nn.Dropout(dropout)
        self.norm = PromeStand()

    def forward(self, x):
        # output of the self-attention heads
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # standardization with a residual connection
        out = self.norm(out + x)
        # apply the linear projection layer
        out = self.dropout(self.proj(out))

        return out

class MLP(nn.Module):
    def __init__(self, num_embed, hidden_dim, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(num_embed, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_embed)

    def forward(self, x):
        x = self.fc1(x)
        x = F.silu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = F.silu(x)
        x = self.fc3(x)
        return x

class TransformerBlock(nn.Module):
    """
    This class groups together MultiHeadAttention and the
    feed-forward MLP so that it can be stacked in the Transformer
    """

    def __init__(self, num_heads, block_size, num_embed, hidden_dim, dropout):
        super().__init__()
        head_size = num_embed // num_heads
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            head_size=head_size,
            num_embed=num_embed,
            block_size=block_size,
            dropout=dropout
        )
        self.mlp = MLP(num_embed=num_embed, hidden_dim=hidden_dim, dropout=dropout)
        # add the standardization layer (PromeStand only takes an epsilon,
        # so it is constructed without arguments)
        self.ln = PromeStand()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Applies one decoder block to the input sequence.

        Args:
            x (torch.Tensor): A tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
            torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        y = x

        x = self.ln(x)
        x = self.mha(x)
        x = self.dropout(x)
        x += y
        y = x
        x = self.ln(x)
        x = self.mlp(x)
        x = self.mha(x)
        x += y
        x = self.dropout(x)

        return x

class TransformerDecoder(nn.Module):
    """
    This class implements a Transformer decoder.

    Args:
        num_heads (int): The number of attention heads.
        block_size (int): The size of the input sequence.
        num_embed (int): The dimension of the embedding.
        hidden_dim (int): The hidden dimension of the feed-forward MLP.
        num_layers (int): The number of decoder blocks.
        dropout (float): The dropout rate.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
    """
    def __init__(self, num_heads, block_size, num_embed, hidden_dim, num_layers, dropout):
        super().__init__()

        # Create the positional embedding layer (one vector per position up to block_size).
        self.pemb = PromeEmbedding(block_size, num_embed)

        # Create a sequential stack of Transformer blocks.
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(
                    num_heads=num_heads,
                    block_size=block_size,
                    num_embed=num_embed,
                    hidden_dim=hidden_dim,
                    dropout=dropout
                )
                for _ in range(num_layers)
            ]
        )

        # Create a softmax layer.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Decodes the input sequence.

        Args:
            x (torch.Tensor): A tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
            torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
        """

        # Add positional encodings to the input sequence
        # (keep the index tensor on the same device as the input).
        x = x + self.pemb(torch.arange(x.size(1), device=x.device))

        x = self.blocks(x)

        # Apply a softmax layer to the output of the last Transformer block.
        x = self.softmax(x)

        return x

class Transformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # a simple lookup table that stores embeddings of a fixed dictionary and size
        # each token directly reads off the logits for the next token from a lookup table
        # see more: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
        self.block_size = kwargs.get("block_size", 8)
        self.num_heads = kwargs.get("num_heads", 4)
        self.num_layers = kwargs.get("num_layers", 4)
        self.hidden_dim = kwargs.get("hidden_dim", 768)
        self.dropout = kwargs.get("dropout", 0.2)
        # each token reads the logits for the next token from a lookup table
        self.token_embedding_table = PromeEmbedding(self.vocab_size, self.num_embed)
        # each position from 0 to block_size - 1 gets its own embedding
        self.position_embedding_table = PromeEmbedding(self.block_size, self.num_embed)

        self.decoder = TransformerDecoder(self.num_heads, self.block_size, self.num_embed, self.hidden_dim, self.num_layers, self.dropout)

        # we add the layer norm before the linear head
        self.dropout = nn.Dropout(self.dropout)
        self.ln_f = PromeLayerNorm(self.num_embed)
        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are (B, T) tensors of integers
        # token_emb is (B, T, C), C = NUM_EMBED
        token_emb = self.token_embedding_table(idx)
        # (T, C)
        posit_emb = self.position_embedding_table(torch.arange(T, device=DEVICE))

        x = token_emb + posit_emb

        # apply dropout
        x = self.dropout(x)

        # apply the decoder stack
        x = self.decoder(x)

        # apply normalization
        x = self.ln_f(x)

        # (B, T, vocab_size)
        logits = self.lm_head(x)

        # Compute the loss
        if targets is not None:
            # cross_entropy expects inputs of shape (batch_size, num_classes),
            # so we reshape the logits to
            # (batch_size * time, dim_vocabulary), time = block_size
            B, T, C = logits.shape
            logits = torch.reshape(logits, (B * T, C))
            targets = torch.reshape(targets, (B * T,))
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None

        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int, block_size: int):
        # idx is a (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop the context to the last block_size tokens
            # because tokens don't communicate between blocks
            idx_crop = idx[:, -block_size:]
            # get the predictions
            logits, loss = self.forward(idx_crop)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution with probabilities probs
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
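
For reference, a minimal usage sketch of the model defined in the deleted file. It assumes model.py (as it existed before this commit) is still importable and that utils provides the DEVICE constant the module itself imports; the hyperparameter values simply mirror the kwargs defaults in Transformer.__init__ and the dummy data is illustrative, not taken from the repository.

    import torch
    from model import Transformer   # the file removed by this commit
    from utils import DEVICE        # the module assumes this constant exists

    # Hyperparameters mirroring the defaults in Transformer.__init__.
    model = Transformer(
        vocab_size=100,
        num_embed=32,
        block_size=8,
        num_heads=4,
        num_layers=4,
        hidden_dim=768,
        dropout=0.2,
    ).to(DEVICE)

    # Dummy batch of token indices: (batch_size=2, sequence_length=8).
    idx = torch.randint(0, 100, (2, 8), device=DEVICE)
    targets = torch.randint(0, 100, (2, 8), device=DEVICE)

    # Forward pass returns logits and, since targets are given, a cross-entropy loss.
    logits, loss = model(idx, targets)

    # Autoregressive sampling of 16 new tokens from a single-token prompt.
    prompt = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
    generated = model.generate(prompt, max_new_tokens=16, block_size=8)

The generate loop crops the context to block_size tokens before each forward call, so the prompt can be shorter or longer than block_size without changing the sampling code.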