{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from contextlib import nullcontext\n", "from bigram_model import BigramLanguageModel\n", "from tokenizer_utils import IntCharTokenizer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler\n", "ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n", "ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)\n", "scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from data_utils import *\n", "model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embed, block_size=BLOCK_SIZE,\n", " bias=False, vocab_size=None, dropout=dropout)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([128, 256, 65])\n", "tensor(4.3690, device='cuda:0', grad_fn=)\n" ] } ], "source": [ "from data_utils import *\n", "xb, yb = get_random_batch('train')\n", "xb = xb.to(device)\n", "yb = yb.to(device)\n", "\n", "m = BigramLanguageModel(vocab_size=65, n_embed=n_embed, block_size=BLOCK_SIZE, num_heads=n_head, n_layers=n_layer).to(device)\n", "logits, loss = m(xb, yb)\n", "print(logits.shape)\n", "print(loss)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "\n", "def estimate_loss(model):\n", " out = {}\n", " model.eval()\n", " for split in ['train', 'val']:\n", " losses = torch.zeros(eval_iters)\n", " for k in range(eval_iters):\n", " X, Y = get_random_batch(split)\n", " with ctx:\n", " logits, loss = model(X, Y)\n", " losses[k] = loss.item()\n", " out[split] = losses.mean()\n", " model.train()\n", " return out" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "char_tokenizer = load_int_char_tokenizer(load_text())" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10.788929 M parameters\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 0: train loss 4.3685, val loss 4.3640\n", "step 500: train loss 1.9681, val loss 2.0837\n", "step 1000: train loss 1.5377, val loss 1.7404\n", "step 1500: train loss 1.3802, val loss 1.6101\n", "step 2000: train loss 1.2855, val loss 1.5551\n", "step 2500: train loss 1.2162, val loss 1.5157\n", "step 3000: train loss 1.1617, val loss 1.5088\n", "step 3500: train loss 1.1061, val loss 1.5088\n", "step 4000: train loss 1.0555, val loss 1.5150\n", "step 4500: train loss 1.0086, val loss 1.5385\n", "step 4999: train loss 0.9583, val loss 1.5524\n" ] } ], "source": [ "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n", "\n", "# create a PyTorch optimizer\n", "optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)\n", "\n", "for iter in range(max_iters):\n", "\n", " # every once in a while evaluate the loss on train and val sets\n", " if iter % eval_interval == 0 or iter == max_iters - 1:\n", " losses = estimate_loss(m)\n", " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", "\n", " # sample a batch of data\n", " xb, yb = get_random_batch('train')\n", "\n", " # evaluate the loss\n", " logits, loss = m(xb, yb)\n", " optimizer.zero_grad(set_to_none=True)\n", " loss.backward()\n", " optimizer.step()\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "saving checkpoint to ./nano_gpt_ckpts\n" ] } ], "source": [ "checkpoint = {\n", " 'model': m.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'model_args': model_args,\n", " 'iter_num': max_iters,\n", " 'best_val_loss': losses['val'],\n", "\n", "}\n", "out_dir = \"./nano_gpt_ckpts\"\n", "print(f\"saving checkpoint to {out_dir}\")\n", "torch.save(checkpoint, os.path.join(out_dir, 'ckpt_5k_iters.pt'))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "#m2 = BigramLanguageModel(vocab_size=65, n_embed=n_embed, block_size=BLOCK_SIZE, num_heads=n_head, n_layers=n_layer).to(device)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "GLOUCESTER: learn, like a nap. Prisoner will to my intents! with my brother! and this bloody makes off flows,--and haste tear'd your roe!--I should not be the other's.---I'ld do hear that be pupy with thear; sweet Montague,--thou as done not--So that they have nage must know,--never speak so many tears,--traightful ner-light,--with'd yet a ping tymp,--which time to stir; now still hurr'd,---water'd honour,--Pray's Coitlinius: the mountake's nobled daughter.' Sir, it is some thee on Rome is sin:--'proud him 'there;' none honest seen; forsweet must be pointed, hurls thee in men; a proud confines, foot, die, gin night, old Ratchard!--Go, good lord!--will'd you not piece, I dare not.' an't; swear by the dog, belike! mother!--How sir!-Spite! Jupiteous put o's!--God leave your lawful coward!'--for I'll dry down, you in death;'--near'---for very 'ven a day.---fa, by; 'twas his mother's disposed;--'I shall make no son,--hard him hear me,--do. Madam, or smother'd wife: and that you may part this denies.'--'--thrieks for Richmond dancerts, in free people's anointed,--O, hold: Curs, on a fiathful doom: every nurse, is I long now, never large.' quoth let return him; for an't plead the fie, his maids; he will not quarrel; 'twas this, but take within, as he learn, as and heat, it see; a gized evassages of season, imagish: yet, a very no other consulance, good den.--To fair cousin, stay! come, sir; and hath been, let it breather ring.' God; I am, trusper, I say: provided, pardone! a never lady; come in God. I'll fight with Montagues come. Why, 'twas bring you to be, if the pass off, and here, it dare, man cryield. Frow, your head A called with Gaunt; the cause. O, prettiest his pale thing, rust, and good. Thou adventure be more, Juliet, perishease: I'll take the queen, and his love.--give me note to de,--dyes help, Edward, and after Romeo!--Whence labour cann'd Warwick! was? whither? why hours! fairs! after was? stay come! your run? a happy kind!--O day, go be--hours, wrong!--ta w\n" ] } ], "source": [ "# generate from the model\n", "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n", "#print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))\n", "print(char_tokenizer.decode(m.generate(context, max_new_tokens=2000)[0].tolist()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m3 = BigramLanguageModel(vocab_size=65, n_embed=n_embed, block_size=BLOCK_SIZE, num_heads=n_head, n_layers=n_layer)\n", "ckpt = torch.load(os.path.join(\"./nano_gpt_ckpts\", \"ckpt_5k_iters.pt\"))\n", "m3.load_state_dict(ckpt['model'])" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "But Dohor, aged by! At Antigonus. You see his court! For death; a talm every hand, here shall!--So,--O, I, title now point!--Who, this I sem blind--that tark;--come boy?---O pray, peace! May, two here, do not---that I troth:----to villain leave, where was the Gallent--if I look the house,--bold Jour---whether may I go,--Mine son,---as I amiled me pized,--or so fled; 'tis a famouse,--there littenants,--If an either lawful hant ther is gone.' Sicilence, if it wer done! I have twize its sourness. P\n" ] } ], "source": [ "context = torch.zeros((1, 1), dtype=torch.long)\n", "print(char_tokenizer.decode(m3.generate(context, max_new_tokens=500)[0].tolist()))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }