import os
import os.path
from io import BytesIO

import lmdb
import torch.utils.data as data
from PIL import Image
from torch.utils.data import Dataset


class MultiResolutionDataset(Dataset):
    """Dataset that reads encoded images out of an LMDB environment.

    Images are looked up under keys of the form ``"<resolution>-<index>"``
    where the index is zero-padded to five digits. The number of samples is
    read from the ``length`` entry of the database.
    """

    def __init__(self, path, transform, resolution=256):
        """Open the LMDB environment and read the dataset length.

        Args:
            path: Location of the LMDB environment, passed to ``lmdb.open``.
            transform: Callable applied to every opened PIL image.
            resolution: Resolution prefix used when building image keys.

        Raises:
            IOError: If the environment cannot be opened or it has no
                ``length`` entry.
        """
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))

        # A missing entry comes back as None; fail with a clear error instead
        # of an AttributeError on ``None.decode``.
        if raw_length is None:
            raise IOError('lmdb dataset has no "length" entry', path)
        self.length = int(raw_length.decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        """Return the number of samples recorded in the database."""
        return self.length

    def __getitem__(self, index):
        """Return the transformed image stored for ``index``.

        Args:
            index (int): Sample index; zero-padded to five digits in the key.

        Returns:
            The result of ``self.transform`` applied to the opened image.

        Raises:
            IOError: If no entry exists for the computed key.
        """
        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)

        # Without this check BytesIO(None) yields an empty buffer and PIL
        # fails with an unhelpful "cannot identify image file" error.
        if img_bytes is None:
            raise IOError(f'No lmdb entry for key {key!r}')

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        return self.transform(img)


def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the given extensions.

    The file name is lower-cased before comparison; the extensions are used
    as given.

    Args:
        filename (string): path to a file
        extensions (iterable of string): accepted suffixes, e.g. ``'.png'``

    Returns:
        bool: True if the lower-cased filename ends with one of ``extensions``
    """
    return filename.lower().endswith(tuple(extensions))


def find_classes(dir):
    """List the sub-directories of ``dir`` and assign each an index.

    Args:
        dir (string): directory whose immediate sub-directories are listed

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted
        list of sub-directory names and ``class_to_idx`` maps each name to
        its position in that list.
    """
    # NOTE: the parameter shadows the builtin ``dir``; the name is kept so
    # that keyword callers continue to work.
    classes = sorted(
        d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(dir, extensions):
    """Collect every file below ``dir`` that has an allowed extension.

    Args:
        dir (string): root directory, walked recursively
        extensions (iterable of string): accepted file-name suffixes

    Returns:
        list: ``(path, 0)`` tuples in sorted walk order; the target is
        always ``0``.
    """
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if has_file_allowed_extension(fname, extensions):
                images.append((os.path.join(root, fname), 0))
    return images


class DatasetFolder(data.Dataset):
    """Dataset over all files below a root directory with given extensions."""

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        """Scan ``root`` for samples and store the loading configuration.

        Args:
            root (string): directory that is scanned recursively
            loader (callable): turns a file path into a sample
            extensions (iterable of string): accepted file-name suffixes
            transform (callable, optional): applied to each loaded sample
            target_transform (callable, optional): stored on the instance

        Raises:
            RuntimeError: If no matching file is found below ``root``.
        """
        samples = make_dataset(root, extensions)
        if not samples:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Load and return the sample at ``index``.

        Args:
            index (int): Index

        Returns:
            The loaded sample, passed through ``self.transform`` when one is
            set. Only the sample is returned; the target stored next to the
            path is not part of the result.
        """
        path, _ = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of samples found below the root."""
        return len(self.samples)

    def __repr__(self):
        """Return a multi-line summary of the dataset and its transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        # Continuation lines of a multi-line transform repr are indented so
        # that they line up under the first line.
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


def pil_loader(path):
    """Open the image at ``path`` and return it converted to RGB."""
    # open path as file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def default_loader(path):
    """Load an image with the default backend (``pil_loader``)."""
    return pil_loader(path)


class ImageFolder(DatasetFolder):
    """``DatasetFolder`` over image files with a switchable transform."""

    def __init__(self, root, transform1=None, transform2=None, target_transform=None,
                 loader=default_loader):
        """Scan ``root`` for images.

        Args:
            root (string): directory that is scanned recursively
            transform1 (callable, optional): transform used initially
            transform2 (callable, optional): transform switched to by
                ``set_stage('last')``
            target_transform (callable, optional): stored on the instance
            loader (callable): turns a file path into an image
        """
        super().__init__(root, loader, IMG_EXTENSIONS,
                         transform=transform1,
                         target_transform=target_transform)
        self.imgs = self.samples
        self.transform2 = transform2

    def set_stage(self, stage):
        """Switch to ``transform2`` when ``stage`` is ``'last'``.

        Any other value leaves the active transform unchanged.
        """
        if stage == 'last':
            self.transform = self.transform2


class ListFolder(Dataset):
    """Dataset over image paths listed one per line in a text file."""

    def __init__(self, txt, transform):
        """Read the image paths from ``txt``.

        Args:
            txt: path of a text file containing one image path per line
            transform (callable): applied to every opened image
        """
        with open(txt) as f:
            stripped = (line.strip() for line in f)
            # Skip blank lines: an empty path would only fail later, when
            # the sample is opened.
            self.imgpaths = [p for p in stripped if p]
        self.transform = transform

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return it transformed."""
        path = self.imgpaths[idx]
        image = Image.open(path)
        return self.transform(image)

    def __len__(self):
        """Return the number of listed image paths."""
        return len(self.imgpaths)