# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import json
import os

import cv2
import imageio
import lmdb
import numpy as np
import torch.utils.data as data
from PIL import Image

from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
from imaginaire.utils.distributed import master_only_print as print


class LMDBDataset(data.Dataset):
    r"""This deals with opening, and reading from an LMDB dataset.

    Args:
        root (str): Path to the LMDB file. ``~`` is expanded.
    """

    def __init__(self, root):
        self.root = os.path.expanduser(root)
        # Open the *expanded* path so roots like '~/datasets/x' work; the
        # expanded path is also what the metadata lookup below relies on.
        self.env = lmdb.open(self.root, max_readers=126, readonly=True,
                             lock=False, readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']

        # Read metadata: a JSON mapping from data_type to the file extension
        # of the items stored for it (consumed by getitem_by_path).
        with open(os.path.join(self.root, '..', 'metadata.json'),
                  encoding='utf-8') as fin:
            self.extensions = json.load(fin)

        print('LMDB file at %s opened.' % (root))

    def getitem_by_path(self, path, data_type):
        r"""Load data item stored for key = path.

        Args:
            path (bytes): Key into LMDB dataset (py-lmdb keys are bytestrings).
            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
        Returns:
            - img (PIL.Image) if the extension is in IMG_EXTENSIONS,
            - img (numpy array from imageio) if it is in HDR_IMG_EXTENSIONS,
            - otherwise buf (bytes, or None if the key is absent): the raw
              LMDB value for this key.
        Raises:
            KeyError: An image/HDR item was requested but ``path`` is not in
                the database.
            ValueError: The stored bytes could not be decoded as an image.
        """
        ext = self.extensions[data_type]

        # Get value from key; py-lmdb returns None when the key is absent.
        with self.env.begin(write=False) as txn:
            buf = txn.get(path)

        # Decode and return. IMG_EXTENSIONS takes precedence over HDR.
        if ext in IMG_EXTENSIONS:
            return self._decode_image(buf, ext, path)
        if ext in HDR_IMG_EXTENSIONS:
            return self._decode_hdr(buf, path)
        return buf

    @staticmethod
    def _imdecode_flag(ext):
        r"""Return the ``cv2.imdecode`` flag used for image extension ext."""
        if 'tif' in ext:
            # -1 == cv2.IMREAD_UNCHANGED: keep native depth (e.g. 16-bit).
            return -1
        if any(tag in ext for tag in ('JPEG', 'JPG', 'jpeg', 'jpg')):
            # 3 == cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH: force 3 channels.
            return 3
        return -1

    def _decode_image(self, buf, ext, path):
        r"""Decode an encoded image buffer into an RGB(-ordered) PIL image."""
        if buf is None:
            raise KeyError(path)
        # The stored value is the encoded file, i.e. a raw byte stream, so it
        # is always viewed as uint8; the output depth is chosen by the flag.
        # (Viewing it as uint16 failed on odd-length payloads, and the binary
        # mode of np.fromstring is deprecated.)
        encoded = np.frombuffer(buf, dtype=np.uint8)
        try:
            img = cv2.imdecode(encoded, self._imdecode_flag(ext))
        except cv2.error as err:
            raise ValueError(
                'Could not decode image at key %r' % (path,)) from err
        if img is None:
            # cv2.imdecode signals unsupported/corrupt data by returning None.
            raise ValueError('Could not decode image at key %r' % (path,))
        # BGR to RGB if 3 channels.
        # NOTE(review): 4-channel (BGRA) images are returned unswapped, as in
        # the original code — confirm downstream code expects that.
        if img.ndim == 3 and img.shape[-1] == 3:
            img = img[:, :, ::-1]
        return Image.fromarray(img)

    @staticmethod
    def _decode_hdr(buf, path):
        r"""Decode an HDR image buffer with imageio into a numpy array."""
        if buf is None:
            raise KeyError(path)
        # Make sure imageio's FreeImage shared library is available; errors
        # here (e.g. no network) propagate unchanged.
        imageio.plugins.freeimage.download()
        try:
            return imageio.imread(buf)
        except Exception as err:
            raise ValueError(
                'Could not decode HDR image at key %r' % (path,)) from err

    def __len__(self):
        r"""Return number of keys in LMDB dataset."""
        return self.length