venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import json
import os
import cv2
import lmdb
import numpy as np
import torch.utils.data as data
from PIL import Image
from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
from imaginaire.utils.distributed import master_only_print as print
import imageio
class LMDBDataset(data.Dataset):
r"""This deals with opening, and reading from an LMDB dataset.
Args:
root (str): Path to the LMDB file.
"""
def __init__(self, root):
self.root = os.path.expanduser(root)
self.env = lmdb.open(root, max_readers=126, readonly=True, lock=False,
readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
self.length = txn.stat()['entries']
# Read metadata.
with open(os.path.join(self.root, '..', 'metadata.json')) as fin:
self.extensions = json.load(fin)
print('LMDB file at %s opened.' % (root))
def getitem_by_path(self, path, data_type):
r"""Load data item stored for key = path.
Args:
path (str): Key into LMDB dataset.
data_type (str): Key into self.extensions e.g. data/data_segmaps/...
Returns:
img (PIL.Image) or buf (str): Contents of LMDB value for this key.
"""
# Figure out decoding params.
ext = self.extensions[data_type]
is_image = False
is_hdr = False
if ext in IMG_EXTENSIONS:
is_image = True
if 'tif' in ext:
dtype, mode = np.uint16, -1
elif 'JPEG' in ext or 'JPG' in ext \
or 'jpeg' in ext or 'jpg' in ext:
dtype, mode = np.uint8, 3
else:
dtype, mode = np.uint8, -1
elif ext in HDR_IMG_EXTENSIONS:
is_hdr = True
else:
is_image = False
# Get value from key.
with self.env.begin(write=False) as txn:
buf = txn.get(path)
# Decode and return.
if is_image:
try:
img = cv2.imdecode(np.fromstring(buf, dtype=dtype), mode)
except Exception:
print(path)
# BGR to RGB if 3 channels.
if img.ndim == 3 and img.shape[-1] == 3:
img = img[:, :, ::-1]
img = Image.fromarray(img)
return img
elif is_hdr:
try:
imageio.plugins.freeimage.download()
img = imageio.imread(buf)
except Exception:
print(path)
return img # Return a numpy array
else:
return buf
def __len__(self):
r"""Return number of keys in LMDB dataset."""
return self.length