{ "cells": [
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
  "import torch\n",
  "from torch import nn\n",
  "import torch.nn.functional as F\n",
  "from datasets import load_dataset\n",
  "import fastcore.all as fc\n",
  "import matplotlib.pyplot as plt\n",
  "import matplotlib as mpl\n",
  "import torchvision.transforms.functional as TF\n",
  "from torch.utils.data import default_collate, DataLoader\n",
  "import torch.optim as optim\n"
 ] },
 { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [ "exclude" ] }, "outputs": [], "source": [
  "%matplotlib inline\n",
  "plt.rcParams['figure.figsize'] = [2, 2]"
 ] },
 { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [
  "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
  "100%|██████████| 2/2 [00:00<00:00, 75.69it/s]\n"
 ] } ], "source": [
  "dataset_nm = 'mnist'\n",
  "x, y = 'image', 'label'\n",
  "ds = load_dataset(dataset_nm)"
 ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
  "def transform_ds(b):\n",
  "    \"\"\"Convert every item stored under key `x` of batch `b` with `TF.to_tensor`.\n",
  "\n",
  "    The batch is updated in place and returned.\n",
  "    \"\"\"\n",
  "    b[x] = list(map(TF.to_tensor, b[x]))\n",
  "    return b"
 ] },
 { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/ob3+kraZ+9VyFfUK
7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "dst = ds.with_transform(transform_ds)\n",
  "plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
 ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
  "bs = 1024\n",
  "class DataLoaders:\n",
  "    \"\"\"Holds two `DataLoader`s: `train` (shuffle=True) and `valid` (shuffle=False).\n",
  "\n",
  "    Both are built with the same batch size `bs`, the same `collate_fn`, and any\n",
  "    extra keyword arguments, which are forwarded to `DataLoader` unchanged.\n",
  "    \"\"\"\n",
  "    def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
  "        self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
  "        self.valid = DataLoader(valid_ds, batch_size=bs, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
  "\n",
  "def collate_fn(b):\n",
  "    \"\"\"Collate samples `b` with `default_collate` and return the `(collate[x], collate[y])` pair.\"\"\"\n",
  "    collate = default_collate(b)\n",
  "    return (collate[x], collate[y])"
 ] },
 { "cell_type": "code", "execution_count": 7, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
  "xb,yb = next(iter(dls.train))\n",
  "xb.shape, yb.shape"
 ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
  "class Reshape(nn.Module):\n",
  "    \"\"\"Module wrapper around `Tensor.reshape`: `forward(x)` returns `x.reshape(dim)`.\"\"\"\n",
  "    def __init__(self, dim):\n",
  "        super().__init__()\n",
  "        self.dim = dim\n",
  "\n",
  "    def forward(self, x):\n",
  "        return x.reshape(self.dim)"
 ] },
 { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [
  "def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
  "    \"\"\"`nn.Conv2d(ni, nf)` with kernel `ks`, stride `s`, padding `ks//2`, then optional norm and activation.\n",
  "\n",
  "    norm: an already-constructed module appended after the convolution; skipped when falsy.\n",
  "    act:  a module class instantiated here and appended last; skipped when falsy\n",
  "          (callers in this notebook pass both `None` and `False`).\n",
  "    Returns the layers wrapped in `nn.Sequential`.\n",
  "    \"\"\"\n",
  "    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]\n",
  "    if norm:\n",
  "        layers.append(norm)\n",
  "    if act:\n",
  "        layers.append(act())\n",
  "    return nn.Sequential(*layers)\n",
  "\n",
  "def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
  "    \"\"\"Two stacked `conv`s: stride-1 `ni`->`nf` without norm, then stride-`s` `nf`->`nf` with `norm`.\"\"\"\n",
  "    return nn.Sequential(\n",
  "        conv(ni, nf, ks=ks, s=1, norm=None, act=act),\n",
  "        conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
  "    )\n",
  "\n",
  "class ResBlock(nn.Module):\n",
  "    \"\"\"Residual block computing `act(convs(x) + idconv(pool(x)))`.\n",
  "\n",
  "    convs:  `_conv_block(ni, nf, ...)`.\n",
  "    idconv: a 1x1 stride-1 `conv` without activation when `ni != nf`, otherwise `fc.noop`.\n",
  "    pool:   `nn.AvgPool2d(2, ceil_mode=True)` when `s != 1`, otherwise `fc.noop`.\n",
  "    \"\"\"\n",
  "    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):\n",
  "        super().__init__()\n",
  "        self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)\n",
  "        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)\n",
  "        # The pool window is fixed at 2, so the shortcut only matches the conv path for s == 2.\n",
  "        self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
  "        self.act = act()\n",
  "\n",
  "    def forward(self, x):\n",
  "        return self.act(self.convs(x) + self.idconv(self.pool(x)))"
 ] },
 { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [
  "def cnn_classifier(norm='layer'):\n",
  "    \"\"\"Build the classifier: five stride-2 `ResBlock`s, a 10-channel `conv` head and `nn.Flatten`.\n",
  "\n",
  "    norm selects the normalisation passed to every ResBlock:\n",
  "        'layer' (default) -> `nn.LayerNorm([nf, hw, hw])`; the `sizes` below are the\n",
  "                             feature-map sides for a 28x28 input, so this variant\n",
  "                             only accepts 28x28 images.\n",
  "        'batch'           -> `nn.BatchNorm2d(nf)`.\n",
  "        None              -> no normalisation.\n",
  "    Raises ValueError for any other value.\n",
  "    \"\"\"\n",
  "    channels = [1, 8, 16, 32, 64, 64]\n",
  "    sizes = [14, 7, 4, 2, 1]  # spatial side after each block for a 28x28 input\n",
  "\n",
  "    def make_norm(nf, hw):\n",
  "        if norm == 'layer':\n",
  "            return nn.LayerNorm([nf, hw, hw])\n",
  "        if norm == 'batch':\n",
  "            return nn.BatchNorm2d(nf)\n",
  "        if norm is None:\n",
  "            return None\n",
  "        raise ValueError(f\"norm must be 'layer', 'batch' or None, got {norm!r}\")\n",
  "\n",
  "    blocks = [ResBlock(ni, nf, norm=make_norm(nf, hw))\n",
  "              for ni, nf, hw in zip(channels[:-1], channels[1:], sizes)]\n",
  "    return nn.Sequential(*blocks, conv(64, 10, act=False), nn.Flatten())"
 ] },
 { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [
  "def kaiming_init(m):\n",
  "    \"\"\"Apply `nn.init.kaiming_normal_` to `m.weight` when `m` is a Conv1d/Conv2d/Conv3d; other modules are left alone.\"\"\"\n",
  "    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n",
  "        nn.init.kaiming_normal_(m.weight)"
 ] },
 { "cell_type": "code", "execution_count": 50, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "train, epoch:1, loss: 1.8902, accuracy: 0.3183\n",
  "eval, epoch:1, loss: 1.0976, accuracy: 0.6274\n",
"train, epoch:2, loss: 0.5929, accuracy: 0.8003\n",
  "eval, epoch:2, loss: 0.2895, accuracy: 0.9102\n",
  "train, epoch:3, loss: 0.2396, accuracy: 0.9264\n",
  "eval, epoch:3, loss: 0.1343, accuracy: 0.9597\n",
  "train, epoch:4, loss: 0.1139, accuracy: 0.9651\n",
  "eval, epoch:4, loss: 0.0801, accuracy: 0.9763\n",
  "train, epoch:5, loss: 0.1368, accuracy: 0.9582\n",
  "eval, epoch:5, loss: 0.0882, accuracy: 0.9722\n"
 ] } ], "source": [
  "# Seed the global RNG so weight init and DataLoader shuffling repeat from run to run.\n",
  "torch.manual_seed(42)\n",
  "model = cnn_classifier()\n",
  "model.apply(kaiming_init)\n",
  "lr = 0.1\n",
  "max_lr = 0.3\n",
  "epochs = 5\n",
  "opt = optim.AdamW(model.parameters(), lr=lr)\n",
  "# NOTE(review): OneCycleLR gets total_steps=len(dls.train) (one epoch of batches; `epochs`\n",
  "# is only consulted when total_steps is omitted) but sched.step() below runs once per\n",
  "# training epoch, so only `epochs` of those steps are ever taken and the cycle never\n",
  "# completes. Stepping per batch with total_steps=epochs*len(dls.train) would complete it,\n",
  "# but max_lr was chosen under the current behaviour - confirm intent and re-tune first.\n",
  "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
  "for epoch in range(epochs):\n",
  "    for train in (True, False):\n",
  "        model.train(train)  # train mode for the training pass, eval mode for validation\n",
  "        n_correct, n_seen, total_loss = 0, 0, 0.\n",
  "        dl = dls.train if train else dls.valid\n",
  "        for xb,yb in dl:\n",
  "            # Autograd bookkeeping is only needed while training.\n",
  "            with torch.set_grad_enabled(train):\n",
  "                preds = model(xb)\n",
  "                loss = F.cross_entropy(preds, yb)\n",
  "            if train:\n",
  "                loss.backward()\n",
  "                opt.step()\n",
  "                opt.zero_grad()\n",
  "            # Accumulate per sample so the shorter final batch is not over-weighted.\n",
  "            n = len(yb)\n",
  "            n_correct += (preds.argmax(1) == yb).sum().item()\n",
  "            total_loss += loss.item() * n\n",
  "            n_seen += n\n",
  "        if train:\n",
  "            sched.step()\n",
  "        accuracy = n_correct / n_seen\n",
  "        total_loss /= n_seen\n",
  "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
 ] },
 { "cell_type": "code", "execution_count": 51, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "eval, epoch:1, loss: 0.0882, accuracy: 0.9722\n",
  "eval, epoch:2, loss: 0.0882, accuracy: 0.9722\n",
  "eval, epoch:3, loss: 0.0882, accuracy: 0.9722\n",
  "eval, epoch:4, loss: 0.0882, accuracy: 0.9722\n",
  "eval, epoch:5, loss: 0.0882, accuracy: 0.9722\n"
 ] } ], "source": [
  "for epoch in range(epochs):\n",
  "    train = False\n",
  "    accuracy = 0\n",
  "    total_loss = 0\n",
  "    dl = dls.valid\n",
  "    for xb,yb in dl:\n",
  "        preds = model(xb)\n",
  "        loss = F.cross_entropy(preds, yb)\n",
  "        with torch.no_grad():\n",
  "            
accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n", " total_loss += loss.item()\n", " accuracy /= len(dl)\n", " total_loss /= len(dl)\n", " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAc5klEQVR4nO3dd3yV1RkH8N+ThL0CGCDISBDCFEQUQaVOtG7UOhARq1itiijWWm21rlZxIVKUqihacWCduGitG1GGIgioTBkyIwHCkiSnf7yXc97nmje8SW6Sm7y/7+fDh+fkvOvek5ucnCnGGBARERFFRUpVPwARERFRZWLlh4iIiCKFlR8iIiKKFFZ+iIiIKFJY+SEiIqJIYeWHiIiIIqXKKz8iskJEjq/q56gsInKbiDxb1c9RUVieNQfLsmZhedYcLMvyq/LKT2nE3oA9IpLv+9ehgu95sYh8WpH3iLufEZHtvtf3RGXdu7KJSLqIPC0iG2L/bquEe1Z2eZ4mIt/EyvIzEelWWfeuTCJyrYgsE5GtIvKjiIwRkbQKvmellaWI7Cci00UkV0TyRGSGiBxRGfeuCuIZHXu9uSJyr4hIBd+zsj+bB4nIHBHZEfv/oMq6d2Xi783iJbTyU9E/7GJeNMY09P1bVgn3rGy9fK9veFU9RCWU5xgA9QFkAegLYKiI/LaC71lpRKQTgMkArgCQDmAqgDcq6XMS/ywVfc+pAA42xjQG0ANALwDXVPA9K1M+gEsAZABoCmA0gKlVUZZApZTn7wAMgleOPQGcCuDyCr5npRGR2gBeB/AsvPJ8GsDrsa9X9rPw92ZilOr35j4rP7HmtZtEZKGIbBaRp0SkbizvaBFZLSI3isg6AE+JSIqI/ElElsb+YpgiIs181xsqIj/E8v5crpdaCiJygIi8H7vvJhGZLCLpvvy2IvKKiGyMHfMPEekKYAKA/rHaZF7s2A9FZLjvXFXLFZGxIrIq9lfwHBEZUFmvc1+SrDxPA3CvMWaHMWYFgInwfsGEeR3VoTxPBPCJMeZTY0wBvF+Y+wM4KuT5JUqmsjTGLDXG5O29FIAiAB1Dvo6kL0tjzC5jzHfGmKLY6yuE90uzWclnhpdM5QlgGIAHjDGrjTFrADwA4OKQryPpyxPA0QDSADxkjNltjHkYXrkeG/L8EiVZWZbndVSHsiyTsC0/Q+D9ID8AQA6Av/jyWsH7AdAe3l8L18D7i+EoAK0BbAYwHgDEa/J/FMDQWF5zAG32XkhEjtz7RpXgNBH5SUQWiMjvQz4/4H1j3x27b1cAbQHcFrtvKoA3AfwArxVifwAvGGMWwfurfUasNpke8l6zABwE7315DsBLe7/xf/FQIvNE5IK4L38sIuti31RZIe9ZGslUnhIX9wj5GqpDeQp++fpK8xrDSJqyFJELRGQrgE3wWgz+GfI1VIeytF8DsAvAGwCeMMZsCHnfsJKlPLsD+NqX/jr2tTCqQ3l2BzDP6P2d5iH8awwjWcoS4O/NXzLGlPgPwAoAV/jSJwNYGouPBvAzgLq+/EUAjvOlMwHsgVfLvhXem7M3r0Hs/OP39Ryx47vBK4RUAIcDWAtgcJhzi7nWIABfxe
L+ADYCSCvmuIsBfBr3tQ8BDC/pmLjjN8NrkgO8b5xnSzj2VwBqw+sm+QeAb4p7rrL+S7LyfBbAKwAawWslWApgd00pTwBdAGyPva+1AdwCr0XkpppWlnHP1QnAnQBa1ZSyjDunLoDBAIYlohyTsTzhtWx1iStTA0BqQnnGPosvxH1tMoDbamBZ8vdmMf/C9jWu8sU/xN7IvTYaY3b50u0BvCoiRb6vFQJoGTvPXssYs11EckM+A4wxC33Jz0RkLIDfAHh+X+eKSAsADwMYAO+XbQq8NxfwarM/GK9rotxE5HoAw+G9XgOgMYD9wpxrjPk4Fv4sIiMBbIVX456fiGeLSYryhPfXzjgAiwHkwivHwWFOrA7laYz5VkSGwfswZsKr7C0EsDoRzxWTLGVpGWMWi8gCAI8AOGtfx1eHsvSLvafPi8giEZlrjPl6nyeFlyzlmQ/vvdmrMYB8E/tNU5JqUp7xrw+x9LZEPFdMUpQlf28WL2y3V1tf3A7Aj/77xh27CsBJxph037+6xus3Xuu/lojUh9eEV1YGuluhJHfHju9pvEGZF/rOXQWgnRQ/8Ky4D/t2eAN192q1N4j1U94I4FwATY3X5LelFM9Z3P0TPcsiKcrTGPOTMWaIMaaVMaY7vO/HmSFPrxblaYz5tzGmhzGmOYC/wvshNyvMuSElRVkWIw1ec38Y1aIsi1ELQKJnzSRLeS6A13W5V6/Y18KoDuW5AEBPETWDrSfCv8YwkqUs4/H3JsJXfq4SkTaxAVg3A3ixhGMnAPibiLQHABHJEJEzYnn/BnBqrI+yNoA7SvEMEJEzRKSpePrCazl43Zf/oQRPl24Er7afJyL7A7jBlzcT3jfYPSLSQETqipvGuh5AG9GzAOYCOEtE6otIRwCXxt2nALHmQBG5Fb/8CyPo9XUXb/plqog0hDfIcA28JtFESpbyPEBEmsde70nw+r7v8uVX6/KMvYY+sdeXAW8MzFRjzLdhzw8hWcpyeOyvxL1jFG4C8D9ffrUuSxHpt/e9EZF6InIjvL/KvwhzfikkRXkCeAbAKBHZX0RaA7gewKS9mdW9POF1wRQCuEZE6ojI1bGvvx/y/DCSoiz5e7N4Yd/A5wD8B8Cy2L+7Sjh2LLzBgP8RkW0APgdwGAAYYxYAuCp2vbXwms9sF4CIDBCR/BKufT6AJfCaJp8BMNoY87Qvvy2A6QHn3g7gYHi1ybfgjTVB7LkK4c086ghgZeyZzotlvw/vr4F1IrIp9rUx8Ppc18ObIjnZd59pAN4B8D28ps5d0M2fingD0IbEki3hfUC2wnufswCcaozZE3R+GSVLefaB1yy5Dd5fGENi19yrupcn4L1/eQC+i/1/WdC5ZZQsZXkEgPkish3A27F/N/vyq3tZ1oE3ADUX3g/WkwGcYoz5Mej8MkqW8vwnvOUL5sMbP/EW9AD2al2expif4Y1fuQje5/ISAINiX0+UZClL/t4s7hr76sIVkRXwBim9V+KBVUxE2gB4yRjTv6qfJZmxPGsOlmXNwvKsOViWya9KFuiqCMaY1fBGn1MNwPKsOViWNQvLs+aIcllWq+0tiIiIiMprn91eRERERDUJW36IiIgoUlj5ISIiokgJNeB5YMo57BurYv8teikhCx2yLKteosoSYHkmA342aw5+NmuWksqTLT9EREQUKaz8EBERUaSw8kNERESRwsoPERERRQorP0RERBQprPwQERFRpLDyQ0RERJHCyg8RERFFCis/REREFCms/BAREVGksPJDREREkcLKDxEREUVKqI1NiYgSpeDYPja+e+IEldendqqNU0X/bVZoigKvecuGg2z80d2Hq7yGUz4vy2MSRU5aZisbrzmng8pLHbjJxqe2W6DyVu5sZuN5GzMDr/9+76dVemJedxu/d96hKq9wwXchnrjs2PJDREREkcLKDxEREUUKu72IKOHSstur9MIbW9r42RNcV1fv2vrvryIYF5vC0Pe7vcVXNh48ooXK2z
Yl9GWoGJsu76/SM28db+M+o6+2ccuZ+eq45YMa2PiLIQ+ovKap9W18wItXqLxOf3JlaXbvLsMTU0kKjuuj0k88OdbGDVLExi1SG6BM2pWUWU+lRjVbZuO3M45WeamoWGz5ISIiokhh5YeIiIgihZUfIiIiipSkHPOTtn9rG+cerTsQN5++I/C80zvNt/Hy7c1tvOjtHHVcxrw9Nq7z1qwyPydR1Eit2i7Ro5PKK6rrfpwc9/gnKu+1pq+Euv6Wol02PvSt61RenQ1uFMCI37yp8n7XZIWNj2mup8i+ldHFxoUbN4Z6jqhLbdzYxnXPXK/y/OOyZt04zsb35XZTxy3+6GgbbywyKq9hihvP9e2541VeTgM3BqjlR3rkR5PJXLagvGpt0eOoJuUdZuMFW9009fw9ddRx36924/ZqrdR5vm8JZM4oUFn1F+faeMX5rVReu3e22jhtvp4+r79jEo8tP0RERBQprPwQERFRpCRFt9ePN+gVWe+5/Ekbn1Bve+B5KRCVLgpqKLt6WuA1Jm7R3WqFvvrgfzfqZtyF0/WKl36tP3VNfQ3m/Rh4XLzCjW7VTE7rdFLTm9h445muHHafkaeOm9f3eRvHrwDce9YQGzea3FjlNXzpi0Q8ZuSsvv4QG381YlwJR4bzcv5+Kn3TJ2fbOOeKmSrPP32+39ClcVdy3SMf5HZWOezqKoMWbtjAhwe+FOqUG5ov1Omz/Om6oW/9/SluKYQPjtXn3bnzEhvXf1V/f8BUdEdJzWBmf6PSs884wMZ5E1yVYFh73cU47t1BNm5972eh7+dfsKLtnctUXlWWGFt+iIiIKFJY+SEiIqJIYeWHiIiIIiUpxvzUytc9fyWN8/Hr9tRVKt1ksYvz27rxQN1O/F4d93wHNwbosiarVJ5/3JB/+iwAoGPws6QMc/cLHHtUjEPuH2HjVmPC96PWNKnd9TiNonFuqfx72j9m48veGa6OO36sGwOQlqfHTBWc4sYNHX7TDJVX62Y3PuiDe/WYs8bPczptkB2ZwTurh3XM/HNs3OgPtVVezjezA89bN9AtgdGzdvDi95t2NlTpOtgUcCTtZfr3UumLJr2R0Ouft/TXKn1I+kobx48V8jum3i6dHveIjY81V6q8+q9yHF8YRUcepNJ3/Mv9fO1cy32++z46Sh2XNcGNFQq/8UzyYssPERERRQorP0RERBQpSdHt1fKzvNDH9vyn6ybKuiO4m6ipL86/R7/MQS1PC3Wvtafpnam3tw11WolqbdXT89s8MsfGUZuo6Z/OfsVrU1XemBUDbXzf+RfYuNOs4Kbt+PevrdscGgv/pQtv2f3pNj77xukqb87CHjYu+npR4P2iqPFS9/fSXzbo3aFvznBdi4MXn63yFn3XxsaZH7hrFH2juxi3Du7nEvqjgqEj3wl8Lv/K0BibEZe7IvA88uxqqVfsTU8NXkm/TNe/tJFKf9DS7RS/Z7zuwrx5v/kI48Cbv1bppa+W8eEiIO8i934/cvtYlefvQj7kvmtt3PYh/fu1JnR1+bHlh4iIiCKFlR8iIiKKFFZ+iIiIKFKqbMxPWpbbVqL/v74KPC5+imT2I27H5rB9kKZA7zJbsCbc9hMZE/Rx8SMJEiFK43xS6tdX6bTX69n4lgVnqLzWQ9xU2KLtP5T73gU/6CUN2p3j0p+c1l/lbb3D7TSceUkzlVeY+1O5n6U6a/mwGwcw7yk9jmNw66EusV5vKVH/9+5Hzabe7rt+wJ/09e9qoXf4DuvQqW4H+Jy3ZpZwJBVn/aF63E2P2rm+VD2EsblIT0sfMOkPNs5e+aXKS1nstjl4Z43eRijsmB8K74DLv7Vxnzpxy0t8NMzG2Q9FZ7kVtvwQERFRpLDyQ0RERJFSZd1eBRPdSpI3/WKFTzfHde2EA1TO9otcfa3Zd3sS/lz1PnbNg0XbtiX8+lG29Em9RPaijpNsfPohJ6u8gu3hVvkOK35V05Unuab8rD/r1Z/r+mbd17TpnYkU//
lIqeeWExg+U3dln97g/Qp9lm73b7BxQQnHkZPaMdvGWf10t3DL1HBdXWsLd9r4pAl/VHlZf3ddKFHq3k8GRUf1Vunnsp+y8Q3rdF6HoQtsHKVyYssPERERRQorP0RERBQprPwQERFRpFTZmJ+3u7hdg0vaBf3j+4KnvqbErX9fmt3Ug67xQr6b0H7L53r6ddcb19i4YN36Ut8ritL2dztxLxowSeX1fvBqG2euTfwUS+nd3cb1/rZO5dXZ3DT+cCql1O6dVXrIi9NsfHqDzeW+/vg8Pd5v1S637MA9rWapvCW/zbRx1i0ryn3vmmj53XpJh4793BISb+ZMjT88lC93t7Jxm7+H/wzvHNTXxld2eKVM9z63mV7S4I/DLrdx06dnxB9e46U0cktPFNySG3jc6+/2U+nsgui9VwBbfoiIiChiWPkhIiKiSEmKXd23xK0MetbCC2289utWKq/pAgTa7Ho5kNlrXeBxx7T63sa3xq0mem5DN2X23OMfV3nj3utk4yde0CtPt70rOitjlsb3I9vbeNZu3S2Z+UBi37Nt5+vm3HNvedfGV6UvVXk9HxuR0HtH0eZeuuvQ/9kJa85unR7y2lU27jhK7/ie2t19/vAf3e314bD7bHzhByNVXtr7c0r9XDXRrwfOVukHMj8PODK8B2+4wMb1EH5l7dyu7lfP4EZlG0Lw3rbuKp3xnuvGi+JyBxvP72Hj2d0eVXnTd7mlZTreE/dLtLNbgmRrj+Y2Lqyth4SsG1CEIJkfuXaU9Pf1z9rCjRvjD08KbPkhIiKiSGHlh4iIiCKlyrq9Tu97qksU6ea0emuX27gDliOssPN3ZjZId8+RfmrgcUuubK/Sc4aNsfHg381TeYdnX2vjnEt183KU+Gd3AcBDZ7mVRS/7eqjKa434lb2LJ2m+b9MD9QyjXaPdStBrV+rG7of/d6KNs06arPKyx7qmX67iXDaNftDd1VevOdLGHyzrpPIyXnYrBtfNdeVUZ32+Oq7jgrJ1xeznW5F45a/1xo0dKnZx6aRWeMzBNj644ZsJuebIH4+wcaO5bnhBSV1NqRl6W2jTZ2vAkeH9tKeBSofdsLqmyusSPNs5PcX1L58wY6XKO7GBGx7QtbbefDq0QS58a0ddlTVqzrk2PuByvUl1Yd6Wst0vAdjyQ0RERJHCyg8RERFFCis/REREFClVt6t7FfbPFvl2DC8qYffwrD/rZ+y7Y5SNH7lkgsqbdvxYG1/b9WKVV7hocVkes1oqbKFHXp1Qz72/t9XSowLixwf5Lbs0y8a/OtXtEJ5eS0/T/GS0m97e+eUvVd6BX+yx8Z+fukjltcnj0gTlJdPnqvSKfqk2zi6ahzBKNd5q/SYbxq/+7F/KoFd//XnTe89Hy7Kzatl4SKO1ZbrGDesO09e8wr33ZkUJa4/4zxmvlyyZ339SmZ6FyqZ77Xq+eIXKW13gxtxO3OLK6W+zTg68Xq2VdVS6oL0b//fBUeNU3ncDnrFxh9uvUHmdRpZ/uYWyYssPERERRQorP0RERBQpSbHCc3XR9m+uq2TEbt18N3/UIzZedI3u+sn5fcU+VzJJWb5apc9c7JYS+Kz38/rYmW4F0fhNaQcuONvG3955oI3rTtWryDaCazZddo/euLHtHtf10n7iEpXH6e0VoKiC39UC122aX1i3hAOja9epfVX63dMf9KXK9p5NX9tBpZvNCdfVlZbtlgqZ0GdyCUeWzaz17VS6Gb4PODIast9w09l7ZA9RebuWu01PW+jF0ZE+262wXbjELS3TCXoYQVi/z7pApV/8dIqNz/mV7uaaW6Y7JAZbfoiIiChSWPkhIiKiSGHlh4iIiCKFY37KSOJWEi80brrgiAHvqbxpaFwZj5QU4pcrNye76eZ9L9E7qRvfpsGZE+eqvDo7VvhSKxAkrW0bG78z+D6VN2j8H23cej2ntlel1E6+cSO1awUfWIIdWe5zdGPz4D0rvvyyo0p3wqaAI2uewrp6J+7stL
KN89lUuNPGuz7aLy433NiaFfc3tPERdfeUcGTZ7HfeGpUO3nM8GlI+ckuC7P9R+PMSPVJvR+cWKl1HyvZ5r2hs+SEiIqJIYeWHiIiIIoXdXhVg3MfHq3QOZgYcWfMV7dhh4xb/CO56KmuT9Y4n3arCJ3yiu9U6Peje9+D9jqk05FC37MBP3RuqvNxe7l3u2lvv3jwu263y2iatHhJtdG53d++H9ErGJe02XtM0maVXpf/VPLej9sc9p8QfHmhVoVvBt/V9wZ/btDb72zj3MV2uf+2YmF3k/Tq/cqWNc3bPSfj1qWwkzVUlhj/8qsqrJe5n9Etf91F5nVB1ZciWHyIiIooUVn6IiIgoUlj5ISIiokjhmJ9SSKlf38YF/bcGHtf8y9TAPCqf3accqtLPdB5j46FjR6k8UxCl0R4Vo/CYg1X6romP27hPnfijgy342U13PW/u+TZOTdGjvbo03WDjx9p+GPr6OXXdOJ9Plkd364uCH1ap9Jbph9t4dbedKq+ksVfDx420ccujdgYed9x4N6d6RNPFoZ+zJOt90+yPef4GlZdzq9tygZ/vqpOa3kSl97zslqEY0mi2ylu6J9/GXe/IVXlVWYJs+SEiIqJIYeWHiIiIIoXdXqWw7KZeNv6m/z9U3ujcbjbOePEblRf1lUfLKzUjw8bjxz+s8k58wq3i3O5NruKcaM3u1FPW/V1dX/2sv7MvG+u6Smrl68UFWrzmVgVuusl1j6iVnwHMPsNNpcd1H4Z+zmPruSne916md5Vu/viM0Nepadre5T4TYwYdo/IeyPw8/nBrzvXjKuyZ4p2w8CyVXjujtY073KrLjktWOCkNGqh0nu+zkz51gcor2rat3Nf/uV8XG2/9Q57Km9H15cDrXHzd9Tauv+yLUj9HRWHLDxEREUUKKz9EREQUKaz8EBERUaRwzE8c/9Te5afVVnkzzrnfxinQ00Qf//QoG+dsi+52FhVh6Ui3S/fc3W1UXvZzbopzoncnjqrU7p1t/GC7iXG57vv+vPeuVDldJ7lxBkXZupyWjswp9l6XnzlNpUuaLl3kGz23vnC3ymsk7u+4VkNWqLw9j4OS2LYprVW6fYTHaJXGkr/2VOnFFz5q44GXnqbyNj/Xw8YtPl4feM1vr3HjK/90/FSV97sm0wPPW13gprMP9I3DBIB2ryXPOB8/tvwQERFRpLDyQ0RERJFSY7u9/Ksxo1N7lbf0vHQbjxr0hsr7TSM3xbNJil4pdnWBm2h5xOSrVF7XB5bYmN0v5edfQfS6s1wZ/X3Seeq4Nks4vT3h1m+y4ZStumnd3y31/ckT9HknJ/Yx4qfSX3HvNTbOeFR3jRQdeZCNa+VuT+yD1BBLTmmqv/Bl8cdVhL9s0Lt5v/pufxt3eGqWyuN09nA6PrtZpU8/7Nc2npLzospremd9lNYeo3+TXbv2MBu/Mb+XyssZ/7ON282qHj+T2fJDREREkcLKDxEREUUKKz9EREQUKVU25ufn/7pxOPm79fbQ8u/moa6xsZ/uk2zfwe0IfVDz1Ta+r9WzgddIgaj0Zt8wg5xpl6u8Lg/vsHGHuXrMAcf5JNaie93U6EvTvrJx+9c3qeP4vide4Sa38/LLtw9UeYMfmGfj/VKDdwWP9+BPbmn8x+YeGXicyXPLS3R55CeVl7EoeAp0yqdzbczvieIV5W1R6cPmuG1AvujzXMLvN3DB2Tauf4HeXiF7kytLjvEpm6J536r0brfaCi7oeYnKW3FmMxuPGaqXr6gre2z8249/a+O2r6eq4+q95pZw6YQ5Kq86liFbfoiIiChSWPkhIiKiSKmybq/3ur1q46L4RrPe4a4R32Xlv87ygl02vn7t0eq4adMOsXHmpwUqr8HcVTbOWTs77vpUUeJ3EB599BQb3/bYhTZuvbB6TKOsKRq+pFdnvfil4C6rsDriq30fBHZfJZrZrVfFzrzSrcp7bF+9Wvdlf3e7dA9uFLwi8KH3jLBxozW6xBp/tc7GBb
6uVKp48V1i7VxvNcbc3jXwvJy47qyajC0/REREFCms/BAREVGksPJDREREkVJlY36Ou8xNIy+op+tgG87ZaeNmjXeovIuz3BTJB+cdp/JSFzW0cfu33dRKM2u+Oi4LwVNmCwJzqCKtvlIvl96l9v9s3PZtN72d40CIEqNg9Rob1/fFADD5lTYuRhsEaYngMXj8WUrJjC0/REREFCms/BAREVGkVFm3V5233U6+deLysl9GoFeR4Y7DvMDjquOKk1E26cqHVPo3L1xn4+yFwd2UREREpcWWHyIiIooUVn6IiIgoUlj5ISIiokipsjE/RH43Z/dV6ewSliMgIiIqD7b8EBERUaSw8kNERESRIsZwUjgRERFFB1t+iIiIKFJY+SEiIqJIYeWHiIiIIoWVHyIiIooUVn6IiIgoUlj5ISIiokj5P0TtsIzRFepPAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "# NOTE(review): the batch is drawn from dls.train although the names end in 'v' -\n",
  "# confirm whether dls.valid was intended.\n",
  "xbv,ybv = next(iter(dls.train))\n",
  "model.eval()\n",
  "# Inference only: no autograd graph is needed to display predictions.\n",
  "with torch.no_grad():\n",
  "    logits = model(xbv)\n",
  "    probs = F.softmax(logits, dim=1)\n",
  "idx = 5\n",
  "_,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
  "for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
  "    ax.imshow(im)\n",
  "    ax.set_axis_off()\n",
  "    ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')"
 ] },
 { "cell_type": "code", "execution_count": 53, "metadata": { "tags": [ "exclude" ] }, "outputs": [], "source": [
  "torch.save(model.state_dict(), 'classifier.pth')"
 ] },
 { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [
  "loaded_model = cnn_classifier()\n",
  "# map_location='cpu' lets the checkpoint load on a CPU-only machine whatever device saved it.\n",
  "loaded_model.load_state_dict(torch.load('classifier.pth', map_location='cpu'));\n",
  "loaded_model.eval();"
 ] },
 { "cell_type": "code", "execution_count": 55, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaPklEQVR4nO3dd3hVRfoH8O+bRopEOpEiCCSAiiKIgg1RdBXrsmtDEVQUxMK6Cuy6P8W26C6yrqhrxQa6CIrYXRFsKAKiSBMQpKj0SO9J5vfHOZk5c5/ccJLcfr6f58nzvJM55565d3Jv5k47opQCERERUVCkxbsARERERLHExg8REREFChs/REREFChs/BAREVGgsPFDREREgcLGDxEREQVK3Bs/IrJKRHrGuxyxIiL3iMj4eJcjWlifqYN1mVpYn6mDdVlzcW/8VIWI/ElEfhKR7SKyVkQeEZGMKF+zv4jMiOY1PNdqICJfikixiGwVkZkicnIsrh0PItJDRD4RkW0isipG14xZfbrXe0ZElopImYj0j9V140VEskRkiYj8EoNrxbQuPdftJyJKRAbE+tqxIiIfiMhOz89+EVkQ5WvG8rP21JDnt9Ot0z/E4vqxJiKdRORz93luEJEhUb5erD9n00XkAbddsENEvhOROpWdE9HGT7QbIgDeAdBJKZUP4GgAxwK4NcrXjKWdAK4F0BBAXQD/APBODF7XCsXgursAPA9gaJSvE0/fAxgM4Nt4FiKGf0NDAWyM0bViTkTqAvgrgEVxLkdU61Mpda5S6pDyHwBfAZgUzWvGklLqi5Dndz6cz98PY12WGHyBbwDneT0NoD6ANgA+iuY14+BeACcB6AYgH0BfAHsrO+GgjR+3e+2vIrJYRLaIyAsiku3mnS4iv4jIcBFZD+AFEUkTkb+IyAq3B2OiiNTzPF5fEVnt5v2tKs9OKbVCKbW1/KEAlMGpyIMSkdYiMt297mYRecXbMhSR5iIyWUQ2ucc8LiLtATwFoJvbYt7qHvup91tfaCtXRB4VkZ/dHqq5InKqz+e3Vym1VClV5j6/UjiNoHqVn+lfgtXnbKXUOAA/VeN5JHx9us/xCaXUNBzkjVgdiVSX7vlHALgKwINVPC8p6tL1IIAxADZX8byDSrT69DxOSwCnAhjn8/hkqs9y/QC8rpTaVc3zLQlWl38G8D+l1CtKqX1KqR1KqR98Po+Er0txvpD8CcD1SqnVyrFQKVWzxo/rSgC/A9AaQBGA//PkFcD559wCwA1wemIuBtAdQBMAWwA84RbySABPwmmVNYHTCm3meRKnlL9Q4YhIHxHZDufD51g4rVk/BM4HVxMA7QE0B3CP+5jpAN4FsBpASwBNAUxw/0AGAZjpfkOo4/NacwB0hPO6vApgUvkffgXPZ76I9An9HZx/lm8DeE4pFelv0glTnzWQNPUZZYlUl48BuBPAnio+h6SoSxE5AcDxcD7YoyWR6rPc1QC+UEqt9Hl8UtSn5/e5AP4I4CWf1/QrUeqyK4DfROQrEdkoIu+IyOE+n0My1GUHACUA/igi60VkmYjcdNCrKaUq/QGwCsAgT7oXgBVufDqA/QCyPfk/ADjTkz4MwAEAGQDudl+c8rw89/yeBytHBeUqBHA/gIKqnuuefzGA79y4G4BNADIqOK4/gBkhv/sUwIDKjgk5fguAY934HgDjfZQvG8AVAPpV5/klU30C6AlgVQ2fV6LX5wwA/VO1LgH8HsCHnmv/kkp1CSAdwDcAulV0nVSrz5ByLa/J324i1mfIOX0BrAQgqViXAJYB2AqgC5z/K2MAfJkqdQmgDwAFYCyAHADHuOU6q7Ln4nes8WdPvBpOK7DcJmV3L7UA8KaIlHl+VwqgsXuefiyl1C4RKfZZBotS6kcRWQTgPwB6H+x4EWkEp9JPBVAbTq/XFje7OYDVSqmS6pSlgmvdDmAAnOer4IxBNqjKY7iv6X9F5AcRmaeU+j4SZXMlXH1WVbLVZxTFvS5FJA/
AP+F8wFdZktTlYADzlVIzI1GOSsS9Pr1E5BQ4vRSvV+GcZKhPr34AXlbuf9IISpS63APgTaXUHAAQkXsBbBaRQ5VS2yo7MUnqsryn+T6l1B4A80VkApzPo6nhTvI77NXcEx8OYK0nHfoH8zOAc5VSdTw/2UqpXwGs8z6W291Y32cZKpIBp0vRjwfdsh6jnAnTV8Hp0isv8+FS8cSzit4QuwDketIF5YE7TjkcwKUA6iqny2+b51pVlQmgVTXPDSdR67MqkrU+Iy0R6rIQTrf3F+4chskADnO7oFv6OD8Z6vJMAL93n9N6OJMrR4vI4z7OrYpEqE+vfgAmK6V2VuGcZKjP8sdoDqcn5mW/51RBotTl/JDrlcd+XqdkqMv5lVwzLL+Nn5tEpJk4E7DuBPBaJcc+BeDvItICAESkoYhc5Oa9DuB8d4wyC8B9VSgDRGSA2xItHwf9K4BpnvxPReSeMKfXhjObf6uINIW9wmg2nD+wh0QkT0SyxSwx3wCgmVvecvMA9BaRXBFpA+C6kOuUwO0OFJG74bRg/Ty/ruWvjYjkiMhwOC3/WX7Or4JEqc80d0w300lKtvd1Tvb6dJ9DlvscBUCmW5ZIrrJMhLpcCOfDuaP7MwDO69wR7jfWFKjL/nDmPJQ/x2/grDCp9kTiMBKhPuE+Xg6ASwC8WEFestdnub4AvlJKrajieX4kSl2+AKfh3lFEMgHcBWe4aat7raSuS7fuvgDwNxGpJc6E68vgzEcKy+8L+CqcpXE/uT8PVHLso3Am6n4kIjsAfA3gRLeQiwDc5D7eOjjdZ3o/EHH3XqjksU8GsEBEdgF43/2505PfHMCXYc69F0AnOK3J9+B8O4VbrlIAF8BZObbGLdNlbvZ0OMta14tI+QqPR+CMuW6AM0nuFc91/gfgAzjjrKvhTFz2dn9aRGSRiFzpJmvBmeRWDOBXON125yml1oY7v5oSpT5Pg9Nl+T6cb0Z7YC/BTPb6hPt89sDpKXjGjU8Ld341xL0ulVIlSqn15T8AfgNQ5qZL3cOSui6VUltDnuN+ANsPNmxQDXGvT4+L4dTJJxXkJXV9elyNyE90LpcQdamUmg7n/+R7cLahaANnnky5VKjLK+AMHRa75bxLOatsw5KDDXOKs/ncAKXUx5UeGGci0gzAJKVUt3iXJZGxPlMH6zK1sD5TB+sy8cVl87xoUEr9Amf2OaUA1mfqYF2mFtZn6ghyXSbV7S2IiIiIauqgw15EREREqYQ9P0RERBQobPwQERFRoPia8HxW2iUcG4uzqWWTIrKpHusy/iJVlwDrMxHwvZk6+N5MLZXVJ3t+iIiIKFBSZqk7ERERxV7GES10PGK6fRu4y9++RceFQ76OWZkOhj0/REREFChs/BAREVGgsPFDREREgcI5P0RERFRtW59M1/FxWXafSq3fErOPJTFLRURERBQlbPwQERFRoHDYi4iIiPw7oYOVHNN2rI43lO638lq+tVXHZVEtVNWw54eIiIgChY0fIiIiChQ2foiIiChQUmvOj2ccctk1OTp+85zHrMM6ZGXqOF3s9l/v5WfpeMddzay8tM++i0gxiYiIklXOqI1W+pgss9S9y0PDrLzG876KSZmqij0/REREFChs/BAREVGgJPWw17rbT7LS9w16Wcfn5W7T8du7GlrHzduXreO0kMV3r7X+UMfH3tLPymv+WfXLStF3+/JFVvrMnH06Dh3ePK/zOTouWbc+ugUjIkpym2/opuOPW4228qbtqavjJhOXW3ml0S1WtbHnh4iIiAKFjR8iIiIKFDZ+iIiIKFCSbs5PRkFjHY8YON7K887zOXL8zTouHLXMOq50c7GOJTPLyps4tYuOJx3/rJU3oM9tOs5/9euqFJti4ADSrXQZlIlVyMizUqDgKb6um53uWqLjomf32gfPXhCLIhElJO8cHwCYe8+TOl60X6y8Mef00nHphpXRLViEsOeHiIiIAoWNHyIiIgqUpBv2WnFjax1fmPe
+lXf24t46LnzYLLfzDnOFUgfsO9AuXnOYjosK7SGx4gt36zj/VZ8FpqhKyzbbFqQj/FDWUeNuttKtiudGrUxBlV6/no53n9Daysv5zGxDULZ7N2Jp37lmKHtLD3toa0mPZ3T8SY9DrLxH2rSPbsESmLcuMybbn4OFh5jdfRfcbHbV//X0POu4Wt3M5+6szpH/wDzlL/Z7us64mRG/RtDIcUfp+OO77OXspcp81t74pyFWXs7y2dEtWBSw54eIiIgChY0fIiIiChQ2foiIiChQkm7OT+YuE/93R2MrL+eKnTqubJ6PV8mZna30hFOf0vEPB+zl0S0eZ1sx3tLbtrHS+c//puMzc8LPJclfYadD53pRzf3at52OZw991Mq7d6N5n809LrrvI+l8lJU+aeQsHY9oZM/1+mC32ZZ/xONXW3kFSMy7UUdDxmEFVrp4rJn/9Fmb/4Y97+Nxi3XcM2eHlZfm+W5dFnIboUjY3Xubla4zLuKXCISMpk10fM74GTo+JK2WdVzhmzfquOhd+32UjBuH8L85ERERBQobP0RERBQoSTfs1fRfZknda5NOtvJKN6/29Rhl3Y/T8SNjn7Dy2mdm6vjoGddZeS1nzPNbTIqg9MaNdJzxzE4rb1zLqbEuDoWxp3H4zm/vcFOvMwbpOGN6ZLYc8A51DZ7wppV3bq4ZjgkdfBk212yP0Wb8EisvUe9GXW1i78qb3qCBjre9kGvlfdYh/FCXV+hQl19rS/bpONMuFhqm1wJFT+hdDZY8ZKaPTKnzjo7/UWwPHxfebIaPk3GYKxR7foiIiChQ2PghIiKiQGHjh4iIiAIl6eb8qBJzF+aSleHn+EgtM27849gjrbxZ3R8Pe17PQbfq+IgP59nX9ltIiqw6+Tp8o81rvk/r+u0VOm48ZbmVl3LzOeLhhA5W8o6L39JxWsj3qk/2mK3xIzHPR53c0UoPe9mscz4t297G4IAytd1r8aVW3hGXz9dxKv5NeLeGWDqogZW3+NLHavz4M/eaz9lr37/ezvTO5Qn58Gw90dTR6vOyrbwFfcf4unb69Dq+jiPbxuvs7V2WnGH+H64sMbd/mXmeva0I8Es0ixVz7PkhIiKiQGHjh4iIiAIlIYe90vPNMIfUr2vlbTzd7EZZd6m9o++P15pl6k+f/pKOe+SE7tRqulmvXHm2lZP3mVnuWspdgBPC2rMbHfwg2F22ALD/c9PNX7ppWUTLFFQlZ5gu8xNHz7Hy+uWbYei5++zvVSPuMttG5OPral17x2VddfyXB1628k7JNnUfupzdO9RV6+xV1bp2stpVaO7OXpVhruHru+l4wZYmVl7m0No6Tt9qtp4oXDkLfpWe3sk8RuudlRxpaz/Z3Mm98An/1wu67X3Me2fOXfb2Lr+W7tFx/+F36Lj2z9V7nyYL9vwQERFRoLDxQ0RERIHCxg8REREFSkLO+Vl1y9E6/n5wzZdjVuaVIz6y0kM/OVHHU1fZ23vXG2/udJz7JsebY+WWwZN9Hdf7yaFWuumo4NyVO1aGPG1ue+C9bQRgz7X5YMcxVl7dj8ycq+ouKd/TZ6uOf5e7Lexxz21rZaWDNs/Hr6UHTE3cvuISKy/rWhNnrF5j5XlXrZfAH9XtWCt923Ov6riyW2R4l9IDQNvntuu4rCwVNyeIjuxr1um4LGTfge7Thui4aEJqz/PxYs8PERERBQobP0RERBQoCTns1XJKsY779epp5eVl+Ft+/tUU083asMdaK2/1yoY67tlxsZV3eQPT7Teqqz209V0n07F/u9xs5eVO5jBYvB3+xnorzU7xmts/tYWV7p4z05Oy7w5998YuOl54vr08urTYfg/6sXbYSVZ6Ssd/elL2cIh3qOu9i08IeaQVVb52qshbZj5LO7x0q5VXMNu8Q3KmzLby/A5nVWb/OebvYfPAXVae37vBX/PJtVa66Ptval6wANhwi/3emXuk2cX5gc32zuxtBy3Usd+7GGQ0b2alSxscqmMptT9
5y+YvQSJizw8REREFChs/REREFCgJOexVumipjotPtvOK4U8zeFb6PGjnFWGVju11DMDDx5hVDzcMy7HylvR4TsejR9s3R72j7CYdh3YhU9XtO9d0mbfOGhvHkqQ+ybSHr9SHZkftj9qFrrQzxz6zraWVM+84byr8MJf3Zpt7W9Sx84Zv0PG37eyVnpliVlt6b1YKAP8Zf4GOmy3jKr9ypcvMkN8Rd8Z2+G/jcWbH/W9PeKmSI21HTrxFx22HfWvl8ebS4aXl5em4zaXhd7Qf/0F3K91qnxnKTm9opoRsvNi+sam60Pz3HX3k61beydkHdLxb2VNTRmw4VceLhhxt5aXNmBe2nNHGnh8iIiIKFDZ+iIiIKFDY+CEiIqJAScg5P/HkXZbX9u+FVt60rrk6PjPHvqP8ukvNOGerKdEpWyo7cPbxVvry0e/r2DueHKr9pwN0XLQhuEuaa+LnO+zX/tt2j+o49A7pXltK8qz0+iGe5bViH1vrrE06/rLjBM/jh79CaM4Bz4SP0PlGLV9apeNILNOmqksvtHfWHt5voq/zQndxbjfG7EZccsDf1iYErP6z2d7l+1ahd0Ywb8jWr223cpY+abaGePdc894vCpkLmOZ5jDn77NlXnWb31fGzx46z8kYVmG1g5r5sbwlzX5ezdFxa/BtiiT0/REREFChs/BAREVGgcNirEqU//GilB7/fX8dL//AfK49LMGtmW8tMK33doaGbEDiWH9hnpetNzdZx6fbtoYeTD2MGPF2t84bXX2Slhw5b4PPM6n3nGunZmfbdR+3luvV+nRl6OMWAd6ir+2S7/q+obbYtCB3C7PC5Ga5u+qL93s9ayV2cq0N5hprTQsad08W859571x6W8tpeZv6TtZ12vZXX5C1TT3lv2MNXTWDulHDN3bdYeQsHmm1humbb733JsofWYok9P0RERBQobPwQERFRoLDxQ0RERIES3Tk/aek6TD/EXhabjPMzCr7yjKP+IX7lCJo1JXt0fPkjw6y8ghd5K4OaGjSrr5Ve2P3ZqF5vxEZzH4wRjeb6Pm/aiFN0XG8K5/gkglWXFuj4rXqTrLxMMZ//H+3OtvKajTX/ejI/5hyfSGg8x2wJUjYwZBaqMrOuykJmqM7eZ/6vDRt+m44LJ9nzevw6qdd8K+293oA1p9p523dU6xqRwJ4fIiIiChQ2foiIiChQojrsVdrd7Dj5t7HPW3kjL7lSx2quvWQ2UaQXtbbSw+4fH/ZYCZtDNbVgv+laL3iUw1yR1qrPPCt9Wt9bdbytjf2XnXXsFh0PLJph5T026QKE02qM2Tn9h4fM++reXt+FPafXkoutdM6U2WGPpdhZc7fZyfv96/6p4zLYOzV7h7pGDulv5dX6eE50ChdguV8u1fF1a3pYeSOamB3zD8/IsfKOzNyr4y1tzVBl7eOO8n3tH6+ureMXmzxs5a0sMZ8hv9xh/09N2zXP9zUijT0/REREFChs/BAREVGgsPFDREREgRLVOT8/9TbbYXerVWrlbeySr+OG/le7Rl1athmnXn9GIyvvvNxtYc+r90FO2DyqmHjuGryzRRwLQpY648wy8jqVHPc26lvpFgg/H8v77q/buKGOQ+/q/t0+830s61r7MXi39vhIy7O3KZl8rZnT0SSjVujh2sDp/XVc9B7n+ESbd/uYTde3s/KmTSrS8VW1V1l5h6SZOvz+Rs/d4G+0H997y4zQ5fJeB5R9y4rz7x+i4wYzEmeLCvb8EBERUaCw8UNERESBEtVhr7ZjPbs39rbznho+Rse3b7jZyst9s3o7S/qV3raNjpcObGDlnXaSWXb/dvPHEU7nOVdZ6WZvmfNKQw+mCqU3Mq/9wv7hX2tKftLZLJud1fllHYfe7bvPOzfpuHD119EuFoWx+YZuOm7bb4mV1yozM/TwChXdwKGueClbaNfZG+3NFI6nBl1k5eX3XqfjqUe94evx5+6z01d+OUDHh71lD3s1eD1xhrq82PNDRER
EgcLGDxEREQUKGz9EREQUKFGd85O2bZeOp+3JtfLOzNmt43//6zErb8JdJ4Z9zNe/7qLj+t+lhz2uuLOZeXNUu5+tvAFNP9RxZcvXlx3Yb6WvHvlnHTebuNjKS8a71BPFyu7meRX+/pM9h1jpdk//pmPOnYudjJaHW+nTB5p5lyMLws/BXH7AbEBw6/LL7MfEmgiVjiKp4VMhc3CeMuH56Fytx2yD8LepSVTs+SEiIqJAYeOHiIiIAiWqw14lK1fr+LGzzrHybh7cRMev/fFRK29k42/CPubIizx5F4U9zLejPre3kc2fbobnGn220cprsMx0F7JLPnae7OvdJ2F+3MpB/qXn51vpEaPH6jhTzHD14P/1t44rXBzdbS7ISG9odtpOf8ke4q9sqOuXErPOefiqS3Sc0ZPDXJQ82PNDREREgcLGDxEREQUKGz9EREQUKFGd8+Plnf8DAK2HmvQd026y8lZdaO4eO+uCR6y8UZtO0XFlc4M6ze6r453r7OW07f9drOMjfgyZQ6LM3Wo5rycxpO01y2lDb4dACaqgoZWsnbZXx5tLzfySptNiViIKsfx2c5ufBW3GVHKk7bIHhuq4/rOJeesCooNhzw8REREFChs/REREFCgxG/aqTNaH9t1/i8wGzOg7+OSQo82wVGW7UTbB4rB5HM5KDKWbzfBju4n20GdZnqml9j8tjVmZKDJKl62w0utLDtXxw7+aO4bnTubS9ljZeYm9c/70PqM8qVphz7t/Uycr3XDCQh1zGJqSFXt+iIiIKFDY+CEiIqJAYeOHiIiIAiUh5vxQMKl9Zpv8Nrd9HfY4ztFKfk//7mwdq23b41iSYCk7paOOH/zH01Zew/Tw83yOef5WHbceZ9/mp2zHitDDiZIOe36IiIgoUNj4ISIiokDhsBcRRV3JT6viXYRAylq7VceTtxxv5Z1YYHZnHlXcwcpr/epmHYduW0CUCtjzQ0RERIHCxg8REREFChs/REREFCic80NElKK8c61+CLkb0IXoUsmZP0alPESJgj0/REREFChs/BAREVGgiFLq4EcRERERpQj2/BAREVGgsPFDREREgcLGDxEREQUKGz9EREQUKGz8EBERUaCw8UNERESB8v9D4PCrWQgE7QAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "with torch.no_grad():\n", " xbv,ybv = next(iter(dls.train))\n", " logits = loaded_model(xbv)\n", " probs = F.softmax(logits, dim=1)\n", " idx = 5\n", " _,axs = plt.subplots(1, idx, figsize=(10, 10))\n", " for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n", " ax.imshow(im)\n", " ax.set_axis_off()\n", " ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n", " " ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaPklEQVR4nO3dd3hVRfoH8O+bRopEOpEiCCSAiiKIgg1RdBXrsmtDEVQUxMK6Cuy6P8W26C6yrqhrxQa6CIrYXRFsKAKiSBMQpKj0SO9J5vfHOZk5c5/ccJLcfr6f58nzvJM55565d3Jv5k47opQCERERUVCkxbsARERERLHExg8REREFChs/REREFChs/BAREVGgsPFDREREgcLGDxEREQVK3Bs/IrJKRHrGuxyxIiL3iMj4eJcjWlifqYN1mVpYn6mDdVlzcW/8VIWI/ElEfhKR7SKyVkQeEZGMKF+zv4jMiOY1PNdqICJfikixiGwVkZkicnIsrh0PItJDRD4RkW0isipG14xZfbrXe0ZElopImYj0j9V140VEskRkiYj8EoNrxbQuPdftJyJKRAbE+tqxIiIfiMhOz89+EVkQ5WvG8rP21JDnt9Ot0z/E4vqxJiKdRORz93luEJEhUb5erD9n00XkAbddsENEvhOROpWdE9HGT7QbIgDeAdBJKZUP4GgAxwK4NcrXjKWdAK4F0BBAXQD/APBODF7XCsXgursAPA9gaJSvE0/fAxgM4Nt4FiKGf0NDAWyM0bViTkTqAvgrgEVxLkdU61Mpda5S6pDyHwBfAZgUzWvGklLqi5Dndz6cz98PY12WGHyBbwDneT0NoD6ANgA+iuY14+BeACcB6AYgH0BfAHsrO+GgjR+3e+2vIrJYRLaIyAsiku3mnS4iv4jIcBFZD+AFEUkTkb+IyAq3B2OiiNTzPF5fEVnt5v2tKs9OKbVCKbW1/KEAlMGpyIMSkdYiMt297mYRecXbMhSR5iIyWUQ2ucc8LiLtATwFoJvbYt7qHvup91tfaCtXRB4VkZ/dHqq5InKqz+e3Vym1VClV5j6/UjiNoHqVn+lfgtXnbKXUOAA/VeN5JHx9us/xCaXUNBzkjVgdiVSX7vlHALgKwINVPC8p6tL1IIAxADZX8byDSrT69DxOSwCnAhjn8/hkqs9y/QC8rpTaVc3zLQlWl38G8D+l1CtKqX1KqR1KqR98Po+Er0txvpD8CcD1SqnVyrFQKVWzxo/rSgC/A9AaQBGA//PkFcD559wCwA1wemIuBtAdQBMAWwA84RbySABPwmmVNYHTCm3meRKnlL9Q4YhIHxHZDufD51g4rVk/BM4HVxMA7QE0B3C
P+5jpAN4FsBpASwBNAUxw/0AGAZjpfkOo4/NacwB0hPO6vApgUvkffgXPZ76I9An9HZx/lm8DeE4pFelv0glTnzWQNPUZZYlUl48BuBPAnio+h6SoSxE5AcDxcD7YoyWR6rPc1QC+UEqt9Hl8UtSn5/e5AP4I4CWf1/QrUeqyK4DfROQrEdkoIu+IyOE+n0My1GUHACUA/igi60VkmYjcdNCrKaUq/QGwCsAgT7oXgBVufDqA/QCyPfk/ADjTkz4MwAEAGQDudl+c8rw89/yeBytHBeUqBHA/gIKqnuuefzGA79y4G4BNADIqOK4/gBkhv/sUwIDKjgk5fguAY934HgDjfZQvG8AVAPpV5/klU30C6AlgVQ2fV6LX5wwA/VO1LgH8HsCHnmv/kkp1CSAdwDcAulV0nVSrz5ByLa/J324i1mfIOX0BrAQgqViXAJYB2AqgC5z/K2MAfJkqdQmgDwAFYCyAHADHuOU6q7Ln4nes8WdPvBpOK7DcJmV3L7UA8KaIlHl+VwqgsXuefiyl1C4RKfZZBotS6kcRWQTgPwB6H+x4EWkEp9JPBVAbTq/XFje7OYDVSqmS6pSlgmvdDmAAnOer4IxBNqjKY7iv6X9F5AcRmaeU+j4SZXMlXH1WVbLVZxTFvS5FJA/AP+F8wFdZktTlYADzlVIzI1GOSsS9Pr1E5BQ4vRSvV+GcZKhPr34AXlbuf9IISpS63APgTaXUHAAQkXsBbBaRQ5VS2yo7MUnqsryn+T6l1B4A80VkApzPo6nhTvI77NXcEx8OYK0nHfoH8zOAc5VSdTw/2UqpXwGs8z6W291Y32cZKpIBp0vRjwfdsh6jnAnTV8Hp0isv8+FS8cSzit4QuwDketIF5YE7TjkcwKUA6iqny2+b51pVlQmgVTXPDSdR67MqkrU+Iy0R6rIQTrf3F+4chskADnO7oFv6OD8Z6vJMAL93n9N6OJMrR4vI4z7OrYpEqE+vfgAmK6V2VuGcZKjP8sdoDqcn5mW/51RBotTl/JDrlcd+XqdkqMv5lVwzLL+Nn5tEpJk4E7DuBPBaJcc+BeDvItICAESkoYhc5Oa9DuB8d4wyC8B9VSgDRGSA2xItHwf9K4BpnvxPReSeMKfXhjObf6uINIW9wmg2nD+wh0QkT0SyxSwx3wCgmVvecvMA9BaRXBFpA+C6kOuUwO0OFJG74bRg/Ty/ruWvjYjkiMhwOC3/WX7Or4JEqc80d0w300lKtvd1Tvb6dJ9DlvscBUCmW5ZIrrJMhLpcCOfDuaP7MwDO69wR7jfWFKjL/nDmPJQ/x2/grDCp9kTiMBKhPuE+Xg6ASwC8WEFestdnub4AvlJKrajieX4kSl2+AKfh3lFEMgHcBWe4aat7raSuS7fuvgDwNxGpJc6E68vgzEcKy+8L+CqcpXE/uT8PVHLso3Am6n4kIjsAfA3gRLeQiwDc5D7eOjjdZ3o/EHH3XqjksU8GsEBEdgF43/2505PfHMCXYc69F0AnOK3J9+B8O4VbrlIAF8BZObbGLdNlbvZ0OMta14tI+QqPR+CMuW6AM0nuFc91/gfgAzjjrKvhTFz2dn9aRGSRiFzpJmvBmeRWDOBXON125yml1oY7v5oSpT5Pg9Nl+T6cb0Z7YC/BTPb6hPt89sDpKXjGjU8Ld341xL0ulVIlSqn15T8AfgNQ5qZL3cOSui6VUltDnuN+ANsPNmxQDXGvT4+L4dTJJxXkJXV9elyNyE90LpcQdamUmg7n/+R7cLahaANnnky5VKjLK+AMHRa75bxLOatsw5KDDXOKs/ncAKXUx5UeGGci0gzAJKVUt3iXJZGxPlMH6zK1sD5TB+sy8cVl87xoUEr9Amf2OaUA1mfqYF2mFtZn6ghyXSbV7S2IiIiIauqgw15EREREqYQ9P0RERBQobPwQERFRoPia8HxW2iUcG4uzqWWTIrKpHusy/iJVlwDrMxHwvZk6+N5MLZXVJ3t+iIiIKFBSZqk7ERERxV7
GES10PGK6fRu4y9++RceFQ76OWZkOhj0/REREFChs/BAREVGgsPFDREREgcI5P0RERFRtW59M1/FxWXafSq3fErOPJTFLRURERBQlbPwQERFRoHDYi4iIiPw7oYOVHNN2rI43lO638lq+tVXHZVEtVNWw54eIiIgChY0fIiIiChQ2foiIiChQUmvOj2ccctk1OTp+85zHrMM6ZGXqOF3s9l/v5WfpeMddzay8tM++i0gxiYiIklXOqI1W+pgss9S9y0PDrLzG876KSZmqij0/REREFChs/BAREVGgJPWw17rbT7LS9w16Wcfn5W7T8du7GlrHzduXreO0kMV3r7X+UMfH3tLPymv+WfXLStF3+/JFVvrMnH06Dh3ePK/zOTouWbc+ugUjIkpym2/opuOPW4228qbtqavjJhOXW3ml0S1WtbHnh4iIiAKFjR8iIiIKFDZ+iIiIKFCSbs5PRkFjHY8YON7K887zOXL8zTouHLXMOq50c7GOJTPLyps4tYuOJx3/rJU3oM9tOs5/9euqFJti4ADSrXQZlIlVyMizUqDgKb6um53uWqLjomf32gfPXhCLIhElJO8cHwCYe8+TOl60X6y8Mef00nHphpXRLViEsOeHiIiIAoWNHyIiIgqUpBv2WnFjax1fmPe+lXf24t46LnzYLLfzDnOFUgfsO9AuXnOYjosK7SGx4gt36zj/VZ8FpqhKyzbbFqQj/FDWUeNuttKtiudGrUxBlV6/no53n9Daysv5zGxDULZ7N2Jp37lmKHtLD3toa0mPZ3T8SY9DrLxH2rSPbsESmLcuMybbn4OFh5jdfRfcbHbV//X0POu4Wt3M5+6szpH/wDzlL/Z7us64mRG/RtDIcUfp+OO77OXspcp81t74pyFWXs7y2dEtWBSw54eIiIgChY0fIiIiChQ2foiIiChQkm7OT+YuE/93R2MrL+eKnTqubJ6PV8mZna30hFOf0vEPB+zl0S0eZ1sx3tLbtrHS+c//puMzc8LPJclfYadD53pRzf3at52OZw991Mq7d6N5n809LrrvI+l8lJU+aeQsHY9oZM/1+mC32ZZ/xONXW3kFSMy7UUdDxmEFVrp4rJn/9Fmb/4Y97+Nxi3XcM2eHlZfm+W5dFnIboUjY3Xubla4zLuKXCISMpk10fM74GTo+JK2WdVzhmzfquOhd+32UjBuH8L85ERERBQobP0RERBQoSTfs1fRfZknda5NOtvJKN6/29Rhl3Y/T8SNjn7Dy2mdm6vjoGddZeS1nzPNbTIqg9MaNdJzxzE4rb1zLqbEuDoWxp3H4zm/vcFOvMwbpOGN6ZLYc8A51DZ7wppV3bq4ZjgkdfBk212yP0Wb8EisvUe9GXW1i78qb3qCBjre9kGvlfdYh/FCXV+hQl19rS/bpONMuFhqm1wJFT+hdDZY8ZKaPTKnzjo7/UWwPHxfebIaPk3GYKxR7foiIiChQ2PghIiKiQGHjh4iIiAIl6eb8qBJzF+aSleHn+EgtM27849gjrbxZ3R8Pe17PQbfq+IgP59nX9ltIiqw6+Tp8o81rvk/r+u0VOm48ZbmVl3LzOeLhhA5W8o6L39JxWsj3qk/2mK3xIzHPR53c0UoPe9mscz4t297G4IAytd1r8aVW3hGXz9dxKv5NeLeGWDqogZW3+NLHavz4M/eaz9lr37/ezvTO5Qn58Gw90dTR6vOyrbwFfcf4unb69Dq+jiPbxuvs7V2WnGH+H64sMbd/mXmeva0I8Es0ixVz7PkhIiKiQGHjh4iIiAIlIYe90vPNMIfUr2vlbTzd7EZZd6m9o++P15pl6k+f/pKOe+SE7tRqulmvXHm2lZP3mVnuWspdgBPC2rMbHfwg2F22ALD/c9PNX7ppWUTLFFQlZ5gu8xNHz7Hy+uWbYei5++zvVSPuMttG5OPral17x2VddfyXB1628k7JNnUfupzdO9RV6+xV1bp2stpVaO7OXpVhruHru+l4wZYmVl7m0No6Tt9qtp4
oXDkLfpWe3sk8RuudlRxpaz/Z3Mm98An/1wu67X3Me2fOXfb2Lr+W7tFx/+F36Lj2z9V7nyYL9vwQERFRoLDxQ0RERIHCxg8REREFSkLO+Vl1y9E6/n5wzZdjVuaVIz6y0kM/OVHHU1fZ23vXG2/udJz7JsebY+WWwZN9Hdf7yaFWuumo4NyVO1aGPG1ue+C9bQRgz7X5YMcxVl7dj8ycq+ouKd/TZ6uOf5e7Lexxz21rZaWDNs/Hr6UHTE3cvuISKy/rWhNnrF5j5XlXrZfAH9XtWCt923Ov6riyW2R4l9IDQNvntuu4rCwVNyeIjuxr1um4LGTfge7Thui4aEJqz/PxYs8PERERBQobP0RERBQoCTns1XJKsY779epp5eVl+Ft+/tUU083asMdaK2/1yoY67tlxsZV3eQPT7Teqqz209V0n07F/u9xs5eVO5jBYvB3+xnorzU7xmts/tYWV7p4z05Oy7w5998YuOl54vr08urTYfg/6sXbYSVZ6Ssd/elL2cIh3qOu9i08IeaQVVb52qshbZj5LO7x0q5VXMNu8Q3KmzLby/A5nVWb/OebvYfPAXVae37vBX/PJtVa66Ptval6wANhwi/3emXuk2cX5gc32zuxtBy3Usd+7GGQ0b2alSxscqmMptT95y+YvQSJizw8REREFChs/REREFCgJOexVumipjotPtvOK4U8zeFb6PGjnFWGVju11DMDDx5hVDzcMy7HylvR4TsejR9s3R72j7CYdh3YhU9XtO9d0mbfOGhvHkqQ+ybSHr9SHZkftj9qFrrQzxz6zraWVM+84byr8MJf3Zpt7W9Sx84Zv0PG37eyVnpliVlt6b1YKAP8Zf4GOmy3jKr9ypcvMkN8Rd8Z2+G/jcWbH/W9PeKmSI21HTrxFx22HfWvl8ebS4aXl5em4zaXhd7Qf/0F3K91qnxnKTm9opoRsvNi+sam60Pz3HX3k61beydkHdLxb2VNTRmw4VceLhhxt5aXNmBe2nNHGnh8iIiIKFDZ+iIiIKFDY+CEiIqJAScg5P/HkXZbX9u+FVt60rrk6PjPHvqP8ukvNOGerKdEpWyo7cPbxVvry0e/r2DueHKr9pwN0XLQhuEuaa+LnO+zX/tt2j+o49A7pXltK8qz0+iGe5bViH1vrrE06/rLjBM/jh79CaM4Bz4SP0PlGLV9apeNILNOmqksvtHfWHt5voq/zQndxbjfG7EZccsDf1iYErP6z2d7l+1ahd0Ywb8jWr223cpY+abaGePdc894vCpkLmOZ5jDn77NlXnWb31fGzx46z8kYVmG1g5r5sbwlzX5ezdFxa/BtiiT0/REREFChs/BAREVGgcNirEqU//GilB7/fX8dL//AfK49LMGtmW8tMK33doaGbEDiWH9hnpetNzdZx6fbtoYeTD2MGPF2t84bXX2Slhw5b4PPM6n3nGunZmfbdR+3luvV+nRl6OMWAd6ir+2S7/q+obbYtCB3C7PC5Ga5u+qL93s9ayV2cq0N5hprTQsad08W859571x6W8tpeZv6TtZ12vZXX5C1TT3lv2MNXTWDulHDN3bdYeQsHmm1humbb733JsofWYok9P0RERBQobPwQERFRoLDxQ0RERIES3Tk/aek6TD/EXhabjPMzCr7yjKP+IX7lCJo1JXt0fPkjw6y8ghd5K4OaGjSrr5Ve2P3ZqF5vxEZzH4wRjeb6Pm/aiFN0XG8K5/gkglWXFuj4rXqTrLxMMZ//H+3OtvKajTX/ejI/5hyfSGg8x2wJUjYwZBaqMrOuykJmqM7eZ/6vDRt+m44LJ9nzevw6qdd8K+293oA1p9p523dU6xqRwJ4fIiIiChQ2foiIiChQojrsVdrd7Dj5t7HPW3kjL7lSx2quvWQ2UaQXtbbSw+4fH/ZYCZtDNbVgv+laL3iUw1yR1qrPPCt9Wt9bdbytjf2XnXXsFh0PLJph5T026QKE02qM2Tn9h4fM++reXt+FPafXkoutdM6U2WG
PpdhZc7fZyfv96/6p4zLYOzV7h7pGDulv5dX6eE50ChdguV8u1fF1a3pYeSOamB3zD8/IsfKOzNyr4y1tzVBl7eOO8n3tH6+ureMXmzxs5a0sMZ8hv9xh/09N2zXP9zUijT0/REREFChs/BAREVGgsPFDREREgRLVOT8/9TbbYXerVWrlbeySr+OG/le7Rl1athmnXn9GIyvvvNxtYc+r90FO2DyqmHjuGryzRRwLQpY648wy8jqVHPc26lvpFgg/H8v77q/buKGOQ+/q/t0+830s61r7MXi39vhIy7O3KZl8rZnT0SSjVujh2sDp/XVc9B7n+ESbd/uYTde3s/KmTSrS8VW1V1l5h6SZOvz+Rs/d4G+0H997y4zQ5fJeB5R9y4rz7x+i4wYzEmeLCvb8EBERUaCw8UNERESBEtVhr7ZjPbs39rbznho+Rse3b7jZyst9s3o7S/qV3raNjpcObGDlnXaSWXb/dvPHEU7nOVdZ6WZvmfNKQw+mCqU3Mq/9wv7hX2tKftLZLJud1fllHYfe7bvPOzfpuHD119EuFoWx+YZuOm7bb4mV1yozM/TwChXdwKGueClbaNfZG+3NFI6nBl1k5eX3XqfjqUe94evx5+6z01d+OUDHh71lD3s1eD1xhrq82PNDREREgcLGDxEREQUKGz9EREQUKFGd85O2bZeOp+3JtfLOzNmt43//6zErb8JdJ4Z9zNe/7qLj+t+lhz2uuLOZeXNUu5+tvAFNP9RxZcvXlx3Yb6WvHvlnHTebuNjKS8a71BPFyu7meRX+/pM9h1jpdk//pmPOnYudjJaHW+nTB5p5lyMLws/BXH7AbEBw6/LL7MfEmgiVjiKp4VMhc3CeMuH56Fytx2yD8LepSVTs+SEiIqJAYeOHiIiIAiWqw14lK1fr+LGzzrHybh7cRMev/fFRK29k42/CPubIizx5F4U9zLejPre3kc2fbobnGn220cprsMx0F7JLPnae7OvdJ2F+3MpB/qXn51vpEaPH6jhTzHD14P/1t44rXBzdbS7ISG9odtpOf8ke4q9sqOuXErPOefiqS3Sc0ZPDXJQ82PNDREREgcLGDxEREQUKGz9EREQUKFGd8+Plnf8DAK2HmvQd026y8lZdaO4eO+uCR6y8UZtO0XFlc4M6ze6r453r7OW07f9drOMjfgyZQ6LM3Wo5rycxpO01y2lDb4dACaqgoZWsnbZXx5tLzfySptNiViIKsfx2c5ufBW3GVHKk7bIHhuq4/rOJeesCooNhzw8REREFChs/REREFCgxG/aqTNaH9t1/i8wGzOg7+OSQo82wVGW7UTbB4rB5HM5KDKWbzfBju4n20GdZnqml9j8tjVmZKDJKl62w0utLDtXxw7+aO4bnTubS9ljZeYm9c/70PqM8qVphz7t/Uycr3XDCQh1zGJqSFXt+iIiIKFDY+CEiIqJAYeOHiIiIAiUh5vxQMKl9Zpv8Nrd9HfY4ztFKfk//7mwdq23b41iSYCk7paOOH/zH01Zew/Tw83yOef5WHbceZ9/mp2zHitDDiZIOe36IiIgoUNj4ISIiokDhsBcRRV3JT6viXYRAylq7VceTtxxv5Z1YYHZnHlXcwcpr/epmHYduW0CUCtjzQ0RERIHCxg8REREFChs/REREFCic80NElKK8c61+CLkb0IXoUsmZP0alPESJgj0/REREFChs/BAREVGgiFLq4EcRERERpQj2/BAREVGgsPFDREREgcLGDxEREQUKGz9EREQUKGz8EBERUaCw8UNERESB8v9D4PCrWQgE7QAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "logits = model(xbv)\n", "probs = F.softmax(logits, dim=1)\n", "idx = 5\n", "_,axs = plt.subplots(1, idx, figsize=(10, 10))\n", "for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n", " ax.imshow(im)\n", " ax.set_axis_off()\n", " ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n", " " ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "def predict(img):\n", " with torch.no_grad():\n", " img = img[None,]\n", " pred = loaded_model(img)[0]\n", " pred_probs = F.softmax(pred, dim=0)\n", " pred = [{\"digit\": i, \"prob\": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]\n", " pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)\n", " return pred" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(3)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[W NNPACK.cpp:64] Could not initialize NNPACK! 
Reason: Unsupported hardware.\n" ] }, { "data": { "text/plain": [ "[{'digit': 0, 'prob': '0.00%', 'logits': tensor(-5.5980)},\n", " {'digit': 1, 'prob': '0.00%', 'logits': tensor(-0.4972)},\n", " {'digit': 2, 'prob': '0.02%', 'logits': tensor(1.2516)},\n", " {'digit': 3, 'prob': '99.95%', 'logits': tensor(9.9263)},\n", " {'digit': 4, 'prob': '0.00%', 'logits': tensor(-5.5094)},\n", " {'digit': 5, 'prob': '0.01%', 'logits': tensor(0.2367)},\n", " {'digit': 6, 'prob': '0.00%', 'logits': tensor(-9.4633)},\n", " {'digit': 7, 'prob': '0.00%', 'logits': tensor(-2.4315)},\n", " {'digit': 8, 'prob': '0.02%', 'logits': tensor(1.4733)},\n", " {'digit': 9, 'prob': '0.00%', 'logits': tensor(-0.0205)}]" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Smoke-test predict() on one sample: print yb[1] (presumably its label),\n", "# then display the per-class scores returned for xb[1].\n", "img = xb[1].reshape(1, 28, 28)\n", "print(yb[1])\n", "predict(img)" ] }, { "cell_type": "markdown", "metadata": { "tags": [ "exclude" ] }, "source": [ "#### Export to a .py script for deployment" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "tags": [ "exclude" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n" ] } ], "source": [ "# Export this notebook as a script, dropping every cell tagged 'exclude'.\n", "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n" ] } ], "metadata": { "kernelspec": { "display_name": "python_main", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }