{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST classifier: test-set evaluation\n",
    "\n",
    "Loads the fully connected classifier saved in `mnistmodel.pt`, counts how many of the\n",
    "10,000 MNIST test images it classifies correctly, and lets you inspect single examples.\n",
    "The stored run below got 9825 of 10000 right (98.25%)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model definition\n",
    "\n",
    "The checkpoint appears to be a whole pickled `Mnet` instance (the loaded object is called\n",
    "directly as a model), so this class must be defined before the checkpoint is loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mnet(nn.Module):\n",
    "    \"\"\"Fully connected MNIST classifier.\n",
    "\n",
    "    Layer widths: 28*28 -> 400 -> 200 -> 100 -> 50 -> 25 -> 10, with a ReLU after\n",
    "    every hidden layer; the last layer returns raw logits (no softmax).\n",
    "\n",
    "    Keep the attribute names unchanged: a pickled instance restores its saved\n",
    "    layers under these names, and forward() looks them up by name.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.linear1 = nn.Linear(28 * 28, 400)\n",
    "        self.linear2 = nn.Linear(400, 200)\n",
    "        self.linear3 = nn.Linear(200, 100)\n",
    "        self.linear4 = nn.Linear(100, 50)\n",
    "        self.linear5 = nn.Linear(50, 25)\n",
    "        self.final_linear = nn.Linear(25, 10)\n",
    "\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, images):\n",
    "        \"\"\"Return logits of shape (N, 10).\n",
    "\n",
    "        Every 28*28 block of `images` becomes one row, so a single (1, 28, 28)\n",
    "        image gives N = 1 and a (B, 1, 28, 28) batch gives N = B.\n",
    "        \"\"\"\n",
    "        # reshape (unlike view) also accepts non-contiguous tensors.\n",
    "        x = images.reshape(-1, 28 * 28)\n",
    "        for hidden in (self.linear1, self.linear2, self.linear3, self.linear4, self.linear5):\n",
    "            x = self.relu(hidden(x))\n",
    "        return self.final_linear(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model and test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# map_location=\"cpu\": a checkpoint saved on a GPU still loads on a CPU-only machine\n",
    "# and matches the CPU tensors that test_data yields.\n",
    "# weights_only=False: torch>=2.6 defaults to weights_only=True, which refuses whole\n",
    "# pickled modules. Unpickling can execute arbitrary code, so only load trusted files;\n",
    "# saving/loading a state_dict instead would avoid this.\n",
    "model = torch.load(\"mnistmodel.pt\", map_location=\"cpu\", weights_only=False)\n",
    "# Inference mode: a no-op for this Linear/ReLU stack, but the safe default for evaluation.\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ToTensor turns each 28x28 grayscale image into a float tensor of shape (1, 28, 28) in [0, 1].\n",
    "transform = torchvision.transforms.ToTensor()\n",
    "test_data = torchvision.datasets.MNIST(\"mnist_data\", train=False, transform=transform, download=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test-set accuracy\n",
    "\n",
    "Prints the index of every misclassified test image, then `correct_answer total_test`.\n",
    "The misclassified indices are also kept in `wrong_indices`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "wrong answer 149\n",
      "wrong answer 151\n",
      "wrong answer 247\n",
      "wrong answer 259\n",
      "wrong answer 268\n",
      "wrong answer 340\n",
      "wrong answer 445\n",
      "wrong answer 495\n",
      "wrong answer 582\n",
      "wrong answer 684\n",
      "wrong answer 720\n",
      "wrong answer 844\n",
      "wrong answer 938\n",
      "wrong answer 947\n",
      "wrong answer 1014\n",
      "wrong answer 1039\n",
      "wrong answer 1200\n",
      "wrong answer 1226\n",
      "wrong answer 1232\n",
      "wrong answer 1242\n",
      "wrong answer 1247\n",
      "wrong answer 1260\n",
      "wrong answer 1289\n",
      "wrong answer 1319\n",
      "wrong answer 1328\n",
      "wrong answer 1393\n",
      "wrong answer 1414\n",
      "wrong answer 1425\n",
      "wrong answer 1530\n",
      "wrong answer 1553\n",
      "wrong answer 1569\n",
      "wrong answer 1681\n",
      "wrong answer 1717\n",
      "wrong answer 1751\n",
      "wrong answer 1754\n",
      "wrong answer 1790\n",
      "wrong answer 1800\n",
      "wrong answer 1850\n",
      "wrong answer 1878\n",
      "wrong answer 1880\n",
      "wrong answer 1901\n",
      "wrong answer 1952\n",
      "wrong answer 2024\n",
      "wrong answer 2109\n",
      "wrong answer 2118\n",
      "wrong answer 2130\n",
      "wrong answer 2135\n",
      "wrong answer 2224\n",
      "wrong answer 2293\n",
      "wrong answer 2369\n",
      "wrong answer 2387\n",
      "wrong answer 2406\n",
      "wrong answer 2414\n",
      "wrong answer 2422\n",
      "wrong answer 2488\n",
      "wrong answer 2582\n",
      "wrong answer 2597\n",
      "wrong answer 2648\n",
      "wrong answer 2654\n",
      "wrong answer 2720\n",
      "wrong answer 2863\n",
      "wrong answer 2877\n",
      "wrong answer 2896\n",
      "wrong answer 2921\n",
      "wrong answer 2927\n",
      "wrong answer 2939\n",
      "wrong answer 2953\n",
      "wrong answer 2979\n",
      "wrong answer 3060\n",
      "wrong answer 3073\n",
      "wrong answer 3117\n",
      "wrong answer 3263\n",
      "wrong answer 3284\n",
      "wrong answer 3394\n",
      "wrong answer 3422\n",
      "wrong answer 3475\n",
      "wrong answer 3503\n",
      "wrong answer 3520\n",
      "wrong answer 3558\n",
      "wrong answer 3565\n",
      "wrong answer 3567\n",
      "wrong answer 3597\n",
      "wrong answer 3727\n",
      "wrong answer 3767\n",
      "wrong answer 3776\n",
      "wrong answer 3796\n",
      "wrong answer 3808\n",
      "wrong answer 3811\n",
      "wrong answer 3817\n",
      "wrong answer 3818\n",
      "wrong answer 3869\n",
      "wrong answer 3893\n",
      "wrong answer 3906\n",
      "wrong answer 3941\n",
      "wrong answer 3943\n",
      "wrong answer 3970\n",
      "wrong answer 3985\n",
      "wrong answer 4000\n",
      "wrong answer 4065\n",
      "wrong answer 4075\n",
      "wrong answer 4140\n",
      "wrong answer 4163\n",
      "wrong answer 4176\n",
      "wrong answer 4199\n",
      "wrong answer 4224\n",
      "wrong answer 4248\n",
      "wrong answer 4289\n",
      "wrong answer 4350\n",
      "wrong answer 4369\n",
      "wrong answer 4437\n",
      "wrong answer 4497\n",
      "wrong answer 4504\n",
      "wrong answer 4536\n",
      "wrong answer 4547\n",
      "wrong answer 4571\n",
      "wrong answer 4601\n",
      "wrong answer 4731\n",
      "wrong answer 4740\n",
      "wrong answer 4761\n",
      "wrong answer 4807\n",
      "wrong answer 4823\n",
      "wrong answer 4833\n",
      "wrong answer 4956\n",
      "wrong answer 4966\n",
      "wrong answer 5078\n",
      "wrong answer 5265\n",
      "wrong answer 5331\n",
      "wrong answer 5457\n",
      "wrong answer 5586\n",
      "wrong answer 5676\n",
      "wrong answer 5734\n",
      "wrong answer 5749\n",
      "wrong answer 5887\n",
      "wrong answer 5888\n",
      "wrong answer 5955\n",
      "wrong answer 5973\n",
      "wrong answer 6011\n",
      "wrong answer 6059\n",
      "wrong answer 6555\n",
      "wrong answer 6571\n",
      "wrong answer 6597\n",
      "wrong answer 6603\n",
      "wrong answer 6625\n",
      "wrong answer 6641\n",
      "wrong answer 6651\n",
      "wrong answer 6755\n",
      "wrong answer 6783\n",
      "wrong answer 6847\n",
      "wrong answer 7434\n",
      "wrong answer 7921\n",
      "wrong answer 8094\n",
      "wrong answer 8246\n",
      "wrong answer 8311\n",
      "wrong answer 8382\n",
      "wrong answer 8408\n",
      "wrong answer 8456\n",
      "wrong answer 8522\n",
      "wrong answer 8527\n",
      "wrong answer 9009\n",
      "wrong answer 9015\n",
      "wrong answer 9024\n",
      "wrong answer 9280\n",
      "wrong answer 9587\n",
      "wrong answer 9634\n",
      "wrong answer 9664\n",
      "wrong answer 9669\n",
      "wrong answer 9679\n",
      "wrong answer 9729\n",
      "wrong answer 9745\n",
      "wrong answer 9749\n",
      "wrong answer 9768\n",
      "wrong answer 9770\n",
      "wrong answer 9792\n",
      "wrong answer 9808\n",
      "wrong answer 9858\n",
      "9825 10000\n"
     ]
    }
   ],
   "source": [
    "total_test = len(test_data)\n",
    "correct_answer = 0\n",
    "wrong_indices = []  # misclassified test-set indices, e.g. to pass to testexam() below\n",
    "\n",
    "# Inference only: no_grad() stops autograd from recording a graph for every sample.\n",
    "with torch.no_grad():\n",
    "    for i, (image, label) in enumerate(test_data):\n",
    "        # Softmax is monotonic, so the argmax of the raw logits is already the predicted digit.\n",
    "        predicted = model(image).argmax(dim=1).item()\n",
    "        if predicted == label:\n",
    "            correct_answer += 1\n",
    "        else:\n",
    "            wrong_indices.append(i)\n",
    "            print('wrong answer', i)\n",
    "\n",
    "print(correct_answer, total_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspect a single example\n",
    "\n",
    "Pass any test index. The indices printed as `wrong answer` above (also in `wrong_indices`)\n",
    "are the misclassified ones; index 9975, used below, is classified correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def testexam(i: int):\n",
    "    \"\"\"Print the model's guess and the true label for test_data[i], then plot the image.\n",
    "\n",
    "    Args:\n",
    "        i: index into test_data.\n",
    "    \"\"\"\n",
    "    image, label = test_data[i]\n",
    "    with torch.no_grad():  # inference only, no autograd graph needed\n",
    "        guess = model(image).argmax(dim=1).item()\n",
    "    print(f\"computer's guess: {guess}, answer: {label}\")\n",
    "    plt.imshow(image[0], cmap=\"gray\")\n",
    "    plt.title(f\"test_data[{i}]: guess {guess}, answer {label}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "computer's guess: 3, answer: 3\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "testexam(9975)"
   ]
  }
 ],
 "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" } },
 "nbformat": 4,
 "nbformat_minor": 2
}