Spaces:
Runtime error
Runtime error
File size: 1,654 Bytes
d4ab5ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from typing import Optional

import torch
from torch.utils.data import random_split
from torchvision.datasets import MNIST, CIFAR10

from .base import ImageDataModule
class MNISTDataModule(ImageDataModule):
    """Datamodule for the MNIST dataset.

    ``prepare_data`` downloads the official train and test sets into
    ``self.data_dir``. ``setup`` splits the official training set into
    ``train_data`` / ``val_data`` and exposes the official test set as
    ``test_data``. All datasets use ``self.transform``.
    """

    # Number of official training images held out for validation
    # (the remaining images form the training subset: 55000 for MNIST).
    VAL_SIZE: int = 5000
    # Fixed seed for the train/validation split, so repeated ``setup`` calls
    # (and separate processes) produce the same partition instead of
    # reshuffling validation images into the training set.
    SPLIT_SEED: int = 42

    def prepare_data(self):
        """Download the MNIST train and test sets into ``self.data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Build the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds ``train_data`` and
                ``val_data``; ``"test"`` builds ``test_data``; ``None``
                builds all three.
        """
        # Validation data is carved out of the training set, so the
        # "validate" stage needs the same split as "fit".
        if stage in ("fit", "validate", None):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            train_size = len(mnist_full) - self.VAL_SIZE
            self.train_data, self.val_data = random_split(
                mnist_full,
                [train_size, self.VAL_SIZE],
                generator=torch.Generator().manual_seed(self.SPLIT_SEED),
            )
        if stage in ("test", None):
            self.test_data = MNIST(self.data_dir, train=False, transform=self.transform)
class CIFAR10DataModule(ImageDataModule):
    """Datamodule for the CIFAR10 dataset.

    ``prepare_data`` downloads the official train and test sets into
    ``self.data_dir``. ``setup`` splits the official training set into
    ``train_data`` / ``val_data`` and exposes the official test set as
    ``test_data``. All datasets use ``self.transform``.
    """

    # Number of official training images held out for validation
    # (the remaining images form the training subset: 45000 for CIFAR10).
    VAL_SIZE: int = 5000
    # Fixed seed for the train/validation split, so repeated ``setup`` calls
    # (and separate processes) produce the same partition instead of
    # reshuffling validation images into the training set.
    SPLIT_SEED: int = 42

    def prepare_data(self):
        """Download the CIFAR10 train and test sets into ``self.data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Build the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds ``train_data`` and
                ``val_data``; ``"test"`` builds ``test_data``; ``None``
                builds all three.
        """
        # Validation data is carved out of the training set, so the
        # "validate" stage needs the same split as "fit".
        if stage in ("fit", "validate", None):
            cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            train_size = len(cifar10_full) - self.VAL_SIZE
            self.train_data, self.val_data = random_split(
                cifar10_full,
                [train_size, self.VAL_SIZE],
                generator=torch.Generator().manual_seed(self.SPLIT_SEED),
            )
        if stage in ("test", None):
            self.test_data = CIFAR10(
                self.data_dir, train=False, transform=self.transform
            )
|