#based on https://github.com/CompVis/taming-transformers import pickle from torch.utils.data import Dataset from ldm.data.base import ImagePaths import ldm.data.constants as CONSTANTS class CustomBase(Dataset): def __init__(self, *args, **kwargs): super().__init__() self.data = None def __len__(self): return len(self.data) def __getitem__(self, i): example = self.data[i] return example class CustomTrain(CustomBase): def __init__(self, size, training_images_list_file, horizontalflip=False, random_contrast=False, shiftrotate=False, add_labels=False, unique_skipped_labels=[], class_to_node=None): super().__init__() with open(training_images_list_file, "r") as f: paths = sorted(f.read().splitlines()) labels=None if add_labels: labels_per_file = list(map(lambda path: path.split('/')[-2], paths)) labels_set = sorted(list(set(labels_per_file))) self.labels_to_idx = {label_name: i for i, label_name in enumerate(labels_set)} if class_to_node: with open(class_to_node, 'rb') as pickle_file: class_to_node_dict = pickle.load(pickle_file) labels = { CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file], CONSTANTS.DATASET_CLASSNAME: labels_per_file, 'class_to_node': [class_to_node_dict[label_name] for label_name in labels_per_file] } # labels = [self.labels_to_idx[label_name] for label_name in labels_per_file] else: labels = { CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file], CONSTANTS.DATASET_CLASSNAME: labels_per_file } self.indx_to_label = {v: k for k, v in self.labels_to_idx.items()} self.data = ImagePaths(paths=paths, size=size, random_crop=False, horizontalflip=horizontalflip, random_contrast=random_contrast, shiftrotate=shiftrotate, labels=labels, unique_skipped_labels=unique_skipped_labels) class CustomTest(CustomTrain): def __init__(self, size, test_images_list_file, add_labels=False, unique_skipped_labels=[], class_to_node=None): super().__init__(size, test_images_list_file, add_labels=add_labels, unique_skipped_labels=unique_skipped_labels, class_to_node=class_to_node)