Spaces:
Runtime error
Runtime error
from typing import Optional

import torch
from torch.utils.data import random_split
from torchvision.datasets import CIFAR10, MNIST

from .base import ImageDataModule
class MNISTDataModule(ImageDataModule):
    """Datamodule for the MNIST dataset.

    The MNIST training set is randomly split into training and validation
    subsets; the MNIST test set is used as-is for testing.
    """

    # Number of training images held out for validation (the remainder of
    # the training set is used for training).
    val_split_size: int = 5000
    # Seed for the train/validation split, so the split is identical across
    # runs and across processes that each call ``setup`` independently.
    split_seed: int = 42

    def prepare_data(self):
        """Download the MNIST train and test sets into ``self.data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Build the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds ``train_data`` and
                ``val_data``; ``"test"`` builds ``test_data``; ``None``
                builds all three.
        """
        # "validate" needs val_data too; without it a validation-only run
        # would find the attribute unset.
        if stage in ("fit", "validate", None):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # random_split requires the lengths to sum to len(mnist_full).
            train_size = len(mnist_full) - self.val_split_size
            self.train_data, self.val_data = random_split(
                mnist_full,
                [train_size, self.val_split_size],
                generator=torch.Generator().manual_seed(self.split_seed),
            )
        if stage in ("test", None):
            self.test_data = MNIST(self.data_dir, train=False, transform=self.transform)
class CIFAR10DataModule(ImageDataModule):
    """Datamodule for the CIFAR10 dataset.

    The CIFAR10 training set is randomly split into training and validation
    subsets; the CIFAR10 test set is used as-is for testing.
    """

    # Number of training images held out for validation (the remainder of
    # the training set is used for training).
    val_split_size: int = 5000
    # Seed for the train/validation split, so the split is identical across
    # runs and across processes that each call ``setup`` independently.
    split_seed: int = 42

    def prepare_data(self):
        """Download the CIFAR10 train and test sets into ``self.data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Build the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds ``train_data`` and
                ``val_data``; ``"test"`` builds ``test_data``; ``None``
                builds all three.
        """
        # "validate" needs val_data too; without it a validation-only run
        # would find the attribute unset.
        if stage in ("fit", "validate", None):
            cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            # random_split requires the lengths to sum to len(cifar10_full).
            train_size = len(cifar10_full) - self.val_split_size
            self.train_data, self.val_data = random_split(
                cifar10_full,
                [train_size, self.val_split_size],
                generator=torch.Generator().manual_seed(self.split_seed),
            )
        if stage in ("test", None):
            self.test_data = CIFAR10(
                self.data_dir, train=False, transform=self.transform
            )