File size: 11,266 Bytes
4673b21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A cute little demo showing the simplest usage of minGPT. Configured to run fine on a MacBook Air in like a minute."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data.dataloader import DataLoader\n",
"from mingpt.utils import set_seed\n",
"set_seed(3407)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"class SortDataset(Dataset):\n",
" \"\"\" \n",
" Dataset for the Sort problem. E.g. for problem length 6:\n",
" Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2\n",
" Which will feed into the transformer concatenated as:\n",
" input: 0 0 2 1 0 1 0 0 0 1 1\n",
" output: I I I I I 0 0 0 1 1 2\n",
" where I is \"ignore\", as the transformer is reading the input sequence\n",
" \"\"\"\n",
"\n",
" def __init__(self, split, length=6, num_digits=3):\n",
" assert split in {'train', 'test'}\n",
" self.split = split\n",
" self.length = length\n",
" self.num_digits = num_digits\n",
" \n",
" def __len__(self):\n",
" return 10000 # ...\n",
" \n",
" def get_vocab_size(self):\n",
" return self.num_digits\n",
" \n",
" def get_block_size(self):\n",
" # the length of the sequence that will feed into transformer, \n",
" # containing concatenated input and the output, but -1 because\n",
" # the transformer starts making predictions at the last input element\n",
" return self.length * 2 - 1\n",
"\n",
" def __getitem__(self, idx):\n",
" \n",
" # use rejection sampling to generate an input example from the desired split\n",
" while True:\n",
" # generate some random integers\n",
" inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)\n",
" # half of the time let's try to boost the number of examples that \n",
" # have a large number of repeats, as this is what the model seems to struggle\n",
" # with later in training, and they are kind of rare\n",
" if torch.rand(1).item() < 0.5:\n",
" if inp.unique().nelement() > self.length // 2:\n",
" # too many unique digits, re-sample\n",
" continue\n",
" # figure out if this generated example is train or test based on its hash\n",
" h = hash(pickle.dumps(inp.tolist()))\n",
" inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test\n",
" if inp_split == self.split:\n",
" break # ok\n",
" \n",
" # solve the task: i.e. sort\n",
" sol = torch.sort(inp)[0]\n",
"\n",
" # concatenate the problem specification and the solution\n",
" cat = torch.cat((inp, sol), dim=0)\n",
"\n",
" # the inputs to the transformer will be the offset sequence\n",
" x = cat[:-1].clone()\n",
" y = cat[1:].clone()\n",
" # we only want to predict at output locations, mask out the loss at the input locations\n",
" y[:self.length-1] = -1\n",
" return x, y\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 -1\n",
"0 -1\n",
"1 -1\n",
"0 -1\n",
"0 -1\n",
"0 0\n",
"0 0\n",
"0 0\n",
"0 0\n",
"0 1\n",
"1 1\n"
]
}
],
"source": [
"# print an example instance of the dataset\n",
"train_dataset = SortDataset('train')\n",
"test_dataset = SortDataset('test')\n",
"x, y = train_dataset[0]\n",
"for a, b in zip(x,y):\n",
" print(int(a),int(b))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of parameters: 0.09M\n"
]
}
],
"source": [
"# create a GPT instance\n",
"from mingpt.model import GPT\n",
"\n",
"model_config = GPT.get_default_config()\n",
"model_config.model_type = 'gpt-nano'\n",
"model_config.vocab_size = train_dataset.get_vocab_size()\n",
"model_config.block_size = train_dataset.get_block_size()\n",
"model = GPT(model_config)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"running on device cuda\n"
]
}
],
"source": [
"# create a Trainer object\n",
"from mingpt.trainer import Trainer\n",
"\n",
"train_config = Trainer.get_default_config()\n",
"train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster\n",
"train_config.max_iters = 2000\n",
"train_config.num_workers = 0\n",
"trainer = Trainer(train_config, model, train_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"iter_dt 0.00ms; iter 0: train loss 1.06407\n",
"iter_dt 18.17ms; iter 100: train loss 0.14712\n",
"iter_dt 18.70ms; iter 200: train loss 0.05315\n",
"iter_dt 19.65ms; iter 300: train loss 0.04404\n",
"iter_dt 31.64ms; iter 400: train loss 0.04724\n",
"iter_dt 18.43ms; iter 500: train loss 0.02521\n",
"iter_dt 19.83ms; iter 600: train loss 0.03352\n",
"iter_dt 19.58ms; iter 700: train loss 0.00539\n",
"iter_dt 18.72ms; iter 800: train loss 0.02057\n",
"iter_dt 18.26ms; iter 900: train loss 0.00360\n",
"iter_dt 18.50ms; iter 1000: train loss 0.00788\n",
"iter_dt 20.64ms; iter 1100: train loss 0.01162\n",
"iter_dt 18.63ms; iter 1200: train loss 0.00963\n",
"iter_dt 18.32ms; iter 1300: train loss 0.02066\n",
"iter_dt 18.40ms; iter 1400: train loss 0.01739\n",
"iter_dt 18.37ms; iter 1500: train loss 0.00376\n",
"iter_dt 18.67ms; iter 1600: train loss 0.00133\n",
"iter_dt 18.38ms; iter 1700: train loss 0.00179\n",
"iter_dt 18.66ms; iter 1800: train loss 0.00079\n",
"iter_dt 18.48ms; iter 1900: train loss 0.00042\n"
]
}
],
"source": [
"def batch_end_callback(trainer):\n",
" if trainer.iter_num % 100 == 0:\n",
" print(f\"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}\")\n",
"trainer.set_callback('on_batch_end', batch_end_callback)\n",
"\n",
"trainer.run()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# now let's perform some evaluation\n",
"model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train final score: 5000/5000 = 100.00% correct\n",
"test final score: 5000/5000 = 100.00% correct\n"
]
}
],
"source": [
"def eval_split(trainer, split, max_batches):\n",
" dataset = {'train':train_dataset, 'test':test_dataset}[split]\n",
" n = train_dataset.length # naughty direct access shrug\n",
" results = []\n",
" mistakes_printed_already = 0\n",
" loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)\n",
" for b, (x, y) in enumerate(loader):\n",
" x = x.to(trainer.device)\n",
" y = y.to(trainer.device)\n",
" # isolate the input pattern alone\n",
" inp = x[:, :n]\n",
" sol = y[:, -n:]\n",
" # let the model sample the rest of the sequence\n",
" cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n",
" sol_candidate = cat[:, n:] # isolate the filled in sequence\n",
" # compare the predicted sequence to the true sequence\n",
" correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n",
" for i in range(x.size(0)):\n",
" results.append(int(correct[i]))\n",
" if not correct[i] and mistakes_printed_already < 3: # only print up to 3 mistakes to get a sense\n",
" mistakes_printed_already += 1\n",
" print(\"GPT claims that %s sorted is %s but gt is %s\" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))\n",
" if max_batches is not None and b+1 >= max_batches:\n",
" break\n",
" rt = torch.tensor(results, dtype=torch.float)\n",
" print(\"%s final score: %d/%d = %.2f%% correct\" % (split, rt.sum(), len(results), 100*rt.mean()))\n",
" return rt.sum()\n",
"\n",
"# run a lot of examples from both train and test through the model and verify the output correctness\n",
"with torch.no_grad():\n",
" train_score = eval_split(trainer, 'train', max_batches=50)\n",
" test_score = eval_split(trainer, 'test', max_batches=50)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input sequence : [[0, 0, 2, 1, 0, 1]]\n",
"predicted sorted: [[0, 0, 0, 1, 1, 2]]\n",
"gt sort : [0, 0, 0, 1, 1, 2]\n",
"matches : True\n"
]
}
],
"source": [
"# let's run a random given sequence through the model as well\n",
"n = train_dataset.length # naughty direct access shrug\n",
"inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n",
"assert inp[0].nelement() == n\n",
"with torch.no_grad():\n",
" cat = model.generate(inp, n, do_sample=False)\n",
"sol = torch.sort(inp[0])[0]\n",
"sol_candidate = cat[:, n:]\n",
"print('input sequence :', inp.tolist())\n",
"print('predicted sorted:', sol_candidate.tolist())\n",
"print('gt sort :', sol.tolist())\n",
"print('matches :', bool((sol == sol_candidate).all()))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|