pt-sk committed
Commit 4081493
1 Parent(s): 67ab825

Upload 4 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ GPT-3[[:space:]]Small_shakespeare filter=lfs diff=lfs merge=lfs -text
GPT-3 Small_shakespeare ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57f62bad78627b3fded475e8f8766554ad9145e6ed94f975b69022b44f659d7d
3
+ size 1032960390
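As a rough check (an assumption, not stated in the commit): the ~1.03 GB pointer size is consistent with this file being the fp32 torch.save checkpoint produced by the notebook below, i.e. the model weights plus the two AdamW moment buffers kept per parameter.

# Back-of-the-envelope size estimate (assumes fp32 weights + AdamW exp_avg and exp_avg_sq buffers)
n_params = 85.226561e6                 # parameter count printed by the notebook below
approx_bytes = int(n_params * 4 * 3)   # ~1.02e9 bytes, close to the pointer's size 1032960390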
GPT_3_Small_shakespeare notebook.ipynb ADDED
@@ -0,0 +1,409 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "id": "VktNs2NoNiDt"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "import torch\n",
12
+ "import torch.nn as nn\n",
13
+ "from torch.nn import functional as F\n",
14
+ "from tqdm import tqdm"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {
21
+ "colab": {
22
+ "base_uri": "https://localhost:8080/"
23
+ },
24
+ "id": "3jOZxHu3NiDt",
25
+ "outputId": "8fe80f12-21b3-4b81-9388-5211f00f6848"
26
+ },
27
+ "outputs": [
28
+ {
29
+ "data": {
30
+ "text/plain": [
31
+ "<torch._C.Generator at 0x78c6be325b90>"
32
+ ]
33
+ },
34
+ "execution_count": 4,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "torch.manual_seed(1337)"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {
47
+ "id": "T-pNtwn5NiDu"
48
+ },
49
+ "outputs": [],
50
+ "source": [
51
+ "# hyperparameters\n",
52
+ "batch_size = 8 # how many independent sequences will we process in parallel?\n",
53
+ "block_size = 128 # what is the maximum context length for predictions?\n",
54
+ "max_iters = 100\n",
55
+ "eval_interval = 10\n",
56
+ "learning_rate = 6.0 * 10**-4\n",
57
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
58
+ "eval_iters = 200\n",
59
+ "n_embd = 768\n",
60
+ "n_head = 12\n",
61
+ "n_layer = 12\n",
62
+ "dropout = 0.25"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {
69
+ "id": "XYHnR6ETNiDu"
70
+ },
71
+ "outputs": [],
72
+ "source": [
73
+ "with open(\"\", \"r\", encoding=\"utf-8\") as f:\n",
74
+ " text = f.read()\n",
75
+ "\n",
76
+ "chars = sorted(list(set(text)))\n",
77
+ "vocab_size = len(chars)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {
84
+ "id": "gJZbe7PyNiDu"
85
+ },
86
+ "outputs": [],
87
+ "source": [
88
+ "# create a mapping from characters to integers\n",
89
+ "stoi = { ch:i for i,ch in enumerate(chars) }\n",
90
+ "itos = { i:ch for i,ch in enumerate(chars) }\n",
91
+ "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
92
+ "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {
99
+ "id": "OAN-qtdPNiDv"
100
+ },
101
+ "outputs": [],
102
+ "source": [
103
+ "# Train and test splits\n",
104
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
105
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
106
+ "train_data = data[:n]\n",
107
+ "val_data = data[n:]"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {
114
+ "id": "wZ70_NY-NiDv"
115
+ },
116
+ "outputs": [],
117
+ "source": [
118
+ "# data loading\n",
119
+ "def get_batch(split):\n",
120
+ " # generate a small batch of data of inputs x and targets y\n",
121
+ " data = train_data if split == 'train' else val_data\n",
122
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
123
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
124
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
125
+ " x, y = x.to(device), y.to(device)\n",
126
+ " return x, y\n",
127
+ "\n",
128
+ "@torch.no_grad()\n",
129
+ "def estimate_loss(model):\n",
130
+ " out = {}\n",
131
+ " model.eval()\n",
132
+ " for split in ['val']:\n",
133
+ " losses = torch.zeros(eval_iters)\n",
134
+ " for k in range(eval_iters):\n",
135
+ " X, Y = get_batch(split)\n",
136
+ " logits, loss = model(X, Y)\n",
137
+ " losses[k] = loss.item()\n",
138
+ " out[split] = losses.mean()\n",
139
+ " model.train()\n",
140
+ " return out"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "metadata": {
147
+ "id": "KgnNQJUENiDv"
148
+ },
149
+ "outputs": [],
150
+ "source": [
151
+ "class Head(nn.Module):\n",
152
+ " \"\"\" one head of self-attention \"\"\"\n",
153
+ "\n",
154
+ " def __init__(self, head_size):\n",
155
+ " super().__init__()\n",
156
+ " self.key = nn.Linear(n_embd, head_size, bias=False)\n",
157
+ " self.query = nn.Linear(n_embd, head_size, bias=False)\n",
158
+ " self.value = nn.Linear(n_embd, head_size, bias=False)\n",
159
+ " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
160
+ "\n",
161
+ " self.dropout = nn.Dropout(dropout)\n",
162
+ "\n",
163
+ " def forward(self, x):\n",
164
+ " B,T,C = x.shape\n",
165
+ " k = self.key(x) # (B,T,C)\n",
166
+ " q = self.query(x) # (B,T,C)\n",
167
+ " # compute attention scores (\"affinities\")\n",
168
+ " wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n",
169
+ " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
170
+ " wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
171
+ " wei = self.dropout(wei)\n",
172
+ " # perform the weighted aggregation of the values\n",
173
+ " v = self.value(x) # (B,T,C)\n",
174
+ " out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n",
175
+ " return out\n",
176
+ "\n",
177
+ "class MultiHeadAttention(nn.Module):\n",
178
+ " \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
179
+ "\n",
180
+ " def __init__(self, num_heads, head_size):\n",
181
+ " super().__init__()\n",
182
+ " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
183
+ " self.proj = nn.Linear(n_embd, n_embd)\n",
184
+ " self.dropout = nn.Dropout(dropout)\n",
185
+ "\n",
186
+ " def forward(self, x):\n",
187
+ " out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
188
+ " out = self.dropout(self.proj(out))\n",
189
+ " return out\n",
190
+ "\n",
191
+ "class FeedFoward(nn.Module):\n",
192
+ " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n",
193
+ "\n",
194
+ " def __init__(self, n_embd):\n",
195
+ " super().__init__()\n",
196
+ " self.net = nn.Sequential(\n",
197
+ " nn.Linear(n_embd, 4 * n_embd),\n",
198
+ " nn.ReLU(),\n",
199
+ " nn.Linear(4 * n_embd, n_embd),\n",
200
+ " nn.Dropout(dropout),\n",
201
+ " )\n",
202
+ "\n",
203
+ " def forward(self, x):\n",
204
+ " return self.net(x)\n",
205
+ "\n",
206
+ "class Block(nn.Module):\n",
207
+ " \"\"\" Transformer block: communication followed by computation \"\"\"\n",
208
+ "\n",
209
+ " def __init__(self, n_embd, n_head):\n",
210
+ " # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
211
+ " super().__init__()\n",
212
+ " head_size = n_embd // n_head\n",
213
+ " self.sa = MultiHeadAttention(n_head, head_size)\n",
214
+ " self.ffwd = FeedFoward(n_embd)\n",
215
+ " self.ln1 = nn.LayerNorm(n_embd)\n",
216
+ " self.ln2 = nn.LayerNorm(n_embd)\n",
217
+ "\n",
218
+ " def forward(self, x):\n",
219
+ " x = x + self.sa(self.ln1(x))\n",
220
+ " x = x + self.ffwd(self.ln2(x))\n",
221
+ " return x\n",
222
+ "\n",
223
+ "# super simple bigram model\n",
224
+ "class BigramLanguageModel(nn.Module):\n",
225
+ "\n",
226
+ " def __init__(self):\n",
227
+ " super().__init__()\n",
228
+ " # each token directly reads off the logits for the next token from a lookup table\n",
229
+ " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
230
+ " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
231
+ " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
232
+ " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
233
+ " self.lm_head = nn.Linear(n_embd, vocab_size)\n",
234
+ "\n",
235
+ " def forward(self, idx, targets=None):\n",
236
+ " B, T = idx.shape\n",
237
+ "\n",
238
+ " # idx and targets are both (B,T) tensor of integers\n",
239
+ " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
240
+ " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
241
+ " x = tok_emb + pos_emb # (B,T,C)\n",
242
+ " x = self.blocks(x) # (B,T,C)\n",
243
+ " x = self.ln_f(x) # (B,T,C)\n",
244
+ " logits = self.lm_head(x) # (B,T,vocab_size)\n",
245
+ "\n",
246
+ " if targets is None:\n",
247
+ " loss = None\n",
248
+ " else:\n",
249
+ " B, T, C = logits.shape\n",
250
+ " logits = logits.view(B*T, C)\n",
251
+ " targets = targets.view(B*T)\n",
252
+ " loss = F.cross_entropy(logits, targets)\n",
253
+ "\n",
254
+ " return logits, loss\n",
255
+ "\n",
256
+ " def generate(self, idx, max_new_tokens):\n",
257
+ " # idx is (B, T) array of indices in the current context\n",
258
+ " for _ in range(max_new_tokens):\n",
259
+ " # crop idx to the last block_size tokens\n",
260
+ " idx_cond = idx[:, -block_size:]\n",
261
+ " # get the predictions\n",
262
+ " logits, loss = self(idx_cond)\n",
263
+ " # focus only on the last time step\n",
264
+ " logits = logits[:, -1, :] # becomes (B, C)\n",
265
+ " # apply softmax to get probabilities\n",
266
+ " probs = F.softmax(logits, dim=-1) # (B, C)\n",
267
+ " # sample from the distribution\n",
268
+ " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
269
+ " # append sampled index to the running sequence\n",
270
+ " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
271
+ " return idx"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "metadata": {
278
+ "colab": {
279
+ "base_uri": "https://localhost:8080/"
280
+ },
281
+ "id": "2J-6tLksNiDv",
282
+ "outputId": "dc261c06-4699-45d6-883d-c94829e06e7c"
283
+ },
284
+ "outputs": [
285
+ {
286
+ "name": "stdout",
287
+ "output_type": "stream",
288
+ "text": [
289
+ "85.226561 M parameters\n"
290
+ ]
291
+ }
292
+ ],
293
+ "source": [
294
+ "model = BigramLanguageModel().to(device)\n",
295
+ "# print the number of parameters in the model\n",
296
+ "print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": null,
302
+ "metadata": {
303
+ "id": "mJ_twWqrNiDw"
304
+ },
305
+ "outputs": [],
306
+ "source": [
307
+ "# create a pytorch optimizer\n",
308
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "code",
313
+ "execution_count": null,
314
+ "metadata": {
315
+ "id": "lyVGZiRPPKuy"
316
+ },
317
+ "outputs": [],
318
+ "source": [
319
+ "state = torch.load(\"\")\n",
320
+ "model.load_state_dict(state[\"model_state_dict\"])\n",
321
+ "optimizer.load_state_dict(state[\"optimizer_state_dict\"])\n",
322
+ "max_iters = state[\"epoch\"]"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": null,
328
+ "metadata": {
329
+ "id": "OQaLj9z3NiDw"
330
+ },
331
+ "outputs": [],
332
+ "source": [
333
+ "# train\n",
334
+ "iterator = tqdm(range(max_iters), desc=\"Training\", postfix={\"train_loss\": 0.0})\n",
335
+ "\n",
336
+ "for iter in iterator:\n",
337
+ "\n",
338
+ " # sample a batch of data\n",
339
+ " xb, yb = get_batch('train')\n",
340
+ "\n",
341
+ " # evaluate the loss\n",
342
+ " logits, loss = model(xb, yb)\n",
343
+ " val_loss = estimate_loss(model)[\"val\"]\n",
344
+ "\n",
345
+ " optimizer.zero_grad(set_to_none=True)\n",
346
+ " loss.backward()\n",
347
+ " optimizer.step()\n",
348
+ "\n",
349
+ " # Update the postfix with current train loss\n",
350
+ " iterator.set_postfix({\"train_loss\": loss.item(), \"val_loss\": val_loss.item()}, refresh=False)"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "metadata": {
357
+ "id": "byBjpL1f5gog"
358
+ },
359
+ "outputs": [],
360
+ "source": [
361
+ "torch.save({\n",
362
+ " \"epoch\": \"\",\n",
363
+ " \"model_state_dict\": model.state_dict(),\n",
364
+ " \"optimizer_state_dict\": optimizer.state_dict(),\n",
365
+ "}, \"\")"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": null,
371
+ "metadata": {
372
+ "colab": {
373
+ "background_save": true
374
+ },
375
+ "id": "SV7zpB87NiDw"
376
+ },
377
+ "outputs": [],
378
+ "source": [
379
+ "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
380
+ "print(decode(model.generate(context, max_new_tokens=2000)[0].tolist()))"
381
+ ]
382
+ }
383
+ ],
384
+ "metadata": {
385
+ "accelerator": "GPU",
386
+ "colab": {
387
+ "gpuType": "T4",
388
+ "provenance": []
389
+ },
390
+ "kernelspec": {
391
+ "display_name": "Python 3",
392
+ "name": "python3"
393
+ },
394
+ "language_info": {
395
+ "codemirror_mode": {
396
+ "name": "ipython",
397
+ "version": 3
398
+ },
399
+ "file_extension": ".py",
400
+ "mimetype": "text/x-python",
401
+ "name": "python",
402
+ "nbconvert_exporter": "python",
403
+ "pygments_lexer": "ipython3",
404
+ "version": "3.11.8"
405
+ }
406
+ },
407
+ "nbformat": 4,
408
+ "nbformat_minor": 0
409
+ }
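A minimal sketch of how the uploaded checkpoint might be reloaded for sampling, assuming the LFS file "GPT-3 Small_shakespeare" above is the torch.save output of this notebook (its "model_state_dict" / "optimizer_state_dict" / "epoch" keys match the save cell) and that the notebook cells defining the model class, hyperparameters and decode have already been run:

# Hedged sketch: reload the uploaded checkpoint and sample from it.
# The checkpoint path is an assumption; the notebook itself leaves it blank.
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BigramLanguageModel().to(device)            # class defined in the notebook above
state = torch.load("GPT-3 Small_shakespeare", map_location=device)
model.load_state_dict(state["model_state_dict"])    # key matches the notebook's torch.save call
model.eval()

context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))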
dataset.txt ADDED
The diff for this file is too large to render. See raw diff
 
gpt_dev.ipynb ADDED
@@ -0,0 +1,1555 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "## Building a GPT\n",
21
+ "\n",
22
+ "Companion notebook to the [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT."
23
+ ],
24
+ "metadata": {
25
+ "id": "wJpXpmjEYC_T"
26
+ }
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {
32
+ "colab": {
33
+ "base_uri": "https://localhost:8080/"
34
+ },
35
+ "id": "h5hjCcLDr2WC",
36
+ "outputId": "ccc60f0c-fd78-4dbe-8598-0512d1036aad"
37
+ },
38
+ "outputs": [
39
+ {
40
+ "output_type": "stream",
41
+ "name": "stdout",
42
+ "text": [
43
+ "--2023-01-17 01:39:27-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
44
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
45
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
46
+ "HTTP request sent, awaiting response... 200 OK\n",
47
+ "Length: 1115394 (1.1M) [text/plain]\n",
48
+ "Saving to: ‘input.txt’\n",
49
+ "\n",
50
+ "input.txt 100%[===================>] 1.06M --.-KB/s in 0.04s \n",
51
+ "\n",
52
+ "2023-01-17 01:39:28 (29.0 MB/s) - ‘input.txt’ saved [1115394/1115394]\n",
53
+ "\n"
54
+ ]
55
+ }
56
+ ],
57
+ "source": [
58
+ "# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n",
59
+ "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "source": [
65
+ "# read it in to inspect it\n",
66
+ "with open('input.txt', 'r', encoding='utf-8') as f:\n",
67
+ " text = f.read()"
68
+ ],
69
+ "metadata": {
70
+ "id": "O6medjfRsLD9"
71
+ },
72
+ "execution_count": null,
73
+ "outputs": []
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "source": [
78
+ "print(\"length of dataset in characters: \", len(text))"
79
+ ],
80
+ "metadata": {
81
+ "colab": {
82
+ "base_uri": "https://localhost:8080/"
83
+ },
84
+ "id": "6xWI_VyAsN8F",
85
+ "outputId": "ed819dd0-72e5-40a6-d2ed-928ff73bfda6"
86
+ },
87
+ "execution_count": null,
88
+ "outputs": [
89
+ {
90
+ "output_type": "stream",
91
+ "name": "stdout",
92
+ "text": [
93
+ "length of dataset in characters: 1115394\n"
94
+ ]
95
+ }
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "source": [
101
+ "# let's look at the first 1000 characters\n",
102
+ "print(text[:1000])"
103
+ ],
104
+ "metadata": {
105
+ "colab": {
106
+ "base_uri": "https://localhost:8080/"
107
+ },
108
+ "id": "2c5V0FvqseE0",
109
+ "outputId": "25ca7adc-b8c0-42d1-b08c-e0863c5c314e"
110
+ },
111
+ "execution_count": null,
112
+ "outputs": [
113
+ {
114
+ "output_type": "stream",
115
+ "name": "stdout",
116
+ "text": [
117
+ "First Citizen:\n",
118
+ "Before we proceed any further, hear me speak.\n",
119
+ "\n",
120
+ "All:\n",
121
+ "Speak, speak.\n",
122
+ "\n",
123
+ "First Citizen:\n",
124
+ "You are all resolved rather to die than to famish?\n",
125
+ "\n",
126
+ "All:\n",
127
+ "Resolved. resolved.\n",
128
+ "\n",
129
+ "First Citizen:\n",
130
+ "First, you know Caius Marcius is chief enemy to the people.\n",
131
+ "\n",
132
+ "All:\n",
133
+ "We know't, we know't.\n",
134
+ "\n",
135
+ "First Citizen:\n",
136
+ "Let us kill him, and we'll have corn at our own price.\n",
137
+ "Is't a verdict?\n",
138
+ "\n",
139
+ "All:\n",
140
+ "No more talking on't; let it be done: away, away!\n",
141
+ "\n",
142
+ "Second Citizen:\n",
143
+ "One word, good citizens.\n",
144
+ "\n",
145
+ "First Citizen:\n",
146
+ "We are accounted poor citizens, the patricians good.\n",
147
+ "What authority surfeits on would relieve us: if they\n",
148
+ "would yield us but the superfluity, while it were\n",
149
+ "wholesome, we might guess they relieved us humanely;\n",
150
+ "but they think we are too dear: the leanness that\n",
151
+ "afflicts us, the object of our misery, is as an\n",
152
+ "inventory to particularise their abundance; our\n",
153
+ "sufferance is a gain to them Let us revenge this with\n",
154
+ "our pikes, ere we become rakes: for the gods know I\n",
155
+ "speak this in hunger for bread, not in thirst for revenge.\n",
156
+ "\n",
157
+ "\n"
158
+ ]
159
+ }
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "source": [
165
+ "# here are all the unique characters that occur in this text\n",
166
+ "chars = sorted(list(set(text)))\n",
167
+ "vocab_size = len(chars)\n",
168
+ "print(''.join(chars))\n",
169
+ "print(vocab_size)"
170
+ ],
171
+ "metadata": {
172
+ "colab": {
173
+ "base_uri": "https://localhost:8080/"
174
+ },
175
+ "id": "0e-Rbyr8sfM8",
176
+ "outputId": "f34e94a9-5b44-4cf3-885b-986731929109"
177
+ },
178
+ "execution_count": null,
179
+ "outputs": [
180
+ {
181
+ "output_type": "stream",
182
+ "name": "stdout",
183
+ "text": [
184
+ "\n",
185
+ " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
186
+ "65\n"
187
+ ]
188
+ }
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "source": [
194
+ "# create a mapping from characters to integers\n",
195
+ "stoi = { ch:i for i,ch in enumerate(chars) }\n",
196
+ "itos = { i:ch for i,ch in enumerate(chars) }\n",
197
+ "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
198
+ "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
199
+ "\n",
200
+ "print(encode(\"hii there\"))\n",
201
+ "print(decode(encode(\"hii there\")))"
202
+ ],
203
+ "metadata": {
204
+ "colab": {
205
+ "base_uri": "https://localhost:8080/"
206
+ },
207
+ "id": "Yw1LKNCgwjj1",
208
+ "outputId": "86fcc21c-2cf7-40d9-cd7b-b5a253da4459"
209
+ },
210
+ "execution_count": null,
211
+ "outputs": [
212
+ {
213
+ "output_type": "stream",
214
+ "name": "stdout",
215
+ "text": [
216
+ "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n",
217
+ "hii there\n"
218
+ ]
219
+ }
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "source": [
225
+ "# let's now encode the entire text dataset and store it into a torch.Tensor\n",
226
+ "import torch # we use PyTorch: https://pytorch.org\n",
227
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
228
+ "print(data.shape, data.dtype)\n",
229
+ "print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this"
230
+ ],
231
+ "metadata": {
232
+ "colab": {
233
+ "base_uri": "https://localhost:8080/"
234
+ },
235
+ "id": "YJb0OXPwzvqg",
236
+ "outputId": "db7297cc-36a9-4fae-e941-e7bb9e0e91d1"
237
+ },
238
+ "execution_count": null,
239
+ "outputs": [
240
+ {
241
+ "output_type": "stream",
242
+ "name": "stdout",
243
+ "text": [
244
+ "torch.Size([1115394]) torch.int64\n",
245
+ "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n",
246
+ " 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n",
247
+ " 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n",
248
+ " 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n",
249
+ " 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n",
250
+ " 58, 47, 64, 43, 52, 10, 0, 37, 53, 59, 1, 39, 56, 43, 1, 39, 50, 50,\n",
251
+ " 1, 56, 43, 57, 53, 50, 60, 43, 42, 1, 56, 39, 58, 46, 43, 56, 1, 58,\n",
252
+ " 53, 1, 42, 47, 43, 1, 58, 46, 39, 52, 1, 58, 53, 1, 44, 39, 51, 47,\n",
253
+ " 57, 46, 12, 0, 0, 13, 50, 50, 10, 0, 30, 43, 57, 53, 50, 60, 43, 42,\n",
254
+ " 8, 1, 56, 43, 57, 53, 50, 60, 43, 42, 8, 0, 0, 18, 47, 56, 57, 58,\n",
255
+ " 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 18, 47, 56, 57, 58, 6, 1, 63,\n",
256
+ " 53, 59, 1, 49, 52, 53, 61, 1, 15, 39, 47, 59, 57, 1, 25, 39, 56, 41,\n",
257
+ " 47, 59, 57, 1, 47, 57, 1, 41, 46, 47, 43, 44, 1, 43, 52, 43, 51, 63,\n",
258
+ " 1, 58, 53, 1, 58, 46, 43, 1, 54, 43, 53, 54, 50, 43, 8, 0, 0, 13,\n",
259
+ " 50, 50, 10, 0, 35, 43, 1, 49, 52, 53, 61, 5, 58, 6, 1, 61, 43, 1,\n",
260
+ " 49, 52, 53, 61, 5, 58, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58,\n",
261
+ " 47, 64, 43, 52, 10, 0, 24, 43, 58, 1, 59, 57, 1, 49, 47, 50, 50, 1,\n",
262
+ " 46, 47, 51, 6, 1, 39, 52, 42, 1, 61, 43, 5, 50, 50, 1, 46, 39, 60,\n",
263
+ " 43, 1, 41, 53, 56, 52, 1, 39, 58, 1, 53, 59, 56, 1, 53, 61, 52, 1,\n",
264
+ " 54, 56, 47, 41, 43, 8, 0, 21, 57, 5, 58, 1, 39, 1, 60, 43, 56, 42,\n",
265
+ " 47, 41, 58, 12, 0, 0, 13, 50, 50, 10, 0, 26, 53, 1, 51, 53, 56, 43,\n",
266
+ " 1, 58, 39, 50, 49, 47, 52, 45, 1, 53, 52, 5, 58, 11, 1, 50, 43, 58,\n",
267
+ " 1, 47, 58, 1, 40, 43, 1, 42, 53, 52, 43, 10, 1, 39, 61, 39, 63, 6,\n",
268
+ " 1, 39, 61, 39, 63, 2, 0, 0, 31, 43, 41, 53, 52, 42, 1, 15, 47, 58,\n",
269
+ " 47, 64, 43, 52, 10, 0, 27, 52, 43, 1, 61, 53, 56, 42, 6, 1, 45, 53,\n",
270
+ " 53, 42, 1, 41, 47, 58, 47, 64, 43, 52, 57, 8, 0, 0, 18, 47, 56, 57,\n",
271
+ " 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 35, 43, 1, 39, 56, 43, 1,\n",
272
+ " 39, 41, 41, 53, 59, 52, 58, 43, 42, 1, 54, 53, 53, 56, 1, 41, 47, 58,\n",
273
+ " 47, 64, 43, 52, 57, 6, 1, 58, 46, 43, 1, 54, 39, 58, 56, 47, 41, 47,\n",
274
+ " 39, 52, 57, 1, 45, 53, 53, 42, 8, 0, 35, 46, 39, 58, 1, 39, 59, 58,\n",
275
+ " 46, 53, 56, 47, 58, 63, 1, 57, 59, 56, 44, 43, 47, 58, 57, 1, 53, 52,\n",
276
+ " 1, 61, 53, 59, 50, 42, 1, 56, 43, 50, 47, 43, 60, 43, 1, 59, 57, 10,\n",
277
+ " 1, 47, 44, 1, 58, 46, 43, 63, 0, 61, 53, 59, 50, 42, 1, 63, 47, 43,\n",
278
+ " 50, 42, 1, 59, 57, 1, 40, 59, 58, 1, 58, 46, 43, 1, 57, 59, 54, 43,\n",
279
+ " 56, 44, 50, 59, 47, 58, 63, 6, 1, 61, 46, 47, 50, 43, 1, 47, 58, 1,\n",
280
+ " 61, 43, 56, 43, 0, 61, 46, 53, 50, 43, 57, 53, 51, 43, 6, 1, 61, 43,\n",
281
+ " 1, 51, 47, 45, 46, 58, 1, 45, 59, 43, 57, 57, 1, 58, 46, 43, 63, 1,\n",
282
+ " 56, 43, 50, 47, 43, 60, 43, 42, 1, 59, 57, 1, 46, 59, 51, 39, 52, 43,\n",
283
+ " 50, 63, 11, 0, 40, 59, 58, 1, 58, 46, 43, 63, 1, 58, 46, 47, 52, 49,\n",
284
+ " 1, 61, 43, 1, 39, 56, 43, 1, 58, 53, 53, 1, 42, 43, 39, 56, 10, 1,\n",
285
+ " 58, 46, 43, 1, 50, 43, 39, 52, 52, 43, 57, 57, 1, 58, 46, 39, 58, 0,\n",
286
+ " 39, 44, 44, 50, 47, 41, 58, 57, 1, 59, 57, 6, 1, 58, 46, 43, 1, 53,\n",
287
+ " 40, 48, 43, 41, 58, 1, 53, 44, 1, 53, 59, 56, 1, 51, 47, 57, 43, 56,\n",
288
+ " 63, 6, 1, 47, 57, 1, 39, 57, 1, 39, 52, 0, 47, 52, 60, 43, 52, 58,\n",
289
+ " 53, 56, 63, 1, 58, 53, 1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,\n",
290
+ " 57, 43, 1, 58, 46, 43, 47, 56, 1, 39, 40, 59, 52, 42, 39, 52, 41, 43,\n",
291
+ " 11, 1, 53, 59, 56, 0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43, 1, 47,\n",
292
+ " 57, 1, 39, 1, 45, 39, 47, 52, 1, 58, 53, 1, 58, 46, 43, 51, 1, 24,\n",
293
+ " 43, 58, 1, 59, 57, 1, 56, 43, 60, 43, 52, 45, 43, 1, 58, 46, 47, 57,\n",
294
+ " 1, 61, 47, 58, 46, 0, 53, 59, 56, 1, 54, 47, 49, 43, 57, 6, 1, 43,\n",
295
+ " 56, 43, 1, 61, 43, 1, 40, 43, 41, 53, 51, 43, 1, 56, 39, 49, 43, 57,\n",
296
+ " 10, 1, 44, 53, 56, 1, 58, 46, 43, 1, 45, 53, 42, 57, 1, 49, 52, 53,\n",
297
+ " 61, 1, 21, 0, 57, 54, 43, 39, 49, 1, 58, 46, 47, 57, 1, 47, 52, 1,\n",
298
+ " 46, 59, 52, 45, 43, 56, 1, 44, 53, 56, 1, 40, 56, 43, 39, 42, 6, 1,\n",
299
+ " 52, 53, 58, 1, 47, 52, 1, 58, 46, 47, 56, 57, 58, 1, 44, 53, 56, 1,\n",
300
+ " 56, 43, 60, 43, 52, 45, 43, 8, 0, 0])\n"
301
+ ]
302
+ }
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "source": [
308
+ "# Let's now split up the data into train and validation sets\n",
309
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
310
+ "train_data = data[:n]\n",
311
+ "val_data = data[n:]"
312
+ ],
313
+ "metadata": {
314
+ "id": "f_WIXqxz0lU5"
315
+ },
316
+ "execution_count": null,
317
+ "outputs": []
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "source": [
322
+ "block_size = 8\n",
323
+ "train_data[:block_size+1]"
324
+ ],
325
+ "metadata": {
326
+ "colab": {
327
+ "base_uri": "https://localhost:8080/"
328
+ },
329
+ "id": "TD5Bj8Y6IAD4",
330
+ "outputId": "bf23c586-1d33-4af1-b63d-ce6f90b0a528"
331
+ },
332
+ "execution_count": null,
333
+ "outputs": [
334
+ {
335
+ "output_type": "execute_result",
336
+ "data": {
337
+ "text/plain": [
338
+ "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])"
339
+ ]
340
+ },
341
+ "metadata": {},
342
+ "execution_count": 9
343
+ }
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "source": [
349
+ "x = train_data[:block_size]\n",
350
+ "y = train_data[1:block_size+1]\n",
351
+ "for t in range(block_size):\n",
352
+ " context = x[:t+1]\n",
353
+ " target = y[t]\n",
354
+ " print(f\"when input is {context} the target: {target}\")"
355
+ ],
356
+ "metadata": {
357
+ "colab": {
358
+ "base_uri": "https://localhost:8080/"
359
+ },
360
+ "id": "9HXDe8vGJCEn",
361
+ "outputId": "588663aa-1de5-4ef7-aba0-4a96fe828353"
362
+ },
363
+ "execution_count": null,
364
+ "outputs": [
365
+ {
366
+ "output_type": "stream",
367
+ "name": "stdout",
368
+ "text": [
369
+ "when input is tensor([18]) the target: 47\n",
370
+ "when input is tensor([18, 47]) the target: 56\n",
371
+ "when input is tensor([18, 47, 56]) the target: 57\n",
372
+ "when input is tensor([18, 47, 56, 57]) the target: 58\n",
373
+ "when input is tensor([18, 47, 56, 57, 58]) the target: 1\n",
374
+ "when input is tensor([18, 47, 56, 57, 58, 1]) the target: 15\n",
375
+ "when input is tensor([18, 47, 56, 57, 58, 1, 15]) the target: 47\n",
376
+ "when input is tensor([18, 47, 56, 57, 58, 1, 15, 47]) the target: 58\n"
377
+ ]
378
+ }
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "source": [
384
+ "torch.manual_seed(1337)\n",
385
+ "batch_size = 4 # how many independent sequences will we process in parallel?\n",
386
+ "block_size = 8 # what is the maximum context length for predictions?\n",
387
+ "\n",
388
+ "def get_batch(split):\n",
389
+ " # generate a small batch of data of inputs x and targets y\n",
390
+ " data = train_data if split == 'train' else val_data\n",
391
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
392
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
393
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
394
+ " return x, y\n",
395
+ "\n",
396
+ "xb, yb = get_batch('train')\n",
397
+ "print('inputs:')\n",
398
+ "print(xb.shape)\n",
399
+ "print(xb)\n",
400
+ "print('targets:')\n",
401
+ "print(yb.shape)\n",
402
+ "print(yb)\n",
403
+ "\n",
404
+ "print('----')\n",
405
+ "\n",
406
+ "for b in range(batch_size): # batch dimension\n",
407
+ " for t in range(block_size): # time dimension\n",
408
+ " context = xb[b, :t+1]\n",
409
+ " target = yb[b,t]\n",
410
+ " print(f\"when input is {context.tolist()} the target: {target}\")"
411
+ ],
412
+ "metadata": {
413
+ "colab": {
414
+ "base_uri": "https://localhost:8080/"
415
+ },
416
+ "id": "Q3k1Czf7LuA9",
417
+ "outputId": "4ea8e8a0-443c-49bb-b3bf-ba36e1712999"
418
+ },
419
+ "execution_count": null,
420
+ "outputs": [
421
+ {
422
+ "output_type": "stream",
423
+ "name": "stdout",
424
+ "text": [
425
+ "inputs:\n",
426
+ "torch.Size([4, 8])\n",
427
+ "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
428
+ " [44, 53, 56, 1, 58, 46, 39, 58],\n",
429
+ " [52, 58, 1, 58, 46, 39, 58, 1],\n",
430
+ " [25, 17, 27, 10, 0, 21, 1, 54]])\n",
431
+ "targets:\n",
432
+ "torch.Size([4, 8])\n",
433
+ "tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n",
434
+ " [53, 56, 1, 58, 46, 39, 58, 1],\n",
435
+ " [58, 1, 58, 46, 39, 58, 1, 46],\n",
436
+ " [17, 27, 10, 0, 21, 1, 54, 39]])\n",
437
+ "----\n",
438
+ "when input is [24] the target: 43\n",
439
+ "when input is [24, 43] the target: 58\n",
440
+ "when input is [24, 43, 58] the target: 5\n",
441
+ "when input is [24, 43, 58, 5] the target: 57\n",
442
+ "when input is [24, 43, 58, 5, 57] the target: 1\n",
443
+ "when input is [24, 43, 58, 5, 57, 1] the target: 46\n",
444
+ "when input is [24, 43, 58, 5, 57, 1, 46] the target: 43\n",
445
+ "when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39\n",
446
+ "when input is [44] the target: 53\n",
447
+ "when input is [44, 53] the target: 56\n",
448
+ "when input is [44, 53, 56] the target: 1\n",
449
+ "when input is [44, 53, 56, 1] the target: 58\n",
450
+ "when input is [44, 53, 56, 1, 58] the target: 46\n",
451
+ "when input is [44, 53, 56, 1, 58, 46] the target: 39\n",
452
+ "when input is [44, 53, 56, 1, 58, 46, 39] the target: 58\n",
453
+ "when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1\n",
454
+ "when input is [52] the target: 58\n",
455
+ "when input is [52, 58] the target: 1\n",
456
+ "when input is [52, 58, 1] the target: 58\n",
457
+ "when input is [52, 58, 1, 58] the target: 46\n",
458
+ "when input is [52, 58, 1, 58, 46] the target: 39\n",
459
+ "when input is [52, 58, 1, 58, 46, 39] the target: 58\n",
460
+ "when input is [52, 58, 1, 58, 46, 39, 58] the target: 1\n",
461
+ "when input is [52, 58, 1, 58, 46, 39, 58, 1] the target: 46\n",
462
+ "when input is [25] the target: 17\n",
463
+ "when input is [25, 17] the target: 27\n",
464
+ "when input is [25, 17, 27] the target: 10\n",
465
+ "when input is [25, 17, 27, 10] the target: 0\n",
466
+ "when input is [25, 17, 27, 10, 0] the target: 21\n",
467
+ "when input is [25, 17, 27, 10, 0, 21] the target: 1\n",
468
+ "when input is [25, 17, 27, 10, 0, 21, 1] the target: 54\n",
469
+ "when input is [25, 17, 27, 10, 0, 21, 1, 54] the target: 39\n"
470
+ ]
471
+ }
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "source": [
477
+ "print(xb) # our input to the transformer"
478
+ ],
479
+ "metadata": {
480
+ "colab": {
481
+ "base_uri": "https://localhost:8080/"
482
+ },
483
+ "id": "qpyyAeIzQjlO",
484
+ "outputId": "a650f8dc-da81-400b-bc59-0a595487fdb9"
485
+ },
486
+ "execution_count": null,
487
+ "outputs": [
488
+ {
489
+ "output_type": "stream",
490
+ "name": "stdout",
491
+ "text": [
492
+ "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
493
+ " [44, 53, 56, 1, 58, 46, 39, 58],\n",
494
+ " [52, 58, 1, 58, 46, 39, 58, 1],\n",
495
+ " [25, 17, 27, 10, 0, 21, 1, 54]])\n"
496
+ ]
497
+ }
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "source": [
503
+ "import torch\n",
504
+ "import torch.nn as nn\n",
505
+ "from torch.nn import functional as F\n",
506
+ "torch.manual_seed(1337)\n",
507
+ "\n",
508
+ "class BigramLanguageModel(nn.Module):\n",
509
+ "\n",
510
+ " def __init__(self, vocab_size):\n",
511
+ " super().__init__()\n",
512
+ " # each token directly reads off the logits for the next token from a lookup table\n",
513
+ " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
514
+ "\n",
515
+ " def forward(self, idx, targets=None):\n",
516
+ "\n",
517
+ " # idx and targets are both (B,T) tensor of integers\n",
518
+ " logits = self.token_embedding_table(idx) # (B,T,C)\n",
519
+ "\n",
520
+ " if targets is None:\n",
521
+ " loss = None\n",
522
+ " else:\n",
523
+ " B, T, C = logits.shape\n",
524
+ " logits = logits.view(B*T, C)\n",
525
+ " targets = targets.view(B*T)\n",
526
+ " loss = F.cross_entropy(logits, targets)\n",
527
+ "\n",
528
+ " return logits, loss\n",
529
+ "\n",
530
+ " def generate(self, idx, max_new_tokens):\n",
531
+ " # idx is (B, T) array of indices in the current context\n",
532
+ " for _ in range(max_new_tokens):\n",
533
+ " # get the predictions\n",
534
+ " logits, loss = self(idx)\n",
535
+ " # focus only on the last time step\n",
536
+ " logits = logits[:, -1, :] # becomes (B, C)\n",
537
+ " # apply softmax to get probabilities\n",
538
+ " probs = F.softmax(logits, dim=-1) # (B, C)\n",
539
+ " # sample from the distribution\n",
540
+ " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
541
+ " # append sampled index to the running sequence\n",
542
+ " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
543
+ " return idx\n",
544
+ "\n",
545
+ "m = BigramLanguageModel(vocab_size)\n",
546
+ "logits, loss = m(xb, yb)\n",
547
+ "print(logits.shape)\n",
548
+ "print(loss)\n",
549
+ "\n",
550
+ "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))\n"
551
+ ],
552
+ "metadata": {
553
+ "colab": {
554
+ "base_uri": "https://localhost:8080/"
555
+ },
556
+ "id": "nql_1ER53oCf",
557
+ "outputId": "5de90b1b-4603-428a-f571-fe4bd3c45436"
558
+ },
559
+ "execution_count": null,
560
+ "outputs": [
561
+ {
562
+ "output_type": "stream",
563
+ "name": "stdout",
564
+ "text": [
565
+ "torch.Size([32, 65])\n",
566
+ "tensor(4.8786, grad_fn=<NllLossBackward0>)\n",
567
+ "\n",
568
+ "SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp\n",
569
+ "wnYWmnxKWWev-tDqXErVKLgJ\n"
570
+ ]
571
+ }
572
+ ]
573
+ },
574
+ {
575
+ "cell_type": "code",
576
+ "source": [
577
+ "# create a PyTorch optimizer\n",
578
+ "optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)"
579
+ ],
580
+ "metadata": {
581
+ "id": "eTyJ8qAaDdiF"
582
+ },
583
+ "execution_count": null,
584
+ "outputs": []
585
+ },
586
+ {
587
+ "cell_type": "code",
588
+ "source": [
589
+ "batch_size = 32\n",
590
+ "for steps in range(100): # increase number of steps for good results...\n",
591
+ "\n",
592
+ " # sample a batch of data\n",
593
+ " xb, yb = get_batch('train')\n",
594
+ "\n",
595
+ " # evaluate the loss\n",
596
+ " logits, loss = m(xb, yb)\n",
597
+ " optimizer.zero_grad(set_to_none=True)\n",
598
+ " loss.backward()\n",
599
+ " optimizer.step()\n",
600
+ "\n",
601
+ "print(loss.item())\n"
602
+ ],
603
+ "metadata": {
604
+ "colab": {
605
+ "base_uri": "https://localhost:8080/"
606
+ },
607
+ "id": "Hs4kI8YdEkQj",
608
+ "outputId": "42ded55c-2983-4d91-c528-675b2edfa849"
609
+ },
610
+ "execution_count": null,
611
+ "outputs": [
612
+ {
613
+ "output_type": "stream",
614
+ "name": "stdout",
615
+ "text": [
616
+ "4.65630578994751\n"
617
+ ]
618
+ }
619
+ ]
620
+ },
621
+ {
622
+ "cell_type": "code",
623
+ "source": [
624
+ "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))"
625
+ ],
626
+ "metadata": {
627
+ "colab": {
628
+ "base_uri": "https://localhost:8080/"
629
+ },
630
+ "id": "EcVIDWAZEtjN",
631
+ "outputId": "0ad6f9d2-ad58-4498-a5f8-6f31407bb18b"
632
+ },
633
+ "execution_count": null,
634
+ "outputs": [
635
+ {
636
+ "output_type": "stream",
637
+ "name": "stdout",
638
+ "text": [
639
+ "\n",
640
+ "oTo.JUZ!!zqe!\n",
641
+ "xBP qbs$Gy'AcOmrLwwt\n",
642
+ "p$x;Seh-onQbfM?OjKbn'NwUAW -Np3fkz$FVwAUEa-wzWC -wQo-R!v -Mj?,SPiTyZ;o-opr$mOiPJEYD-CfigkzD3p3?zvS;ADz;.y?o,ivCuC'zqHxcVT cHA\n",
643
+ "rT'Fd,SBMZyOslg!NXeF$sBe,juUzLq?w-wzP-h\n",
644
+ "ERjjxlgJzPbHxf$ q,q,KCDCU fqBOQT\n",
645
+ "SV&CW:xSVwZv'DG'NSPypDhKStKzC -$hslxIVzoivnp ,ethA:NCCGoi\n",
646
+ "tN!ljjP3fwJMwNelgUzzPGJlgihJ!d?q.d\n",
647
+ "pSPYgCuCJrIFtb\n",
648
+ "jQXg\n",
649
+ "pA.P LP,SPJi\n",
650
+ "DBcuBM:CixjJ$Jzkq,OLf3KLQLMGph$O 3DfiPHnXKuHMlyjxEiyZib3FaHV-oJa!zoc'XSP :CKGUhd?lgCOF$;;DTHZMlvvcmZAm;:iv'MMgO&Ywbc;BLCUd&vZINLIzkuTGZa\n",
651
+ "D.?\n"
652
+ ]
653
+ }
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "markdown",
658
+ "source": [
659
+ "## The mathematical trick in self-attention"
660
+ ],
661
+ "metadata": {
662
+ "id": "XinV8nmAnmKN"
663
+ }
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "source": [
668
+ "# toy example illustrating how matrix multiplication can be used for a \"weighted aggregation\"\n",
669
+ "torch.manual_seed(42)\n",
670
+ "a = torch.tril(torch.ones(3, 3))\n",
671
+ "a = a / torch.sum(a, 1, keepdim=True)\n",
672
+ "b = torch.randint(0,10,(3,2)).float()\n",
673
+ "c = a @ b\n",
674
+ "print('a=')\n",
675
+ "print(a)\n",
676
+ "print('--')\n",
677
+ "print('b=')\n",
678
+ "print(b)\n",
679
+ "print('--')\n",
680
+ "print('c=')\n",
681
+ "print(c)"
682
+ ],
683
+ "metadata": {
684
+ "colab": {
685
+ "base_uri": "https://localhost:8080/"
686
+ },
687
+ "id": "tukiH-NbRBhA",
688
+ "outputId": "d981f6d4-ac08-4ec2-8284-82f5fa1e0815"
689
+ },
690
+ "execution_count": null,
691
+ "outputs": [
692
+ {
693
+ "output_type": "stream",
694
+ "name": "stdout",
695
+ "text": [
696
+ "a=\n",
697
+ "tensor([[1.0000, 0.0000, 0.0000],\n",
698
+ " [0.5000, 0.5000, 0.0000],\n",
699
+ " [0.3333, 0.3333, 0.3333]])\n",
700
+ "--\n",
701
+ "b=\n",
702
+ "tensor([[2., 7.],\n",
703
+ " [6., 4.],\n",
704
+ " [6., 5.]])\n",
705
+ "--\n",
706
+ "c=\n",
707
+ "tensor([[2.0000, 7.0000],\n",
708
+ " [4.0000, 5.5000],\n",
709
+ " [4.6667, 5.3333]])\n"
710
+ ]
711
+ }
712
+ ]
713
+ },
714
+ {
715
+ "cell_type": "code",
716
+ "source": [
717
+ "# consider the following toy example:\n",
718
+ "\n",
719
+ "torch.manual_seed(1337)\n",
720
+ "B,T,C = 4,8,2 # batch, time, channels\n",
721
+ "x = torch.randn(B,T,C)\n",
722
+ "x.shape"
723
+ ],
724
+ "metadata": {
725
+ "colab": {
726
+ "base_uri": "https://localhost:8080/"
727
+ },
728
+ "id": "Hs_E24uRE8kr",
729
+ "outputId": "8bf3ff5f-565e-48b8-de8e-7272706c8e12"
730
+ },
731
+ "execution_count": null,
732
+ "outputs": [
733
+ {
734
+ "output_type": "execute_result",
735
+ "data": {
736
+ "text/plain": [
737
+ "torch.Size([4, 8, 2])"
738
+ ]
739
+ },
740
+ "metadata": {},
741
+ "execution_count": 18
742
+ }
743
+ ]
744
+ },
745
+ {
746
+ "cell_type": "code",
747
+ "source": [
748
+ "# We want x[b,t] = mean_{i<=t} x[b,i]\n",
749
+ "xbow = torch.zeros((B,T,C))\n",
750
+ "for b in range(B):\n",
751
+ " for t in range(T):\n",
752
+ " xprev = x[b,:t+1] # (t,C)\n",
753
+ " xbow[b,t] = torch.mean(xprev, 0)\n"
754
+ ],
755
+ "metadata": {
756
+ "id": "86NuXX0fn7ps"
757
+ },
758
+ "execution_count": null,
759
+ "outputs": []
760
+ },
761
+ {
762
+ "cell_type": "code",
763
+ "source": [
764
+ "# version 2: using matrix multiply for a weighted aggregation\n",
765
+ "wei = torch.tril(torch.ones(T, T))\n",
766
+ "wei = wei / wei.sum(1, keepdim=True)\n",
767
+ "xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)\n",
768
+ "torch.allclose(xbow, xbow2)"
769
+ ],
770
+ "metadata": {
771
+ "colab": {
772
+ "base_uri": "https://localhost:8080/"
773
+ },
774
+ "id": "yhdOAd6-wXkZ",
775
+ "outputId": "eaf6ab61-dff1-4bb7-e623-47f692bad5f9"
776
+ },
777
+ "execution_count": null,
778
+ "outputs": [
779
+ {
780
+ "output_type": "execute_result",
781
+ "data": {
782
+ "text/plain": [
783
+ "True"
784
+ ]
785
+ },
786
+ "metadata": {},
787
+ "execution_count": 20
788
+ }
789
+ ]
790
+ },
791
+ {
792
+ "cell_type": "code",
793
+ "source": [
794
+ "# version 3: use Softmax\n",
795
+ "tril = torch.tril(torch.ones(T, T))\n",
796
+ "wei = torch.zeros((T,T))\n",
797
+ "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
798
+ "wei = F.softmax(wei, dim=-1)\n",
799
+ "xbow3 = wei @ x\n",
800
+ "torch.allclose(xbow, xbow3)\n"
801
+ ],
802
+ "metadata": {
803
+ "colab": {
804
+ "base_uri": "https://localhost:8080/"
805
+ },
806
+ "id": "wOURrfG-ysoL",
807
+ "outputId": "080b500d-8110-4602-fcef-7d6f2ebfc6bc"
808
+ },
809
+ "execution_count": null,
810
+ "outputs": [
811
+ {
812
+ "output_type": "execute_result",
813
+ "data": {
814
+ "text/plain": [
815
+ "True"
816
+ ]
817
+ },
818
+ "metadata": {},
819
+ "execution_count": 21
820
+ }
821
+ ]
822
+ },
823
+ {
824
+ "cell_type": "code",
825
+ "source": [
826
+ "# version 4: self-attention!\n",
827
+ "torch.manual_seed(1337)\n",
828
+ "B,T,C = 4,8,32 # batch, time, channels\n",
829
+ "x = torch.randn(B,T,C)\n",
830
+ "\n",
831
+ "# let's see a single Head perform self-attention\n",
832
+ "head_size = 16\n",
833
+ "key = nn.Linear(C, head_size, bias=False)\n",
834
+ "query = nn.Linear(C, head_size, bias=False)\n",
835
+ "value = nn.Linear(C, head_size, bias=False)\n",
836
+ "k = key(x) # (B, T, 16)\n",
837
+ "q = query(x) # (B, T, 16)\n",
838
+ "wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n",
839
+ "\n",
840
+ "tril = torch.tril(torch.ones(T, T))\n",
841
+ "#wei = torch.zeros((T,T))\n",
842
+ "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
843
+ "wei = F.softmax(wei, dim=-1)\n",
844
+ "\n",
845
+ "v = value(x)\n",
846
+ "out = wei @ v\n",
847
+ "#out = wei @ x\n",
848
+ "\n",
849
+ "out.shape"
850
+ ],
851
+ "metadata": {
852
+ "colab": {
853
+ "base_uri": "https://localhost:8080/"
854
+ },
855
+ "id": "EDarxEWIRMKq",
856
+ "outputId": "07b587dd-a91c-4bb0-d7f1-e247cd5dacb5"
857
+ },
858
+ "execution_count": null,
859
+ "outputs": [
860
+ {
861
+ "output_type": "execute_result",
862
+ "data": {
863
+ "text/plain": [
864
+ "torch.Size([4, 8, 16])"
865
+ ]
866
+ },
867
+ "metadata": {},
868
+ "execution_count": 22
869
+ }
870
+ ]
871
+ },
872
+ {
873
+ "cell_type": "code",
874
+ "source": [
875
+ "wei[0]"
876
+ ],
877
+ "metadata": {
878
+ "colab": {
879
+ "base_uri": "https://localhost:8080/"
880
+ },
881
+ "id": "vT1hdtzXCjgL",
882
+ "outputId": "6d2c569b-7922-451f-9934-0fc564678d17"
883
+ },
884
+ "execution_count": null,
885
+ "outputs": [
886
+ {
887
+ "output_type": "execute_result",
888
+ "data": {
889
+ "text/plain": [
890
+ "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
891
+ " [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
892
+ " [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
893
+ " [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],\n",
894
+ " [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],\n",
895
+ " [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],\n",
896
+ " [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],\n",
897
+ " [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],\n",
898
+ " grad_fn=<SelectBackward0>)"
899
+ ]
900
+ },
901
+ "metadata": {},
902
+ "execution_count": 23
903
+ }
904
+ ]
905
+ },
906
+ {
907
+ "cell_type": "markdown",
908
+ "source": [
909
+ "Notes:\n",
910
+ "- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.\n",
911
+ "- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.\n",
912
+ "- Each example across batch dimension is of course processed completely independently and never \"talk\" to each other\n",
913
+ "- In an \"encoder\" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a \"decoder\" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.\n",
914
+ "- \"self-attention\" just means that the keys and values are produced from the same source as queries. In \"cross-attention\", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)\n",
915
+ "- \"Scaled\" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below"
916
+ ],
917
+ "metadata": {
918
+ "id": "M5CvobiQ0pLr"
919
+ }
920
+ },
921
+ {
922
+ "cell_type": "code",
923
+ "source": [
924
+ "k = torch.randn(B,T,head_size)\n",
925
+ "q = torch.randn(B,T,head_size)\n",
926
+ "wei = q @ k.transpose(-2, -1) * head_size**-0.5"
927
+ ],
928
+ "metadata": {
929
+ "id": "4SNbLq5z3oBw"
930
+ },
931
+ "execution_count": null,
932
+ "outputs": []
933
+ },
934
+ {
935
+ "cell_type": "code",
936
+ "source": [
937
+ "k.var()"
938
+ ],
939
+ "metadata": {
940
+ "colab": {
941
+ "base_uri": "https://localhost:8080/"
942
+ },
943
+ "id": "Nl6I9n9IRTSo",
944
+ "outputId": "0c5b9cd0-af8a-4564-fbad-41d844e54822"
945
+ },
946
+ "execution_count": null,
947
+ "outputs": [
948
+ {
949
+ "output_type": "execute_result",
950
+ "data": {
951
+ "text/plain": [
952
+ "tensor(1.0449)"
953
+ ]
954
+ },
955
+ "metadata": {},
956
+ "execution_count": 25
957
+ }
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "source": [
963
+ "q.var()"
964
+ ],
965
+ "metadata": {
966
+ "colab": {
967
+ "base_uri": "https://localhost:8080/"
968
+ },
969
+ "id": "T1tQx7oeRvtc",
970
+ "outputId": "3541ca1a-7447-4ef7-835e-81824aebc1b5"
971
+ },
972
+ "execution_count": null,
973
+ "outputs": [
974
+ {
975
+ "output_type": "execute_result",
976
+ "data": {
977
+ "text/plain": [
978
+ "tensor(1.0700)"
979
+ ]
980
+ },
981
+ "metadata": {},
982
+ "execution_count": 26
983
+ }
984
+ ]
985
+ },
986
+ {
987
+ "cell_type": "code",
988
+ "source": [
989
+ "wei.var()"
990
+ ],
991
+ "metadata": {
992
+ "colab": {
993
+ "base_uri": "https://localhost:8080/"
994
+ },
995
+ "id": "MLb_odHU3iKM",
996
+ "outputId": "a687a222-5a2c-4cdb-c1bf-17cd05b45b69"
997
+ },
998
+ "execution_count": null,
999
+ "outputs": [
1000
+ {
1001
+ "output_type": "execute_result",
1002
+ "data": {
1003
+ "text/plain": [
1004
+ "tensor(1.0918)"
1005
+ ]
1006
+ },
1007
+ "metadata": {},
1008
+ "execution_count": 27
1009
+ }
1010
+ ]
1011
+ },
1012
+ {
1013
+ "cell_type": "code",
1014
+ "source": [
1015
+ "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)"
1016
+ ],
1017
+ "metadata": {
1018
+ "colab": {
1019
+ "base_uri": "https://localhost:8080/"
1020
+ },
1021
+ "id": "JB82yzt44REI",
1022
+ "outputId": "f07da2f1-10bb-4a7a-bcaa-578587977d00"
1023
+ },
1024
+ "execution_count": null,
1025
+ "outputs": [
1026
+ {
1027
+ "output_type": "execute_result",
1028
+ "data": {
1029
+ "text/plain": [
1030
+ "tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])"
1031
+ ]
1032
+ },
1033
+ "metadata": {},
1034
+ "execution_count": 28
1035
+ }
1036
+ ]
1037
+ },
1038
+ {
1039
+ "cell_type": "code",
1040
+ "source": [
1041
+ "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot"
1042
+ ],
1043
+ "metadata": {
1044
+ "colab": {
1045
+ "base_uri": "https://localhost:8080/"
1046
+ },
1047
+ "id": "Mpt8569BB9_f",
1048
+ "outputId": "5d8b910a-6192-44ba-ebb2-497d88e0b629"
1049
+ },
1050
+ "execution_count": null,
1051
+ "outputs": [
1052
+ {
1053
+ "output_type": "execute_result",
1054
+ "data": {
1055
+ "text/plain": [
1056
+ "tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])"
1057
+ ]
1058
+ },
1059
+ "metadata": {},
1060
+ "execution_count": 31
1061
+ }
1062
+ ]
1063
+ },
1064
+ {
1065
+ "cell_type": "code",
1066
+ "source": [
1067
+ "class LayerNorm1d: # (used to be BatchNorm1d)\n",
1068
+ "\n",
1069
+ " def __init__(self, dim, eps=1e-5, momentum=0.1):\n",
1070
+ " self.eps = eps\n",
1071
+ " self.gamma = torch.ones(dim)\n",
1072
+ " self.beta = torch.zeros(dim)\n",
1073
+ "\n",
1074
+ " def __call__(self, x):\n",
1075
+ " # calculate the forward pass\n",
1076
+ " xmean = x.mean(1, keepdim=True) # batch mean\n",
1077
+ " xvar = x.var(1, keepdim=True) # batch variance\n",
1078
+ " xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance\n",
1079
+ " self.out = self.gamma * xhat + self.beta\n",
1080
+ " return self.out\n",
1081
+ "\n",
1082
+ " def parameters(self):\n",
1083
+ " return [self.gamma, self.beta]\n",
1084
+ "\n",
1085
+ "torch.manual_seed(1337)\n",
1086
+ "module = LayerNorm1d(100)\n",
1087
+ "x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors\n",
1088
+ "x = module(x)\n",
1089
+ "x.shape"
1090
+ ],
1091
+ "metadata": {
1092
+ "colab": {
1093
+ "base_uri": "https://localhost:8080/"
1094
+ },
1095
+ "id": "2Num7sX9CKOH",
1096
+ "outputId": "929ceb78-a639-41d6-aac7-12997b5c93f0"
1097
+ },
1098
+ "execution_count": null,
1099
+ "outputs": [
1100
+ {
1101
+ "output_type": "execute_result",
1102
+ "data": {
1103
+ "text/plain": [
1104
+ "torch.Size([32, 100])"
1105
+ ]
1106
+ },
1107
+ "metadata": {},
1108
+ "execution_count": 32
1109
+ }
1110
+ ]
1111
+ },
1112
+ {
1113
+ "cell_type": "code",
1114
+ "source": [
1115
+ "x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs"
1116
+ ],
1117
+ "metadata": {
1118
+ "colab": {
1119
+ "base_uri": "https://localhost:8080/"
1120
+ },
1121
+ "id": "633T2cmnW1uk",
1122
+ "outputId": "7720fa58-0478-4e8a-86a7-502d4cce9443"
1123
+ },
1124
+ "execution_count": null,
1125
+ "outputs": [
1126
+ {
1127
+ "output_type": "execute_result",
1128
+ "data": {
1129
+ "text/plain": [
1130
+ "(tensor(0.1469), tensor(0.8803))"
1131
+ ]
1132
+ },
1133
+ "metadata": {},
1134
+ "execution_count": 33
1135
+ }
1136
+ ]
1137
+ },
1138
+ {
1139
+ "cell_type": "code",
1140
+ "source": [
1141
+ "x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features"
1142
+ ],
1143
+ "metadata": {
1144
+ "colab": {
1145
+ "base_uri": "https://localhost:8080/"
1146
+ },
1147
+ "id": "LN9cK9BoXCYb",
1148
+ "outputId": "6368ece0-600e-417d-8a91-7c1e5d750ba8"
1149
+ },
1150
+ "execution_count": null,
1151
+ "outputs": [
1152
+ {
1153
+ "output_type": "execute_result",
1154
+ "data": {
1155
+ "text/plain": [
1156
+ "(tensor(-9.5367e-09), tensor(1.0000))"
1157
+ ]
1158
+ },
1159
+ "metadata": {},
1160
+ "execution_count": 34
1161
+ }
1162
+ ]
1163
+ },
1164
+ {
1165
+ "cell_type": "code",
1166
+ "source": [
1167
+ "# French to English translation example:\n",
1168
+ "\n",
1169
+ "# <--------- ENCODE ------------------><--------------- DECODE ----------------->\n",
1170
+ "# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>\n",
1171
+ "\n"
1172
+ ],
1173
+ "metadata": {
1174
+ "id": "dRJH6wM_XFfU"
1175
+ },
1176
+ "execution_count": null,
1177
+ "outputs": []
1178
+ },
1179
+ {
1180
+ "cell_type": "markdown",
1181
+ "source": [
1182
+ "### Full finished code, for reference\n",
1183
+ "\n",
1184
+ "You may want to refer directly to the git repo instead though."
1185
+ ],
1186
+ "metadata": {
1187
+ "id": "ZcvKeBXoZFOY"
1188
+ }
1189
+ },
1190
+ {
1191
+ "cell_type": "code",
1192
+ "source": [
1193
+ "import torch\n",
1194
+ "import torch.nn as nn\n",
1195
+ "from torch.nn import functional as F\n",
1196
+ "\n",
1197
+ "# hyperparameters\n",
1198
+ "batch_size = 16 # how many independent sequences will we process in parallel?\n",
1199
+ "block_size = 32 # what is the maximum context length for predictions?\n",
1200
+ "max_iters = 5000\n",
1201
+ "eval_interval = 100\n",
1202
+ "learning_rate = 1e-3\n",
1203
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
1204
+ "eval_iters = 200\n",
1205
+ "n_embd = 64\n",
1206
+ "n_head = 4\n",
1207
+ "n_layer = 4\n",
1208
+ "dropout = 0.0\n",
1209
+ "# ------------\n",
1210
+ "\n",
1211
+ "torch.manual_seed(1337)\n",
1212
+ "\n",
1213
+ "# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
1214
+ "with open('input.txt', 'r', encoding='utf-8') as f:\n",
1215
+ " text = f.read()\n",
1216
+ "\n",
1217
+ "# here are all the unique characters that occur in this text\n",
1218
+ "chars = sorted(list(set(text)))\n",
1219
+ "vocab_size = len(chars)\n",
1220
+ "# create a mapping from characters to integers\n",
1221
+ "stoi = { ch:i for i,ch in enumerate(chars) }\n",
1222
+ "itos = { i:ch for i,ch in enumerate(chars) }\n",
1223
+ "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
1224
+ "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
1225
+ "\n",
1226
+ "# Train and test splits\n",
1227
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
1228
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
1229
+ "train_data = data[:n]\n",
1230
+ "val_data = data[n:]\n",
1231
+ "\n",
1232
+ "# data loading\n",
1233
+ "def get_batch(split):\n",
1234
+ " # generate a small batch of data of inputs x and targets y\n",
1235
+ " data = train_data if split == 'train' else val_data\n",
1236
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
1237
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
1238
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
1239
+ " x, y = x.to(device), y.to(device)\n",
1240
+ " return x, y\n",
1241
+ "\n",
1242
+ "@torch.no_grad()\n",
1243
+ "def estimate_loss():\n",
1244
+ " out = {}\n",
1245
+ " model.eval()\n",
1246
+ " for split in ['train', 'val']:\n",
1247
+ " losses = torch.zeros(eval_iters)\n",
1248
+ " for k in range(eval_iters):\n",
1249
+ " X, Y = get_batch(split)\n",
1250
+ " logits, loss = model(X, Y)\n",
1251
+ " losses[k] = loss.item()\n",
1252
+ " out[split] = losses.mean()\n",
1253
+ " model.train()\n",
1254
+ " return out\n",
1255
+ "\n",
1256
+ "class Head(nn.Module):\n",
1257
+ " \"\"\" one head of self-attention \"\"\"\n",
1258
+ "\n",
1259
+ " def __init__(self, head_size):\n",
1260
+ " super().__init__()\n",
1261
+ " self.key = nn.Linear(n_embd, head_size, bias=False)\n",
1262
+ " self.query = nn.Linear(n_embd, head_size, bias=False)\n",
1263
+ " self.value = nn.Linear(n_embd, head_size, bias=False)\n",
1264
+ " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
1265
+ "\n",
1266
+ " self.dropout = nn.Dropout(dropout)\n",
1267
+ "\n",
1268
+ " def forward(self, x):\n",
1269
+ " B,T,C = x.shape\n",
1270
+ " k = self.key(x) # (B,T,C)\n",
1271
+ " q = self.query(x) # (B,T,C)\n",
1272
+ " # compute attention scores (\"affinities\")\n",
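+ "        # (note: the line below scales by C**-0.5 with C = n_embd from x.shape; the scaled dot-product attention formula conventionally uses head_size**-0.5, the key dimension, instead)\n",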
1273
+ " wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n",
1274
+ " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
1275
+ " wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
1276
+ " wei = self.dropout(wei)\n",
1277
+ " # perform the weighted aggregation of the values\n",
1278
+ " v = self.value(x) # (B,T,C)\n",
1279
+ " out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n",
1280
+ " return out\n",
1281
+ "\n",
1282
+ "class MultiHeadAttention(nn.Module):\n",
1283
+ " \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
1284
+ "\n",
1285
+ " def __init__(self, num_heads, head_size):\n",
1286
+ " super().__init__()\n",
1287
+ " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
1288
+ " self.proj = nn.Linear(n_embd, n_embd)\n",
1289
+ " self.dropout = nn.Dropout(dropout)\n",
1290
+ "\n",
1291
+ " def forward(self, x):\n",
1292
+ " out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
1293
+ " out = self.dropout(self.proj(out))\n",
1294
+ " return out\n",
1295
+ "\n",
1296
+ "class FeedForward(nn.Module):\n",
1297
+ "    \"\"\" a simple MLP: a linear layer, a non-linearity, and a projection back to n_embd \"\"\"\n",
1298
+ "\n",
1299
+ " def __init__(self, n_embd):\n",
1300
+ " super().__init__()\n",
1301
+ " self.net = nn.Sequential(\n",
1302
+ " nn.Linear(n_embd, 4 * n_embd),\n",
1303
+ " nn.ReLU(),\n",
1304
+ " nn.Linear(4 * n_embd, n_embd),\n",
1305
+ " nn.Dropout(dropout),\n",
1306
+ " )\n",
1307
+ "\n",
1308
+ " def forward(self, x):\n",
1309
+ " return self.net(x)\n",
1310
+ "\n",
1311
+ "class Block(nn.Module):\n",
1312
+ " \"\"\" Transformer block: communication followed by computation \"\"\"\n",
1313
+ "\n",
1314
+ " def __init__(self, n_embd, n_head):\n",
1315
+ " # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
1316
+ " super().__init__()\n",
1317
+ " head_size = n_embd // n_head\n",
1318
+ " self.sa = MultiHeadAttention(n_head, head_size)\n",
1319
+ "        self.ffwd = FeedForward(n_embd)\n",
1320
+ " self.ln1 = nn.LayerNorm(n_embd)\n",
1321
+ " self.ln2 = nn.LayerNorm(n_embd)\n",
1322
+ "\n",
1323
+ " def forward(self, x):\n",
1324
+ " x = x + self.sa(self.ln1(x))\n",
1325
+ " x = x + self.ffwd(self.ln2(x))\n",
1326
+ " return x\n",
1327
+ "\n",
1328
+ "# language model: the simple bigram model from earlier, now grown into a full GPT-style transformer\n",
1329
+ "class BigramLanguageModel(nn.Module):\n",
1330
+ "\n",
1331
+ " def __init__(self):\n",
1332
+ " super().__init__()\n",
1333
+ " # each token directly reads off the logits for the next token from a lookup table\n",
1334
+ " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
1335
+ " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
1336
+ " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
1337
+ " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
1338
+ " self.lm_head = nn.Linear(n_embd, vocab_size)\n",
1339
+ "\n",
1340
+ " def forward(self, idx, targets=None):\n",
1341
+ " B, T = idx.shape\n",
1342
+ "\n",
1343
+ " # idx and targets are both (B,T) tensor of integers\n",
1344
+ " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
1345
+ " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
1346
+ " x = tok_emb + pos_emb # (B,T,C)\n",
1347
+ " x = self.blocks(x) # (B,T,C)\n",
1348
+ " x = self.ln_f(x) # (B,T,C)\n",
1349
+ " logits = self.lm_head(x) # (B,T,vocab_size)\n",
1350
+ "\n",
1351
+ " if targets is None:\n",
1352
+ " loss = None\n",
1353
+ " else:\n",
1354
+ " B, T, C = logits.shape\n",
1355
+ " logits = logits.view(B*T, C)\n",
1356
+ " targets = targets.view(B*T)\n",
1357
+ " loss = F.cross_entropy(logits, targets)\n",
1358
+ "\n",
1359
+ " return logits, loss\n",
1360
+ "\n",
1361
+ " def generate(self, idx, max_new_tokens):\n",
1362
+ " # idx is (B, T) array of indices in the current context\n",
1363
+ " for _ in range(max_new_tokens):\n",
1364
+ " # crop idx to the last block_size tokens\n",
1365
+ " idx_cond = idx[:, -block_size:]\n",
1366
+ " # get the predictions\n",
1367
+ " logits, loss = self(idx_cond)\n",
1368
+ " # focus only on the last time step\n",
1369
+ " logits = logits[:, -1, :] # becomes (B, C)\n",
1370
+ " # apply softmax to get probabilities\n",
1371
+ " probs = F.softmax(logits, dim=-1) # (B, C)\n",
1372
+ " # sample from the distribution\n",
1373
+ " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
1374
+ " # append sampled index to the running sequence\n",
1375
+ " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
1376
+ " return idx\n",
1377
+ "\n",
1378
+ "model = BigramLanguageModel()\n",
1379
+ "m = model.to(device)\n",
1380
+ "# print the number of parameters in the model\n",
1381
+ "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
1382
+ "\n",
1383
+ "# create a PyTorch optimizer\n",
1384
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
1385
+ "\n",
1386
+ "for iter in range(max_iters):\n",
1387
+ "\n",
1388
+ " # every once in a while evaluate the loss on train and val sets\n",
1389
+ " if iter % eval_interval == 0 or iter == max_iters - 1:\n",
1390
+ " losses = estimate_loss()\n",
1391
+ " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
1392
+ "\n",
1393
+ " # sample a batch of data\n",
1394
+ " xb, yb = get_batch('train')\n",
1395
+ "\n",
1396
+ " # evaluate the loss\n",
1397
+ " logits, loss = model(xb, yb)\n",
1398
+ " optimizer.zero_grad(set_to_none=True)\n",
1399
+ " loss.backward()\n",
1400
+ " optimizer.step()\n",
1401
+ "\n",
1402
+ "# generate from the model\n",
1403
+ "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
1404
+ "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))\n"
1405
+ ],
1406
+ "metadata": {
1407
+ "colab": {
1408
+ "base_uri": "https://localhost:8080/"
1409
+ },
1410
+ "id": "hoelkOrFY8bN",
1411
+ "outputId": "961304cd-e379-40d4-dd56-8de0b91d2861"
1412
+ },
1413
+ "execution_count": null,
1414
+ "outputs": [
1415
+ {
1416
+ "output_type": "stream",
1417
+ "name": "stdout",
1418
+ "text": [
1419
+ "0.209729 M parameters\n",
1420
+ "step 0: train loss 4.4116, val loss 4.4022\n",
1421
+ "step 100: train loss 2.6568, val loss 2.6670\n",
1422
+ "step 200: train loss 2.5090, val loss 2.5058\n",
1423
+ "step 300: train loss 2.4198, val loss 2.4340\n",
1424
+ "step 400: train loss 2.3503, val loss 2.3567\n",
1425
+ "step 500: train loss 2.2970, val loss 2.3136\n",
1426
+ "step 600: train loss 2.2410, val loss 2.2506\n",
1427
+ "step 700: train loss 2.2062, val loss 2.2198\n",
1428
+ "step 800: train loss 2.1638, val loss 2.1871\n",
1429
+ "step 900: train loss 2.1232, val loss 2.1494\n",
1430
+ "step 1000: train loss 2.1020, val loss 2.1293\n",
1431
+ "step 1100: train loss 2.0704, val loss 2.1196\n",
1432
+ "step 1200: train loss 2.0382, val loss 2.0798\n",
1433
+ "step 1300: train loss 2.0249, val loss 2.0640\n",
1434
+ "step 1400: train loss 1.9922, val loss 2.0354\n",
1435
+ "step 1500: train loss 1.9707, val loss 2.0308\n",
1436
+ "step 1600: train loss 1.9614, val loss 2.0474\n",
1437
+ "step 1700: train loss 1.9393, val loss 2.0130\n",
1438
+ "step 1800: train loss 1.9070, val loss 1.9943\n",
1439
+ "step 1900: train loss 1.9057, val loss 1.9871\n",
1440
+ "step 2000: train loss 1.8834, val loss 1.9954\n",
1441
+ "step 2100: train loss 1.8719, val loss 1.9758\n",
1442
+ "step 2200: train loss 1.8582, val loss 1.9623\n",
1443
+ "step 2300: train loss 1.8546, val loss 1.9517\n",
1444
+ "step 2400: train loss 1.8410, val loss 1.9476\n",
1445
+ "step 2500: train loss 1.8167, val loss 1.9455\n",
1446
+ "step 2600: train loss 1.8263, val loss 1.9401\n",
1447
+ "step 2700: train loss 1.8108, val loss 1.9340\n",
1448
+ "step 2800: train loss 1.8040, val loss 1.9247\n",
1449
+ "step 2900: train loss 1.8044, val loss 1.9304\n",
1450
+ "step 3000: train loss 1.7963, val loss 1.9242\n",
1451
+ "step 3100: train loss 1.7687, val loss 1.9147\n",
1452
+ "step 3200: train loss 1.7547, val loss 1.9102\n",
1453
+ "step 3300: train loss 1.7557, val loss 1.9037\n",
1454
+ "step 3400: train loss 1.7547, val loss 1.8946\n",
1455
+ "step 3500: train loss 1.7385, val loss 1.8968\n",
1456
+ "step 3600: train loss 1.7260, val loss 1.8914\n",
1457
+ "step 3700: train loss 1.7257, val loss 1.8808\n",
1458
+ "step 3800: train loss 1.7204, val loss 1.8919\n",
1459
+ "step 3900: train loss 1.7215, val loss 1.8788\n",
1460
+ "step 4000: train loss 1.7146, val loss 1.8639\n",
1461
+ "step 4100: train loss 1.7095, val loss 1.8724\n",
1462
+ "step 4200: train loss 1.7079, val loss 1.8707\n",
1463
+ "step 4300: train loss 1.7035, val loss 1.8502\n",
1464
+ "step 4400: train loss 1.7043, val loss 1.8693\n",
1465
+ "step 4500: train loss 1.6914, val loss 1.8522\n",
1466
+ "step 4600: train loss 1.6853, val loss 1.8357\n",
1467
+ "step 4700: train loss 1.6862, val loss 1.8483\n",
1468
+ "step 4800: train loss 1.6671, val loss 1.8434\n",
1469
+ "step 4900: train loss 1.6736, val loss 1.8415\n",
1470
+ "step 4999: train loss 1.6635, val loss 1.8226\n",
1471
+ "\n",
1472
+ "FlY BOLINGLO:\n",
1473
+ "Them thrumply towiter arts the\n",
1474
+ "muscue rike begatt the sea it\n",
1475
+ "What satell in rowers that some than othis Marrity.\n",
1476
+ "\n",
1477
+ "LUCENTVO:\n",
1478
+ "But userman these that, where can is not diesty rege;\n",
1479
+ "What and see to not. But's eyes. What?\n",
1480
+ "\n",
1481
+ "JOHN MARGARET:\n",
1482
+ "Than up I wark, what out, I ever of and love,\n",
1483
+ "one these do sponce, vois I me;\n",
1484
+ "But my pray sape to ries all to the not erralied in may.\n",
1485
+ "\n",
1486
+ "BENVOLIO:\n",
1487
+ "To spits as stold's bewear I would and say mesby all\n",
1488
+ "on sworn make he anough\n",
1489
+ "As cousins the solle, whose be my conforeful may lie them yet\n",
1490
+ "nobe allimely untraled to be thre I say be,\n",
1491
+ "Notham a brotes theme an make come,\n",
1492
+ "And that his reach to the duke ento\n",
1493
+ "the grmeants bell! and now there king-liff-or grief?\n",
1494
+ "\n",
1495
+ "GLOUCESTER:\n",
1496
+ "All the bettle dreene, for To his like thou thron!\n",
1497
+ "\n",
1498
+ "MENENIUS:\n",
1499
+ "Then, if I knom her all.\n",
1500
+ "My lord, but terruly friend\n",
1501
+ "Rish of the ploceiness and wilt tends sure?\n",
1502
+ "Is you knows a fasir wead\n",
1503
+ "That with him my spaut,\n",
1504
+ "I shall not tas where's not, becomity; my coulds sting,\n",
1505
+ "then the wit be dong to tyget our hereefore,\n",
1506
+ "Who strop me, mend here, if agains, bitten, thy lack.\n",
1507
+ "The but these it were is tus. For the her skeep the fasting. joy tweet Bumner:-\n",
1508
+ "How the enclady: It you and how,\n",
1509
+ "I am in him, And ladderle:\n",
1510
+ "Their hand whose wife, it my hithre,\n",
1511
+ "Roman and where sposs gives'd you.\n",
1512
+ "\n",
1513
+ "TROMIOLANUS:\n",
1514
+ "But livants you great, I shom mistrot come, for to she to lot\n",
1515
+ "for smy to men ventry mehus. Gazise;\n",
1516
+ "Full't were some the cause, and stouch set,\n",
1517
+ "Or promises, which a kingsasted to your gove them; and sterrer,\n",
1518
+ "And that wae love him.\n",
1519
+ "\n",
1520
+ "BRUTUS:\n",
1521
+ "You shape with these sweet.\n",
1522
+ "\n",
1523
+ "CORTENGONO:\n",
1524
+ "Lo, where 'twon elmes, 'morth young agres;\n",
1525
+ "Sir, azavoust to striel accurded we missery sets crave.\n",
1526
+ "\n",
1527
+ "ANGOLUM:\n",
1528
+ "For is Henry to have gleise the dreason\n",
1529
+ "That I ant shorfold wefth their servy in enscy.\n",
1530
+ "\n",
1531
+ "ISABELLA:\n",
1532
+ "O, I better you eyse such formfetrews.\n",
1533
+ "\n",
1534
+ "BUCKINGHARENT:\n",
1535
+ "Qead my lightle this righanneds flase them\n",
1536
+ "Wam which an take was our some pleasurs,\n",
1537
+ "Lovisoname to me, then fult me?--have it?\n",
1538
+ "\n",
1539
+ "HENRY BOLINGBROY:\n",
1540
+ "That wha\n"
1541
+ ]
1542
+ }
1543
+ ]
1544
+ },
1545
+ {
1546
+ "cell_type": "code",
1547
+ "source": [],
1548
+ "metadata": {
1549
+ "id": "fjjvMifYZf7x"
1550
+ },
1551
+ "execution_count": null,
1552
+ "outputs": []
1553
+ }
1554
+ ]
1555
+ }