Spaces:
Runtime error
Runtime error
from torch.utils.data import DataLoader | |
from torchvision import datasets, transforms | |
import torchvision | |
def make_dataloaders(train_ds,
                     test_ds,
                     batch_size: int,
                     *,
                     num_workers: int = 1) -> "tuple[DataLoader, DataLoader]":
    """Wrap a training and a test dataset in ``DataLoader`` objects.

    The datasets must already be constructed (with whatever transforms they
    need); this function only builds the loaders around them.

    Args:
        train_ds: Dataset to draw training batches from. Its loader is
            created with ``shuffle=True``.
        test_ds: Dataset to draw evaluation batches from. Its loader is
            created with ``shuffle=False``.
        batch_size: Number of samples per batch, used for both loaders.
        num_workers: Number of worker subprocesses each loader uses to load
            data. Defaults to 1; pass 0 to load in the calling process.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``.
    """
    train_dataloader = DataLoader(
        dataset=train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    test_dataloader = DataLoader(
        dataset=test_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    return train_dataloader, test_dataloader