Hannes Kuchelmeister
commited on
Commit
•
ad947ed
1
Parent(s):
596687f
add testing notebook for convolutional neural network
Browse files
notebooks/2.0-hfk-convolutional_testing.ipynb
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"## Testing FocusDataSet"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 1,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [
|
15 |
+
{
|
16 |
+
"name": "stdout",
|
17 |
+
"output_type": "stream",
|
18 |
+
"text": [
|
19 |
+
"2450\n"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"data": {
|
24 |
+
"text/plain": [
|
25 |
+
"{'image': array([[[211, 185, 62],\n",
|
26 |
+
" [216, 192, 68],\n",
|
27 |
+
" [223, 198, 79],\n",
|
28 |
+
" ...,\n",
|
29 |
+
" [214, 190, 64],\n",
|
30 |
+
" [222, 199, 71],\n",
|
31 |
+
" [224, 201, 73]],\n",
|
32 |
+
" \n",
|
33 |
+
" [[218, 192, 69],\n",
|
34 |
+
" [223, 197, 74],\n",
|
35 |
+
" [229, 205, 83],\n",
|
36 |
+
" ...,\n",
|
37 |
+
" [216, 193, 65],\n",
|
38 |
+
" [225, 202, 74],\n",
|
39 |
+
" [226, 203, 75]],\n",
|
40 |
+
" \n",
|
41 |
+
" [[223, 198, 72],\n",
|
42 |
+
" [228, 202, 79],\n",
|
43 |
+
" [234, 210, 88],\n",
|
44 |
+
" ...,\n",
|
45 |
+
" [220, 197, 69],\n",
|
46 |
+
" [228, 205, 77],\n",
|
47 |
+
" [226, 203, 73]],\n",
|
48 |
+
" \n",
|
49 |
+
" ...,\n",
|
50 |
+
" \n",
|
51 |
+
" [[157, 138, 17],\n",
|
52 |
+
" [163, 145, 21],\n",
|
53 |
+
" [178, 157, 32],\n",
|
54 |
+
" ...,\n",
|
55 |
+
" [166, 169, 40],\n",
|
56 |
+
" [170, 173, 42],\n",
|
57 |
+
" [176, 179, 46]],\n",
|
58 |
+
" \n",
|
59 |
+
" [[145, 126, 5],\n",
|
60 |
+
" [155, 137, 13],\n",
|
61 |
+
" [177, 156, 31],\n",
|
62 |
+
" ...,\n",
|
63 |
+
" [156, 158, 31],\n",
|
64 |
+
" [166, 169, 40],\n",
|
65 |
+
" [175, 178, 47]],\n",
|
66 |
+
" \n",
|
67 |
+
" [[147, 128, 7],\n",
|
68 |
+
" [159, 141, 17],\n",
|
69 |
+
" [181, 160, 35],\n",
|
70 |
+
" ...,\n",
|
71 |
+
" [149, 151, 24],\n",
|
72 |
+
" [162, 164, 37],\n",
|
73 |
+
" [172, 175, 46]]], dtype=uint8),\n",
|
74 |
+
" 'focus_value': tensor(0.5450)}"
|
75 |
+
]
|
76 |
+
},
|
77 |
+
"execution_count": 1,
|
78 |
+
"metadata": {},
|
79 |
+
"output_type": "execute_result"
|
80 |
+
}
|
81 |
+
],
|
82 |
+
"source": [
|
83 |
+
"from importlib.machinery import SourceFileLoader\n",
|
84 |
+
"\n",
|
85 |
+
"focus_datamodule = SourceFileLoader(\"focus_datamodule\", \"../src/datamodules/focus_datamodule.py\").load_module()\n",
|
86 |
+
"from focus_datamodule import FocusDataSet\n",
|
87 |
+
"\n",
|
88 |
+
"ds = FocusDataSet(\"../data/focus150/metadata.csv\", \"../data/focus150/\")\n",
|
89 |
+
"\n",
|
90 |
+
"counter = 0\n",
|
91 |
+
"for d in ds:\n",
|
92 |
+
" counter += 1\n",
|
93 |
+
"\n",
|
94 |
+
"print(counter)\n",
|
95 |
+
"\n",
|
96 |
+
"d"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": 81,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [
|
104 |
+
{
|
105 |
+
"name": "stdout",
|
106 |
+
"output_type": "stream",
|
107 |
+
"text": [
|
108 |
+
"14\n"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"data": {
|
113 |
+
"text/plain": [
|
114 |
+
"torch.Size([64, 1])"
|
115 |
+
]
|
116 |
+
},
|
117 |
+
"execution_count": 81,
|
118 |
+
"metadata": {},
|
119 |
+
"output_type": "execute_result"
|
120 |
+
}
|
121 |
+
],
|
122 |
+
"source": [
|
123 |
+
"from focus_datamodule import FocusDataModule\n",
|
124 |
+
"\n",
|
125 |
+
"datamodule = FocusDataModule(data_dir=\"../data/focus150\", csv_file=\"../data/focus150/metadata.csv\")\n",
|
126 |
+
"datamodule.setup()\n",
|
127 |
+
"\n",
|
128 |
+
"for data in datamodule.test_dataloader():\n",
|
129 |
+
" break\n",
|
130 |
+
"\n",
|
131 |
+
"len(data[\"focus_value\"])\n",
|
132 |
+
"\n",
|
133 |
+
"# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html\n",
|
134 |
+
"import torch\n",
|
135 |
+
"import torch.nn as nn\n",
|
136 |
+
"import torch.nn.functional as F\n",
|
137 |
+
"\n",
|
138 |
+
"class Net(nn.Module):\n",
|
139 |
+
" def __init__(self):\n",
|
140 |
+
" super().__init__()\n",
|
141 |
+
" pool_size = 3\n",
|
142 |
+
" \n",
|
143 |
+
" conv1_size = 5\n",
|
144 |
+
" conv1_out = 6\n",
|
145 |
+
" conv2_size = 5\n",
|
146 |
+
" conv2_out = 16\n",
|
147 |
+
" size_img = 150\n",
|
148 |
+
"\n",
|
149 |
+
" size_img -= conv1_size - 1\n",
|
150 |
+
" size_img = int( (size_img) / pool_size)\n",
|
151 |
+
" size_img -= conv2_size - 1\n",
|
152 |
+
" size_img = int(size_img / pool_size)\n",
|
153 |
+
"\n",
|
154 |
+
" print(size_img)\n",
|
155 |
+
"\n",
|
156 |
+
" self.model = nn.Sequential(\n",
|
157 |
+
" nn.Conv2d(3, conv1_out, conv1_size),\n",
|
158 |
+
" nn.MaxPool2d(pool_size, pool_size),\n",
|
159 |
+
" nn.Conv2d(conv1_out, conv2_out, conv2_size),\n",
|
160 |
+
" nn.MaxPool2d(pool_size, pool_size),\n",
|
161 |
+
" nn.Flatten(),\n",
|
162 |
+
" nn.Linear(conv2_out * size_img * size_img, 120), # 16 * 34 * 34 or [64, 16, 15, 15]\n",
|
163 |
+
" nn.Linear(120, 84),\n",
|
164 |
+
" nn.Linear(84, 1)\n",
|
165 |
+
" )\n",
|
166 |
+
"\n",
|
167 |
+
" def forward(self, x):\n",
|
168 |
+
" x = self.model(x)\n",
|
169 |
+
" return x\n",
|
170 |
+
"\n",
|
171 |
+
"\n",
|
172 |
+
"net = Net()\n",
|
173 |
+
"\n",
|
174 |
+
"net(data[\"image\"]).shape"
|
175 |
+
]
|
176 |
+
}
|
177 |
+
],
|
178 |
+
"metadata": {
|
179 |
+
"interpreter": {
|
180 |
+
"hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac"
|
181 |
+
},
|
182 |
+
"kernelspec": {
|
183 |
+
"display_name": "Python 3.9.7 64-bit",
|
184 |
+
"language": "python",
|
185 |
+
"name": "python3"
|
186 |
+
},
|
187 |
+
"language_info": {
|
188 |
+
"codemirror_mode": {
|
189 |
+
"name": "ipython",
|
190 |
+
"version": 3
|
191 |
+
},
|
192 |
+
"file_extension": ".py",
|
193 |
+
"mimetype": "text/x-python",
|
194 |
+
"name": "python",
|
195 |
+
"nbconvert_exporter": "python",
|
196 |
+
"pygments_lexer": "ipython3",
|
197 |
+
"version": "3.8.10"
|
198 |
+
},
|
199 |
+
"orig_nbformat": 4
|
200 |
+
},
|
201 |
+
"nbformat": 4,
|
202 |
+
"nbformat_minor": 2
|
203 |
+
}
|
src/models/focus_conv_module.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from pytorch_lightning import LightningModule
|
7 |
+
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
|
8 |
+
from torchmetrics.classification.accuracy import Accuracy
|
9 |
+
|
10 |
+
|
11 |
+
class SimpleConvNet(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
self.conv1 = nn.Conv2d(3, 6, 5)
|
15 |
+
self.pool = nn.MaxPool2d(2, 2)
|
16 |
+
self.conv2 = nn.Conv2d(6, 16, 5)
|
17 |
+
self.pool = nn.MaxPool2d(2, 2)
|
18 |
+
self.conv3 = nn.Conv2d(6, 16, 5)
|
19 |
+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
|
20 |
+
self.fc2 = nn.Linear(120, 84)
|
21 |
+
self.fc3 = nn.Linear(84, 10)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
x = self.pool(F.relu(self.conv1(x)))
|
25 |
+
x = self.pool(F.relu(self.conv2(x)))
|
26 |
+
x = torch.flatten(x, 1) # flatten all dimensions except batch
|
27 |
+
x = F.relu(self.fc1(x))
|
28 |
+
x = F.relu(self.fc2(x))
|
29 |
+
x = self.fc3(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
class SimpleDenseNet(nn.Module):
|
34 |
+
def __init__(self, hparams: dict):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.model = nn.Sequential(
|
38 |
+
nn.Linear(hparams["input_size"], hparams["lin1_size"]),
|
39 |
+
nn.BatchNorm1d(hparams["lin1_size"]),
|
40 |
+
nn.ReLU(),
|
41 |
+
nn.Linear(hparams["lin1_size"], hparams["lin2_size"]),
|
42 |
+
nn.BatchNorm1d(hparams["lin2_size"]),
|
43 |
+
nn.ReLU(),
|
44 |
+
nn.Linear(hparams["lin2_size"], hparams["lin3_size"]),
|
45 |
+
nn.BatchNorm1d(hparams["lin3_size"]),
|
46 |
+
nn.ReLU(),
|
47 |
+
nn.Linear(hparams["lin3_size"], hparams["output_size"]),
|
48 |
+
)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
batch_size, channels, width, height = x.size()
|
52 |
+
|
53 |
+
# (batch, 1, width, height) -> (batch, 1*width*height)
|
54 |
+
x = x.view(batch_size, -1)
|
55 |
+
|
56 |
+
return self.model(x)
|
57 |
+
|
58 |
+
|
59 |
+
class FocusLitModule(LightningModule):
|
60 |
+
"""
|
61 |
+
Example of LightningModule for MNIST classification.
|
62 |
+
|
63 |
+
A LightningModule organizes your PyTorch code into 5 sections:
|
64 |
+
- Computations (init).
|
65 |
+
- Train loop (training_step)
|
66 |
+
- Validation loop (validation_step)
|
67 |
+
- Test loop (test_step)
|
68 |
+
- Optimizers (configure_optimizers)
|
69 |
+
|
70 |
+
Read the docs:
|
71 |
+
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
input_size: int = 75 * 75 * 3,
|
77 |
+
lin1_size: int = 256,
|
78 |
+
lin2_size: int = 256,
|
79 |
+
lin3_size: int = 256,
|
80 |
+
output_size: int = 1,
|
81 |
+
lr: float = 0.001,
|
82 |
+
weight_decay: float = 0.0005,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
# this line allows to access init params with 'self.hparams' attribute
|
87 |
+
# it also ensures init params will be stored in ckpt
|
88 |
+
self.save_hyperparameters(logger=False)
|
89 |
+
|
90 |
+
self.model = SimpleDenseNet(hparams=self.hparams)
|
91 |
+
|
92 |
+
# loss function
|
93 |
+
self.criterion = torch.nn.L1Loss()
|
94 |
+
|
95 |
+
# use separate metric instance for train, val and test step
|
96 |
+
# to ensure a proper reduction over the epoch
|
97 |
+
self.train_mae = MeanAbsoluteError()
|
98 |
+
self.val_mae = MeanAbsoluteError()
|
99 |
+
self.test_mae = MeanAbsoluteError()
|
100 |
+
|
101 |
+
# for logging best so far validation accuracy
|
102 |
+
self.val_mae_best = MinMetric()
|
103 |
+
|
104 |
+
def forward(self, x: torch.Tensor):
|
105 |
+
return self.model(x)
|
106 |
+
|
107 |
+
def step(self, batch: Any):
|
108 |
+
x = batch["image"]
|
109 |
+
y = batch["focus_value"]
|
110 |
+
logits = self.forward(x)
|
111 |
+
loss = self.criterion(logits, y)
|
112 |
+
preds = torch.squeeze(logits)
|
113 |
+
return loss, preds, y
|
114 |
+
|
115 |
+
def training_step(self, batch: Any, batch_idx: int):
|
116 |
+
loss, preds, targets = self.step(batch)
|
117 |
+
|
118 |
+
# log train metrics
|
119 |
+
mae = self.train_mae(preds, targets)
|
120 |
+
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
121 |
+
self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
122 |
+
|
123 |
+
# we can return here dict with any tensors
|
124 |
+
# and then read it in some callback or in `training_epoch_end()`` below
|
125 |
+
# remember to always return loss from `training_step()` or else backpropagation will fail!
|
126 |
+
return {"loss": loss, "preds": preds, "targets": targets}
|
127 |
+
|
128 |
+
def training_epoch_end(self, outputs: List[Any]):
|
129 |
+
# `outputs` is a list of dicts returned from `training_step()`
|
130 |
+
pass
|
131 |
+
|
132 |
+
def validation_step(self, batch: Any, batch_idx: int):
|
133 |
+
loss, preds, targets = self.step(batch)
|
134 |
+
|
135 |
+
# log val metrics
|
136 |
+
mae = self.val_mae(preds, targets)
|
137 |
+
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
138 |
+
self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
139 |
+
|
140 |
+
return {"loss": loss, "preds": preds, "targets": targets}
|
141 |
+
|
142 |
+
def validation_epoch_end(self, outputs: List[Any]):
|
143 |
+
mae = self.val_mae.compute() # get val accuracy from current epoch
|
144 |
+
self.val_mae_best.update(mae)
|
145 |
+
self.log(
|
146 |
+
"val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
|
147 |
+
)
|
148 |
+
|
149 |
+
def test_step(self, batch: Any, batch_idx: int):
|
150 |
+
loss, preds, targets = self.step(batch)
|
151 |
+
|
152 |
+
# log test metrics
|
153 |
+
mae = self.test_mae(preds, targets)
|
154 |
+
self.log("test/loss", loss, on_step=False, on_epoch=True)
|
155 |
+
self.log("test/mae", mae, on_step=False, on_epoch=True)
|
156 |
+
|
157 |
+
return {"loss": loss, "preds": preds, "targets": targets}
|
158 |
+
|
159 |
+
def test_epoch_end(self, outputs: List[Any]):
|
160 |
+
print(outputs)
|
161 |
+
pass
|
162 |
+
|
163 |
+
def on_epoch_end(self):
|
164 |
+
# reset metrics at the end of every epoch
|
165 |
+
self.train_mae.reset()
|
166 |
+
self.test_mae.reset()
|
167 |
+
self.val_mae.reset()
|
168 |
+
|
169 |
+
def configure_optimizers(self):
|
170 |
+
"""Choose what optimizers and learning-rate schedulers.
|
171 |
+
|
172 |
+
Normally you'd need one. But in the case of GANs or similar you might have multiple.
|
173 |
+
|
174 |
+
See examples here:
|
175 |
+
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
|
176 |
+
"""
|
177 |
+
return torch.optim.Adam(
|
178 |
+
params=self.parameters(),
|
179 |
+
lr=self.hparams.lr,
|
180 |
+
weight_decay=self.hparams.weight_decay,
|
181 |
+
)
|