matiusX committed on
Commit
589d5e7
1 Parent(s): 530cc54

Upload llama.py

Files changed (1)
  1. llama.py +365 -0
llama.py ADDED
import pandas as pd
import numpy as np
import os
import requests
import torch
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
import random
from collections import OrderedDict
from matplotlib import pyplot as plt
import time

# Pick the best available device (CUDA GPU, Apple Silicon MPS, or CPU).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Hyperparameters
VOCAB_SIZE = 130       # SentencePiece BPE vocabulary size
BATCH_SIZE = 32
CONTEXT_WINDOW = 16    # tokens of context seen by the model
EPOCHS = 1000
DIM = 128              # embedding / model dimension
LOG_INTERVAL = 10      # evaluate and checkpoint every LOG_INTERVAL epochs
HEADS = 8
LAYERS = 4

# Download the TinyShakespeare corpus.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)

if response.status_code == 200:
    tinyshakespeare = response.text
else:
    raise RuntimeError(f"Failed to download dataset: HTTP {response.status_code}")

# Split into non-empty lines for the tokenizer trainer.
tinyshakespeare_list = tinyshakespeare.split("\n")
tinyshakespeare_list = [i for i in tinyshakespeare_list if i != ""]

spm.SentencePieceTrainer.Train(
    sentence_iterator=iter(tinyshakespeare_list),
    model_prefix="tinyshakespeare_model",
    vocab_size=VOCAB_SIZE,
    character_coverage=1.0,
    model_type="bpe",
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
)

sp = spm.SentencePieceProcessor(model_file="tinyshakespeare_model.model")
dataset_tensor = torch.tensor(sp.Encode(tinyshakespeare))

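# The entire corpus is encoded into a single 1-D tensor of token ids, which the batch
# helpers below slice into training examples. As an illustrative sanity check (not part
# of the original script), the tokenizer can be round-tripped on a short string:
# sp.Decode(sp.Encode("First Citizen:")) should reproduce the input text.
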
def get_batch_train(dataset, batch_size, context_window):
    # First 70% of the tokens are the training split.
    train_data = dataset[:int(.7 * len(dataset))]
    ix = torch.randint(0, train_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([train_data[i:i+context_window] for i in ix]).long().to(device)
    y = torch.stack([train_data[i+1:i+context_window+1] for i in ix]).long().to(device)
    return x, y


def get_batch_val(dataset, batch_size, context_window):
    # Tokens from 70% to 85% are the validation split.
    val_data = dataset[int(.7 * len(dataset)): int(.85 * len(dataset))]
    ix = torch.randint(0, val_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([val_data[i:i+context_window] for i in ix]).long().to(device)
    y = torch.stack([val_data[i+1:i+context_window+1] for i in ix]).long().to(device)
    return x, y


def get_batch_test(dataset, batch_size, context_window):
    # Last 15% of the tokens are held out as the test split.
    test_data = dataset[int(.85 * len(dataset)): len(dataset)]
    ix = torch.randint(0, test_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([test_data[i:i+context_window] for i in ix]).long().to(device)
    y = torch.stack([test_data[i+1:i+context_window+1] for i in ix]).long().to(device)
    return x, y

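# Each batch is a pair of (BATCH_SIZE, CONTEXT_WINDOW) LongTensors in which y is x
# shifted one token to the right, so the target at every position is the next token.
# Illustrative usage: x, y = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
# yields x.shape == y.shape == torch.Size([32, 16]).
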
@torch.no_grad()
def calculate_loss(model):
    model.eval()
    train_losses = []
    val_losses = []
    for i in range(EPOCHS):
        # train evaluation
        x_train, y_train = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
        _, train_loss = model(x_train, y_train)
        train_losses.append(train_loss.item())

        # val evaluation
        x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
        _, val_loss = model(x_val, y_val)
        val_losses.append(val_loss.item())

    losses_dict = {"train": np.mean(train_losses), "val": np.mean(val_losses)}
    model.train()  # restore training mode so dropout is active again
    return losses_dict

@torch.no_grad()
def calculate_accuracy(model):
    model.eval()
    correct_predictions = 0
    total_predictions = 0

    for i in range(EPOCHS):
        # Get a batch of validation data
        x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)

        # Get model predictions
        logits = model(x_val)

        # Convert predictions to class labels
        predicted_labels = torch.argmax(logits, dim=-1)

        # Compare with true labels
        correct_predictions += (predicted_labels == y_val).sum().item()
        total_predictions += y_val.numel()

    accuracy = correct_predictions / total_predictions
    model.train()  # restore training mode
    return accuracy

@torch.no_grad()
def calculate_perplexity(model):
    model.eval()
    val_losses = []

    for i in range(EPOCHS):
        # Get a batch of validation data
        x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)

        # Get model predictions and loss
        _, val_loss = model(x_val, y_val)
        val_losses.append(val_loss.item())

    # Calculate the mean validation loss
    mean_val_loss = np.mean(val_losses)

    # Perplexity is the exponential of the cross-entropy loss
    perplexity = np.exp(mean_val_loss)
    model.train()  # restore training mode
    return perplexity

def train(model, optimizer, checkpoint_path="checkpoints"):
    os.makedirs(checkpoint_path, exist_ok=True)
    losses = []
    accs = []
    perps = []
    for epoch in range(EPOCHS):
        optimizer.zero_grad()
        x_train, y_train = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
        logits, loss = model(x_train, y_train)
        loss.backward()
        optimizer.step()

        if epoch % LOG_INTERVAL == 0:
            current_loss = calculate_loss(model)
            current_accuracy = calculate_accuracy(model)
            current_perplexity = calculate_perplexity(model)

            losses.append(current_loss)
            accs.append(current_accuracy)
            perps.append(current_perplexity)

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': current_loss,
                'accuracy': current_accuracy,
                'perplexity': current_perplexity
            }, f"{checkpoint_path}/checkpoint_epoch_{epoch}.pth")

            print(f"Epoch {epoch}: Loss - {current_loss['val']}, Accuracy - {current_accuracy}, Perplexity - {current_perplexity}")

    print("Validation Loss: ", losses[-1]['val'])
    print("Validation Accuracy: ", accs[-1])
    print("Validation Perplexity: ", perps[-1])
    return pd.DataFrame(losses).plot()

class RMSNorm(torch.nn.Module):
    def __init__(self, layer_shape, eps=1e-8, bias=False):
        super(RMSNorm, self).__init__()
        self.register_parameter("scale", torch.nn.Parameter(torch.ones(layer_shape)))

    def forward(self, x):
        # RMS over the token and embedding dimensions of each sequence in the batch.
        ff_rms = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -.5
        # Normalize x by its RMS, then apply the learned per-position scale.
        raw = x / ff_rms.unsqueeze(-1).unsqueeze(-1)
        return self.scale[:x.shape[1], :].unsqueeze(0) * raw

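# For reference, RMSNorm (Zhang & Sennrich, 2019) rescales activations by
# 1 / sqrt(mean(x^2)) and multiplies by a learned gain, skipping LayerNorm's mean
# subtraction. Note that this implementation computes a single RMS over the whole
# (sequence, embedding) slice per example rather than one RMS per token as in the
# original formulation.
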
def get_rotary_matrix(context_window, embedding_dim):
    # Build one (embedding_dim x embedding_dim) block-diagonal rotation matrix per
    # position, made of 2x2 rotations with position-dependent angles (RoPE).
    R = torch.zeros((context_window, embedding_dim, embedding_dim), requires_grad=False)
    for position in range(context_window):
        for i in range(embedding_dim // 2):
            theta = 10000. ** (-2. * (i - 1) / embedding_dim)
            m_theta = position * theta
            R[position, 2*i, 2*i] = np.cos(m_theta)
            R[position, 2*i, 2*i+1] = -np.sin(m_theta)
            R[position, 2*i+1, 2*i] = np.sin(m_theta)
            R[position, 2*i+1, 2*i+1] = np.cos(m_theta)
    return R

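# Rotary positional embeddings (RoPE) encode position by rotating each consecutive
# pair of query/key dimensions by an angle proportional to the token position, so the
# dot product between a query and a key depends only on their relative offset. The
# dense per-position matrices built above are a simple (memory-hungry) way to apply
# these rotations with a batched matrix multiply; production implementations usually
# use elementwise sin/cos instead.
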
class RoPEAttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.w_q = nn.Linear(DIM, DIM, bias=False)
        self.w_k = nn.Linear(DIM, DIM, bias=False)
        self.w_v = nn.Linear(DIM, DIM, bias=False)

        # Register the rotary matrices as a buffer so they follow the module's device.
        self.register_buffer("R", get_rotary_matrix(CONTEXT_WINDOW, DIM))

    def forward(self, x, return_attn_weights=False):
        b, m, d = x.shape

        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # Apply the position-dependent rotation to queries and keys.
        q_rotated = (torch.bmm(q.transpose(0, 1), self.R[:m])).transpose(0, 1)
        k_rotated = (torch.bmm(k.transpose(0, 1), self.R[:m])).transpose(0, 1)

        # Causal scaled dot-product attention; dropout only while training.
        activations = F.scaled_dot_product_attention(
            q_rotated, k_rotated, v,
            dropout_p=0.1 if self.training else 0.0,
            is_causal=True,
        )

        if return_attn_weights:
            # Recompute the attention weights explicitly (for inspection only),
            # masking future positions before the softmax.
            attn_mask = torch.tril(torch.ones((m, m), device=x.device))
            scores = torch.bmm(q_rotated, k_rotated.transpose(1, 2)) / np.sqrt(d)
            scores = scores.masked_fill(attn_mask == 0, float("-inf"))
            attn_weights = F.softmax(scores, dim=-1)
            return activations, attn_weights
        return activations

class RoPEMultiheadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([
            RoPEAttentionHead() for _ in range(HEADS)
        ])
        self.linear = nn.Linear(HEADS * DIM, DIM)
        self.dropout = nn.Dropout(.1)

    def forward(self, x):
        # Run every head on the same input, concatenate along the feature axis,
        # then project back down to the model dimension.
        heads = [h(x) for h in self.heads]
        x = torch.cat(heads, dim=-1)
        x = self.linear(x)
        x = self.dropout(x)
        return x

class SwiGLU(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linear_gate = nn.Linear(size, size)
        self.linear = nn.Linear(size, size)
        # Learnable scalar for the Swish gate.
        self.beta = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Swish gate: W_g x * sigmoid(beta * W_g x)
        swish_gate = self.linear_gate(x) * torch.sigmoid(self.beta * self.linear_gate(x))
        # Gate elementwise-multiplied with a second linear projection.
        out = swish_gate * self.linear(x)
        return out

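# SwiGLU(x) = Swish_beta(W_g x) * (W x), where Swish_beta(z) = z * sigmoid(beta * z);
# with beta = 1 the gate reduces to the SiLU activation. Llama-style models use this
# gated feed-forward because the multiplicative gate tends to train better than a
# plain ReLU feed-forward of similar size.
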
class LlamaBlock(nn.Module):
    def __init__(self):
        super().__init__()

        self.rms = RMSNorm((CONTEXT_WINDOW, DIM))

        self.attention = RoPEMultiheadAttention()
        self.feedforward = nn.Sequential(
            nn.Linear(DIM, DIM),
            SwiGLU(DIM),
        )

    def forward(self, x):
        x = self.rms(x)                # RMS normalization
        x = x + self.attention(x)      # self-attention with residual connection

        x = self.rms(x)                # RMS normalization
        x = x + self.feedforward(x)    # feed-forward (SwiGLU) with residual connection
        return x

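# Each block follows the pre-normalization pattern used by Llama-style models:
# normalize, apply a sublayer (attention or the SwiGLU feed-forward), and add the
# result back to the sublayer input as a residual. Note that this implementation
# reuses a single RMSNorm instance, and therefore one set of scale parameters, for
# both sublayers.
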
class Llama(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(VOCAB_SIZE, DIM)
        self.llama_blocks = nn.Sequential(
            OrderedDict([(f"llama_{i}", LlamaBlock()) for i in range(LAYERS)])
        )

        self.ffn = nn.Sequential(
            nn.Linear(DIM, DIM),
            SwiGLU(DIM),
            nn.Linear(DIM, VOCAB_SIZE),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        x = self.embeddings(idx)
        x = self.llama_blocks(x)
        logits = self.ffn(x)

        if targets is None:
            return logits
        else:
            loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
            return logits, loss

llama = Llama().to(device)
optimizer = torch.optim.Adam(llama.parameters())
train(llama, optimizer)
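
# A minimal, illustrative way to sample from the trained model (not part of the
# original script; greedy decoding and the `generate`/`prompt` names are assumptions):
#
#   def generate(model, prompt, max_new_tokens=50):
#       model.eval()
#       ids = torch.tensor([sp.Encode(prompt)], dtype=torch.long, device=device)
#       with torch.no_grad():
#           for _ in range(max_new_tokens):
#               logits = model(ids[:, -CONTEXT_WINDOW:])   # (1, m, VOCAB_SIZE)
#               next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
#               ids = torch.cat([ids, next_id], dim=-1)
#       return sp.Decode(ids[0].tolist())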