"""
Contains functionality for creating PyTorch DataLoaders for
image classification data.
"""
|
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
|
|
|
def train_test_dataloader(train_dir: str,
                          test_dir: str,
                          transform: transforms.Compose,
                          batch_size: int,
                          num_workers: int = 0
                          ) -> Tuple[DataLoader, DataLoader, List[str]]:
    """Create training and testing DataLoaders from image folders.

    Builds an ImageFolder dataset for each directory, applying the same
    transform to both, and wraps each dataset in a DataLoader.

    Args:
        train_dir: Path to training directory.
        test_dir: Path to testing directory.
        transform: torchvision transforms to apply to both the training
            and the testing data.
        batch_size: Number of samples per batch in each of the DataLoaders.
        num_workers: Number of worker subprocesses each DataLoader uses to
            load data. The default of 0 loads data in the calling process,
            which is also the DataLoader default.

    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names),
        where class_names is the list of target classes taken from the
        training dataset.

    Example usage:
        train_dataloader, test_dataloader, class_names = train_test_dataloader(
            train_dir="path/to/train_dir",
            test_dir="path/to/test_dir",
            transform=some_transform,
            batch_size=32,
        )
    """
    train_data = ImageFolder(root=train_dir, transform=transform)
    test_data = ImageFolder(root=test_dir, transform=transform)

    # Class names come from the training set only; the test directory is
    # assumed to contain the same classes (not verified here).
    class_names = train_data.classes

    # Training batches are reshuffled so the model does not see samples in
    # a fixed order.
    train_dataloader = DataLoader(train_data,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)

    # Evaluation data is not shuffled: ordering has no training benefit and
    # a fixed order keeps evaluation runs reproducible and comparable.
    test_dataloader = DataLoader(test_data,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)

    return train_dataloader, test_dataloader, class_names