{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "A cute little demo showing the simplest usage of minGPT. Configured to run fine on Macbook Air in like a minute." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import Dataset\n", "from torch.utils.data.dataloader import DataLoader\n", "from mingpt.utils import set_seed\n", "set_seed(3407)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "class SortDataset(Dataset):\n", " \"\"\" \n", " Dataset for the Sort problem. E.g. for problem length 6:\n", " Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2\n", " Which will feed into the transformer concatenated as:\n", " input: 0 0 2 1 0 1 0 0 0 1 1\n", " output: I I I I I 0 0 0 1 1 2\n", " where I is \"ignore\", as the transformer is reading the input sequence\n", " \"\"\"\n", "\n", " def __init__(self, split, length=6, num_digits=3):\n", " assert split in {'train', 'test'}\n", " self.split = split\n", " self.length = length\n", " self.num_digits = num_digits\n", " \n", " def __len__(self):\n", " return 10000 # ...\n", " \n", " def get_vocab_size(self):\n", " return self.num_digits\n", " \n", " def get_block_size(self):\n", " # the length of the sequence that will feed into transformer, \n", " # containing concatenated input and the output, but -1 because\n", " # the transformer starts making predictions at the last input element\n", " return self.length * 2 - 1\n", "\n", " def __getitem__(self, idx):\n", " \n", " # use rejection sampling to generate an input example from the desired split\n", " while True:\n", " # generate some random integers\n", " inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)\n", " # half of the time let's try to boost the number of examples that \n", " # have a large number of repeats, as this is what the model seems to struggle\n", " # with later in training, and they are kind of rate\n", " if torch.rand(1).item() < 0.5:\n", " if inp.unique().nelement() > self.length // 2:\n", " # too many unqiue digits, re-sample\n", " continue\n", " # figure out if this generated example is train or test based on its hash\n", " h = hash(pickle.dumps(inp.tolist()))\n", " inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test\n", " if inp_split == self.split:\n", " break # ok\n", " \n", " # solve the task: i.e. 
sort\n", " sol = torch.sort(inp)[0]\n", "\n", " # concatenate the problem specification and the solution\n", " cat = torch.cat((inp, sol), dim=0)\n", "\n", " # the inputs to the transformer will be the offset sequence\n", " x = cat[:-1].clone()\n", " y = cat[1:].clone()\n", " # we only want to predict at output locations, mask out the loss at the input locations\n", " y[:self.length-1] = -1\n", " return x, y\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 -1\n", "0 -1\n", "1 -1\n", "0 -1\n", "0 -1\n", "0 0\n", "0 0\n", "0 0\n", "0 0\n", "0 1\n", "1 1\n" ] } ], "source": [ "# print an example instance of the dataset\n", "train_dataset = SortDataset('train')\n", "test_dataset = SortDataset('test')\n", "x, y = train_dataset[0]\n", "for a, b in zip(x,y):\n", " print(int(a),int(b))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of parameters: 0.09M\n" ] } ], "source": [ "# create a GPT instance\n", "from mingpt.model import GPT\n", "\n", "model_config = GPT.get_default_config()\n", "model_config.model_type = 'gpt-nano'\n", "model_config.vocab_size = train_dataset.get_vocab_size()\n", "model_config.block_size = train_dataset.get_block_size()\n", "model = GPT(model_config)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "running on device cuda\n" ] } ], "source": [ "# create a Trainer object\n", "from mingpt.trainer import Trainer\n", "\n", "train_config = Trainer.get_default_config()\n", "train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster\n", "train_config.max_iters = 2000\n", "train_config.num_workers = 0\n", "trainer = Trainer(train_config, model, train_dataset)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter_dt 0.00ms; iter 0: train loss 1.06407\n", "iter_dt 18.17ms; iter 100: train loss 0.14712\n", "iter_dt 18.70ms; iter 200: train loss 0.05315\n", "iter_dt 19.65ms; iter 300: train loss 0.04404\n", "iter_dt 31.64ms; iter 400: train loss 0.04724\n", "iter_dt 18.43ms; iter 500: train loss 0.02521\n", "iter_dt 19.83ms; iter 600: train loss 0.03352\n", "iter_dt 19.58ms; iter 700: train loss 0.00539\n", "iter_dt 18.72ms; iter 800: train loss 0.02057\n", "iter_dt 18.26ms; iter 900: train loss 0.00360\n", "iter_dt 18.50ms; iter 1000: train loss 0.00788\n", "iter_dt 20.64ms; iter 1100: train loss 0.01162\n", "iter_dt 18.63ms; iter 1200: train loss 0.00963\n", "iter_dt 18.32ms; iter 1300: train loss 0.02066\n", "iter_dt 18.40ms; iter 1400: train loss 0.01739\n", "iter_dt 18.37ms; iter 1500: train loss 0.00376\n", "iter_dt 18.67ms; iter 1600: train loss 0.00133\n", "iter_dt 18.38ms; iter 1700: train loss 0.00179\n", "iter_dt 18.66ms; iter 1800: train loss 0.00079\n", "iter_dt 18.48ms; iter 1900: train loss 0.00042\n" ] } ], "source": [ "def batch_end_callback(trainer):\n", " if trainer.iter_num % 100 == 0:\n", " print(f\"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}\")\n", "trainer.set_callback('on_batch_end', batch_end_callback)\n", "\n", "trainer.run()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# now let's perform some evaluation\n", "model.eval();" ] }, { "cell_type": "code", 
"execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train final score: 5000/5000 = 100.00% correct\n", "test final score: 5000/5000 = 100.00% correct\n" ] } ], "source": [ "def eval_split(trainer, split, max_batches):\n", " dataset = {'train':train_dataset, 'test':test_dataset}[split]\n", " n = train_dataset.length # naugy direct access shrug\n", " results = []\n", " mistakes_printed_already = 0\n", " loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)\n", " for b, (x, y) in enumerate(loader):\n", " x = x.to(trainer.device)\n", " y = y.to(trainer.device)\n", " # isolate the input pattern alone\n", " inp = x[:, :n]\n", " sol = y[:, -n:]\n", " # let the model sample the rest of the sequence\n", " cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n", " sol_candidate = cat[:, n:] # isolate the filled in sequence\n", " # compare the predicted sequence to the true sequence\n", " correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n", " for i in range(x.size(0)):\n", " results.append(int(correct[i]))\n", " if not correct[i] and mistakes_printed_already < 3: # only print up to 5 mistakes to get a sense\n", " mistakes_printed_already += 1\n", " print(\"GPT claims that %s sorted is %s but gt is %s\" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))\n", " if max_batches is not None and b+1 >= max_batches:\n", " break\n", " rt = torch.tensor(results, dtype=torch.float)\n", " print(\"%s final score: %d/%d = %.2f%% correct\" % (split, rt.sum(), len(results), 100*rt.mean()))\n", " return rt.sum()\n", "\n", "# run a lot of examples from both train and test through the model and verify the output correctness\n", "with torch.no_grad():\n", " train_score = eval_split(trainer, 'train', max_batches=50)\n", " test_score = eval_split(trainer, 'test', max_batches=50)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input sequence : [[0, 0, 2, 1, 0, 1]]\n", "predicted sorted: [[0, 0, 0, 1, 1, 2]]\n", "gt sort : [0, 0, 0, 1, 1, 2]\n", "matches : True\n" ] } ], "source": [ "# let's run a random given sequence through the model as well\n", "n = train_dataset.length # naugy direct access shrug\n", "inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n", "assert inp[0].nelement() == n\n", "with torch.no_grad():\n", " cat = model.generate(inp, n, do_sample=False)\n", "sol = torch.sort(inp[0])[0]\n", "sol_candidate = cat[:, n:]\n", "print('input sequence :', inp.tolist())\n", "print('predicted sorted:', sol_candidate.tolist())\n", "print('gt sort :', sol.tolist())\n", "print('matches :', bool((sol == sol_candidate).all()))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.4 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" } } }, "nbformat": 4, "nbformat_minor": 2 }