File size: 604 Bytes
cff2458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def load_mnist(*, root='./data', batch_size=64, download=True):
    """Build MNIST train/test DataLoaders.

    Every image goes through ``transforms.ToTensor()`` followed by
    ``transforms.Normalize((0.5,), (0.5,))``, i.e. ``(x - 0.5) / 0.5`` on the
    single channel.

    Args:
        root: Directory passed to ``datasets.MNIST`` as the dataset location.
        batch_size: Samples per batch, used for both the train and test loader.
        download: Forwarded to ``datasets.MNIST``; when True the files are
            fetched into ``root`` if they are not already there.

    Returns:
        dict with key ``'train'`` mapped to a shuffled DataLoader over the
        training split and key ``'test'`` mapped to an unshuffled DataLoader
        over the test split.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # ToTensor maps pixel values into [0, 1]; mean=std=0.5 rescales to [-1, 1].
        transforms.Normalize((0.5,), (0.5,)),
    ])

    loaders = {}
    for split, is_train in (('train', True), ('test', False)):
        dataset = datasets.MNIST(
            root=root, train=is_train, download=download, transform=transform
        )
        # Shuffle only the training split so evaluation order stays deterministic.
        loaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
    return loaders
|