{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "VktNs2NoNiDt" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3jOZxHu3NiDt", "outputId": "8fe80f12-21b3-4b81-9388-5211f00f6848" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.manual_seed(1337)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T-pNtwn5NiDu" }, "outputs": [], "source": [ "# hyperparameters\n", "batch_size = 8 # how many independent sequences will we process in parallel?\n", "block_size = 128 # what is the maximum context length for predictions?\n", "max_iters = 100\n", "eval_interval = 10\n", "learning_rate = 6.0 * 10**-4\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "eval_iters = 200\n", "n_embd = 768\n", "n_head = 12\n", "n_layer = 12\n", "dropout = 0.25" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XYHnR6ETNiDu" }, "outputs": [], "source": [ "with open(\"\", \"r\", encoding=\"utf-8\") as f:\n", " text = f.read()\n", "\n", "chars = sorted(list(set(text)))\n", "vocab_size = len(chars)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gJZbe7PyNiDu" }, "outputs": [], "source": [ "# create a mapping from characters to integers\n", "stoi = { ch:i for i,ch in enumerate(chars) }\n", "itos = { i:ch for i,ch in enumerate(chars) }\n", "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OAN-qtdPNiDv" }, "outputs": [], "source": [ "# Train and test splits\n", "data = torch.tensor(encode(text), dtype=torch.long)\n", "n = int(0.9*len(data)) # first 90% will be train, rest val\n", "train_data = data[:n]\n", "val_data = data[n:]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wZ70_NY-NiDv" }, "outputs": [], "source": [ "# data loading\n", "def get_batch(split):\n", " # generate a small batch of data of inputs x and targets y\n", " data = train_data if split == 'train' else val_data\n", " ix = torch.randint(len(data) - block_size, (batch_size,))\n", " x = torch.stack([data[i:i+block_size] for i in ix])\n", " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n", " x, y = x.to(device), y.to(device)\n", " return x, y\n", "\n", "@torch.no_grad()\n", "def estimate_loss(model):\n", " out = {}\n", " model.eval()\n", " for split in ['val']:\n", " losses = torch.zeros(eval_iters)\n", " for k in range(eval_iters):\n", " X, Y = get_batch(split)\n", " logits, loss = model(X, Y)\n", " losses[k] = loss.item()\n", " out[split] = losses.mean()\n", " model.train()\n", " return out" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KgnNQJUENiDv" }, "outputs": [], "source": [ "class Head(nn.Module):\n", " \"\"\" one head of self-attention \"\"\"\n", "\n", " def __init__(self, head_size):\n", " super().__init__()\n", " self.key = nn.Linear(n_embd, head_size, bias=False)\n", " self.query = nn.Linear(n_embd, head_size, bias=False)\n", " self.value = nn.Linear(n_embd, head_size, bias=False)\n", " self.register_buffer('tril', 
        "\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        B,T,C = x.shape\n",
        "        k = self.key(x)   # (B,T,head_size)\n",
        "        q = self.query(x) # (B,T,head_size)\n",
        "        # compute attention scores (\"affinities\"), scaled by 1/sqrt(head_size)\n",
        "        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B,T,head_size) @ (B,head_size,T) -> (B,T,T)\n",
        "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
        "        wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
        "        wei = self.dropout(wei)\n",
        "        # perform the weighted aggregation of the values\n",
        "        v = self.value(x) # (B,T,head_size)\n",
        "        out = wei @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)\n",
        "        return out\n",
        "\n",
        "class MultiHeadAttention(nn.Module):\n",
        "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
        "\n",
        "    def __init__(self, num_heads, head_size):\n",
        "        super().__init__()\n",
        "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
        "        self.proj = nn.Linear(n_embd, n_embd)\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
        "        out = self.dropout(self.proj(out))\n",
        "        return out\n",
        "\n",
        "class FeedForward(nn.Module):\n",
        "    \"\"\" a simple linear layer followed by a non-linearity, then a projection back \"\"\"\n",
        "\n",
        "    def __init__(self, n_embd):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(n_embd, 4 * n_embd),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(4 * n_embd, n_embd),\n",
        "            nn.Dropout(dropout),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.net(x)\n",
        "\n",
        "class Block(nn.Module):\n",
        "    \"\"\" Transformer block: communication followed by computation \"\"\"\n",
        "\n",
        "    def __init__(self, n_embd, n_head):\n",
        "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
        "        super().__init__()\n",
        "        head_size = n_embd // n_head\n",
        "        self.sa = MultiHeadAttention(n_head, head_size)\n",
        "        self.ffwd = FeedForward(n_embd)\n",
        "        self.ln1 = nn.LayerNorm(n_embd)\n",
        "        self.ln2 = nn.LayerNorm(n_embd)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = x + self.sa(self.ln1(x))\n",
        "        x = x + self.ffwd(self.ln2(x))\n",
        "        return x\n",
        "\n",
        "# decoder-only Transformer language model (the class name is a holdover from the simple bigram version)\n",
        "class BigramLanguageModel(nn.Module):\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # token and position embeddings, a stack of Transformer blocks, and a linear head for the logits\n",
        "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
        "        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
        "        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
        "        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
        "        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
        "\n",
        "    def forward(self, idx, targets=None):\n",
        "        B, T = idx.shape\n",
        "\n",
        "        # idx and targets are both (B,T) tensors of integers\n",
        "        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
        "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
        "        x = tok_emb + pos_emb # (B,T,C)\n",
        "        x = self.blocks(x) # (B,T,C)\n",
        "        x = self.ln_f(x) # (B,T,C)\n",
        "        logits = self.lm_head(x) # (B,T,vocab_size)\n",
        "\n",
        "        if targets is None:\n",
        "            loss = None\n",
        "        else:\n",
        "            B, T, C = logits.shape\n",
        "            logits = logits.view(B*T, C)\n",
        "            targets = targets.view(B*T)\n",
        "            loss = F.cross_entropy(logits, targets)\n",
        "\n",
        "        return logits, loss\n",
        "\n",
        "    def generate(self, idx, max_new_tokens):\n",
        "        # idx is (B, T) array of indices in the current context\n",
        "        for _ in range(max_new_tokens):\n",
        "            # crop idx to the last block_size tokens\n",
        "            idx_cond = idx[:, -block_size:]\n",
        "            # get the predictions\n",
        "            logits, loss = self(idx_cond)\n",
        "            # focus only on the last time step\n",
        "            logits = logits[:, -1, :] # becomes (B, C)\n",
        "            # apply softmax to get probabilities\n",
        "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
        "            # sample from the distribution\n",
        "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
        "            # append sampled index to the running sequence\n",
        "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
        "        return idx"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": { "base_uri": "https://localhost:8080/" },
        "id": "2J-6tLksNiDv",
        "outputId": "dc261c06-4699-45d6-883d-c94829e06e7c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "85.226561 M parameters\n"
          ]
        }
      ],
      "source": [
        "model = BigramLanguageModel().to(device)\n",
        "# print the number of parameters in the model\n",
        "print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": { "id": "mJ_twWqrNiDw" },
      "outputs": [],
      "source": [
        "# create a pytorch optimizer\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": { "id": "lyVGZiRPPKuy" },
      "outputs": [],
      "source": [
        "# (optional) resume model and optimizer state from a saved checkpoint\n",
        "state = torch.load(\"\")\n",
        "model.load_state_dict(state[\"model_state_dict\"])\n",
        "optimizer.load_state_dict(state[\"optimizer_state_dict\"])\n",
        "max_iters = state[\"epoch\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": { "id": "OQaLj9z3NiDw" },
      "outputs": [],
      "source": [
        "# train\n",
        "iterator = tqdm(range(max_iters), desc=\"Training\", postfix={\"train_loss\": 0.0})\n",
        "\n",
        "for iter in iterator:\n",
        "\n",
        "    # sample a batch of data\n",
        "    xb, yb = get_batch('train')\n",
        "\n",
        "    # evaluate the loss\n",
        "    logits, loss = model(xb, yb)\n",
        "\n",
        "    optimizer.zero_grad(set_to_none=True)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    # periodically estimate the validation loss (running eval_iters batches every step would be too slow)\n",
        "    if iter % eval_interval == 0:\n",
        "        val_loss = estimate_loss(model)[\"val\"]\n",
        "\n",
        "    # update the postfix with the current train and val losses\n",
        "    iterator.set_postfix({\"train_loss\": loss.item(), \"val_loss\": val_loss.item()}, refresh=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": { "id": "byBjpL1f5gog" },
      "outputs": [],
      "source": [
        "# save a checkpoint so training can be resumed later\n",
        "torch.save({\n",
        "    \"epoch\": \"\",\n",
        "    \"model_state_dict\": model.state_dict(),\n",
        "    \"optimizer_state_dict\": optimizer.state_dict(),\n",
        "}, \"\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": { "background_save": true },
        "id": "SV7zpB87NiDw"
      },
      "outputs": [],
      "source": [
        "# generate from the model, starting from a single zero-token context\n",
        "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
        "print(decode(model.generate(context, max_new_tokens=2000)[0].tolist()))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": { "gpuType": "T4", "provenance": [] },
    "kernelspec": { "display_name": "Python 3", "name": "python3" },
    "language_info": {
      "codemirror_mode": { "name": "ipython", "version": 3 },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}