# Spaces status: Runtime error
from typing import Callable

import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler
from torchvision import datasets
class StyleDataset(Dataset):
    """Build a ``DataLoader`` over an ``ImageFolder`` of style images.

    Calling ``StyleDataset(...)`` returns a ``torch.utils.data.DataLoader``,
    not a ``StyleDataset`` instance.

    The original code built the loader in ``__init__`` and returned it.
    Python rejects that with ``TypeError: __init__() should return None``.
    Building the loader in ``__new__`` keeps the same call signature and the
    intended result type. Because the returned object is not an instance of
    this class, ``__init__`` is never run.

    Each image goes through three steps:

    1. It is resized to ``(2 * image_side_length, 2 * image_side_length)``.
    2. It is randomly cropped to ``image_side_length`` x ``image_side_length``.
    3. It is converted to a tensor with ``ToTensor``.

    Args:
        datadir: Root directory passed to ``torchvision.datasets.ImageFolder``
            (one sub-directory per class, per ImageFolder's convention).
        batch_size: Number of images per batch.
        sampler: Callable that receives the number of images in the dataset
            and returns the object passed as ``DataLoader(sampler=...)``.
        image_side_length: Side length, in pixels, of the square random crop.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        A ``DataLoader`` over the ``ImageFolder`` dataset.
    """

    def __new__(
        cls,
        datadir,
        batch_size: int,
        sampler: Callable[[int], Sampler],
        image_side_length: int = 256,
        num_workers: int = 2,
    ) -> DataLoader:
        transform = T.Compose([
            # Resize to twice the crop side so RandomCrop picks a random
            # window covering a quarter of the resized image's area.
            T.Resize(size=(image_side_length * 2, image_side_length * 2)),
            T.RandomCrop(image_side_length),
            T.ToTensor(),
        ])
        dataset = datasets.ImageFolder(datadir, transform=transform)
        # The sampler factory is sized to the dataset it will index.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler(len(dataset)),
            num_workers=num_workers,
        )