Hannes Kuchelmeister commited on
Commit
ad947ed
·
1 Parent(s): 596687f

add testing notebook for convolutional neural network

Browse files
notebooks/2.0-hfk-convolutional_testing.ipynb ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Testing FocusDataSet"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 1,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "name": "stdout",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "2450\n"
20
+ ]
21
+ },
22
+ {
23
+ "data": {
24
+ "text/plain": [
25
+ "{'image': array([[[211, 185, 62],\n",
26
+ " [216, 192, 68],\n",
27
+ " [223, 198, 79],\n",
28
+ " ...,\n",
29
+ " [214, 190, 64],\n",
30
+ " [222, 199, 71],\n",
31
+ " [224, 201, 73]],\n",
32
+ " \n",
33
+ " [[218, 192, 69],\n",
34
+ " [223, 197, 74],\n",
35
+ " [229, 205, 83],\n",
36
+ " ...,\n",
37
+ " [216, 193, 65],\n",
38
+ " [225, 202, 74],\n",
39
+ " [226, 203, 75]],\n",
40
+ " \n",
41
+ " [[223, 198, 72],\n",
42
+ " [228, 202, 79],\n",
43
+ " [234, 210, 88],\n",
44
+ " ...,\n",
45
+ " [220, 197, 69],\n",
46
+ " [228, 205, 77],\n",
47
+ " [226, 203, 73]],\n",
48
+ " \n",
49
+ " ...,\n",
50
+ " \n",
51
+ " [[157, 138, 17],\n",
52
+ " [163, 145, 21],\n",
53
+ " [178, 157, 32],\n",
54
+ " ...,\n",
55
+ " [166, 169, 40],\n",
56
+ " [170, 173, 42],\n",
57
+ " [176, 179, 46]],\n",
58
+ " \n",
59
+ " [[145, 126, 5],\n",
60
+ " [155, 137, 13],\n",
61
+ " [177, 156, 31],\n",
62
+ " ...,\n",
63
+ " [156, 158, 31],\n",
64
+ " [166, 169, 40],\n",
65
+ " [175, 178, 47]],\n",
66
+ " \n",
67
+ " [[147, 128, 7],\n",
68
+ " [159, 141, 17],\n",
69
+ " [181, 160, 35],\n",
70
+ " ...,\n",
71
+ " [149, 151, 24],\n",
72
+ " [162, 164, 37],\n",
73
+ " [172, 175, 46]]], dtype=uint8),\n",
74
+ " 'focus_value': tensor(0.5450)}"
75
+ ]
76
+ },
77
+ "execution_count": 1,
78
+ "metadata": {},
79
+ "output_type": "execute_result"
80
+ }
81
+ ],
82
+ "source": [
83
+ "from importlib.machinery import SourceFileLoader\n",
84
+ "\n",
85
+ "focus_datamodule = SourceFileLoader(\"focus_datamodule\", \"../src/datamodules/focus_datamodule.py\").load_module()\n",
86
+ "from focus_datamodule import FocusDataSet\n",
87
+ "\n",
88
+ "ds = FocusDataSet(\"../data/focus150/metadata.csv\", \"../data/focus150/\")\n",
89
+ "\n",
90
+ "counter = 0\n",
91
+ "for d in ds:\n",
92
+ " counter += 1\n",
93
+ "\n",
94
+ "print(counter)\n",
95
+ "\n",
96
+ "d"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 81,
102
+ "metadata": {},
103
+ "outputs": [
104
+ {
105
+ "name": "stdout",
106
+ "output_type": "stream",
107
+ "text": [
108
+ "14\n"
109
+ ]
110
+ },
111
+ {
112
+ "data": {
113
+ "text/plain": [
114
+ "torch.Size([64, 1])"
115
+ ]
116
+ },
117
+ "execution_count": 81,
118
+ "metadata": {},
119
+ "output_type": "execute_result"
120
+ }
121
+ ],
122
+ "source": [
123
+ "from focus_datamodule import FocusDataModule\n",
124
+ "\n",
125
+ "datamodule = FocusDataModule(data_dir=\"../data/focus150\", csv_file=\"../data/focus150/metadata.csv\")\n",
126
+ "datamodule.setup()\n",
127
+ "\n",
128
+ "for data in datamodule.test_dataloader():\n",
129
+ " break\n",
130
+ "\n",
131
+ "len(data[\"focus_value\"])\n",
132
+ "\n",
133
+ "# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html\n",
134
+ "import torch\n",
135
+ "import torch.nn as nn\n",
136
+ "import torch.nn.functional as F\n",
137
+ "\n",
138
+ "class Net(nn.Module):\n",
139
+ " def __init__(self):\n",
140
+ " super().__init__()\n",
141
+ " pool_size = 3\n",
142
+ " \n",
143
+ " conv1_size = 5\n",
144
+ " conv1_out = 6\n",
145
+ " conv2_size = 5\n",
146
+ " conv2_out = 16\n",
147
+ " size_img = 150\n",
148
+ "\n",
149
+ " size_img -= conv1_size - 1\n",
150
+ " size_img = int( (size_img) / pool_size)\n",
151
+ " size_img -= conv2_size - 1\n",
152
+ " size_img = int(size_img / pool_size)\n",
153
+ "\n",
154
+ " print(size_img)\n",
155
+ "\n",
156
+ " self.model = nn.Sequential(\n",
157
+ " nn.Conv2d(3, conv1_out, conv1_size),\n",
158
+ " nn.MaxPool2d(pool_size, pool_size),\n",
159
+ " nn.Conv2d(conv1_out, conv2_out, conv2_size),\n",
160
+ " nn.MaxPool2d(pool_size, pool_size),\n",
161
+ " nn.Flatten(),\n",
162
+ " nn.Linear(conv2_out * size_img * size_img, 120), # 16 * 34 * 34 or [64, 16, 15, 15]\n",
163
+ " nn.Linear(120, 84),\n",
164
+ " nn.Linear(84, 1)\n",
165
+ " )\n",
166
+ "\n",
167
+ " def forward(self, x):\n",
168
+ " x = self.model(x)\n",
169
+ " return x\n",
170
+ "\n",
171
+ "\n",
172
+ "net = Net()\n",
173
+ "\n",
174
+ "net(data[\"image\"]).shape"
175
+ ]
176
+ }
177
+ ],
178
+ "metadata": {
179
+ "interpreter": {
180
+ "hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac"
181
+ },
182
+ "kernelspec": {
183
+ "display_name": "Python 3.9.7 64-bit",
184
+ "language": "python",
185
+ "name": "python3"
186
+ },
187
+ "language_info": {
188
+ "codemirror_mode": {
189
+ "name": "ipython",
190
+ "version": 3
191
+ },
192
+ "file_extension": ".py",
193
+ "mimetype": "text/x-python",
194
+ "name": "python",
195
+ "nbconvert_exporter": "python",
196
+ "pygments_lexer": "ipython3",
197
+ "version": "3.8.10"
198
+ },
199
+ "orig_nbformat": 4
200
+ },
201
+ "nbformat": 4,
202
+ "nbformat_minor": 2
203
+ }
src/models/focus_conv_module.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from pytorch_lightning import LightningModule
7
+ from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
8
+ from torchmetrics.classification.accuracy import Accuracy
9
+
10
+
11
+ class SimpleConvNet(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.conv1 = nn.Conv2d(3, 6, 5)
15
+ self.pool = nn.MaxPool2d(2, 2)
16
+ self.conv2 = nn.Conv2d(6, 16, 5)
17
+ self.pool = nn.MaxPool2d(2, 2)
18
+ self.conv3 = nn.Conv2d(6, 16, 5)
19
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
20
+ self.fc2 = nn.Linear(120, 84)
21
+ self.fc3 = nn.Linear(84, 10)
22
+
23
+ def forward(self, x):
24
+ x = self.pool(F.relu(self.conv1(x)))
25
+ x = self.pool(F.relu(self.conv2(x)))
26
+ x = torch.flatten(x, 1) # flatten all dimensions except batch
27
+ x = F.relu(self.fc1(x))
28
+ x = F.relu(self.fc2(x))
29
+ x = self.fc3(x)
30
+ return x
31
+
32
+
33
+ class SimpleDenseNet(nn.Module):
34
+ def __init__(self, hparams: dict):
35
+ super().__init__()
36
+
37
+ self.model = nn.Sequential(
38
+ nn.Linear(hparams["input_size"], hparams["lin1_size"]),
39
+ nn.BatchNorm1d(hparams["lin1_size"]),
40
+ nn.ReLU(),
41
+ nn.Linear(hparams["lin1_size"], hparams["lin2_size"]),
42
+ nn.BatchNorm1d(hparams["lin2_size"]),
43
+ nn.ReLU(),
44
+ nn.Linear(hparams["lin2_size"], hparams["lin3_size"]),
45
+ nn.BatchNorm1d(hparams["lin3_size"]),
46
+ nn.ReLU(),
47
+ nn.Linear(hparams["lin3_size"], hparams["output_size"]),
48
+ )
49
+
50
+ def forward(self, x):
51
+ batch_size, channels, width, height = x.size()
52
+
53
+ # (batch, 1, width, height) -> (batch, 1*width*height)
54
+ x = x.view(batch_size, -1)
55
+
56
+ return self.model(x)
57
+
58
+
59
+ class FocusLitModule(LightningModule):
60
+ """
61
+ Example of LightningModule for MNIST classification.
62
+
63
+ A LightningModule organizes your PyTorch code into 5 sections:
64
+ - Computations (init).
65
+ - Train loop (training_step)
66
+ - Validation loop (validation_step)
67
+ - Test loop (test_step)
68
+ - Optimizers (configure_optimizers)
69
+
70
+ Read the docs:
71
+ https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ input_size: int = 75 * 75 * 3,
77
+ lin1_size: int = 256,
78
+ lin2_size: int = 256,
79
+ lin3_size: int = 256,
80
+ output_size: int = 1,
81
+ lr: float = 0.001,
82
+ weight_decay: float = 0.0005,
83
+ ):
84
+ super().__init__()
85
+
86
+ # this line allows to access init params with 'self.hparams' attribute
87
+ # it also ensures init params will be stored in ckpt
88
+ self.save_hyperparameters(logger=False)
89
+
90
+ self.model = SimpleDenseNet(hparams=self.hparams)
91
+
92
+ # loss function
93
+ self.criterion = torch.nn.L1Loss()
94
+
95
+ # use separate metric instance for train, val and test step
96
+ # to ensure a proper reduction over the epoch
97
+ self.train_mae = MeanAbsoluteError()
98
+ self.val_mae = MeanAbsoluteError()
99
+ self.test_mae = MeanAbsoluteError()
100
+
101
+ # for logging best so far validation accuracy
102
+ self.val_mae_best = MinMetric()
103
+
104
+ def forward(self, x: torch.Tensor):
105
+ return self.model(x)
106
+
107
+ def step(self, batch: Any):
108
+ x = batch["image"]
109
+ y = batch["focus_value"]
110
+ logits = self.forward(x)
111
+ loss = self.criterion(logits, y)
112
+ preds = torch.squeeze(logits)
113
+ return loss, preds, y
114
+
115
+ def training_step(self, batch: Any, batch_idx: int):
116
+ loss, preds, targets = self.step(batch)
117
+
118
+ # log train metrics
119
+ mae = self.train_mae(preds, targets)
120
+ self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
121
+ self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
122
+
123
+ # we can return here dict with any tensors
124
+ # and then read it in some callback or in `training_epoch_end()`` below
125
+ # remember to always return loss from `training_step()` or else backpropagation will fail!
126
+ return {"loss": loss, "preds": preds, "targets": targets}
127
+
128
+ def training_epoch_end(self, outputs: List[Any]):
129
+ # `outputs` is a list of dicts returned from `training_step()`
130
+ pass
131
+
132
+ def validation_step(self, batch: Any, batch_idx: int):
133
+ loss, preds, targets = self.step(batch)
134
+
135
+ # log val metrics
136
+ mae = self.val_mae(preds, targets)
137
+ self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
138
+ self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
139
+
140
+ return {"loss": loss, "preds": preds, "targets": targets}
141
+
142
+ def validation_epoch_end(self, outputs: List[Any]):
143
+ mae = self.val_mae.compute() # get val accuracy from current epoch
144
+ self.val_mae_best.update(mae)
145
+ self.log(
146
+ "val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
147
+ )
148
+
149
+ def test_step(self, batch: Any, batch_idx: int):
150
+ loss, preds, targets = self.step(batch)
151
+
152
+ # log test metrics
153
+ mae = self.test_mae(preds, targets)
154
+ self.log("test/loss", loss, on_step=False, on_epoch=True)
155
+ self.log("test/mae", mae, on_step=False, on_epoch=True)
156
+
157
+ return {"loss": loss, "preds": preds, "targets": targets}
158
+
159
+ def test_epoch_end(self, outputs: List[Any]):
160
+ print(outputs)
161
+ pass
162
+
163
+ def on_epoch_end(self):
164
+ # reset metrics at the end of every epoch
165
+ self.train_mae.reset()
166
+ self.test_mae.reset()
167
+ self.val_mae.reset()
168
+
169
+ def configure_optimizers(self):
170
+ """Choose what optimizers and learning-rate schedulers.
171
+
172
+ Normally you'd need one. But in the case of GANs or similar you might have multiple.
173
+
174
+ See examples here:
175
+ https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
176
+ """
177
+ return torch.optim.Adam(
178
+ params=self.parameters(),
179
+ lr=self.hparams.lr,
180
+ weight_decay=self.hparams.weight_decay,
181
+ )