Hannes Kuchelmeister commited on
Commit
12319ae
·
1 Parent(s): eef4d7c

add FNN with MSE loss

Browse files
src/datamodules/focus_datamodule.py CHANGED
@@ -3,6 +3,8 @@ from typing import Optional, Tuple
3
  import pandas as pd
4
  from skimage import io
5
 
 
 
6
  import torch
7
  from pytorch_lightning import LightningDataModule
8
  from torch.utils.data import DataLoader, Dataset, random_split
@@ -51,7 +53,9 @@ class FocusDataSet(Dataset):
51
  self.root_dir, self.metadata.iloc[idx, self.col_index_path]
52
  )
53
  image = io.imread(img_name)
54
- focus_value = self.metadata.iloc[idx, self.col_index_focus]
 
 
55
  sample = {"image": image, "focus_value": focus_value}
56
 
57
  if self.transform:
 
3
  import pandas as pd
4
  from skimage import io
5
 
6
+ import numpy as np
7
+
8
  import torch
9
  from pytorch_lightning import LightningDataModule
10
  from torch.utils.data import DataLoader, Dataset, random_split
 
53
  self.root_dir, self.metadata.iloc[idx, self.col_index_path]
54
  )
55
  image = io.imread(img_name)
56
+ focus_value = torch.from_numpy(
57
+ np.asarray(self.metadata.iloc[idx, self.col_index_focus])
58
+ ).float()
59
  sample = {"image": image, "focus_value": focus_value}
60
 
61
  if self.transform:
src/models/focus_module.py CHANGED
@@ -156,3 +156,128 @@ class FocusLitModule(LightningModule):
156
  lr=self.hparams.lr,
157
  weight_decay=self.hparams.weight_decay,
158
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  lr=self.hparams.lr,
157
  weight_decay=self.hparams.weight_decay,
158
  )
159
+
160
+
161
+ class FocusMSELitModule(LightningModule):
162
+ """
163
+ Example of LightningModule for MNIST classification.
164
+
165
+ A LightningModule organizes your PyTorch code into 5 sections:
166
+ - Computations (init).
167
+ - Train loop (training_step)
168
+ - Validation loop (validation_step)
169
+ - Test loop (test_step)
170
+ - Optimizers (configure_optimizers)
171
+
172
+ Read the docs:
173
+ https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
174
+ """
175
+
176
+ def __init__(
177
+ self,
178
+ input_size: int = 75 * 75 * 3,
179
+ lin1_size: int = 256,
180
+ lin2_size: int = 256,
181
+ lin3_size: int = 256,
182
+ output_size: int = 1,
183
+ lr: float = 0.001,
184
+ weight_decay: float = 0.0005,
185
+ ):
186
+ super().__init__()
187
+
188
+ # this line allows to access init params with 'self.hparams' attribute
189
+ # it also ensures init params will be stored in ckpt
190
+ self.save_hyperparameters(logger=False)
191
+
192
+ self.model = SimpleDenseNet(hparams=self.hparams)
193
+
194
+ # loss function
195
+ self.criterion = torch.nn.MSELoss()
196
+
197
+ # use separate metric instance for train, val and test step
198
+ # to ensure a proper reduction over the epoch
199
+ self.train_mae = MeanAbsoluteError()
200
+ self.val_mae = MeanAbsoluteError()
201
+ self.test_mae = MeanAbsoluteError()
202
+
203
+ # for logging best so far validation accuracy
204
+ self.val_mae_best = MinMetric()
205
+
206
+ def forward(self, x: torch.Tensor):
207
+ return self.model(x)
208
+
209
+ def step(self, batch: Any):
210
+ x = batch["image"]
211
+ y = batch["focus_value"]
212
+ logits = self.forward(x)
213
+ loss = self.criterion(logits, y)
214
+ preds = torch.squeeze(logits)
215
+ return loss, preds, y
216
+
217
+ def training_step(self, batch: Any, batch_idx: int):
218
+ loss, preds, targets = self.step(batch)
219
+
220
+ # log train metrics
221
+ mae = self.train_mae(preds, targets)
222
+ self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
223
+ self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
224
+
225
+ # we can return here dict with any tensors
226
+ # and then read it in some callback or in `training_epoch_end()`` below
227
+ # remember to always return loss from `training_step()` or else backpropagation will fail!
228
+ return {"loss": loss, "preds": preds, "targets": targets}
229
+
230
+ def training_epoch_end(self, outputs: List[Any]):
231
+ # `outputs` is a list of dicts returned from `training_step()`
232
+ pass
233
+
234
+ def validation_step(self, batch: Any, batch_idx: int):
235
+ loss, preds, targets = self.step(batch)
236
+
237
+ # log val metrics
238
+ mae = self.val_mae(preds, targets)
239
+ self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
240
+ self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
241
+
242
+ return {"loss": loss, "preds": preds, "targets": targets}
243
+
244
+ def validation_epoch_end(self, outputs: List[Any]):
245
+ mae = self.val_mae.compute() # get val accuracy from current epoch
246
+ self.val_mae_best.update(mae)
247
+ self.log(
248
+ "val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
249
+ )
250
+
251
+ def test_step(self, batch: Any, batch_idx: int):
252
+ loss, preds, targets = self.step(batch)
253
+
254
+ # log test metrics
255
+ mae = self.test_mae(preds, targets)
256
+ self.log("test/loss", loss, on_step=False, on_epoch=True)
257
+ self.log("test/mae", mae, on_step=False, on_epoch=True)
258
+
259
+ return {"loss": loss, "preds": preds, "targets": targets}
260
+
261
+ def test_epoch_end(self, outputs: List[Any]):
262
+ print(outputs)
263
+ pass
264
+
265
+ def on_epoch_end(self):
266
+ # reset metrics at the end of every epoch
267
+ self.train_mae.reset()
268
+ self.test_mae.reset()
269
+ self.val_mae.reset()
270
+
271
+ def configure_optimizers(self):
272
+ """Choose what optimizers and learning-rate schedulers.
273
+
274
+ Normally you'd need one. But in the case of GANs or similar you might have multiple.
275
+
276
+ See examples here:
277
+ https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
278
+ """
279
+ return torch.optim.Adam(
280
+ params=self.parameters(),
281
+ lr=self.hparams.lr,
282
+ weight_decay=self.hparams.weight_decay,
283
+ )