from .base import ImageDataModule from torch.utils.data import random_split from torchvision.datasets import MNIST, CIFAR10 from typing import Optional class MNISTDataModule(ImageDataModule): """Datamodule for the MNIST dataset.""" def prepare_data(self): # Download MNIST MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) def setup(self, stage: Optional[str] = None): # Set the training and validation data if stage == "fit" or stage is None: mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) self.train_data, self.val_data = random_split(mnist_full, [55000, 5000]) # Set the test data if stage == "test" or stage is None: self.test_data = MNIST(self.data_dir, train=False, transform=self.transform) class CIFAR10DataModule(ImageDataModule): """Datamodule for the CIFAR10 dataset.""" def prepare_data(self): # Download CIFAR10 CIFAR10(self.data_dir, train=True, download=True) CIFAR10(self.data_dir, train=False, download=True) def setup(self, stage: Optional[str] = None): # Set the training and validation data if stage == "fit" or stage is None: cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform) self.train_data, self.val_data = random_split(cifar10_full, [45000, 5000]) # Set the test data if stage == "test" or stage is None: self.test_data = CIFAR10( self.data_dir, train=False, transform=self.transform )