T-Flet committed on
Commit
104ddc3
1 Parent(s): bae138d

Added (Altair) plotting of images and their prediction correctness

Browse files
Files changed (1) hide show
  1. pytorch_utils.py +6 -19
pytorch_utils.py CHANGED
@@ -6,9 +6,9 @@ import torch
6
  from torch import nn
7
  from torch.utils.data import DataLoader
8
  from torch.utils.tensorboard import SummaryWriter
9
- from torchinfo import summary # Here just to be exported
10
- import torchvision as tv
11
 
 
 
12
  import torchmetrics
13
 
14
  import numpy as np
@@ -265,21 +265,6 @@ def tensorboard_writer(experiment_name: str, model_name: str, extra: str = None)
265
  return SummaryWriter(log_dir = log_dir)
266
 
267
 
268
- def image_dataloaders(folders: str | Path | list[str | Path], transform: tv.transforms.Compose, batch_size: int, num_workers: int = os.cpu_count()) -> tuple[list[DataLoader], list[str]]:
269
- '''Return PyTorch DataLoaders and class names for the given folder or list of folders (with expected subfolders named by class).
270
- In the non-list folders case, the folder content is checked for subfolders called train, valid and test (yes, in this order for consistency), and if any are present they are treated as the list input.
271
- The first folder is assumed to be the training data and will therefore produce a shuffling dataloader, while the others will not.
272
- The class names are from the first folder and assumed to be consistent across the others.'''
273
- if isinstance(folders, (str, Path)):
274
- data_path = Path(folders)
275
- folders = subfolders if (subfolders := [full_sub for sub in ['train', 'valid', 'test'] if (full_sub := data_path / sub).is_dir()]) else [folders]
276
-
277
- datasets = [tv.datasets.ImageFolder(folder, transform = transform) for folder in folders]
278
- dataloaders = [DataLoader(ds, batch_size = batch_size, shuffle = i == 0, num_workers = num_workers, pin_memory = True) for i, ds in enumerate(datasets)]
279
-
280
- return dataloaders, datasets[0].classes
281
-
282
-
283
  def download_unzip(source: str, destination: str, remove_source: bool = True) -> Path:
284
  '''Downloads a zipped dataset from source and unzips it at destination.
285
 
@@ -322,7 +307,8 @@ def download_unzip(source: str, destination: str, remove_source: bool = True) ->
322
 
323
  def plot_predictions(train_data, train_labels, test_data, test_labels, predictions = None):
324
  '''Plots (matplotlib) linear training data and test data and compares predictions.
325
- Training data is in blue, test data in green, and predictions in red (if present).'''
 
326
  plt.figure(figsize = (10, 7))
327
 
328
  plt.scatter(train_data, train_labels, c = 'b', s = 4, label = 'Training data')
@@ -333,7 +319,8 @@ def plot_predictions(train_data, train_labels, test_data, test_labels, predictio
333
 
334
 
335
  def plot_loss_curves(train_loss: list, train_metric: list, test_loss: list, test_metric: list):
336
- '''Plots (matplotlib) training (and testing) curves from lists of values.'''
 
337
  epochs = range(len(train_loss))
338
 
339
  plt.figure(figsize = (15, 7))
 
6
  from torch import nn
7
  from torch.utils.data import DataLoader
8
  from torch.utils.tensorboard import SummaryWriter
 
 
9
 
10
+ # Here just to be exported
11
+ from torchinfo import summary
12
  import torchmetrics
13
 
14
  import numpy as np
 
265
  return SummaryWriter(log_dir = log_dir)
266
 
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  def download_unzip(source: str, destination: str, remove_source: bool = True) -> Path:
269
  '''Downloads a zipped dataset from source and unzips it at destination.
270
 
 
307
 
308
  def plot_predictions(train_data, train_labels, test_data, test_labels, predictions = None):
309
  '''Plots (matplotlib) linear training data and test data and compares predictions.
310
+ Training data is in blue, test data in green, and predictions in red (if present).
311
+ '''
312
  plt.figure(figsize = (10, 7))
313
 
314
  plt.scatter(train_data, train_labels, c = 'b', s = 4, label = 'Training data')
 
319
 
320
 
321
  def plot_loss_curves(train_loss: list, train_metric: list, test_loss: list, test_metric: list):
322
+ '''Plots (matplotlib) training (and testing) curves from lists of values.
323
+ '''
324
  epochs = range(len(train_loss))
325
 
326
  plt.figure(figsize = (15, 7))