# --- Data: download and tokenize the Tiny Shakespeare corpus ---
import os
import urllib.request

DATA_URL = (
    "https://raw.githubusercontent.com/karpathy/char-rnn/"
    "master/data/tinyshakespeare/input.txt"
)
DATA_PATH = "input.txt"

# Download only if missing. The original `!wget` re-downloaded on every
# run and saved a duplicate as `input.txt.1` (visible in the cell output),
# which the notebook then never read.
if not os.path.exists(DATA_PATH):
    urllib.request.urlretrieve(DATA_URL, DATA_PATH)

# Explicit encoding: the corpus is plain ASCII/UTF-8 text; relying on the
# platform default encoding is a portability bug.
with open(DATA_PATH, encoding="utf-8") as f:
    text = f.read()

print(text[:50])

# Character-level vocabulary: every distinct character in the corpus,
# sorted for a deterministic id assignment (65 symbols for this corpus).
chars = sorted(set(text))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)

# Bidirectional char <-> integer lookup tables.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]            # string -> list of token ids
decode = lambda l: "".join(itos[i] for i in l)     # list of token ids -> string

print(encode("hi there"))
print(decode(encode("hi there")))
import torch

# Encode the entire corpus as one long 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])

# 90/10 train/validation split. No shuffling: the text must stay
# contiguous so that sliding windows are valid training examples.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

torch.manual_seed(1337)  # reproducible batch sampling
batch_size = 4
block_size = 8


def get_batch(split):
    """Draw one random batch of (context, target) windows.

    Args:
        split: "train" selects the training tensor; any other value
            selects the validation tensor.

    Returns:
        Tuple (x, y) of LongTensors shaped (batch_size, block_size),
        where y is x shifted left by one position (next-token targets).
    """
    source = train_data if split == "train" else val_data
    # Random window starts; subtracting block_size keeps x and the
    # one-step-shifted y fully inside the tensor.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[s : s + block_size] for s in starts])
    y = torch.stack([source[s + 1 : s + block_size + 1] for s in starts])
    return x, y


xb, yb = get_batch("train")
print("Inputs:")
print(xb.shape)
print(xb)

print("-----------")
print("Targets:")
print(yb.shape)
print(yb)
import torch
import torch.nn as nn
from torch.nn import functional as F


# NOTE: the notebook originally defined this class twice; the first
# version (required-but-ignored `targets`, logits-only return) was dead
# code silently shadowed by the second. Only the final version is kept.
class BigramLanguageModel(nn.Module):
    """Simplest possible language model: the logits for the next token
    are read straight out of a learned (vocab_size x vocab_size)
    embedding row for the current token — i.e. a trainable bigram table.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the next-token logits conditioned on token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits (and loss when targets are given).

        Args:
            idx: LongTensor (B, T) of token ids.
            targets: optional LongTensor (B, T) of next-token ids.

        Returns:
            (logits, loss). Without targets: logits is (B, T, C) and
            loss is None. With targets: logits is flattened to (B*T, C)
            — callers relying on the printed (32, 65) shape depend on
            this — and loss is the mean cross-entropy.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)
        loss = None
        if targets is not None:
            # cross_entropy expects (N, C) logits against (N,) targets.
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Autoregressively sample tokens, appending them to idx.

        Args:
            idx: LongTensor (B, T) seed context.
            max_new_tokens: number of tokens to append.

        Returns:
            LongTensor (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            logits, _ = self(idx)         # (B, T, C); loss is unused here
            logits = logits[:, -1, :]     # keep only the last time step: (B, C)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T+1)
        return idx
# Instantiate the model and sanity-check loss on one batch.
# Untrained loss ~4.53 is consistent with -ln(1/65) for a 65-way uniform guess.
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)  # (32, 65): logits come back flattened when targets are given
print(loss)

# Sample from the untrained model — output is uniform character noise.
idx = torch.zeros((1, 1), dtype=torch.long)
results = decode(m.generate(idx, max_new_tokens=100)[0].tolist())
print(results)

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# Plain training loop over freshly sampled random batches.
batch_size = 32
for _ in range(10000):
    xb, yb = get_batch("train")
    _, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

# Sample again — character statistics now vaguely resemble the corpus.
idx = torch.zeros((1, 1), dtype=torch.long)
results = decode(m.generate(idx, max_new_tokens=100)[0].tolist())
print(results)
# --- Toy single-head self-attention (the core transformer trick) ---
B, T, C = 4, 8, 32  # batch, time (context length), channels
head_size = 16

x = torch.randn(B, T, C)

# Bias-free linear projections of every position into head space.
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)

# Scaled dot-product affinities: (B,T,hs) @ (B,hs,T) -> (B,T,T).
# The 1/sqrt(head_size) factor keeps the pre-softmax variance near 1.
wei = q @ k.transpose(-2, -1) * head_size**-0.5

# Causal mask: position t may only attend to positions <= t,
# so future tokens carry zero weight after the softmax.
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)  # each row sums to 1 over the visible prefix

# Weighted aggregation of the value vectors.
v = value(x)
out = wei @ v  # (B, T, head_size)

print(out.shape)
print(wei[0])  # lower-triangular attention weights for the first batch item