# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import io
import json

# import cv2
import boto3
from botocore.config import Config
import numpy as np
import torch.utils.data as data
from PIL import Image
import imageio
from botocore.exceptions import ClientError

from imaginaire.datasets.cache import Cache
from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS

Image.MAX_IMAGE_PIXELS = None
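
# Expected bucket layout (a sketch inferred from how this class reads the
# bucket; the file names all_filenames.json and metadata.json come from the
# code below, while the example values are purely illustrative):
#   <root>/all_filenames.json  -- maps each sequence name to its list of
#                                 file names, e.g.
#                                 {"seq_0000": ["frame_0000", "frame_0001"]}.
#   <root>/metadata.json       -- maps each data_type to its file extension,
#                                 e.g. {"images": "jpg", "seg_maps": "png"}.
# Individual objects are then fetched with keys built by splicing the
# data_type in after the first component of a sample path and appending the
# extension (see getitem_by_path).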


class ObjectStoreDataset(data.Dataset):
    r"""This deals with opening, and reading from an AWS S3 bucket.

    Args:
        root (str): Path to the AWS S3 bucket.
        aws_credentials_file (str): Path to file containing AWS credentials.
        data_type (str): Which data type should this dataset load?
        cache (obj): Object with ``root`` and ``size_GB`` attributes
            describing the on-disk cache, or None to disable caching.
    """

    def __init__(self, root, aws_credentials_file, data_type='', cache=None):
        # Cache.
        self.cache = False
        if cache is not None:
            # raise NotImplementedError
            self.cache = Cache(cache.root, cache.size_GB)

        # Get bucket info, and keys to info about dataset.
        with open(aws_credentials_file) as fin:
            self.credentials = json.load(fin)
        parts = root.split('/')
        self.bucket = parts[0]
        self.all_filenames_key = '/'.join(parts[1:]) + '/all_filenames.json'
        self.metadata_key = '/'.join(parts[1:]) + '/metadata.json'

        # Get list of filenames.
        filename_info = self._get_object(self.all_filenames_key)
        self.sequence_list = json.loads(filename_info.decode('utf-8'))

        # Get length.
        length = 0
        for _, value in self.sequence_list.items():
            length += len(value)
        self.length = length

        # Read metadata.
        metadata_info = self._get_object(self.metadata_key)
        self.extensions = json.loads(metadata_info.decode('utf-8'))
        self.data_type = data_type

        print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type))

    def _get_object(self, key):
        r"""Download object from bucket.

        Args:
            key (str): Key inside bucket.
        """
        # Look up value in cache.
        object_content = self.cache.read(key) if self.cache else False
        if not object_content:
            # Either no cache used or key not found in cache.
            config = Config(connect_timeout=30,
                            signature_version="s3",
                            retries={"max_attempts": 999999})
            s3 = boto3.client('s3', **self.credentials, config=config)
            try:
                s3_response_object = s3.get_object(Bucket=self.bucket, Key=key)
                object_content = s3_response_object['Body'].read()
            except Exception as e:
                print('%s not found' % (key))
                print(e)
            # Save content to cache.
            if self.cache:
                self.cache.write(key, object_content)
        return object_content

    def getitem_by_path(self, path, data_type):
        r"""Load data item stored for key = path.

        Args:
            path (str): Path into AWS S3 bucket, without data_type prefix.
            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
        Returns:
            Decoded image (PIL.Image), HDR image (numpy.ndarray), or raw
            object contents (bytes) for this key.
        """
        # Figure out decoding params.
        ext = self.extensions[data_type]
        is_image = False
        is_hdr = False
        parts = path.split('/')
        key = parts[0] + '/' + data_type + '/' + \
            '/'.join(parts[1:]) + '.' + ext
        if ext in IMG_EXTENSIONS:
            is_image = True
            if 'tif' in ext:
                _, mode = np.uint16, -1
            elif 'JPEG' in ext or 'JPG' in ext \
                    or 'jpeg' in ext or 'jpg' in ext:
                _, mode = np.uint8, 3
            else:
                _, mode = np.uint8, -1
        elif ext in HDR_IMG_EXTENSIONS:
            is_hdr = True
        else:
            is_image = False

        # Get value from key.
        buf = self._get_object(key)

        # Decode and return.
        if is_image:
            # This is totally a hack.
            # We should have a better way to handle grayscale images.
            img = Image.open(io.BytesIO(buf))
            if mode == 3:
                img = img.convert('RGB')
            return img
        elif is_hdr:
            try:
                # The FreeImage plugin is needed to decode HDR formats.
                imageio.plugins.freeimage.download()
                img = imageio.imread(buf)
            except Exception:
                # Report the offending path before propagating the error.
                print(path)
                raise
            return img  # Return a numpy array
        else:
            return buf

    def __len__(self):
        r"""Return the number of items in the dataset."""
        return self.length
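

# Example usage (an illustrative sketch, not part of this module: the bucket
# name, prefix, credentials file, data_type, and sample path below are
# hypothetical and depend on how a dataset was uploaded to S3):
#
#     from types import SimpleNamespace
#
#     dataset = ObjectStoreDataset(
#         root='my-bucket/datasets/my_dataset',
#         aws_credentials_file='aws_credentials.json',
#         data_type='images',
#         cache=SimpleNamespace(root='/tmp/s3_cache', size_GB=10))
#     print(len(dataset))
#     img = dataset.getitem_by_path('my_dataset/seq_0000/frame_0000', 'images')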