{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from datasets import load_dataset\n",
    "import fastcore.all as fc\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import torchvision.transforms.functional as TF\n",
    "from torch.utils.data import default_collate, DataLoader\n",
    "import torch.optim as optim\n",
    "import pickle\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = [2, 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the Hugging Face dataset to load.\n",
    "dataset_nm = 'mnist'\n",
    "# Keys of each sample dict: `x` holds the input image, `y` the target label.\n",
    "# `transform_ds` and `collate_fn` read these module-level names.\n",
    "x,y = 'image', 'label'\n",
    "ds = load_dataset(dataset_nm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/ob3+kraZ+9VyFfUK
7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
   "source": [
    "def transform_ds(b):\n",
    "    \"\"\"Convert the images stored under key `x` of batch dict `b` to tensors; mutates and returns `b`.\"\"\"\n",
    "    b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
    "    return b\n",
    "\n",
    "# `with_transform` applies `transform_ds` on the fly, whenever items are accessed.\n",
    "dst = ds.with_transform(transform_ds)\n",
    "# The tensors are channel-first (C, H, W); imshow expects channel-last (H, W, C).\n",
    "plt.imshow(dst['train'][0][x].permute(1,2,0));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 1024\n",
    "\n",
    "class DataLoaders:\n",
    "    \"\"\"Pair of `DataLoader`s: `train` (shuffled, batch size `bs`) and `valid` (in order, batch size `2*bs`).\"\"\"\n",
    "    def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
    "        # Extra keyword arguments are forwarded unchanged to both loaders.\n",
    "        self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
    "        self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
    "\n",
    "def collate_fn(b):\n",
    "    \"\"\"Collate a list of sample dicts into an `(inputs, targets)` tuple, using keys `x` and `y`.\"\"\"\n",
    "    collate = default_collate(b)\n",
    "    return (collate[x], collate[y])\n",
    "\n",
    "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
    "# Peek at one batch to sanity-check the tensor shapes.\n",
    "xb,yb = next(iter(dls.train))\n",
    "xb.shape, yb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Reshape(nn.Module):\n",
    "    \"\"\"Module wrapper around `Tensor.reshape`, so a reshape to `dim` can sit inside an `nn.Sequential`.\"\"\"\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x.reshape(self.dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_norm(norm, nf):\n",
    "    \"\"\"Build the normalisation layer for a conv with `nf` output channels.\n",
    "\n",
    "    `norm` may be falsy (no layer), a factory such as `nn.BatchNorm2d` (called as `norm(nf)`,\n",
    "    so every conv gets its own layer), or an already-built `nn.Module`, which is used as-is.\n",
    "    An already-built module handed to several convs is shared between them, parameters and\n",
    "    running statistics included; pass a factory to avoid that.\n",
    "    \"\"\"\n",
    "    if not norm:\n",
    "        return None\n",
    "    if isinstance(norm, nn.Module):\n",
    "        return norm\n",
    "    return norm(nf)\n",
    "\n",
    "def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
    "    \"\"\"`nn.Conv2d` (padding `ks//2`) followed by an optional norm layer and an optional activation.\n",
    "\n",
    "    `act` is an activation class/factory (or a falsy value for none); `norm` is described in `_make_norm`.\n",
    "    Returns the layers wrapped in an `nn.Sequential`.\n",
    "    \"\"\"\n",
    "    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]\n",
    "    norm_layer = _make_norm(norm, nf)\n",
    "    if norm_layer is not None:\n",
    "        layers.append(norm_layer)\n",
    "    if act:\n",
    "        layers.append(act())\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
    "    \"\"\"Two stacked `conv` layers: `ni -> nf` at stride 1, then `nf -> nf` at stride `s`.\"\"\"\n",
    "    return nn.Sequential(\n",
    "        conv(ni, nf, ks=ks, s=1, norm=norm, act=act),\n",
    "        conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
    "    )\n",
    "\n",
    "class ResBlock(nn.Module):\n",
    "    \"\"\"Residual block computing `act(convs(x) + idconv(pool(x)))`.\"\"\"\n",
    "    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):\n",
    "        super().__init__()\n",
    "        self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)\n",
    "        # The shortcut only needs a 1x1 conv when the channel count changes.\n",
    "        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)\n",
    "        # Downsample the shortcut by the same factor as the conv path;\n",
    "        # ceil_mode keeps odd sizes aligned (e.g. 7 -> 4 for s=2).\n",
    "        self.pool = fc.noop if s==1 else nn.AvgPool2d(s, ceil_mode=True)\n",
    "        self.act = act()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.act(self.convs(x) + self.idconv(self.pool(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cnn_classifier():\n",
    "    \"\"\"Five stride-2 `ResBlock`s, a 10-channel conv head and a flatten.\n",
    "\n",
    "    For a 1x28x28 input the spatial size goes 28 -> 14 -> 7 -> 4 -> 2 -> 1, so the output is `(batch, 10)`.\n",
    "    `norm` is passed as the `nn.BatchNorm2d` class so that every conv builds its own BatchNorm layer.\n",
    "    \"\"\"\n",
    "    return nn.Sequential(\n",
    "        ResBlock(1, 8, norm=nn.BatchNorm2d),\n",
    "        ResBlock(8, 16, norm=nn.BatchNorm2d),\n",
    "        ResBlock(16, 32, norm=nn.BatchNorm2d),\n",
    "        ResBlock(32, 64, norm=nn.BatchNorm2d),\n",
    "        ResBlock(64, 64, norm=nn.BatchNorm2d),\n",
    "        conv(64, 10, act=False),\n",
    "        nn.Flatten(),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kaiming_init(m):\n",
    "    \"\"\"Re-initialise the weight of conv modules with `kaiming_normal_`; meant for `model.apply`.\"\"\"\n",
    "    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n",
    "        nn.init.kaiming_normal_(m.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "torch.manual_seed(seed)  # fix torch's global RNG so the weight initialisation is repeatable\n",
    "model = cnn_classifier()\n",
    "model.apply(kaiming_init)\n",
    "lr = 0.1\n",
    "max_lr = 0.3\n",
    "epochs = 5\n",
    "# OneCycleLR overwrites the optimizer's learning rate, so `lr` is only the constructor's starting value.\n",
    "opt = optim.AdamW(model.parameters(), lr=lr)\n",
    "# The schedule spans every optimizer step of the run (epochs * batches per epoch) and is stepped once per batch.\n",
    "# NOTE(review): the full cycle now really peaks at `max_lr`; 0.3 is high for AdamW - re-tune if training is unstable.\n",
    "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, epochs=epochs, steps_per_epoch=len(dls.train))\n",
    "for epoch in range(epochs):\n",
    "    for train in (True, False):\n",
    "        # BatchNorm uses batch statistics while training and its running statistics while evaluating.\n",
    "        model.train(train)\n",
    "        dl = dls.train if train else dls.valid\n",
    "        total_loss, n_correct, n_seen = 0.0, 0, 0\n",
    "        # No autograd graph is needed during evaluation.\n",
    "        with torch.set_grad_enabled(train):\n",
    "            for xb,yb in dl:\n",
    "                preds = model(xb)\n",
    "                loss = F.cross_entropy(preds, yb)\n",
    "                if train:\n",
    "                    loss.backward()\n",
    "                    opt.step()\n",
    "                    sched.step()\n",
    "                    opt.zero_grad()\n",
    "                # Weight by batch size so a smaller final batch does not skew the epoch averages.\n",
    "                total_loss += loss.item() * len(yb)\n",
    "                n_correct += (preds.argmax(1) == yb).sum().item()\n",
    "                n_seen += len(yb)\n",
    "        epoch_loss, epoch_acc = total_loss / n_seen, n_correct / n_seen\n",
    "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {epoch_loss:.4f}, accuracy: {epoch_acc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [],
   "source": [
    "# Pickling the whole module stores its classes and functions by reference, so whatever loads this file\n",
    "# must be able to import or define them (e.g. `ResBlock`). The model is left in eval mode by the loop above.\n",
    "# Only unpickle files from a trusted source: unpickling can execute arbitrary code.\n",
    "with open('./classifier.pkl', 'wb') as model_file:\n",
    "    pickle.dump(model, model_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "source": [
    "#### Export to a .py script for deployment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {
    "tags": [
     "exclude"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
      "[NbConvertApp] Writing 3691 bytes to mnist_classifier.py\n"
     ]
    }
   ],
   "source": [
    "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python_main",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}