T-Flet committed on
Commit
7d4973c
1 Parent(s): 104ddc3

Added PyTorch Lightning utilities and model pipeline

Browse files
Files changed (5) hide show
  1. .gitignore +174 -0
  2. lightning_utils.py +179 -0
  3. main.ipynb +0 -0
  4. pytorch_utils.py +12 -5
  5. pytorch_vision_utils.py +168 -0
.gitignore ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # Misc
132
+ *.DS_Store
133
+ */.DS_Store
134
+ extras/data/*
135
+
136
+ # Data files
137
+ data/10_whole_foods*
138
+ data/FashionMNIST
139
+ data/10_whole_foods*
140
+ data/food*
141
+ data/pizza_steak_sushi/*
142
+ data/pizza_steak_sushi_20_percent/
143
+ 10_whole_foods.zip
144
+ going_modular/data/
145
+ data/cifar-*
146
+ logs/
147
+ runs/
148
+ lightning_logs/
149
+ extras/cifar-10*
150
+
151
+ # Notebooks
152
+ 09_pytorch_model_deployment-*
153
+
154
+ # Models
155
+ models/03_pytorch_computer_vision_model_2.pth
156
+ models/04_pytorch_custom_datasets_tinyvgg.pth
157
+ models/07_effnetb0_data_10_percent_10_epochs.pth
158
+ models/07_effnetb0_data_10_percent_5_epochs.pth
159
+ models/07_effnetb0_data_20_percent_10_epochs.pth
160
+ models/07_effnetb0_data_20_percent_5_epochs.pth
161
+ models/07_effnetb2_data_10_percent_10_epochs.pth
162
+ models/07_effnetb2_data_10_percent_5_epochs.pth
163
+ models/07_effnetb2_data_20_percent_5_epochs.pth
164
+ models/08_*
165
+ models/09_*
166
+
167
+ # Demos
168
+ demos/foodvision_big/
169
+ demos/foodvision_mini/
170
+ flagged/
171
+
172
+ # Docs
173
+ .cache
174
+ mkdocs-material-insiders
lightning_utils.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Collection of boilerplate and utility functions for PyTorch Lightning.
3
+ '''
4
+ import torch
5
+ from torch import nn
6
+ from torch.utils.data import DataLoader, random_split
7
+ import torchvision as tv
8
+ # from torch.optim.optimizer import ParamsT # Could use instead of nn.Module as optimiser_factory argument
9
+ # # I.e. ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
10
+
11
+ import pytorch_lightning as L
12
+
13
+ import os
14
+ from pathlib import Path
15
+
16
+ from typing import Callable
17
+
18
+
19
class Strike(L.LightningModule):
    '''As in "Lightning Strike", to make a PyTorch Module a LightningModule.

    Wraps an arbitrary nn.Module together with a loss function, a named metric, an optimiser factory
    and a prediction function, and provides the standard train / validation / test / predict steps.

    Logged keys are `{train|val|test}_loss` and `{train|val|test}_{metric_name}`.
    '''

    def __init__(self, model: nn.Module,
            loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], metric_name_and_fn: tuple[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
            optimiser_factory: Callable[[nn.Module], torch.optim.Optimizer],
            prediction_fn: Callable[[torch.Tensor], torch.Tensor],
            learning_rate = 0.001, log_at_every_step = False):
        '''Class for turning a nn.Module into a LightningModule (a lightning strike of sorts).

        Arguments:
            model: the module whose forward pass is used.
            loss_fn: callable taking (model output, target) and returning the loss.
            metric_name_and_fn: (name, callable) pair; the callable takes (model output, target) and the name is used in the logged key.
            optimiser_factory: callable taking in this module, from which it extracts .parameters() and .learning_rate to produce an optimiser.
            prediction_fn: callable turning the model output into predictions in predict_step.
            learning_rate: stored as self.learning_rate for the optimiser_factory to read.
            log_at_every_step: if True, loss and metric are logged at every step in addition to the per-epoch aggregate.
        '''
        super().__init__()

        self.model = model
        # If the model form were known then its layers could be moved to this object's level rather than a nested one (not necessary but neater)
        # The procedural versions of this are not useful since a nested nn.Sequential still exists, i.e. any of
        #   self.model = nn.Sequential(target._modules) # Preserves layer names
        #   self.model = nn.Sequential(*source.children()) # *source.modules() would return the larger container as well

        self.loss_fn = loss_fn
        self.metric_name, self.metric_fn = metric_name_and_fn
        self.optimiser_factory = optimiser_factory
        self.prediction_fn = prediction_fn

        self.learning_rate = learning_rate
        self.log_at_every_step = log_at_every_step
        # Retained for backward compatibility only: logging now happens inside the step hooks, so these are no longer populated
        self.train_step_outputs, self.validation_step_outputs, self.test_step_outputs = dict(), dict(), dict()

    def forward(self, x):
        return self.model(x)

    # No need to override these two hooks
    # def backward(self, trainer, loss, optimizer, optimizer_idx):
    #     loss.backward()
    # def optimizer_step(self, epoch, batch_idx, optimiser, optimizer_idx):
    #     optimiser.step()

    def training_step(self, batch, batch_idx):
        return self._logged_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self._logged_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        return self._logged_step(batch, batch_idx, 'test')

    def _common_step(self, batch, batch_idx):
        '''Run the forward pass on an (input, target) batch and compute loss and metric.'''
        x, y = batch
        x_hat = self.forward(x)
        loss = self.loss_fn(x_hat, y)
        metric = self.metric_fn(x_hat, y)
        return loss, metric, x_hat, y

    def _logged_step(self, batch, batch_idx, prefix: str):
        '''Compute loss and metric for the batch, log them under the given prefix and return the loss.

        Logging is done here (with on_epoch = True) rather than in the epoch-end hooks, so that the per-epoch value
        is aggregated by Lightning over all batches instead of being the last batch's value only,
        and so that on_step = True is used in a hook which supports it.
        '''
        loss, metric, _, _ = self._common_step(batch, batch_idx)
        self.log_dict({f'{prefix}_loss': loss, f'{prefix}_{self.metric_name}': metric},
            prog_bar = True, on_step = self.log_at_every_step, on_epoch = True)
        return loss

    def predict_step(self, batch, batch_idx):
        # Accept both (input, target) batches and bare input batches, since targets are not needed for prediction
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.prediction_fn(self.forward(x))

    def configure_optimizers(self):
        return self.optimiser_factory(self)
97
+
98
+
99
+
100
class IteratedLearningRateFinder(L.callbacks.LearningRateFinder):
    '''LearningRateFinder callback which triggers the learning rate search at the start of selected epochs.'''

    def __init__(self, at_epochs: list[int], *args, **kwargs):
        '''CURRENTLY FAILS AT THE 2ND OCCURRENCE (despite being directly from the docs: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateFinder.html)
        The lr finding runs at epoch 0 regardless of whether 0 is in at_epochs.
        E.g. for periodic lr adjustments pass [e for e in range(epochs) if e % period == 0]
        All other positional and keyword arguments are forwarded to LearningRateFinder.
        '''
        super().__init__(*args, **kwargs)
        self.at_epochs = at_epochs

    def on_fit_start(self, *args, **kwargs):
        # Deliberately does nothing: the search is triggered from on_train_epoch_start instead
        return None

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch == 0 or epoch in self.at_epochs:
            self.lr_find(trainer, pl_module)
114
+
115
+
116
+
117
class LocalImageDataModule(L.LightningDataModule):
    '''LightningDataModule for a local image folder (or folders) for classification purposes.

    Images are expected to be in subfolders named by their classes.
    In the str or Path folders case, the folder content is checked for subfolders called train, valid and test:
    if all three are present they are used as the respective datasets;
    if none is present, then the whole folder is randomly split according to the split argument,
    a tuple of proportions to allocate to the training, validation and testing datasets (in this order).
    In the dict folders case the keys are expected to be exactly ['train', 'valid', 'test'].
    The class names are from the training folder (or the single folder) and assumed to be consistent across the others.
    '''

    def __init__(self, folders: str | Path | dict[str, str | Path], transform: tv.transforms.Compose,
            batch_size: int, num_workers: int = os.cpu_count(), split: tuple[float, float, float] = (0.7, 0.2, 0.1)):
        super().__init__()

        ########### Could relax requirement to train and test and then produce a validate dataset from the training one #########

        self.prefixes = ['train', 'valid', 'test']

        data_path = None
        if isinstance(folders, (str, Path)):
            data_path = Path(folders)
            folders = {sub: full_sub for sub in self.prefixes if (full_sub := data_path / sub).is_dir()}
        elif not isinstance(folders, dict): raise ValueError('Please provide a folders argument of types str | Path | dict[str, str | Path].')

        assert set(folders.keys()).issubset(self.prefixes), f'Exactly the {self.prefixes} folders are required; {folders.keys()} were provided.'
        if len(folders) == 0 and data_path is not None:
            # Tolerance rather than equality: e.g. 0.7 + 0.2 + 0.1 != 1 in floating point arithmetic
            assert len(split) == len(self.prefixes) and abs(sum(split) - 1) < 1e-6, f'split requires {len(self.prefixes)} proportions adding up to 1; {split} was provided.'
            # A (path, proportions) tuple marks the single-folder case for setup
            folders = (data_path, dict(zip(self.prefixes, split)))
        elif len(folders) != 3:
            raise ValueError(f'All of {self.prefixes} subfolders are required for the single-folder folders argument; only {folders.keys()} were provided.')

        self.folders = folders
        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers or 0 # os.cpu_count() may return None

        self.train_ds, self.val_ds, self.test_ds = None, None, None
        self.classes = None

    # def prepare_data(self):
    #     '''Not currently implemented. Mostly meant for downloading data.'''
    #     pass

    def setup(self, stage = None):
        '''Create the datasets required by the given stage ('fit', 'validate', 'test'; None sets up everything).'''
        if isinstance(self.folders, tuple):
            # Split only once: setup is called separately per stage, and re-splitting would leak training images into the test set
            if self.train_ds is None:
                root, proportions = self.folders
                all_data = tv.datasets.ImageFolder(root, transform = self.transform)
                self.classes = all_data.classes
                self.train_ds, self.val_ds, self.test_ds = random_split(all_data, [proportions[prefix] for prefix in self.prefixes])
            return

        train_key, val_key, test_key = self.prefixes
        if stage in (None, 'fit'):
            self.train_ds = tv.datasets.ImageFolder(self.folders[train_key], transform = self.transform)
            self.classes = self.train_ds.classes
        if stage in (None, 'fit', 'validate'):
            self.val_ds = tv.datasets.ImageFolder(self.folders[val_key], transform = self.transform)
            if self.classes is None: self.classes = self.val_ds.classes
        if stage in (None, 'test'):
            self.test_ds = tv.datasets.ImageFolder(self.folders[test_key], transform = self.transform)
            if self.classes is None: self.classes = self.test_ds.classes

    def _dataloader(self, dataset, shuffle: bool) -> DataLoader:
        '''Build a DataLoader with the module settings; persistent workers are only valid with at least one worker.'''
        return DataLoader(dataset, batch_size = self.batch_size, shuffle = shuffle, num_workers = self.num_workers,
            pin_memory = True, persistent_workers = self.num_workers > 0)

    def train_dataloader(self):
        return self._dataloader(self.train_ds, shuffle = True)

    def val_dataloader(self):
        return self._dataloader(self.val_ds, shuffle = False)

    def test_dataloader(self):
        return self._dataloader(self.test_ds, shuffle = False)
178
+
179
+
main.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
pytorch_utils.py CHANGED
@@ -7,7 +7,6 @@ from torch import nn
7
  from torch.utils.data import DataLoader
8
  from torch.utils.tensorboard import SummaryWriter
9
 
10
- # Here just to be exported
11
  from torchinfo import summary
12
  import torchmetrics
13
 
@@ -41,7 +40,7 @@ def set_seeds(seed: int = 42):
41
  def train_combinations(combinations: dict[str, tuple[str, str, str, int, str]],
42
  model_factories: dict[str, Callable[[], nn.Module]], train_dataloaders: dict[str, DataLoader],
43
  optimiser_factories: dict[str, Callable[[nn.Module], torch.optim.Optimizer]],
44
- test_dataloader: DataLoader, loss_fn: nn.Module, metric_name_and_fn: tuple[str, Callable[[torch.tensor, torch.tensor], torch.tensor]],
45
  reset_seed: int = 42, device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True):
46
  '''Run a series of modelling tasks by defining combinations of models, dataloaders, optimisers and epochs, as well as an optional previously-fit combination
47
  to start from (e.g. for a combination which is the same as a previous one but with more epochs or different training data).
@@ -96,7 +95,7 @@ def train_combinations(combinations: dict[str, tuple[str, str, str, int, str]],
96
 
97
 
98
  def fit(model: nn.Module, train_dataloader: DataLoader, test_dataloader: DataLoader,
99
- optimiser: torch.optim.Optimizer, loss_fn: nn.Module, metric_name_and_fn: tuple[str, Callable[[torch.tensor, torch.tensor], torch.tensor]],
100
  epochs: int, writer: torch.utils.tensorboard.writer.SummaryWriter,
101
  device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, model_name: str = None) -> dict[str, list]:
102
  '''Trains and tests a PyTorch model.
@@ -148,7 +147,7 @@ def fit(model: nn.Module, train_dataloader: DataLoader, test_dataloader: DataLoa
148
 
149
 
150
  def training_step(model: nn.Module, dataloader: DataLoader,
151
- loss_fn: nn.Module, metric_fn: Callable[[torch.tensor, torch.tensor], torch.tensor], optimiser: torch.optim.Optimizer,
152
  device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, epoch: int = None) -> tuple[float, float]:
153
  '''Trains a PyTorch model for a single epoch.
154
 
@@ -189,7 +188,7 @@ def training_step(model: nn.Module, dataloader: DataLoader,
189
 
190
 
191
  def testing_step(model: nn.Module, dataloader: DataLoader,
192
- loss_fn: nn.Module, metric_fn: Callable[[torch.tensor, torch.tensor], torch.tensor],
193
  device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, epoch: int = None) -> tuple[float, float]:
194
  '''Tests a PyTorch model for a single epoch.
195
 
@@ -302,6 +301,14 @@ def download_unzip(source: str, destination: str, remove_source: bool = True) ->
302
 
303
 
304
 
 
 
 
 
 
 
 
 
305
 
306
  #### Plotting Functions ####
307
 
 
7
  from torch.utils.data import DataLoader
8
  from torch.utils.tensorboard import SummaryWriter
9
 
 
10
  from torchinfo import summary
11
  import torchmetrics
12
 
 
40
  def train_combinations(combinations: dict[str, tuple[str, str, str, int, str]],
41
  model_factories: dict[str, Callable[[], nn.Module]], train_dataloaders: dict[str, DataLoader],
42
  optimiser_factories: dict[str, Callable[[nn.Module], torch.optim.Optimizer]],
43
+ test_dataloader: DataLoader, loss_fn: nn.Module, metric_name_and_fn: tuple[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
44
  reset_seed: int = 42, device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True):
45
  '''Run a series of modelling tasks by defining combinations of models, dataloaders, optimisers and epochs, as well as an optional previously-fit combination
46
  to start from (e.g. for a combination which is the same as a previous one but with more epochs or different training data).
 
95
 
96
 
97
  def fit(model: nn.Module, train_dataloader: DataLoader, test_dataloader: DataLoader,
98
+ optimiser: torch.optim.Optimizer, loss_fn: nn.Module, metric_name_and_fn: tuple[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
99
  epochs: int, writer: torch.utils.tensorboard.writer.SummaryWriter,
100
  device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, model_name: str = None) -> dict[str, list]:
101
  '''Trains and tests a PyTorch model.
 
147
 
148
 
149
  def training_step(model: nn.Module, dataloader: DataLoader,
150
+ loss_fn: nn.Module, metric_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: torch.optim.Optimizer,
151
  device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, epoch: int = None) -> tuple[float, float]:
152
  '''Trains a PyTorch model for a single epoch.
153
 
 
188
 
189
 
190
  def testing_step(model: nn.Module, dataloader: DataLoader,
191
+ loss_fn: nn.Module, metric_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
192
  device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, epoch: int = None) -> tuple[float, float]:
193
  '''Tests a PyTorch model for a single epoch.
194
 
 
301
 
302
 
303
 
304
+ #### Info Functions ####
305
+
306
def summ(model: nn.Module, input_size: tuple = (32, 3, 224, 224)):
    '''Shorthand for typical summary specification.

    input_size is the full input shape passed on to torchinfo's summary;
    it defaults to (32, 3, 224, 224), which used to be hard-coded regardless of the argument.
    '''
    return summary(model = model, input_size = input_size,
        col_names = ['input_size', 'output_size', 'num_params', 'trainable'], col_width = 20, row_settings = ['var_names'])
310
+
311
+
312
 
313
  #### Plotting Functions ####
314
 
pytorch_vision_utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''Torchvision and related utility functions'''
2
+
3
+ import torch
4
+ import torchvision as tv
5
+ from torch.utils.data import DataLoader
6
+ import timm # Here just to be exported
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from tqdm.auto import tqdm
11
+ import base64
12
+ import altair as alt
13
+ import matplotlib.pyplot as plt # REMOVE IN FAVOUR OF ALTAIR
14
+
15
+ import os
16
+ import io
17
+ from pathlib import Path
18
+ from PIL import Image
19
+ # from itertools import batched # in Python>=3.12
20
+
21
+
22
def image_dataloaders(folders: str | Path | list[str | Path], transform: tv.transforms.Compose, batch_size: int, num_workers: int = os.cpu_count()) -> tuple[list[DataLoader], list[str]]:
    '''Return PyTorch DataLoaders and class names for the given folder or list of folders (with expected subfolders named by class).
    In the non-list folders case, the folder content is checked for subfolders called train, valid and test (in this order), and if any is present they are treated as the list input;
    otherwise the folder itself is used as the only dataset.
    The first folder is assumed to be the training data and will therefore produce a shuffling dataloader, while the others will not.
    The class names are from the first folder and assumed to be consistent across the others.
    '''
    if isinstance(folders, (str, Path)):
        data_path = Path(folders)
        subfolders = [full_sub for sub in ['train', 'valid', 'test'] if (full_sub := data_path / sub).is_dir()]
        folders = subfolders if subfolders else [folders]

    num_workers = num_workers or 0 # os.cpu_count() may return None
    datasets = [tv.datasets.ImageFolder(folder, transform = transform) for folder in folders]
    # persistent_workers is only valid with at least one worker process
    dataloaders = [DataLoader(ds, batch_size = batch_size, shuffle = i == 0, num_workers = num_workers, pin_memory = True, persistent_workers = num_workers > 0)
        for i, ds in enumerate(datasets)]

    return dataloaders, datasets[0].classes
36
+
37
+
38
def plot_img_preds(model: torch.nn.Module, image_path: str, class_names: list[str], transform: tv.transforms, device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu'):
    '''Plot one image with its prediction and probability as the title.
    The model is moved to the given device and set to evaluation mode; the image is passed through transform before prediction.
    '''
    image = Image.open(image_path)

    model.to(device)
    model.eval()
    with torch.inference_mode():
        # Prepend "batch" dimension (-> [batch_size, color_channels, height, width])
        single_image_batch = transform(image).unsqueeze(dim = 0).to(device)
        logits = model(single_image_batch)
    probabilities = torch.softmax(logits, dim = 1)
    predicted_label = torch.argmax(probabilities, dim = 1)

    plt.figure()
    plt.imshow(image)
    plt.title(f"Pred: {class_names[predicted_label]} | Prob: {probabilities.max():.3f}")
    plt.axis(False)

    # Change text colour based on correctness?
55
+
56
+
57
def record_image_preds(image_paths: str | list[str], model: torch.nn.Module, transform: tv.transforms.Compose, class_names: list[str],
        sort_by_correctness = True, device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu'):
    '''Generate a dataframe of paths, true classes, (single) predicted classes and their confidence.
    Column names: path, true_class, pred_class, pred_prob, correct.
    The true class of an image is the name of the folder containing it.
    If sort_by_correctness, then the dataframe is sorted by increasing correctness and confidence, i.e. first by prediction correctness and then by its probability,
    with wrong predictions first, and both wrong and right by decreasing confidence.
    If a single string (or Path) is given as image_paths, then all */*.jpg and */*.png matches from it are used instead.
    The model is moved to the given device and set to evaluation mode.
    '''
    if isinstance(image_paths, (str, Path)):
        root = Path(image_paths)
        image_paths = list(root.glob('*/*.jpg')) + list(root.glob('*/*.png'))
    # Normalise to Path objects so that plain string paths work too (the parent folder name is needed below)
    image_paths = [Path(p) for p in image_paths]

    # The inputs are sent to device, so the model has to be there as well
    model.to(device)
    model.eval()

    true_classes, pred_classes, pred_probs, correctness = [], [], [], []
    for path in tqdm(image_paths):
        with Image.open(path) as img:
            # Prepend "batch" dimension (-> [batch_size, color_channels, height, width])
            batch = transform(img).unsqueeze(0).to(device)
        with torch.inference_mode(): pred_logit = model(batch)
        confidence, pred_label = torch.softmax(pred_logit, dim = 1).max(dim = 1)

        # .name rather than .stem, so that class folders with dots in their names are not truncated
        true_classes.append(class_name := path.parent.name)
        pred_classes.append(pred_class := class_names[pred_label.item()])
        pred_probs.append(confidence.item())
        correctness.append(class_name == pred_class)

    res = pd.DataFrame(dict(path = [str(p) for p in image_paths], true_class = true_classes, pred_class = pred_classes, pred_prob = pred_probs, correct = correctness))
    return res.sort_values(by = ['correct', 'pred_prob'], ascending = [True, False]) if sort_by_correctness else res
85
+
86
+
87
def base64_image_formatter(image_or_path: Image.Image | str) -> str:
    '''Generate a base64-encoded string representation of the given image (or path), as a PNG data URL.
    Example usecase: a dataframe meant for Altair contains PIL images (or their paths) in a column, in which case pass this temporary dataframe to the alt.Chart:
        `df.assign(image = df.image_or_path_column.apply(base64_image_formatter))`
    Images opened here from a path (str or Path) are closed before returning; images passed in are left open.
    '''
    if isinstance(image_or_path, (str, Path)):
        # Open and close the file here, then encode the in-memory image through the branch below
        with Image.open(image_or_path) as opened_image:
            return base64_image_formatter(opened_image)

    with io.BytesIO() as buffer: # Docs: https://altair-viz.github.io/user_guide/marks/image.html#use-local-images-as-image-marks
        image_or_path.save(buffer, format = 'PNG')
        data = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f'data:image/png;base64,{data}'
97
+
98
+
99
def image_pred_grid(image_df: pd.DataFrame, ncols = 4, img_width = 200, img_height = 200, allow_1_col_reduction = True):
    '''Create an Altair plot displaying a grid of images and their predicted classes, highlighting incorrect predictions.
    image_df is expected to have the columns: path, true_class, pred_class, pred_prob, correct.
    Each image is titled with its predicted class and probability, in green if correct and red otherwise.
    If allow_1_col_reduction and the last row (by the given ncols) is at least half empty and using ncols-1 would not increase rows, then ncols-1 is used instead.
    An empty image_df produces an empty chart.
    '''
    # Docs: https://altair-viz.github.io/user_guide/compound_charts.html
    # Opened issue on making it easier through alt.Facet: https://github.com/altair-viz/altair/issues/3398

    n_images = len(image_df)
    if n_images == 0: return alt.vconcat()

    ncols = max(1, min(ncols, n_images))
    nrows = -(-n_images // ncols) # Ceiling division: a full last row does not add an extra row
    in_last_row = n_images % ncols # 0 means that the last row is full, i.e. not reducible
    # If the last row is at least half empty and could reduce columns without increasing rows, do so
    if allow_1_col_reduction and ncols > 1 and nrows > 1 and 0 < in_last_row <= ncols / 2 and -(-n_images // (ncols - 1)) == nrows: ncols -= 1

    expanded_df = image_df.assign(
        image = image_df.path.apply(base64_image_formatter),
        title = image_df.pred_class + ' - ' + image_df.pred_prob.map(lambda p: f'{p:.2f}'),
        index = image_df.index
    )

    base = alt.Chart(expanded_df).mark_image(width = img_width, height = img_height).encode(url = 'image:N')
    chart = alt.vconcat()
    for row_indices in (expanded_df.index[i:i + ncols] for i in range(0, n_images, ncols)): # itertools.batched(expanded_df.index, ncols) in Python>=3.12
        row_chart = alt.hconcat()
        for index in row_indices:
            # One single-image subplot per row of the dataframe, so that each can have its own coloured title
            row_chart |= base.transform_filter(alt.datum.index == index).properties(
                title = alt.Title(expanded_df.loc[index, 'title'], fontSize = 17, color = 'green' if expanded_df.loc[index, 'correct'] else 'red'))
        chart &= row_chart

    ## Version with no subplots (but no titles)
    # chart = alt.Chart(image_df.assign( # vv cannot trust the df index since it might not be ordered
    #     row = np.arange(len(image_df)) // ncols, col = np.arange(len(image_df)) % ncols # Could use the transform_compose block for this, but no // in the alt.expr language
    # )).mark_image(width = img_width, height = img_height).encode(
    #     alt.X('col:O', title = None, axis = None), alt.Y('row:O', title = None, axis = None), url = 'image:N'
    # ).properties(
    #     width = img_width * 1.1 * ncols, height = img_height * 1.1 * nrows
    # )

    ## Version with faceting (but not coloured titles (no titles in fact, but non-coloured headers))
    # chart = alt.Chart(image_df.assign(
    #     image = image_df.path.apply(base64_image_formatter),
    #     title = image_df.pred_class + ' - ' + image_df.pred_prob.map(lambda p: f'{p:.2f}')
    # )).mark_image(width = img_width, height = img_height).encode(url = 'image:N'
    # ).facet( # Header fields: https://altair-viz.github.io/user_guide/generated/core/altair.Header.html
    #     alt.Facet('title:N', header = alt.Header(labelFontSize = 17, labelColor = 'red')).title('Prediction and Confidence'), columns = ncols, title = 'Hi'
    # )

    return chart
146
+
147
+
148
+
149
+
150
+
151
+ # import torchvision
152
+ # import matplotlib.pyplot as plt
153
+ # # Plot the top 5 most wrong images
154
+ # for row in top_5_most_wrong.iterrows():
155
+ # row = row[1]
156
+ # image_path = row[0]
157
+ # true_label = row[1]
158
+ # pred_prob = row[2]
159
+ # pred_class = row[3]
160
+ # # Plot the image and various details
161
+ # img = torchvision.io.read_image(str(image_path)) # get image as tensor
162
+ # plt.figure()
163
+ # plt.imshow(img.permute(1, 2, 0)) # matplotlib likes images in [height, width, color_channels]
164
+ # plt.title(f"True: {true_label} | Pred: {pred_class} | Prob: {pred_prob:.3f}")
165
+ # plt.axis(False);
166
+
167
+
168
+