refactor: changing structure
Browse files- model/cnn.py → cnn/__init__.py +0 -1
- {model → cnn}/model-old.pt +0 -0
- {model → cnn}/model.pt +0 -0
- testbench.ipynb +8 -8
- train.ipynb +3 -3
model/cnn.py → cnn/__init__.py
RENAMED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
|
|
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
import torch.nn.functional as F
|
| 3 |
|
{model → cnn}/model-old.pt
RENAMED
|
File without changes
|
{model → cnn}/model.pt
RENAMED
|
File without changes
|
testbench.ipynb
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "c831c34c",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [
|
|
@@ -20,10 +20,10 @@
|
|
| 20 |
"import torch\n",
|
| 21 |
"from torch.utils.data import DataLoader\n",
|
| 22 |
"from torchvision import datasets, transforms\n",
|
| 23 |
-
"from
|
| 24 |
"\n",
|
| 25 |
"model = CNN()\n",
|
| 26 |
-
"model.load_state_dict(torch.load(\"model.pt\"))\n",
|
| 27 |
"\n",
|
| 28 |
"check_gpu = torch.cuda.is_available()\n",
|
| 29 |
"device = torch.device(\"cpu\")\n",
|
|
@@ -41,7 +41,7 @@
|
|
| 41 |
},
|
| 42 |
{
|
| 43 |
"cell_type": "code",
|
| 44 |
-
"execution_count":
|
| 45 |
"id": "cd2d6928",
|
| 46 |
"metadata": {},
|
| 47 |
"outputs": [],
|
|
@@ -57,7 +57,7 @@
|
|
| 57 |
},
|
| 58 |
{
|
| 59 |
"cell_type": "code",
|
| 60 |
-
"execution_count":
|
| 61 |
"id": "f7bb207f",
|
| 62 |
"metadata": {},
|
| 63 |
"outputs": [],
|
|
@@ -72,7 +72,7 @@
|
|
| 72 |
},
|
| 73 |
{
|
| 74 |
"cell_type": "code",
|
| 75 |
-
"execution_count":
|
| 76 |
"id": "9ca78681",
|
| 77 |
"metadata": {},
|
| 78 |
"outputs": [],
|
|
@@ -82,7 +82,7 @@
|
|
| 82 |
},
|
| 83 |
{
|
| 84 |
"cell_type": "code",
|
| 85 |
-
"execution_count":
|
| 86 |
"id": "9c5c7fae",
|
| 87 |
"metadata": {},
|
| 88 |
"outputs": [
|
|
@@ -137,7 +137,7 @@
|
|
| 137 |
},
|
| 138 |
{
|
| 139 |
"cell_type": "code",
|
| 140 |
-
"execution_count":
|
| 141 |
"id": "1e171b86",
|
| 142 |
"metadata": {},
|
| 143 |
"outputs": [
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 42,
|
| 6 |
"id": "c831c34c",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [
|
|
|
|
| 20 |
"import torch\n",
|
| 21 |
"from torch.utils.data import DataLoader\n",
|
| 22 |
"from torchvision import datasets, transforms\n",
|
| 23 |
+
"from cnn import CNN\n",
|
| 24 |
"\n",
|
| 25 |
"model = CNN()\n",
|
| 26 |
+
"model.load_state_dict(torch.load(\"cnn/model.pt\"))\n",
|
| 27 |
"\n",
|
| 28 |
"check_gpu = torch.cuda.is_available()\n",
|
| 29 |
"device = torch.device(\"cpu\")\n",
|
|
|
|
| 41 |
},
|
| 42 |
{
|
| 43 |
"cell_type": "code",
|
| 44 |
+
"execution_count": 43,
|
| 45 |
"id": "cd2d6928",
|
| 46 |
"metadata": {},
|
| 47 |
"outputs": [],
|
|
|
|
| 57 |
},
|
| 58 |
{
|
| 59 |
"cell_type": "code",
|
| 60 |
+
"execution_count": 44,
|
| 61 |
"id": "f7bb207f",
|
| 62 |
"metadata": {},
|
| 63 |
"outputs": [],
|
|
|
|
| 72 |
},
|
| 73 |
{
|
| 74 |
"cell_type": "code",
|
| 75 |
+
"execution_count": 45,
|
| 76 |
"id": "9ca78681",
|
| 77 |
"metadata": {},
|
| 78 |
"outputs": [],
|
|
|
|
| 82 |
},
|
| 83 |
{
|
| 84 |
"cell_type": "code",
|
| 85 |
+
"execution_count": 46,
|
| 86 |
"id": "9c5c7fae",
|
| 87 |
"metadata": {},
|
| 88 |
"outputs": [
|
|
|
|
| 137 |
},
|
| 138 |
{
|
| 139 |
"cell_type": "code",
|
| 140 |
+
"execution_count": 47,
|
| 141 |
"id": "1e171b86",
|
| 142 |
"metadata": {},
|
| 143 |
"outputs": [
|
train.ipynb
CHANGED
|
@@ -9,7 +9,7 @@
|
|
| 9 |
},
|
| 10 |
{
|
| 11 |
"cell_type": "code",
|
| 12 |
-
"execution_count":
|
| 13 |
"metadata": {},
|
| 14 |
"outputs": [
|
| 15 |
{
|
|
@@ -28,7 +28,7 @@
|
|
| 28 |
"import numpy as np\n",
|
| 29 |
"import matplotlib.pyplot as plt\n",
|
| 30 |
"import torch.nn as nn\n",
|
| 31 |
-
"from
|
| 32 |
"from tabulate import tabulate\n",
|
| 33 |
"\n",
|
| 34 |
"\n",
|
|
@@ -465,7 +465,7 @@
|
|
| 465 |
" if valid_loss < min_valid_loss:\n",
|
| 466 |
" saved = \"*\"\n",
|
| 467 |
" min_valid_loss = valid_loss\n",
|
| 468 |
-
" torch.save(model.state_dict(), \"model.pt\")\n",
|
| 469 |
"\n",
|
| 470 |
" row = [\n",
|
| 471 |
" epoch + 1,\n",
|
|
|
|
| 9 |
},
|
| 10 |
{
|
| 11 |
"cell_type": "code",
|
| 12 |
+
"execution_count": 19,
|
| 13 |
"metadata": {},
|
| 14 |
"outputs": [
|
| 15 |
{
|
|
|
|
| 28 |
"import numpy as np\n",
|
| 29 |
"import matplotlib.pyplot as plt\n",
|
| 30 |
"import torch.nn as nn\n",
|
| 31 |
+
"from cnn import CNN\n",
|
| 32 |
"from tabulate import tabulate\n",
|
| 33 |
"\n",
|
| 34 |
"\n",
|
|
|
|
| 465 |
" if valid_loss < min_valid_loss:\n",
|
| 466 |
" saved = \"*\"\n",
|
| 467 |
" min_valid_loss = valid_loss\n",
|
| 468 |
+
" torch.save(model.state_dict(), \"cnn/model.pt\")\n",
|
| 469 |
"\n",
|
| 470 |
" row = [\n",
|
| 471 |
" epoch + 1,\n",
|