{ "cells": [ { "cell_type": "code", "execution_count": 100, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "from datasets import load_dataset\n", "import fastcore.all as fc\n", "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "import torchvision.transforms.functional as TF\n", "from torch.utils.data import default_collate, DataLoader\n", "import torch.optim as optim\n", "import pickle\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "exclude" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "plt.rcParams['figure.figsize'] = [2, 2]" ] }, { "cell_type": "code", "execution_count": 101, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n", "100%|██████████| 2/2 [00:00<00:00, 35.54it/s]\n" ] } ], "source": [ "dataset_nm = 'mnist'\n", "x,y = 'image', 'label'\n", "ds = load_dataset(dataset_nm)" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [], "source": [ "def transform_ds(b):\n", " b[x] = [TF.to_tensor(ele) for ele in b[x]]\n", " return b" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "exclude" ] }, "outputs": [], "source": [ "dst = ds.with_transform(transform_ds)\n", "plt.imshow(dst['train'][0]['image'].permute(1,2,0));" ] }, { "cell_type": "code", "execution_count": 103, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))" ] }, "execution_count": 103, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bs = 1024\n", "class DataLoaders:\n", " def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n", " self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n", " self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n", "\n", "def collate_fn(b):\n", " collate = default_collate(b)\n", " return (collate[x], collate[y])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "exclude" ] }, "outputs": [], "source": [ "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n", "xb,yb = next(iter(dls.train))\n", "xb.shape, yb.shape" ] }, { "cell_type": "code", "execution_count": 105, "metadata": {}, "outputs": [], "source": [ "class Reshape(nn.Module):\n", " def __init__(self, dim):\n", " super().__init__()\n", " self.dim = dim\n", " \n", " def forward(self, x):\n", " return x.reshape(self.dim)" ] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [], "source": [ "def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n", " layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]\n", " if norm:\n", " layers.append(norm)\n", " if act:\n", " layers.append(act())\n", " return nn.Sequential(*layers)\n", "\n", "def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n", " return nn.Sequential(\n", " conv(ni, nf, ks=ks, s=1, norm=norm, act=act),\n", " conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n", " )\n", "\n", "class ResBlock(nn.Module):\n", " def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):\n", " super().__init__()\n", " self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)\n", " self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)\n", " self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)\n", " self.act = act()\n", " \n", " def forward(self, x):\n", " return self.act(self.convs(x) + self.idconv(self.pool(x)))" ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [], "source": [ "def cnn_classifier():\n", " return nn.Sequential(\n", " ResBlock(1, 8, norm=nn.BatchNorm2d(8)),\n", " ResBlock(8, 16, norm=nn.BatchNorm2d(16)),\n", " ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n", " ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n", " ResBlock(64, 64, norm=nn.BatchNorm2d(64)),\n", " conv(64, 10, act=False),\n", " nn.Flatten(),\n", " )" ] }, { "cell_type": "code", "execution_count": 108, "metadata": {}, "outputs": [], "source": [ "def kaiming_init(m):\n", " if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n", " nn.init.kaiming_normal_(m.weight) " ] }, { "cell_type": "code", "execution_count": 195, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train, epoch:1, loss: 0.0776, accuracy: 0.9172\n", "eval, epoch:1, loss: 0.0372, accuracy: 0.9818\n", "train, epoch:2, loss: 0.0571, accuracy: 0.9828\n", "eval, epoch:2, loss: 0.0287, accuracy: 0.9863\n", "train, epoch:3, loss: 0.0425, accuracy: 0.9847\n", "eval, epoch:3, loss: 0.0256, accuracy: 0.9865\n", "train, epoch:4, loss: 0.0271, accuracy: 0.9868\n", "eval, epoch:4, loss: 0.0378, accuracy: 0.9826\n", "train, epoch:5, loss: 0.0395, accuracy: 0.9844\n", "eval, epoch:5, loss: 0.0307, accuracy: 0.9873\n" ] } ], "source": [ "model = cnn_classifier()\n", "model.apply(kaiming_init)\n", "lr = 0.1\n", "max_lr = 0.3\n", "epochs = 5\n", "opt = optim.AdamW(model.parameters(), lr=lr)\n", "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n", "for epoch in range(epochs):\n", " for train in (True, False):\n", " accuracy = 0\n", " dl = dls.train if train else dls.valid\n", " for xb,yb in dl:\n", " preds = model(xb)\n", " loss = F.cross_entropy(preds, yb)\n", " if train:\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", " with torch.no_grad():\n", " accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n", " if train:\n", " sched.step()\n", " accuracy /= len(dl)\n", " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")" ] }, { "cell_type": "code", "execution_count": 196, "metadata": { "tags": [ "exclude" ] }, "outputs": [], "source": [ "torch.save(model.state_dict(), 'classifier.pth')" ] }, { "cell_type": "code", "execution_count": 197, "metadata": {}, "outputs": [], "source": [ "loaded_model = cnn_classifier()\n", "loaded_model.load_state_dict(torch.load('classifier.pth'))\n", "loaded_model.eval();" ] }, { "cell_type": "code", "execution_count": 206, "metadata": {}, "outputs": [], "source": [ "def predict(img):\n", " with torch.no_grad():\n", " img = img[None,]\n", " pred = loaded_model(img)[0]\n", " pred_probs = F.softmax(pred, dim=0)\n", " pred = [{\"digit\": i, \"prob\": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]\n", " pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)\n", " return pred" ] }, { "cell_type": "code", "execution_count": 204, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(5)\n" ] }, { "data": { "text/plain": [ "[{'digit': 0, 'prob': '21.42%', 'logits': tensor(0.0559)},\n", " {'digit': 8, 'prob': '19.44%', 'logits': tensor(-0.0408)},\n", " {'digit': 4, 'prob': '18.08%', 'logits': tensor(-0.1135)},\n", " {'digit': 9, 'prob': '16.41%', 'logits': tensor(-0.2104)},\n", " {'digit': 6, 'prob': '12.23%', 'logits': tensor(-0.5049)},\n", " {'digit': 1, 'prob': '6.87%', 'logits': tensor(-1.0806)},\n", " {'digit': 7, 'prob': '2.33%', 'logits': tensor(-2.1633)},\n", " {'digit': 5, 'prob': '1.19%', 'logits': tensor(-2.8386)},\n", " {'digit': 2, 'prob': '1.06%', 'logits': tensor(-2.9527)},\n", " {'digit': 3, 'prob': '0.97%', 'logits': tensor(-3.0359)}]" ] }, "execution_count": 204, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = xb[0].reshape(1, 28, 28)\n", "print(yb[0])\n", "predict(img)" ] }, { "cell_type": "markdown", "metadata": { "tags": [ "exclude" ] }, "source": [ "#### commit to .py file for deployment" ] }, { "cell_type": "code", "execution_count": 205, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n", "[NbConvertApp] Writing 2904 bytes to mnist_classifier.py\n" ] } ], "source": [ "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "python_main", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }