tinyvgg / data_setup.py
ajitsi's picture
tinyvgg cnn model for image classification
7fc0372
"""
Contains functionality for creating Pytorch DataLoaders for
image classification data.
"""
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Default number of DataLoader worker processes: one per logical CPU.
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 0 (load data in the main process) so the default is always a valid int.
NUM_WORKERS = os.cpu_count() or 0
def create_dataloaders(
    train_dir: str,
    test_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_workers: int = NUM_WORKERS,
):
    """Create training and testing DataLoaders from image folders.

    Both directories are read with ``torchvision.datasets.ImageFolder`` and
    the same ``transform`` is applied to the training and the testing images.

    Args:
        train_dir: Path to the training directory, readable by ``ImageFolder``.
        test_dir: Path to the testing directory, readable by ``ImageFolder``.
        transform: Transform applied to every image of both datasets.
        batch_size: Number of samples per batch in each DataLoader.
        num_workers: Number of worker processes per DataLoader. Defaults to
            the module-level ``NUM_WORKERS``.

    Returns:
        A tuple ``(train_dataloader, test_dataloader, class_names)`` where
        ``class_names`` is the ``classes`` attribute of the training dataset.
    """
    # Use ImageFolder to create Datasets
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # Class names are taken from the training dataset only.
    class_names = train_data.classes

    # Turn images into data loaders
    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        # Evaluation data is not shuffled: a fixed order keeps per-batch
        # results reproducible and predictions aligned with the dataset order.
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_dataloader, test_dataloader, class_names