Hannes Kuchelmeister
commited on
Commit
·
12319ae
1
Parent(s):
eef4d7c
add FNN with MSE loss
Browse files
src/datamodules/focus_datamodule.py
CHANGED
@@ -3,6 +3,8 @@ from typing import Optional, Tuple
|
|
3 |
import pandas as pd
|
4 |
from skimage import io
|
5 |
|
|
|
|
|
6 |
import torch
|
7 |
from pytorch_lightning import LightningDataModule
|
8 |
from torch.utils.data import DataLoader, Dataset, random_split
|
@@ -51,7 +53,9 @@ class FocusDataSet(Dataset):
|
|
51 |
self.root_dir, self.metadata.iloc[idx, self.col_index_path]
|
52 |
)
|
53 |
image = io.imread(img_name)
|
54 |
-
focus_value =
|
|
|
|
|
55 |
sample = {"image": image, "focus_value": focus_value}
|
56 |
|
57 |
if self.transform:
|
|
|
3 |
import pandas as pd
|
4 |
from skimage import io
|
5 |
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
import torch
|
9 |
from pytorch_lightning import LightningDataModule
|
10 |
from torch.utils.data import DataLoader, Dataset, random_split
|
|
|
53 |
self.root_dir, self.metadata.iloc[idx, self.col_index_path]
|
54 |
)
|
55 |
image = io.imread(img_name)
|
56 |
+
focus_value = torch.from_numpy(
|
57 |
+
np.asarray(self.metadata.iloc[idx, self.col_index_focus])
|
58 |
+
).float()
|
59 |
sample = {"image": image, "focus_value": focus_value}
|
60 |
|
61 |
if self.transform:
|
src/models/focus_module.py
CHANGED
@@ -156,3 +156,128 @@ class FocusLitModule(LightningModule):
|
|
156 |
lr=self.hparams.lr,
|
157 |
weight_decay=self.hparams.weight_decay,
|
158 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
lr=self.hparams.lr,
|
157 |
weight_decay=self.hparams.weight_decay,
|
158 |
)
|
159 |
+
|
160 |
+
|
161 |
+
class FocusMSELitModule(LightningModule):
|
162 |
+
"""
|
163 |
+
Example of LightningModule for MNIST classification.
|
164 |
+
|
165 |
+
A LightningModule organizes your PyTorch code into 5 sections:
|
166 |
+
- Computations (init).
|
167 |
+
- Train loop (training_step)
|
168 |
+
- Validation loop (validation_step)
|
169 |
+
- Test loop (test_step)
|
170 |
+
- Optimizers (configure_optimizers)
|
171 |
+
|
172 |
+
Read the docs:
|
173 |
+
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
input_size: int = 75 * 75 * 3,
|
179 |
+
lin1_size: int = 256,
|
180 |
+
lin2_size: int = 256,
|
181 |
+
lin3_size: int = 256,
|
182 |
+
output_size: int = 1,
|
183 |
+
lr: float = 0.001,
|
184 |
+
weight_decay: float = 0.0005,
|
185 |
+
):
|
186 |
+
super().__init__()
|
187 |
+
|
188 |
+
# this line allows to access init params with 'self.hparams' attribute
|
189 |
+
# it also ensures init params will be stored in ckpt
|
190 |
+
self.save_hyperparameters(logger=False)
|
191 |
+
|
192 |
+
self.model = SimpleDenseNet(hparams=self.hparams)
|
193 |
+
|
194 |
+
# loss function
|
195 |
+
self.criterion = torch.nn.MSELoss()
|
196 |
+
|
197 |
+
# use separate metric instance for train, val and test step
|
198 |
+
# to ensure a proper reduction over the epoch
|
199 |
+
self.train_mae = MeanAbsoluteError()
|
200 |
+
self.val_mae = MeanAbsoluteError()
|
201 |
+
self.test_mae = MeanAbsoluteError()
|
202 |
+
|
203 |
+
# for logging best so far validation accuracy
|
204 |
+
self.val_mae_best = MinMetric()
|
205 |
+
|
206 |
+
def forward(self, x: torch.Tensor):
|
207 |
+
return self.model(x)
|
208 |
+
|
209 |
+
def step(self, batch: Any):
|
210 |
+
x = batch["image"]
|
211 |
+
y = batch["focus_value"]
|
212 |
+
logits = self.forward(x)
|
213 |
+
loss = self.criterion(logits, y)
|
214 |
+
preds = torch.squeeze(logits)
|
215 |
+
return loss, preds, y
|
216 |
+
|
217 |
+
def training_step(self, batch: Any, batch_idx: int):
|
218 |
+
loss, preds, targets = self.step(batch)
|
219 |
+
|
220 |
+
# log train metrics
|
221 |
+
mae = self.train_mae(preds, targets)
|
222 |
+
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
223 |
+
self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
224 |
+
|
225 |
+
# we can return here dict with any tensors
|
226 |
+
# and then read it in some callback or in `training_epoch_end()`` below
|
227 |
+
# remember to always return loss from `training_step()` or else backpropagation will fail!
|
228 |
+
return {"loss": loss, "preds": preds, "targets": targets}
|
229 |
+
|
230 |
+
def training_epoch_end(self, outputs: List[Any]):
|
231 |
+
# `outputs` is a list of dicts returned from `training_step()`
|
232 |
+
pass
|
233 |
+
|
234 |
+
def validation_step(self, batch: Any, batch_idx: int):
|
235 |
+
loss, preds, targets = self.step(batch)
|
236 |
+
|
237 |
+
# log val metrics
|
238 |
+
mae = self.val_mae(preds, targets)
|
239 |
+
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
240 |
+
self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
241 |
+
|
242 |
+
return {"loss": loss, "preds": preds, "targets": targets}
|
243 |
+
|
244 |
+
def validation_epoch_end(self, outputs: List[Any]):
|
245 |
+
mae = self.val_mae.compute() # get val accuracy from current epoch
|
246 |
+
self.val_mae_best.update(mae)
|
247 |
+
self.log(
|
248 |
+
"val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
|
249 |
+
)
|
250 |
+
|
251 |
+
def test_step(self, batch: Any, batch_idx: int):
|
252 |
+
loss, preds, targets = self.step(batch)
|
253 |
+
|
254 |
+
# log test metrics
|
255 |
+
mae = self.test_mae(preds, targets)
|
256 |
+
self.log("test/loss", loss, on_step=False, on_epoch=True)
|
257 |
+
self.log("test/mae", mae, on_step=False, on_epoch=True)
|
258 |
+
|
259 |
+
return {"loss": loss, "preds": preds, "targets": targets}
|
260 |
+
|
261 |
+
def test_epoch_end(self, outputs: List[Any]):
|
262 |
+
print(outputs)
|
263 |
+
pass
|
264 |
+
|
265 |
+
def on_epoch_end(self):
|
266 |
+
# reset metrics at the end of every epoch
|
267 |
+
self.train_mae.reset()
|
268 |
+
self.test_mae.reset()
|
269 |
+
self.val_mae.reset()
|
270 |
+
|
271 |
+
def configure_optimizers(self):
|
272 |
+
"""Choose what optimizers and learning-rate schedulers.
|
273 |
+
|
274 |
+
Normally you'd need one. But in the case of GANs or similar you might have multiple.
|
275 |
+
|
276 |
+
See examples here:
|
277 |
+
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
|
278 |
+
"""
|
279 |
+
return torch.optim.Adam(
|
280 |
+
params=self.parameters(),
|
281 |
+
lr=self.hparams.lr,
|
282 |
+
weight_decay=self.hparams.weight_decay,
|
283 |
+
)
|