Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os | |
import cv2 | |
import numpy as np | |
import torch.utils.data as data | |
from PIL import Image | |
from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS | |
import imageio | |
class FolderDataset(data.Dataset):
    r"""This deals with opening, and reading from a Folder dataset.

    Each item is stored as a single file ``<root>/<key>.<ext>``, where the
    extension is looked up per data type in ``metadata``.

    Args:
        root (str): Path to the folder. A leading ``~`` is expanded.
        metadata (dict): Mapping from data type to the file extension used
            for that data type (e.g. ``{'images': 'png'}``).
    """

    def __init__(self, root, metadata):
        self.root = os.path.expanduser(root)
        self.extensions = metadata
        print('Folder at %s opened.' % (root))

    def getitem_by_path(self, path, data_type):
        r"""Load data item stored for key = path.

        Args:
            path (bytes or str): Key into Folder dataset, relative to
                ``root`` and without the file extension.
            data_type (str): Key into self.extensions e.g.
                data/data_segmaps/...

        Returns:
            ``PIL.Image`` if the extension is in ``IMG_EXTENSIONS``; the
            array returned by ``imageio.imread`` if it is in
            ``HDR_IMG_EXTENSIONS``; otherwise the raw file contents (bytes).

        Raises:
            FileNotFoundError: If ``<root>/<path>.<ext>`` does not exist.
            ValueError: If an image or HDR file cannot be decoded.
        """
        ext = self.extensions[data_type]
        # Keys are typically bytes (``path.decode()`` in the original API),
        # but plain ``str`` keys are accepted as well.
        key = path.decode() if isinstance(path, (bytes, bytearray)) else path
        filepath = os.path.join(self.root, key + '.' + ext)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not os.path.exists(filepath):
            raise FileNotFoundError('%s does not exist' % (filepath))
        with open(filepath, 'rb') as f:
            buf = f.read()

        if ext in IMG_EXTENSIONS:
            return self._decode_image(buf, ext, key)
        if ext in HDR_IMG_EXTENSIONS:
            return self._decode_hdr(buf, key)
        return buf

    @staticmethod
    def _decode_image(buf, ext, key):
        """Decode an encoded image buffer with OpenCV into an RGB PIL image."""
        ext_lower = ext.lower()
        if 'jpg' in ext_lower or 'jpeg' in ext_lower:
            # 3 == cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH.
            mode = 3
        else:
            # -1 == cv2.IMREAD_UNCHANGED: keep channels and bit depth
            # (e.g. 16-bit TIFF, alpha in PNG).
            mode = -1
        # The *encoded* byte stream is always handed over as uint8; the
        # decoded bit depth is controlled by ``mode``, not by this dtype.
        # (Viewing it as uint16 would fail on odd-length files.)
        try:
            img = cv2.imdecode(np.frombuffer(buf, dtype=np.uint8), mode)
        except cv2.error as err:
            raise ValueError('Failed to decode image %s' % key) from err
        if img is None:
            # imdecode signals unreadable data by returning None.
            raise ValueError('Failed to decode image %s' % key)
        # OpenCV decodes to BGR; convert to RGB if 3 channels.
        if img.ndim == 3 and img.shape[-1] == 3:
            img = img[:, :, ::-1]
        return Image.fromarray(img)

    @staticmethod
    def _decode_hdr(buf, key):
        """Decode an HDR image buffer with imageio and return the array."""
        try:
            # Make sure imageio's FreeImage plugin binary is available
            # before reading HDR data.
            imageio.plugins.freeimage.download()
            return imageio.imread(buf)  # Return a numpy array
        except Exception as err:
            raise ValueError('Failed to load HDR image %s' % key) from err

    def __len__(self):
        r"""Return number of keys in Folder dataset.

        A key is a file path relative to ``root`` with its extension
        removed. Only files whose name ends with one of the extensions in
        ``self.extensions`` are counted, and each key is counted once even
        if several data types store a file for it. The result is cached in
        ``self.length``; a value already stored there is returned as-is.
        A missing ``root`` yields 0.
        """
        length = getattr(self, 'length', None)
        if length is None:
            length = self._count_keys()
            self.length = length
        return length

    def _count_keys(self):
        """Count distinct extension-less relative file paths under root."""
        suffixes = {'.' + ext for ext in self.extensions.values()}
        keys = set()
        for dirpath, _, filenames in os.walk(self.root):
            for filename in filenames:
                for suffix in suffixes:
                    if filename.endswith(suffix):
                        stem = filename[:-len(suffix)]
                        keys.add(os.path.relpath(
                            os.path.join(dirpath, stem), self.root))
        return len(keys)