Spaces:
Build error
Build error
import os | |
import numpy as np | |
import albumentations | |
from torch.utils.data import Dataset | |
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex | |
class CustomBase(Dataset): | |
def __init__(self, *args, **kwargs): | |
super().__init__() | |
self.data = None | |
def __len__(self): | |
return len(self.data) | |
def __getitem__(self, i): | |
example = self.data[i] | |
return example | |
class CustomTrain(CustomBase): | |
def __init__(self, size, training_images_list_file): | |
super().__init__() | |
with open(training_images_list_file, "r") as f: | |
paths = f.read().splitlines() | |
self.data = ImagePaths(paths=paths, size=size, random_crop=False) | |
class CustomTest(CustomBase): | |
def __init__(self, size, test_images_list_file): | |
super().__init__() | |
with open(test_images_list_file, "r") as f: | |
paths = f.read().splitlines() | |
self.data = ImagePaths(paths=paths, size=size, random_crop=False) | |