masabhuq's picture
Initial Commit
0c7049d
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision
def make_dataloaders(train_ds,
                     test_ds,
                     batch_size: int,
                     *,
                     num_workers: int = 1) -> tuple[DataLoader, DataLoader]:
    """Create train and test dataloaders from already-built datasets.

    The training loader shuffles its data (reshuffled on every pass over the
    loader). The test loader keeps the dataset's original order so evaluation
    batches are deterministic.

    Args:
        train_ds: Dataset used for training. It must support random access,
            because it is loaded with ``shuffle=True``.
        test_ds: Dataset used for evaluation; iterated in its natural order.
        batch_size (int): Number of samples per batch, shared by both loaders.
        num_workers (int): Number of worker subprocesses each loader uses.
            Defaults to 1, which keeps the previously hard-coded behaviour.
            Pass 0 to load data in the main process.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``.
    """
    train_dataloader = DataLoader(dataset=train_ds,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=True)
    # No shuffling for evaluation: a stable order makes metrics and
    # per-batch inspection reproducible across runs.
    test_dataloader = DataLoader(dataset=test_ds,
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 shuffle=False)
    return train_dataloader, test_dataloader