from io import BytesIO import lmdb from PIL import Image from torch.utils.data import Dataset class MultiResolutionDataset(Dataset): def __init__(self, path, transform, resolution=256): self.env = lmdb.open( path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False, ) if not self.env: raise IOError('Cannot open lmdb dataset', path) with self.env.begin(write=False) as txn: self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) self.resolution = resolution self.transform = transform def __len__(self): return self.length def __getitem__(self, index): with self.env.begin(write=False) as txn: key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') img_bytes = txn.get(key) buffer = BytesIO(img_bytes) img = Image.open(buffer) img = self.transform(img) return img