import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


def _to_symmetric_range(x):
    """Map a tensor from ``[0, 1]`` (the ``ToTensor`` output range) to ``[-1, 1]``."""
    return x * 2.0 - 1.0


def _build_transform():
    """Return the MNIST image transform: PIL image -> float tensor scaled to ``[-1, 1]``.

    A module-level function is used instead of a ``lambda`` so the transform (and
    therefore the wrapped dataset) stays picklable, which DataLoader workers need
    when processes are started with the ``spawn`` method.
    """
    return transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(_to_symmetric_range)]
    )


class MNISTDataDictWrapper(Dataset):
    """Adapt an ``(image, label)`` dataset so each item is ``{"jpg": image, "cls": label}``."""

    def __init__(self, dset):
        super().__init__()
        self.dset = dset

    def __getitem__(self, i):
        x, y = self.dset[i]
        return {"jpg": x, "cls": y}

    def __len__(self):
        return len(self.dset)


class MNISTLoader(pl.LightningDataModule):
    """LightningDataModule serving MNIST as dict samples (see ``MNISTDataDictWrapper``).

    The MNIST test split backs both the validation and the test dataloaders.

    Args:
        batch_size: Number of samples per batch.
        num_workers: DataLoader worker processes; ``0`` loads in the main process.
        prefetch_factor: Batches prefetched per worker. Only used when
            ``num_workers > 0``; ignored (stored as ``None``) otherwise.
        shuffle: Whether the train, validation and test loaders shuffle.
        root: Directory where MNIST is downloaded to / read from.
    """

    def __init__(
        self,
        batch_size,
        num_workers=0,
        prefetch_factor=2,
        shuffle=True,
        root=".data/",
    ):
        super().__init__()
        transform = _build_transform()

        self.batch_size = batch_size
        self.num_workers = num_workers
        # DataLoader rejects an explicit prefetch_factor when num_workers == 0
        # (torch < 2.0 only accepts the default 2, torch >= 2.0 only accepts None),
        # so record None here and omit the argument when building loaders.
        self.prefetch_factor = prefetch_factor if num_workers > 0 else None
        self.shuffle = shuffle

        # Datasets are constructed (and downloaded if missing) eagerly so they are
        # available immediately after the module is constructed.
        self.train_dataset = MNISTDataDictWrapper(
            torchvision.datasets.MNIST(
                root=root, train=True, download=True, transform=transform
            )
        )
        self.test_dataset = MNISTDataDictWrapper(
            torchvision.datasets.MNIST(
                root=root, train=False, download=True, transform=transform
            )
        )

    def prepare_data(self):
        # Nothing to do: the download already happened in __init__.
        pass

    def _make_dataloader(self, dataset):
        """Build a DataLoader for *dataset* using this module's loader settings."""
        extra = {}
        # Only forward prefetch_factor when worker processes exist; passing it with
        # num_workers == 0 raises ValueError in DataLoader.
        if self.num_workers > 0 and self.prefetch_factor is not None:
            extra["prefetch_factor"] = self.prefetch_factor
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            **extra,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset)

    def val_dataloader(self):
        # NOTE(review): validation reuses the test split and honours `shuffle`,
        # matching previous behavior; Lightning recommends unshuffled eval loaders.
        return self._make_dataloader(self.test_dataset)


if __name__ == "__main__":
    dset = MNISTDataDictWrapper(
        torchvision.datasets.MNIST(
            root=".data/",
            train=False,
            download=True,
            transform=_build_transform(),
        )
    )
    ex = dset[0]