# Hosting-page header captured together with this file (not part of the program):
# AnTo2209 - commit "refactor" (e32c848) - 704 Bytes
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T
from torch.utils.data import DataLoader
class StyleDataset(Dataset):
    """Image-folder dataset bundled with a ``DataLoader`` built over it.

    Images are read with ``torchvision.datasets.ImageFolder``. Each image is
    resized to ``2 * image_side_length`` on both sides, randomly cropped to an
    ``image_side_length`` square, and converted with ``ToTensor``.

    Attributes:
        dataset: The underlying ``ImageFolder``.
        dataloader: ``DataLoader`` over ``dataset`` configured with the
            constructor arguments.

    Indexing and ``len()`` address individual samples of ``dataset``;
    iterating the object yields whatever ``dataloader`` yields.
    """

    def __init__(self, datadir, batch_size, sampler, image_side_length=256, num_workers=2):
        """Build the transformed image folder and its data loader.

        Args:
            datadir: Root directory passed to ``ImageFolder``.
            batch_size: Forwarded to ``DataLoader``.
            sampler: Callable invoked with ``len(dataset)``; its return value
                is passed to ``DataLoader`` as ``sampler=``.
            image_side_length: Side length of the random square crop. Images
                are first resized to twice this value per side.
            num_workers: Forwarded to ``DataLoader``.
        """
        super().__init__()
        transform = T.Compose([
            T.Resize(size=(image_side_length * 2, image_side_length * 2)),
            T.RandomCrop(image_side_length),
            T.ToTensor(),
        ])
        self.dataset = datasets.ImageFolder(datadir, transform=transform)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            sampler=sampler(len(self.dataset)),
            num_workers=num_workers,
        )
        # NOTE: the previous version ended with ``return dataloader``.
        # ``__init__`` must return None, so that made every instantiation
        # raise TypeError. The loader is exposed as ``self.dataloader`` instead.

    def __len__(self):
        """Return the number of samples in the underlying image folder."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the sample at ``index`` from the underlying image folder."""
        return self.dataset[index]

    def __iter__(self):
        """Iterate over ``self.dataloader``.

        Lets the object be used directly where the loader itself was
        expected, e.g. ``iter(StyleDataset(...))``.
        """
        return iter(self.dataloader)