{ "cells": [ { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'dlopen(/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so, 0x0006): Symbol not found: __ZN3c106detail19maybe_wrap_dim_slowIxEET_S2_S2_b\n", " Referenced from: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so\n", " Expected in: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/lib/libc10.dylib'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n", " warn(\n" ] } ], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "from datasets import load_dataset\n", "import fastcore.all as fc\n", "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "import torchvision.transforms.functional as TF\n", "from torch.utils.data import default_collate, DataLoader\n", "import torch.optim as optim\n", "import pickle\n", "%matplotlib inline\n", "plt.rcParams['figure.figsize'] = [2, 2]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset_nm = 'mnist'\n", "x,y = 'image', 'label'\n", "ds = load_dataset(dataset_nm)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIt0lEQVR4nO3df2yU9R0H8PfHtrQroFLBrmJHO6j8EBxujUAgSDJx1SxxZkFgZtmMC5nInBvb+LFlmwsumCwkyJiJZF0xUXQwF4hhI4MocRkyqgMHspafzkIpFgYyh9JeP/vjzvY+F9peP8/9eO76fiVN7/08d32+IR++973n7vmcqCqIBuqabA+AchMLh1xYOOTCwiEXFg65sHDIJVDhiEidiDSJyFERWZ6qQVH4ifc8jogUAGgGMBdAC4B9ABaq6jupGx6FVWGAx94B4KiqHgcAEXkRwH0Aei2cIVKsJRga4JCUaZfwn3ZVHZW4PUjhjAbwXlxuATCtrweUYCimyRcDHJIybaduefdq24MUTlJEZBGARQBQgtJ0H44yJMji+BSAyrh8c2yboarPqmqtqtYWoTjA4ShMghTOPgA1IlItIkMALACwLTXDorBzP1WpaqeILAGwA0ABgHpVPZSykVGoBVrjqOp2ANtTNBbKITxzTC4sHHJh4ZALC4dcWDjkwsIhFxYOubBwyIWFQy4sHHJh4ZALC4dc0v5BrnwhhfafqmDUyKQf2/SDKpMjpV0mjxl71uTSxWLymTVDTH6r9iWT2yMfmjxt89Lu2+O+/0bS4xwIzjjkwsIhFxYOuQyaNU7BxBqTtbjI5NN3Xm/y5el23VB2nc2vf86uM4L40/+Gm/zUr+tM3jvlBZNPdFw2eXXbXJNvej39PY8445ALC4dcWDjkkrdrnMicz5u8pmG9ybcU2XMjmdShEZN/uu6bJhd+aNcoMzYvMXn4qU6Ti9vtmqe0cW/AEfaPMw65sHDIhYVDLnm7xiluOm3ymx9VmnxLUVvKjrW0dbrJx/9r38dqGLvF5Itddg1T/vTfAh0/G52qOeOQCwuHXFg45JK3a5zO1jMmr3tqnslP1tn3ngreHmbygcXr+vz7q9pv67599C7bMCpyodXkr81YbPLJx+zfqsaBPo8VRpxxyKXfwhGRehE5KyIH47aVichfRORI7PeI9A6TwiaZGacBQF3CtuUAdqlqDYBdsUyDSFJ9jkWkCsArqjo5lpsAzFHVVhGpAPCaqo7v7+9cK2Ualq6jBSNvMDly7rzJJ164zeRDs+tNvuOX3+m+feP6YOdhwmynbnlTVWsTt3vXOOWq+skK8AyAcvfIKCcFXhxrdMrqddoSkUUi0igijR34OOjhKCS8hdMWe4pC7PfZ3u7IdrX5yXseZxuAbwBYHfu9NWUjypBI+7k+93d80PfndW59sOebB95/psDu7Iog3yXzcnwTgD0AxotIi4g8jGjBzBWRIwDuimUaRPqdcVR1YS+7wvHyiLKCZ47JJW/fqwpq4rJmkx+aYifY343Z1X37znmPmn3DX0rP9dphwhmHXFg45MLCIReucXoRuXDR5HOPTDT539t6rmVavuo5s2/FA/ebrP+4zuTKJ/fYgzm/FzWbOOOQCwuHXPhUlaSuA4dNXvDED7tvP/+zX5l9+6fbpy7Yq2dw61B7SW/NBvtR087jJ32DzCDOOOTCwiEXFg65JPXR0VQJ00dHU0lnTjX52tUtJm/67I4+Hz/h1W+ZPP4JeyogcuS4f3ABpfqjozTIsXDIhYVDLlzjpEFB+Y0mn54/zuS9y9aafE3C/98HT9xt8sVZfX/MNZ24xqGUYuGQCwuHXPheVRpE2uxlZuVP2/zRj2y72VKxl+JsqHrF5C/f/7i9/x/T3462P5xxyIWFQy4sHHLhGicFumZNNfnYvBKTJ089aXLimibRuvO32/tvbXSPLV0445ALC4dcWDjkwjVOkqR2ssnNj/WsUzbM3Gj2zS65MqC//bF2mPzG+Wp7hy77meQw4IxDLsn0x6kUkVdF5B0ROSQi341tZ8vaQSyZGacTwFJVnYTohR6PisgksGXtoJZMY6VWAK2x25dE5DCA0QDuAzAndreNAF4DsCwto8yAwuoxJh976CaTfz7/RZO/OqzdfayVbfbjLbvX2guvRmxMuEQ4hAa0xon1O74dwF6wZe2glnThiMgwAH8A8LiqfhC/r6+WtWxXm5+SKhwRKUK0aJ5X1Zdjm5NqWct2tfmp3zWOiAiA3wI4rKpr4nblVMvawqrPmHzxCxUmz//Fn03+9vUvwyvxqxb3/Mauacoa/m7yiK7wr2kSJXMCcCaArwP4p4jsj21biWjB/D7WvvZdAA+kZYQUSsm8qvorAOlld/5fskBXxTPH5JI371UVVnza5PP1Q01+pHq3yQuHB/v66CWnZnXffuuZqWbfyC0HTS67lHtrmP5wxiEXFg65sHDIJafWOFe+1HM+5Mr37Fchrhy33eS7P2W/Hnqg2iKXTZ69banJE37yr+7bZRfsGqYr0JFzA2cccmHhkEtOPVWd/EpPnTdP2Tygx66/MNbktbttKxGJ2HOcE1adMLmmzV52m//fgdc3zjjkwsIhFxYOubCVG/WJrdwopVg45MLCIRcWDrmwcMiFhUMuLBxyYeGQCwuHXFg45MLCIZeMvlclIu8jetXnSAD+PiHpFdaxZWtcY1R1VOLGjBZO90FFGq/2xlkYhHVsYRsXn6rIhYVDLtkqnGezdNxkhHVsoRpXVtY4lPv4VEUuGS0cEakTkSYROSoiWW1vKyL1InJWRA7GbQtF7+Zc6C2dscIRkQIA6wHcA2ASgIWxfsnZ0gCgLmFbWHo3h7+3tKpm5AfADAA74vIKACsydfxexlQF4GBcbgJQEbtdAaApm+OLG9dWAHPDNL5MPlWNBvBeXG6JbQuT0PVuDmtvaS6Oe6HR/9ZZfcnp7S2dCZksnFMAKuPyzbFtYZJU7+ZMCNJbOhMyWTj7ANSISLWIDAGwANFeyWHySe9mIIu9m5PoLQ1ku7d0hhd59wJoBnAMwI+zvODchOiXm3Qgut56GMANiL5aOQJgJ4CyLI1tFqJPQ28D2B/7uTcs41NVnjkmHy6OyYWFQy4sHHJh4ZALC4dcWDjkwsIhFxYOufwfp3xNA0HdZ/0AAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def transform_ds(b):\n", " b[x] = [TF.to_tensor(ele) for ele in b[x]]\n", " return b\n", "\n", "dst = ds.with_transform(transform_ds)\n", "plt.imshow(dst['train'][0]['image'].permute(1,2,0));" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bs = 1024\n", "class DataLoaders:\n", " def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n", " self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n", " self.valid = DataLoader(train_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n", "\n", "def collate_fn(b):\n", " collate = default_collate(b)\n", " return (collate[x], collate[y])\n", "\n", "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n", "xb,yb = next(iter(dls.train))\n", "xb.shape, yb.shape" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "class Reshape(nn.Module):\n", " def __init__(self, dim):\n", " super().__init__()\n", " self.dim = dim\n", " \n", " def forward(self, x):\n", " return x.reshape(self.dim)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def cnn_classifier():\n", " ks,stride = 3,2\n", " return nn.Sequential(\n", " nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),\n", " nn.ReLU(),\n", " nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),\n", " nn.ReLU(),\n", " nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n", " nn.ReLU(),\n", " nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n", " nn.ReLU(),\n", " nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),\n", " nn.ReLU(),\n", " nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),\n", " nn.Flatten(),\n", " )" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def linear_classifier():\n", " return nn.Sequential(\n", " Reshape((-1, 784)),\n", " nn.Linear(784, 50),\n", " nn.ReLU(),\n", " nn.Linear(50, 50),\n", " nn.ReLU(),\n", " nn.Linear(50, 10)\n", " )" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train, epoch:1, loss: 0.4002, accuracy: 0.7806\n", "eval, epoch:1, loss: 0.2896, accuracy: 0.9007\n", "train, epoch:2, loss: 0.2815, accuracy: 0.9171\n", "eval, epoch:2, loss: 0.2144, accuracy: 0.9318\n", "train, epoch:3, loss: 0.2128, accuracy: 0.9370\n", "eval, epoch:3, loss: 0.1721, accuracy: 0.9435\n", "train, epoch:4, loss: 0.1453, accuracy: 0.9489\n", "eval, epoch:4, loss: 0.1629, accuracy: 0.9590\n", "train, epoch:5, loss: 0.1110, accuracy: 0.9565\n", "eval, epoch:5, loss: 0.1162, accuracy: 0.9681\n" ] } ], "source": [ "model = linear_classifier()\n", "lr = 0.1\n", "max_lr = 0.1\n", "epochs = 5\n", "opt = optim.AdamW(model.parameters(), lr=lr)\n", "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n", "\n", "for epoch in range(epochs):\n", " for train in (True, False):\n", " accuracy = 0\n", " dl = dls.train if train else dls.valid\n", " for xb,yb in dl:\n", " preds = model(xb)\n", " loss = F.cross_entropy(preds, yb)\n", " if train:\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", " with torch.no_grad():\n", " accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n", " if train:\n", " sched.step()\n", " accuracy /= len(dl)\n", " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n", " " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# with open('./linear_classifier.pkl', 'wb') as model_file:\n", "# pickle.dump(model, model_file)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from IPython.display import HTML, display, Image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "\n", "\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%html\n", "\n", "
\n", "\n", "\n", "
\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "python_main", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }