Spaces:
Runtime error
Runtime error
| from torch.utils.data import DataLoader | |
| from torchvision import datasets, transforms | |
| import torchvision | |
def make_dataloaders(train_ds,
                     test_ds,
                     batch_size: int,
                     *,
                     num_workers: int = 1):
    """Create train and test DataLoaders from already-built datasets.

    The training loader shuffles its samples (PyTorch reshuffles at every
    epoch when ``shuffle=True``); the test loader iterates in dataset order
    so evaluation is deterministic.

    Args:
        train_ds: Dataset used for training (any object ``DataLoader``
            accepts, e.g. a ``torchvision.datasets.ImageFolder``).
        test_ds: Dataset used for evaluation.
        batch_size (int): Number of samples per batch for both loaders.
        num_workers (int): Worker subprocesses used by each loader.
            Defaults to 1 (the previously hard-coded value); pass 0 to
            load data in the main process.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``.
    """
    # Settings shared by both loaders; only shuffling differs between them.
    common = {"batch_size": batch_size, "num_workers": num_workers}

    train_dataloader = DataLoader(dataset=train_ds, shuffle=True, **common)
    test_dataloader = DataLoader(dataset=test_ds, shuffle=False, **common)
    return train_dataloader, test_dataloader