Spaces:
Sleeping
Sleeping
Added (Altair) plotting of images and their prediction correctness
Browse files
- pytorch_utils.py +6 -19
pytorch_utils.py
CHANGED
@@ -6,9 +6,9 @@ import torch
|
|
6 |
from torch import nn
|
7 |
from torch.utils.data import DataLoader
|
8 |
from torch.utils.tensorboard import SummaryWriter
|
9 |
-
from torchinfo import summary # Here just to be exported
|
10 |
-
import torchvision as tv
|
11 |
|
|
|
|
|
12 |
import torchmetrics
|
13 |
|
14 |
import numpy as np
|
@@ -265,21 +265,6 @@ def tensorboard_writer(experiment_name: str, model_name: str, extra: str = None)
|
|
265 |
return SummaryWriter(log_dir = log_dir)
|
266 |
|
267 |
|
268 |
-
def image_dataloaders(folders: str | Path | list[str | Path], transform: tv.transforms.Compose, batch_size: int, num_workers: int | None = os.cpu_count()) -> tuple[list[DataLoader], list[str]]:
    '''Return PyTorch DataLoaders and class names for the given folder or list of folders.

    Every folder is loaded with torchvision's ImageFolder, so each one is expected to contain one subfolder per class.

    When a single folder (str or Path) is given, it is checked for subfolders called train, valid and test (always probed in this order, so the returned list has a predictable layout).
    Those that exist are used exactly as if they had been passed as a list; if none exists, the folder itself becomes the only dataset.

    The first folder is assumed to be the training data and will therefore produce a shuffling dataloader, while the others will not.
    The class names are from the first folder and assumed to be consistent across the others.

    Args:
        folders: A single data folder, or an explicit list of folders (training folder first).
        transform: Transform applied to every image of every dataset.
        batch_size: Batch size used by all the dataloaders.
        num_workers: Worker processes per dataloader; None (which os.cpu_count() may return) means loading in the main process.

    Returns:
        A tuple (dataloaders, class_names), with one dataloader per folder, in folder order.
    '''
    if isinstance(folders, (str, Path)):
        data_path = Path(folders)
        candidates = [data_path / split for split in ('train', 'valid', 'test')]
        # Fall back to the folder itself when it has none of the standard split subfolders
        folders = [candidate for candidate in candidates if candidate.is_dir()] or [folders]

    # os.cpu_count() returns None when the CPU count cannot be determined, and DataLoader only accepts an integer
    if num_workers is None:
        num_workers = 0

    datasets = [tv.datasets.ImageFolder(folder, transform = transform) for folder in folders]

    # Only the first (training) dataset is shuffled
    dataloaders = [
        DataLoader(dataset, batch_size = batch_size, shuffle = index == 0, num_workers = num_workers, pin_memory = True)
        for index, dataset in enumerate(datasets)
    ]

    return dataloaders, datasets[0].classes
|
281 |
-
|
282 |
-
|
283 |
def download_unzip(source: str, destination: str, remove_source: bool = True) -> Path:
|
284 |
'''Downloads a zipped dataset from source and unzips it at destination.
|
285 |
|
@@ -322,7 +307,8 @@ def download_unzip(source: str, destination: str, remove_source: bool = True) ->
|
|
322 |
|
323 |
def plot_predictions(train_data, train_labels, test_data, test_labels, predictions = None):
|
324 |
'''Plots (matplotlib) linear training data and test data and compares predictions.
|
325 |
-
Training data is in blue, test data in green, and predictions in red (if present).
|
|
|
326 |
plt.figure(figsize = (10, 7))
|
327 |
|
328 |
plt.scatter(train_data, train_labels, c = 'b', s = 4, label = 'Training data')
|
@@ -333,7 +319,8 @@ def plot_predictions(train_data, train_labels, test_data, test_labels, predictio
|
|
333 |
|
334 |
|
335 |
def plot_loss_curves(train_loss: list, train_metric: list, test_loss: list, test_metric: list):
|
336 |
-
'''Plots (matplotlib) training (and testing) curves from lists of values.
|
|
|
337 |
epochs = range(len(train_loss))
|
338 |
|
339 |
plt.figure(figsize = (15, 7))
|
|
|
6 |
from torch import nn
|
7 |
from torch.utils.data import DataLoader
|
8 |
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
9 |
|
10 |
+
# Here just to be exported
|
11 |
+
from torchinfo import summary
|
12 |
import torchmetrics
|
13 |
|
14 |
import numpy as np
|
|
|
265 |
return SummaryWriter(log_dir = log_dir)
|
266 |
|
267 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
def download_unzip(source: str, destination: str, remove_source: bool = True) -> Path:
|
269 |
'''Downloads a zipped dataset from source and unzips it at destination.
|
270 |
|
|
|
307 |
|
308 |
def plot_predictions(train_data, train_labels, test_data, test_labels, predictions = None):
|
309 |
'''Plots (matplotlib) linear training data and test data and compares predictions.
|
310 |
+
Training data is in blue, test data in green, and predictions in red (if present).
|
311 |
+
'''
|
312 |
plt.figure(figsize = (10, 7))
|
313 |
|
314 |
plt.scatter(train_data, train_labels, c = 'b', s = 4, label = 'Training data')
|
|
|
319 |
|
320 |
|
321 |
def plot_loss_curves(train_loss: list, train_metric: list, test_loss: list, test_metric: list):
|
322 |
+
'''Plots (matplotlib) training (and testing) curves from lists of values.
|
323 |
+
'''
|
324 |
epochs = range(len(train_loss))
|
325 |
|
326 |
plt.figure(figsize = (15, 7))
|