Spaces:
Sleeping
Sleeping
Added PyTorch Lightning utilities and model pipeline
Browse files- .gitignore +174 -0
- lightning_utils.py +179 -0
- main.ipynb +0 -0
- pytorch_utils.py +12 -5
- pytorch_vision_utils.py +168 -0
.gitignore
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# Misc
|
132 |
+
*.DS_Store
|
133 |
+
*/.DS_Store
|
134 |
+
extras/data/*
|
135 |
+
|
136 |
+
# Data files
|
137 |
+
data/10_whole_foods*
|
138 |
+
data/FashionMNIST
|
139 |
+
data/10_whole_foods*
|
140 |
+
data/food*
|
141 |
+
data/pizza_steak_sushi/*
|
142 |
+
data/pizza_steak_sushi_20_percent/
|
143 |
+
10_whole_foods.zip
|
144 |
+
going_modular/data/
|
145 |
+
data/cifar-*
|
146 |
+
logs/
|
147 |
+
runs/
|
148 |
+
lightning_logs/
|
149 |
+
extras/cifar-10*
|
150 |
+
|
151 |
+
# Notebooks
|
152 |
+
09_pytorch_model_deployment-*
|
153 |
+
|
154 |
+
# Models
|
155 |
+
models/03_pytorch_computer_vision_model_2.pth
|
156 |
+
models/04_pytorch_custom_datasets_tinyvgg.pth
|
157 |
+
models/07_effnetb0_data_10_percent_10_epochs.pth
|
158 |
+
models/07_effnetb0_data_10_percent_5_epochs.pth
|
159 |
+
models/07_effnetb0_data_20_percent_10_epochs.pth
|
160 |
+
models/07_effnetb0_data_20_percent_5_epochs.pth
|
161 |
+
models/07_effnetb2_data_10_percent_10_epochs.pth
|
162 |
+
models/07_effnetb2_data_10_percent_5_epochs.pth
|
163 |
+
models/07_effnetb2_data_20_percent_5_epochs.pth
|
164 |
+
models/08_*
|
165 |
+
models/09_*
|
166 |
+
|
167 |
+
# Demos
|
168 |
+
demos/foodvision_big/
|
169 |
+
demos/foodvision_mini/
|
170 |
+
flagged/
|
171 |
+
|
172 |
+
# Docs
|
173 |
+
.cache
|
174 |
+
mkdocs-material-insiders
|
lightning_utils.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Collection of boilerplate and utility functions for PyTorch Lightning.
|
3 |
+
'''
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.utils.data import DataLoader, random_split
|
7 |
+
import torchvision as tv
|
8 |
+
# from torch.optim.optimizer import ParamsT # Could use instead of nn.Module as optimiser_factory argument
|
9 |
+
# # I.e. ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
|
10 |
+
|
11 |
+
import pytorch_lightning as L
|
12 |
+
|
13 |
+
import os
|
14 |
+
from pathlib import Path
|
15 |
+
|
16 |
+
from typing import Callable
|
17 |
+
|
18 |
+
|
19 |
+
class Strike(L.LightningModule):
|
20 |
+
'''As in "Lightning Strike", to make a PyTorch Module a LightningModule'''
|
21 |
+
def __init__(self, model: nn.Module,
|
22 |
+
loss_fn: Callable[[torch.Tensor], torch.Tensor], metric_name_and_fn: tuple[str, Callable[[torch.Tensor, torch.Tensor], torch.tensor]],
|
23 |
+
optimiser_factory: Callable[[nn.Module], torch.optim.Optimizer],
|
24 |
+
prediction_fn: Callable[[torch.Tensor], torch.Tensor],
|
25 |
+
learning_rate = 0.001, log_at_every_step = False):
|
26 |
+
'''Class for turning a nn.Module into a LightningModule (a lightning strike of sorts).
|
27 |
+
The optimiser_factory argument is a callable taking in the module, from which it extracts .parameters() and .learning_rate to produce an optimiser.'''
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.model = model
|
31 |
+
# If the model form were known then its layers could be moved to this object's level rather than a nested one (not necessary but neater)
|
32 |
+
# The procedural versions pf this are not useful since a nested nn.Sequential still exists, i.e. any of
|
33 |
+
# self.model = nn.Sequential(target._modules) # Preserves layer names
|
34 |
+
# self.model = nn.Sequential(*source.children()) # *source.modules() would return the larger container as well
|
35 |
+
|
36 |
+
self.loss_fn = loss_fn
|
37 |
+
self.metric_name, self.metric_fn = metric_name_and_fn
|
38 |
+
self.optimiser_factory = optimiser_factory
|
39 |
+
self.prediction_fn = prediction_fn
|
40 |
+
|
41 |
+
self.learning_rate = learning_rate
|
42 |
+
self.log_at_every_step = log_at_every_step
|
43 |
+
self.train_step_outputs, self.validation_step_outputs, self.test_step_outputs = dict(), dict(), dict()
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
return self.model(x)
|
47 |
+
|
48 |
+
# No need to override these two hooks
|
49 |
+
# def backward(self, trainer, loss, optimizer, optimizer_idx):
|
50 |
+
# loss.backward()
|
51 |
+
# def optimizer_step(self, epoch, batch_idx, optimiser, optimizer_idx):
|
52 |
+
# optimiser.step()
|
53 |
+
|
54 |
+
def training_step(self, batch, batch_idx):
|
55 |
+
loss, metric, x_hat, y = self._common_step(batch, batch_idx)
|
56 |
+
self.train_step_outputs = dict(prefix = 'train', loss = loss, metric = metric)
|
57 |
+
return loss
|
58 |
+
|
59 |
+
def on_train_epoch_end(self):
|
60 |
+
self._common_epoch_end_step(self.train_step_outputs)
|
61 |
+
|
62 |
+
def validation_step(self, batch, batch_idx):
|
63 |
+
loss, metric, x_hat, y = self._common_step(batch, batch_idx)
|
64 |
+
self.validation_step_outputs = dict(prefix = 'val', loss = loss, metric = metric)
|
65 |
+
return loss
|
66 |
+
|
67 |
+
def on_validation_epoch_end(self):
|
68 |
+
self._common_epoch_end_step(self.validation_step_outputs)
|
69 |
+
|
70 |
+
def test_step(self, batch, batch_idx):
|
71 |
+
loss, metric, x_hat, y = self._common_step(batch, batch_idx)
|
72 |
+
self.test_step_outputs = dict(prefix = 'test', loss = loss, metric = metric)
|
73 |
+
return loss
|
74 |
+
|
75 |
+
def on_test_epoch_end(self):
|
76 |
+
self._common_epoch_end_step(self.test_step_outputs)
|
77 |
+
|
78 |
+
def _common_step(self, batch, batch_idx):
|
79 |
+
x, y = batch
|
80 |
+
x_hat = self.forward(x)
|
81 |
+
loss = self.loss_fn(x_hat, y)
|
82 |
+
metric = self.metric_fn(x_hat, y)
|
83 |
+
return loss, metric, x_hat, y
|
84 |
+
|
85 |
+
def _common_epoch_end_step(self, outputs):
|
86 |
+
self.log_dict({f'{outputs["prefix"]}_loss': outputs['loss'], f'{outputs["prefix"]}_{self.metric_name}': outputs['metric']}, prog_bar = True, on_step = self.log_at_every_step, on_epoch = True)
|
87 |
+
outputs.clear() # Freeing memory is suggested in the docs, though it is trivial in this class
|
88 |
+
|
89 |
+
def predict_step(self, batch, batch_idx):
|
90 |
+
x, y = batch
|
91 |
+
x_hat = self.forward(x)
|
92 |
+
preds = self.prediction_fn(x_hat)
|
93 |
+
return preds
|
94 |
+
|
95 |
+
def configure_optimizers(self):
|
96 |
+
return self.optimiser_factory(self)
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
class IteratedLearningRateFinder(L.callbacks.LearningRateFinder):
|
101 |
+
def __init__(self, at_epochs: list[int], *args, **kwargs):
|
102 |
+
'''CURRENTLY FAILS AT THE 2ND OCCURRENCE (despite being directly from the docs: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateFinder.html)
|
103 |
+
The lr finding tuns at epoch 0 regardless of whether 0 is in at_epochs.
|
104 |
+
E.g. for periodic lr adjustments pass [e for e in range(epochs) if e % period == 0]'''
|
105 |
+
super().__init__(*args, **kwargs)
|
106 |
+
self.at_epochs = at_epochs
|
107 |
+
|
108 |
+
def on_fit_start(self, *args, **kwargs):
|
109 |
+
return
|
110 |
+
|
111 |
+
def on_train_epoch_start(self, trainer, pl_module):
|
112 |
+
if trainer.current_epoch in self.at_epochs or trainer.current_epoch == 0:
|
113 |
+
self.lr_find(trainer, pl_module)
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
class LocalImageDataModule(L.LightningDataModule):
|
118 |
+
def __init__(self, folders: str | Path | dict[str, str | Path], transform: tv.transforms.Compose,
|
119 |
+
batch_size: int, num_workers: int = os.cpu_count(), split: tuple[float, float, float] = (0.7, 0.2, 0.1)):
|
120 |
+
super().__init__()
|
121 |
+
'''Return a LightningDataModule for a local image folder (or folders) for classification purposes.
|
122 |
+
Images are expected to be in subfolders named by their classes.
|
123 |
+
In the str or Path folders cases, the folder content is checked for subfolders called train, test and valid (yes, in this order for consistency), and if any is present they are treated as the list input,
|
124 |
+
however, if none is present, then the split argument is required, containing a tuple of proportions to allocate to training, validation and testing datasets.
|
125 |
+
In the dict folders case the keys are expected to be in ['train', 'valid', 'test'].
|
126 |
+
The class names are from the first folder and assumed to be consistent across the others.
|
127 |
+
'''
|
128 |
+
|
129 |
+
########### Could relax requirement to train and test and then produce a validate dataset from the training one #########
|
130 |
+
|
131 |
+
self.prefixes = ['train', 'valid', 'test']
|
132 |
+
|
133 |
+
data_path = None
|
134 |
+
if isinstance(folders, (str, Path)):
|
135 |
+
data_path = Path(folders)
|
136 |
+
folders = {sub: full_sub for sub in self.prefixes if (full_sub := data_path / sub).is_dir()}
|
137 |
+
elif not isinstance(folders, dict): raise ValueError('Please provide a folders argument of types str | Path | dict[str, str | Path].')
|
138 |
+
|
139 |
+
assert set(folders.keys()).issubset(self.prefixes), f'Exactly the {self.prefixes} folders are required; {folders.keys()} were provided.'
|
140 |
+
if len(folders) == 3: folders = folders
|
141 |
+
elif len(folders) == 0 and data_path is not None:
|
142 |
+
assert sum(split) == 1
|
143 |
+
folders = (data_path, dict(zip(self.prefixes, split)))
|
144 |
+
else: raise ValueError(f'All of {self.prefixes} subfolders are required for the single-folder folders argument; only {folders.keys()} were provided.')
|
145 |
+
|
146 |
+
self.folders = folders
|
147 |
+
self.transform = transform
|
148 |
+
self.batch_size = batch_size
|
149 |
+
self.num_workers = num_workers
|
150 |
+
|
151 |
+
self.train_ds, self.val_ds, self.test_ds = None, None, None
|
152 |
+
self.classes = None
|
153 |
+
|
154 |
+
# def prepare_data(self):
|
155 |
+
# '''Not currently implemented. Mostly meant for downloading data.'''
|
156 |
+
# pass
|
157 |
+
|
158 |
+
def setup(self, stage):
|
159 |
+
if isinstance(self.folders, tuple):
|
160 |
+
all_data = tv.datasets.ImageFolder(self.folders[0], transform = self.transform)
|
161 |
+
self.classes = all_data.classes
|
162 |
+
self.train_ds, self.val_ds, self.test_ds = random_split(all_data, self.folders[1])
|
163 |
+
else:
|
164 |
+
if stage == 'fit':
|
165 |
+
self.train_ds, self.val_ds = [tv.datasets.ImageFolder(self.folders[k], transform = self.transform) for k in self.prefixes[:-1]]
|
166 |
+
self.classes = self.train_ds.classes
|
167 |
+
if stage == 'test':
|
168 |
+
self.test_ds = tv.datasets.ImageFolder(self.folders[self.prefixes[-1]], transform = self.transform)
|
169 |
+
|
170 |
+
def train_dataloader(self):
|
171 |
+
return DataLoader(self.train_ds, batch_size = self.batch_size, shuffle = True, num_workers = self.num_workers, pin_memory = True, persistent_workers = True)
|
172 |
+
|
173 |
+
def val_dataloader(self):
|
174 |
+
return DataLoader(self.val_ds, batch_size = self.batch_size, shuffle = False, num_workers = self.num_workers, pin_memory = True, persistent_workers = True)
|
175 |
+
|
176 |
+
def test_dataloader(self):
|
177 |
+
return DataLoader(self.test_ds, batch_size = self.batch_size, shuffle = False, num_workers = self.num_workers, pin_memory = True, persistent_workers = True)
|
178 |
+
|
179 |
+
|
main.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
pytorch_utils.py
CHANGED
@@ -7,7 +7,6 @@ from torch import nn
|
|
7 |
from torch.utils.data import DataLoader
|
8 |
from torch.utils.tensorboard import SummaryWriter
|
9 |
|
10 |
-
# Here just to be exported
|
11 |
from torchinfo import summary
|
12 |
import torchmetrics
|
13 |
|
@@ -41,7 +40,7 @@ def set_seeds(seed: int = 42):
|
|
41 |
def train_combinations(combinations: dict[str, tuple[str, str, str, int, str]],
|
42 |
model_factories: dict[str, Callable[[], nn.Module]], train_dataloaders: dict[str, DataLoader],
|
43 |
optimiser_factories: dict[str, Callable[[nn.Module], torch.optim.Optimizer]],
|
44 |
-
test_dataloader: DataLoader, loss_fn: nn.Module, metric_name_and_fn: tuple[str, Callable[[torch.
|
45 |
reset_seed: int = 42, device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True):
|
46 |
'''Run a series of modelling tasks by defining combinations of models, dataloaders, optimisers and epochs, as well as an optional previously-fit combination
|
47 |
to start from (e.g. for a combination which is the same as a previous one but with more epochs or different training data).
|
@@ -96,7 +95,7 @@ def train_combinations(combinations: dict[str, tuple[str, str, str, int, str]],
|
|
96 |
|
97 |
|
98 |
def fit(model: nn.Module, train_dataloader: DataLoader, test_dataloader: DataLoader,
|
99 |
-
optimiser: torch.optim.Optimizer, loss_fn: nn.Module, metric_name_and_fn: tuple[str, Callable[[torch.
|
100 |
epochs: int, writer: torch.utils.tensorboard.writer.SummaryWriter,
|
101 |
device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, model_name: str = None) -> dict[str, list]:
|
102 |
'''Trains and tests a PyTorch model.
|
@@ -148,7 +147,7 @@ def fit(model: nn.Module, train_dataloader: DataLoader, test_dataloader: DataLoa
|
|
148 |
|
149 |
|
150 |
def training_step(model: nn.Module, dataloader: DataLoader,
|
151 |
-
loss_fn: nn.Module, metric_fn: Callable[[torch.
|
152 |
device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, epoch: int = None) -> tuple[float, float]:
|
153 |
'''Trains a PyTorch model for a single epoch.
|
154 |
|
@@ -189,7 +188,7 @@ def training_step(model: nn.Module, dataloader: DataLoader,
|
|
189 |
|
190 |
|
191 |
def testing_step(model: nn.Module, dataloader: DataLoader,
|
192 |
-
loss_fn: nn.Module, metric_fn: Callable[[torch.
|
193 |
device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, epoch: int = None) -> tuple[float, float]:
|
194 |
'''Tests a PyTorch model for a single epoch.
|
195 |
|
@@ -302,6 +301,14 @@ def download_unzip(source: str, destination: str, remove_source: bool = True) ->
|
|
302 |
|
303 |
|
304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
|
306 |
#### Plotting Functions ####
|
307 |
|
|
|
7 |
from torch.utils.data import DataLoader
|
8 |
from torch.utils.tensorboard import SummaryWriter
|
9 |
|
|
|
10 |
from torchinfo import summary
|
11 |
import torchmetrics
|
12 |
|
|
|
40 |
def train_combinations(combinations: dict[str, tuple[str, str, str, int, str]],
|
41 |
model_factories: dict[str, Callable[[], nn.Module]], train_dataloaders: dict[str, DataLoader],
|
42 |
optimiser_factories: dict[str, Callable[[nn.Module], torch.optim.Optimizer]],
|
43 |
+
test_dataloader: DataLoader, loss_fn: nn.Module, metric_name_and_fn: tuple[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
|
44 |
reset_seed: int = 42, device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True):
|
45 |
'''Run a series of modelling tasks by defining combinations of models, dataloaders, optimisers and epochs, as well as an optional previously-fit combination
|
46 |
to start from (e.g. for a combination which is the same as a previous one but with more epochs or different training data).
|
|
|
95 |
|
96 |
|
97 |
def fit(model: nn.Module, train_dataloader: DataLoader, test_dataloader: DataLoader,
|
98 |
+
optimiser: torch.optim.Optimizer, loss_fn: nn.Module, metric_name_and_fn: tuple[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
|
99 |
epochs: int, writer: torch.utils.tensorboard.writer.SummaryWriter,
|
100 |
device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, model_name: str = None) -> dict[str, list]:
|
101 |
'''Trains and tests a PyTorch model.
|
|
|
147 |
|
148 |
|
149 |
def training_step(model: nn.Module, dataloader: DataLoader,
|
150 |
+
loss_fn: nn.Module, metric_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: torch.optim.Optimizer,
|
151 |
device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, epoch: int = None) -> tuple[float, float]:
|
152 |
'''Trains a PyTorch model for a single epoch.
|
153 |
|
|
|
188 |
|
189 |
|
190 |
def testing_step(model: nn.Module, dataloader: DataLoader,
|
191 |
+
loss_fn: nn.Module, metric_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
192 |
device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu', show_progress_bar = True, epoch: int = None) -> tuple[float, float]:
|
193 |
'''Tests a PyTorch model for a single epoch.
|
194 |
|
|
|
301 |
|
302 |
|
303 |
|
304 |
+
#### Info Functions ####
|
305 |
+
|
306 |
+
def summ(model: nn.Module, input_size: tuple):
|
307 |
+
'''Shorthand for typical summary specification'''
|
308 |
+
return summary(model = model, input_size = (32, 3, 224, 224),
|
309 |
+
col_names = ['input_size', 'output_size', 'num_params', 'trainable'], col_width = 20, row_settings = ['var_names'])
|
310 |
+
|
311 |
+
|
312 |
|
313 |
#### Plotting Functions ####
|
314 |
|
pytorch_vision_utils.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''Torchvision and related utility functions'''
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchvision as tv
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import timm # Here just to be exported
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
import base64
|
12 |
+
import altair as alt
|
13 |
+
import matplotlib.pyplot as plt # REMOVE IN FAVOUR OF ALTAIR
|
14 |
+
|
15 |
+
import os
|
16 |
+
import io
|
17 |
+
from pathlib import Path
|
18 |
+
from PIL import Image
|
19 |
+
# from itertools import batched # in Python>=3.12
|
20 |
+
|
21 |
+
|
22 |
+
def image_dataloaders(folders: str | Path | list[str | Path], transform: tv.transforms.Compose, batch_size: int, num_workers: int = os.cpu_count()) -> tuple[list[DataLoader], list[str]]:
|
23 |
+
'''Return PyTorch DataLoaders and class names for the given folder or list of folders (with expected subfolders named by class).
|
24 |
+
In the non-list folders case, the folder content is checked for subfolders called train, test and valid (yes, in this order for consistency), and if any is present they are treated as the list input.
|
25 |
+
The first folder is assumed to be the training data and will therefore produce a shuffling dataloader, while the others will not.
|
26 |
+
The class names are from the first folder and assumed to be consistent across the others.
|
27 |
+
'''
|
28 |
+
if isinstance(folders, (str, Path)):
|
29 |
+
data_path = Path(folders)
|
30 |
+
folders = subfolders if (subfolders := [full_sub for sub in ['train', 'valid', 'test'] if (full_sub := data_path / sub).is_dir()]) else [folders]
|
31 |
+
|
32 |
+
datasets = [tv.datasets.ImageFolder(folder, transform = transform) for folder in folders]
|
33 |
+
dataloaders = [DataLoader(ds, batch_size = batch_size, shuffle = i == 0, num_workers = num_workers, pin_memory = True, persistent_workers = True) for i, ds in enumerate(datasets)]
|
34 |
+
|
35 |
+
return dataloaders, datasets[0].classes
|
36 |
+
|
37 |
+
|
38 |
+
def plot_img_preds(model: torch.nn.Module, image_path: str, class_names: list[str], transform: tv.transforms, device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu'):
|
39 |
+
'''Plot one image with its prediction and probability as the title.
|
40 |
+
'''
|
41 |
+
img = Image.open(image_path)
|
42 |
+
|
43 |
+
model.to(device)
|
44 |
+
model.eval()
|
45 |
+
with torch.inference_mode(): pred_logit = model(transform(img).unsqueeze(dim = 0).to(device)) # Prepend "batch" dimension (-> [batch_size, color_channels, height, width])
|
46 |
+
pred_prob = torch.softmax(pred_logit, dim = 1)
|
47 |
+
pred_label = torch.argmax(pred_prob, dim = 1)
|
48 |
+
|
49 |
+
plt.figure()
|
50 |
+
plt.imshow(img)
|
51 |
+
plt.title(f"Pred: {class_names[pred_label]} | Prob: {pred_prob.max():.3f}")
|
52 |
+
plt.axis(False)
|
53 |
+
|
54 |
+
# Change text colour based on correctness?
|
55 |
+
|
56 |
+
|
57 |
+
def record_image_preds(image_paths: str | list[str], model: torch.nn.Module, transform: tv.transforms.Compose, class_names: list[str],
|
58 |
+
sort_by_correctness = True, device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu'):
|
59 |
+
'''Generate a dataframe of paths, true classes, (single) predicted classes and their confidence.
|
60 |
+
Column names: path, true_class, pred_class, pred_prob, correct.
|
61 |
+
If sort_by_correctness, then the dataframe is sorted by increasing correctness and confidence, i.e. first by prediction correctness and then by its probability,
|
62 |
+
with wrong predictions first, and both wrong and right by decreasing confidence.
|
63 |
+
If a single string is given as image_paths, then all */*.jpg and */*.png matches from it are used instead.
|
64 |
+
'''
|
65 |
+
true_classes, pred_classes, pred_probs, correctness, image_data = [], [], [], [], []
|
66 |
+
|
67 |
+
if isinstance(image_paths, str): image_paths = list(Path(image_paths).glob('*/*.jpg')) + list(Path(image_paths).glob('*/*.png'))
|
68 |
+
|
69 |
+
for path in tqdm(image_paths):
|
70 |
+
img = Image.open(path)
|
71 |
+
|
72 |
+
model.eval()
|
73 |
+
with torch.inference_mode(): pred_logit = model(transform(img).unsqueeze(0).to(device)) # Prepend "batch" dimension (-> [batch_size, color_channels, height, width])
|
74 |
+
pred_prob = torch.softmax(pred_logit, dim = 1)
|
75 |
+
pred_label = torch.argmax(pred_prob, dim = 1)
|
76 |
+
|
77 |
+
true_classes.append(class_name := path.parent.stem)
|
78 |
+
pred_classes.append(pred_class := class_names[pred_label.cpu()])
|
79 |
+
pred_probs.append(pred_prob.unsqueeze(0).max().cpu().item())
|
80 |
+
correctness.append(class_name == pred_class)
|
81 |
+
|
82 |
+
|
83 |
+
res = pd.DataFrame(dict(path = [str(p) for p in image_paths], true_class = true_classes, pred_class = pred_classes, pred_prob = pred_probs, correct = correctness))
|
84 |
+
return res.sort_values(by = ['correct', 'pred_prob'], ascending = [True, False]) if sort_by_correctness else res
|
85 |
+
|
86 |
+
|
87 |
+
def base64_image_formatter(image_or_path: Image.Image | str) -> str:
|
88 |
+
'''Generate a base64-encoded string representation of the given image (or path).
|
89 |
+
Example usecase: a dataframe meant for Altair contains PIL images (or their paths) in a column, in which case pass this temporary dataframe to the alt.Chart:
|
90 |
+
`df.assign(image = df.image_or_path_column.apply(base64_image_formatter))`
|
91 |
+
'''
|
92 |
+
if isinstance(image_or_path, str): image_or_path = Image.open(image_or_path)
|
93 |
+
with io.BytesIO() as buffer: # Docs: https://altair-viz.github.io/user_guide/marks/image.html#use-local-images-as-image-marks
|
94 |
+
image_or_path.save(buffer, format = 'PNG')
|
95 |
+
data = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
96 |
+
return f'data:image/png;base64,{data}'
|
97 |
+
|
98 |
+
|
99 |
+
def image_pred_grid(image_df: pd.DataFrame, ncols = 4, img_width = 200, img_height = 200, allow_1_col_reduction = True):
|
100 |
+
'''Create an Altair plot displaying a grid of images and their predicted classes, highlighting incorrect predictions.
|
101 |
+
image_df is expected to have the columns: path, true_class, pred_class, pred_prob, correct.
|
102 |
+
If allow_1_col_reduction and the last row (by the given ncols) is at least half empty and using ncols-1 would not increase rows, then ncols-1 is used instead.
|
103 |
+
'''
|
104 |
+
# Docs: https://altair-viz.github.io/user_guide/compound_charts.html
|
105 |
+
# Opened issue on making it easier through alt.Facet: https://github.com/altair-viz/altair/issues/3398
|
106 |
+
|
107 |
+
ncols = min(ncols, len(image_df))
|
108 |
+
nrows = 1 + len(image_df) // ncols
|
109 |
+
# If the last row is at least half empty and could reduce columns without increasing rows, do so
|
110 |
+
if allow_1_col_reduction and nrows > 1 and len(image_df) % ncols <= ncols / 2 and 1 + len(image_df) // (ncols - 1) == nrows: ncols -= 1
|
111 |
+
|
112 |
+
expanded_df = image_df.assign(
|
113 |
+
image = image_df.path.apply(base64_image_formatter),
|
114 |
+
title = image_df.pred_class + ' - ' + image_df.pred_prob.map(lambda p: f'{p:.2f}'),
|
115 |
+
index = image_df.index
|
116 |
+
)
|
117 |
+
|
118 |
+
base = alt.Chart(expanded_df).mark_image(width = img_width, height = img_height).encode(url = 'image:N')
|
119 |
+
chart = alt.vconcat()
|
120 |
+
for row_indices in (expanded_df.index[i:i + ncols] for i in range(0, len(expanded_df), ncols)): # itertools.batched(expanded_df.index, ncols) in Python>=3.12
|
121 |
+
row_chart = alt.hconcat()
|
122 |
+
for index in row_indices:
|
123 |
+
row_chart |= base.transform_filter(alt.datum.index == index).properties(
|
124 |
+
title = alt.Title(expanded_df.title[index], fontSize = 17, color = 'green' if expanded_df.correct[index] else 'red'))
|
125 |
+
chart &= row_chart
|
126 |
+
|
127 |
+
## Version with no subplots (but no titles)
|
128 |
+
# chart = alt.Chart(image_df.assign( # vv cannot trust the df index since it might not be ordered
|
129 |
+
# row = np.arange(len(image_df)) // ncols, col = np.arange(len(image_df)) % ncols # Could use the transform_compose block for this, but no // in the alt.expr language
|
130 |
+
# )).mark_image(width = img_width, height = img_height).encode(
|
131 |
+
# alt.X('col:O', title = None, axis = None), alt.Y('row:O', title = None, axis = None), url = 'image:N'
|
132 |
+
# ).properties(
|
133 |
+
# width = img_width * 1.1 * ncols, height = img_height * 1.1 * nrows
|
134 |
+
# )
|
135 |
+
|
136 |
+
## Version with faceting (but not coloured titles (no titles in fact, but non-coloured headers))
|
137 |
+
# chart = alt.Chart(image_df.assign(
|
138 |
+
# image = image_df.path.apply(base64_image_formatter),
|
139 |
+
# title = image_df.pred_class + ' - ' + image_df.pred_prob.map(lambda p: f'{p:.2f}')
|
140 |
+
# )).mark_image(width = img_width, height = img_height).encode(url = 'image:N'
|
141 |
+
# ).facet( # Header fields: https://altair-viz.github.io/user_guide/generated/core/altair.Header.html
|
142 |
+
# alt.Facet('title:N', header = alt.Header(labelFontSize = 17, labelColor = 'red')).title('Prediction and Confidence'), columns = ncols, title = 'Hi'
|
143 |
+
# )
|
144 |
+
|
145 |
+
return chart
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
# import torchvision
|
152 |
+
# import matplotlib.pyplot as plt
|
153 |
+
# # Plot the top 5 most wrong images
|
154 |
+
# for row in top_5_most_wrong.iterrows():
|
155 |
+
# row = row[1]
|
156 |
+
# image_path = row[0]
|
157 |
+
# true_label = row[1]
|
158 |
+
# pred_prob = row[2]
|
159 |
+
# pred_class = row[3]
|
160 |
+
# # Plot the image and various details
|
161 |
+
# img = torchvision.io.read_image(str(image_path)) # get image as tensor
|
162 |
+
# plt.figure()
|
163 |
+
# plt.imshow(img.permute(1, 2, 0)) # matplotlib likes images in [height, width, color_channels]
|
164 |
+
# plt.title(f"True: {true_label} | Pred: {pred_class} | Prob: {pred_prob:.3f}")
|
165 |
+
# plt.axis(False);
|
166 |
+
|
167 |
+
|
168 |
+
|