Spaces:
Runtime error
Runtime error
| import os | |
| from os.path import join | |
| import torchvision.transforms as transforms | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
class ImageDataset(Dataset):
    """Dataset that serves every image file of one folder as an RGBA tensor.

    Each sample is opened with PIL, converted to RGBA, turned into a tensor
    by ``ToTensor`` and then normalized per channel with mean 0.5 and
    std 0.5 over four channels.
    """

    def __init__(self, folder_path):
        """Collect the image paths found directly inside *folder_path*.

        Args:
            folder_path: Directory whose regular files are treated as images.

        Raises:
            OSError: If *folder_path* cannot be listed (propagated from
                ``os.listdir``).
        """
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Four means / stds: the pipeline expects 4-channel (RGBA) input.
            transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5)),
        ])
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories; sort for a reproducible index -> file mapping and
        # keep only regular files so Image.open is never handed a directory.
        candidates = (join(folder_path, name) for name in sorted(os.listdir(folder_path)))
        self.files = [path for path in candidates if os.path.isfile(path)]

    def __getitem__(self, index):
        """Return the transformed image at *index*.

        The index wraps around modulo the dataset length, so any integer is
        accepted as long as the dataset is not empty.

        Args:
            index: Integer position of the requested sample.

        Returns:
            The result of ``self.transform`` applied to the RGBA image.

        Raises:
            IndexError: If the dataset contains no files.
        """
        if not self.files:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError.
            raise IndexError("ImageDataset is empty: no image files were found")
        path = self.files[index % len(self.files)]
        # The context manager closes the underlying file deterministically;
        # convert() returns a fully loaded copy that outlives the file.
        with Image.open(path) as image:
            # Force four channels so the 4-channel Normalize also works for
            # RGB, grayscale and palette images.
            rgba = image.convert("RGBA")
        return self.transform(rgba)

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.files)