sat3density / imaginaire /datasets /object_store.py
venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import io
import json
# import cv2
import boto3
from botocore.config import Config
import numpy as np
import torch.utils.data as data
from PIL import Image
import imageio
from botocore.exceptions import ClientError
from imaginaire.datasets.cache import Cache
from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
# Disable Pillow's decompression-bomb pixel limit so that arbitrarily large
# images fetched from the bucket can be decoded with Image.open.
Image.MAX_IMAGE_PIXELS = None
class ObjectStoreDataset(data.Dataset):
    r"""This deals with opening, and reading from an AWS S3 bucket.

    Args:
        root (str): Path to the AWS S3 bucket, given as ``bucket`` or
            ``bucket/prefix``. ``all_filenames.json`` and ``metadata.json``
            are read from under the prefix.
        aws_credentials_file (str): Path to a JSON file containing AWS
            credentials; its contents are passed as keyword arguments to
            ``boto3.client``.
        data_type (str): Which data type should this dataset load?
        cache (object, optional): Object with ``root`` and ``size_GB``
            attributes used to build a local ``Cache`` of downloaded
            objects, or None to disable caching.
    """

    def __init__(self, root, aws_credentials_file, data_type='', cache=None):
        # Cache. False means "no cache"; _get_object tests it by truthiness.
        self.cache = False
        if cache is not None:
            self.cache = Cache(cache.root, cache.size_GB)

        # Get bucket info, and keys to info about dataset.
        with open(aws_credentials_file) as fin:
            self.credentials = json.load(fin)
        # Strip surrounding slashes so 'bucket', 'bucket/' and
        # 'bucket/prefix/' all yield well-formed keys (no leading '/' and no
        # '//' inside the key).
        bucket, _, prefix = root.strip('/').partition('/')
        key_prefix = prefix + '/' if prefix else ''
        self.bucket = bucket
        self.all_filenames_key = key_prefix + 'all_filenames.json'
        self.metadata_key = key_prefix + 'metadata.json'

        # Get list of filenames (a JSON object mapping sequence -> files).
        filename_info = self._get_object(self.all_filenames_key)
        self.sequence_list = json.loads(filename_info.decode('utf-8'))

        # Total number of files over all sequences.
        self.length = sum(len(value) for value in self.sequence_list.values())

        # Read metadata: maps each data_type to its file extension.
        metadata_info = self._get_object(self.metadata_key)
        self.extensions = json.loads(metadata_info.decode('utf-8'))
        self.data_type = data_type
        print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type))

    def _get_object(self, key):
        r"""Download object from bucket.

        The local cache (if any) is consulted first. On a miss the object is
        fetched from S3 and, only if the download succeeded, written back to
        the cache.

        Args:
            key (str): Key inside bucket.
        Returns:
            (bytes): Object contents. If the download fails, the error is
            printed and a falsy placeholder is returned instead (the cache
            lookup result, or ``False`` when no cache is used).
        """
        # Look up value in cache.
        object_content = self.cache.read(key) if self.cache else False
        if object_content:
            return object_content

        # Either no cache used or key not found in cache.
        config = Config(connect_timeout=30,
                        signature_version="s3",
                        retries={"max_attempts": 999999})
        s3 = boto3.client('s3', **self.credentials, config=config)
        try:
            s3_response_object = s3.get_object(Bucket=self.bucket, Key=key)
            object_content = s3_response_object['Body'].read()
        except Exception as e:
            # Best effort: report the failure, but never persist it in the
            # cache, otherwise later lookups would be served the placeholder.
            print('%s not found' % (key))
            print(e)
            return object_content

        # Save content to cache (successful downloads only).
        if self.cache:
            self.cache.write(key, object_content)
        return object_content

    def getitem_by_path(self, path, data_type):
        r"""Load data item stored for key = path.

        Args:
            path (str): Path into AWS S3 bucket, without data_type prefix.
            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
        Returns:
            img (PIL.Image), img (numpy array) or buf (bytes): A decoded PIL
            image for image extensions, the result of ``imageio.imread`` for
            HDR extensions, otherwise the raw object contents.
        Raises:
            Exception: Whatever ``imageio`` raised when an HDR item cannot
            be decoded (the offending path is printed first).
        """
        # Figure out decoding params.
        ext = self.extensions[data_type]
        parts = path.split('/')
        # The data type directory is inserted after the first path component.
        key = parts[0] + '/' + data_type + '/' + '/'.join(parts[1:]) + '.' + ext

        # Get value from key.
        buf = self._get_object(key)

        # Decode and return.
        if ext in IMG_EXTENSIONS:
            # This is totally a hack.
            # We should have a better way to handle grayscale images.
            img = Image.open(io.BytesIO(buf))
            lowered = ext.lower()
            if 'tif' not in ext and ('jpeg' in lowered or 'jpg' in lowered):
                # JPEGs are forced to 3-channel RGB; other formats keep
                # their native mode.
                img = img.convert('RGB')
            return img
        if ext in HDR_IMG_EXTENSIONS:
            try:
                imageio.plugins.freeimage.download()
                img = imageio.imread(buf)
            except Exception:
                # Report which item failed, then propagate the real error
                # instead of hitting an unbound `img` below.
                print(path)
                raise
            return img  # Return a numpy array
        return buf

    def __len__(self):
        r"""Return the total number of files across all sequences."""
        return self.length