Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
"""All datasets are inherited from this class.""" | |
import importlib | |
import json | |
import os | |
import pickle | |
from collections import OrderedDict | |
from functools import partial | |
from inspect import signature | |
import numpy as np | |
import torch | |
import torch.utils.data as data | |
import torchvision.transforms as transforms | |
from imaginaire.datasets.folder import FolderDataset | |
from imaginaire.datasets.lmdb import \ | |
IMG_EXTENSIONS, HDR_IMG_EXTENSIONS, LMDBDataset | |
from imaginaire.datasets.object_store import ObjectStoreDataset | |
from imaginaire.utils.data import \ | |
(VIDEO_EXTENSIONS, Augmentor, | |
load_from_folder, load_from_lmdb, load_from_object_store) | |
from imaginaire.utils.lmdb import create_metadata | |
# Storage back-ends understood by BaseDataset; each selects a different
# loader (LMDB, plain folder, or S3-style object store).
DATASET_TYPES = [
    'lmdb',
    'folder',
    'object_store',
]
class BaseDataset(data.Dataset):
    r"""Base class for image/video datasets.

    Subclasses must implement ``_create_mapping`` (called at the end of
    ``__init__``) and ``__getitem__``; they may override
    ``_compute_dataset_stats``.

    Args:
        cfg (Config object): Input config.
        is_inference (bool): Training if False, else validation.
        is_test (bool): Final test set after training and validation.
    """

    def __init__(self, cfg, is_inference, is_test):
        super(BaseDataset, self).__init__()
        self.cfg = cfg
        self.is_inference = is_inference
        self.is_test = is_test

        # Select the config section and the split (train/val/test) to load.
        if self.is_test:
            self.cfgdata = self.cfg.test_data
            data_info = self.cfgdata.test
        else:
            self.cfgdata = self.cfg.data
            if self.is_inference:
                data_info = self.cfgdata.val
            else:
                data_info = self.cfgdata.train
        self.name = self.cfgdata.name
        self.lmdb_roots = data_info.roots
        self.dataset_type = getattr(data_info, 'dataset_type', None)
        self.cache = getattr(self.cfgdata, 'cache', None)
        self.interpolator = getattr(
            self.cfgdata, 'interpolator', "INTER_LINEAR")

        # Get AWS secret keys.
        if self.dataset_type == 'object_store':
            assert hasattr(cfg, 'aws_credentials_file')
            self.aws_credentials_file = cfg.aws_credentials_file

        # Legacy lmdb/folder only support: infer the type from ``is_lmdb``.
        if self.dataset_type is None:
            self.dataset_is_lmdb = getattr(data_info, 'is_lmdb', False)
            if self.dataset_is_lmdb:
                self.dataset_type = 'lmdb'
            else:
                self.dataset_type = 'folder'
        # Legacy support ends.

        assert self.dataset_type in DATASET_TYPES
        if self.dataset_type == 'lmdb':
            # Add handle to function to load data from LMDB.
            self.load_from_dataset = load_from_lmdb
        elif self.dataset_type == 'folder':
            # For some unpaired experiments, we would like the dataset to be
            # presented in a paired way. ``self.paired`` is never assigned in
            # this class; presumably subclasses set it before calling
            # ``super().__init__`` -- TODO confirm.
            if not hasattr(self.cfgdata, 'paired'):
                self.cfgdata.paired = self.paired
            # Add handle to function to load data from folder.
            self.load_from_dataset = load_from_folder
            # Create metadata for folders.
            print('Creating metadata')
            all_filenames, all_metadata = [], []
            if self.is_test:
                # Temporarily point ``cfg.data`` at the test config while
                # ``create_metadata`` runs (presumably it reads ``cfg.data``).
                cfg.data_backup = cfg.data
                cfg.data = cfg.test_data
            try:
                for root in self.lmdb_roots:
                    filenames, metadata = create_metadata(
                        data_root=root, cfg=cfg,
                        paired=self.cfgdata['paired'])
                    all_filenames.append(filenames)
                    all_metadata.append(metadata)
            finally:
                # Always restore the caller's original ``cfg.data``, even if
                # metadata creation raises.
                if self.is_test:
                    cfg.data = cfg.data_backup
        elif self.dataset_type == 'object_store':
            # Add handle to function to load data from AWS S3.
            self.load_from_dataset = load_from_object_store

        # Get the types of data stored in dataset, and their extensions.
        self.data_types = []  # Names of data types.
        self.dataset_data_types = []  # These data types are in the dataset.
        self.image_data_types = []  # These types are images.
        self.hdr_image_data_types = []  # These types are HDR images.
        self.normalize = {}  # Does this data type need normalization?
        self.extensions = {}  # What is this data type's file extension.
        self.is_mask = {}  # Whether this data type is discrete masks?
        self.num_channels = {}  # How many channels does this data type have?
        self.pre_aug_ops = {}  # Ops on data type before augmentation.
        self.post_aug_ops = {}  # Ops on data type after augmentation.

        # Defaults written into each input type's info dict when absent
        # (inserted in this order).
        info_defaults = (
            ('ext', None),
            ('normalize', False),
            ('is_mask', False),
            ('pre_aug_ops', 'None'),
            ('post_aug_ops', 'None'),
            ('computed_on_the_fly', False),
            ('num_channels', None),
        )
        # Extract info from data types. Each entry is a single-key mapping
        # ``{name: info}``.
        for data_type in self.cfgdata.input_types:
            names = list(data_type.keys())
            assert len(names) == 1
            name = names[0]
            info = data_type[name]
            for key, default in info_defaults:
                if key not in info:
                    info[key] = default
            self.data_types.append(name)
            if not info['computed_on_the_fly']:
                self.dataset_data_types.append(name)
            self.extensions[name] = info['ext']
            self.normalize[name] = info['normalize']
            self.num_channels[name] = info['num_channels']
            # Op lists are comma-separated strings; 'None' means no op.
            self.pre_aug_ops[name] = [op.strip() for op in
                                      info['pre_aug_ops'].split(',')]
            self.post_aug_ops[name] = [op.strip() for op in
                                       info['post_aug_ops'].split(',')]
            self.is_mask[name] = info['is_mask']
            ext = info['ext']
            if ext is not None and (ext in IMG_EXTENSIONS or
                                    ext in VIDEO_EXTENSIONS):
                self.image_data_types.append(name)
            if ext is not None and ext in HDR_IMG_EXTENSIONS:
                self.hdr_image_data_types.append(name)

        # Add some info into cfgdata for legacy support.
        self.cfgdata.data_types = self.data_types
        self.cfgdata.num_channels = [self.num_channels[name]
                                     for name in self.data_types]

        # Augmentations which need full dict.
        self.full_data_post_aug_ops, self.full_data_ops = [], []
        if hasattr(self.cfgdata, 'full_data_ops'):
            ops = self.cfgdata.full_data_ops
            self.full_data_ops.extend([op.strip() for op in ops.split(',')])
        if hasattr(self.cfgdata, 'full_data_post_aug_ops'):
            ops = self.cfgdata.full_data_post_aug_ops
            self.full_data_post_aug_ops.extend(
                [op.strip() for op in ops.split(',')])

        # These are the labels which will be concatenated for generator input.
        self.input_labels = []
        if hasattr(self.cfgdata, 'input_labels'):
            self.input_labels = self.cfgdata.input_labels

        # These are the keypoints which also need to be augmented.
        self.keypoint_data_types = []
        if hasattr(self.cfgdata, 'keypoint_data_types'):
            self.keypoint_data_types = self.cfgdata.keypoint_data_types

        # Create augmentation operations.
        aug_list = data_info.augmentations
        individual_video_frame_aug_list = getattr(
            data_info, 'individual_video_frame_augmentations', dict())
        self.augmentor = Augmentor(
            aug_list, individual_video_frame_aug_list, self.image_data_types,
            self.is_mask, self.keypoint_data_types, self.interpolator)
        self.augmentable_types = self.image_data_types + \
            self.keypoint_data_types

        # Create torch transformations for regular and HDR images.
        self.transform = {}
        for data_type in self.image_data_types + self.hdr_image_data_types:
            self.transform[data_type] = self._get_transform(
                self.normalize[data_type], self.num_channels[data_type])

        # Initialize handles.
        self.sequence_lists = []  # List of sequences per dataset root.
        self.lmdbs = {}  # Dict for list of lmdb handles per data type.
        for data_type in self.dataset_data_types:
            self.lmdbs[data_type] = []
        self.dataset_probability = None
        self.additional_lists = []

        # Load each dataset.
        for idx, root in enumerate(self.lmdb_roots):
            if self.dataset_type == 'lmdb':
                self._add_dataset(root)
            elif self.dataset_type == 'folder':
                self._add_dataset(root, filenames=all_filenames[idx],
                                  metadata=all_metadata[idx])
            elif self.dataset_type == 'object_store':
                self._add_dataset(
                    root, aws_credentials_file=self.aws_credentials_file)

        # Compute dataset statistics and create whatever self.variables
        # required for the specific dataloader.
        self._compute_dataset_stats()

        # Build index of data to sample.
        self.mapping, self.epoch_length = self._create_mapping()

    def _create_mapping(self):
        r"""Creates mapping from data sample idx to actual LMDB keys.

        All children need to implement their own.

        Returns:
            (tuple):
              - mapping: Index structure used by ``__getitem__`` (stored as
                ``self.mapping``).
              - epoch_length (int): Value returned by ``__len__``.
        """
        raise NotImplementedError

    def _compute_dataset_stats(self):
        r"""Computes required statistics about dataset.

        All children need to implement their own. The default does nothing.
        """
        pass

    def __getitem__(self, index):
        r"""Entry function for dataset. Must be implemented by children."""
        raise NotImplementedError

    def _get_transform(self, normalize, num_channels):
        r"""Convert numpy to torch tensor.

        Args:
            normalize (bool): Normalize image i.e. (x - 0.5) * 2.
                Goes from [0, 1] -> [-1, 1].
            num_channels (int): Number of channels; sets the length of the
                mean/std tuples when ``normalize`` is True.
        Returns:
            Composed list of torch transforms.
        """
        transform_list = [transforms.ToTensor()]
        if normalize:
            transform_list.append(
                transforms.Normalize((0.5, ) * num_channels,
                                     (0.5, ) * num_channels, inplace=True))
        return transforms.Compose(transform_list)

    def _add_dataset(self, root, filenames=None, metadata=None,
                     aws_credentials_file=None):
        r"""Adds an LMDB dataset to a list of datasets.

        Appends this root's sequence list to ``self.sequence_lists``, the
        contents of ``all_indices.json`` (if present, non-object-store only)
        to ``self.additional_lists``, and one handle per entry of
        ``self.dataset_data_types`` to ``self.lmdbs``.

        Args:
            root (str): Path to LMDB or folder dataset.
            filenames: List of filenames for folder dataset.
            metadata: Metadata for folder dataset.
            aws_credentials_file: Path to file containing AWS credentials.
        """
        if aws_credentials_file and self.dataset_type == 'object_store':
            object_store_dataset = ObjectStoreDataset(
                root, aws_credentials_file, cache=self.cache)
            sequence_list = object_store_dataset.sequence_list
        else:
            # Get sequences associated with this dataset.
            if filenames is None:
                list_path = 'all_filenames.json'
                with open(os.path.join(root, list_path)) as fin:
                    sequence_list = OrderedDict(json.load(fin))
            else:
                sequence_list = filenames

            additional_path = 'all_indices.json'
            if os.path.exists(os.path.join(root, additional_path)):
                print('Using additional list for object indices.')
                with open(os.path.join(root, additional_path)) as fin:
                    additional_list = OrderedDict(json.load(fin))
                self.additional_lists.append(additional_list)
        self.sequence_lists.append(sequence_list)

        # Get LMDB dataset handles.
        for data_type in self.dataset_data_types:
            if self.dataset_type == 'lmdb':
                self.lmdbs[data_type].append(
                    LMDBDataset(os.path.join(root, data_type)))
            elif self.dataset_type == 'folder':
                self.lmdbs[data_type].append(
                    FolderDataset(os.path.join(root, data_type), metadata))
            elif self.dataset_type == 'object_store':
                # All data types use the same handle.
                self.lmdbs[data_type].append(object_store_dataset)

    def perform_individual_video_frame(self, data, augment_ops):
        r"""Augment every frame of a sequence independently.

        The frame count is taken from the first key of ``data``. Frame ``j``
        of every key is gathered into its own dict and passed through
        ``perform_augmentation`` (paired across keys). The per-frame results
        are then concatenated back into one list per key.

        Args:
            data (dict): Keys are from data types. Values are sliceable
                sequences of frames (e.g. list of numpy.ndarray).
            augment_ops (list): The augmentation operations for individual
                frames. If empty, ``data`` is returned unchanged.
        Returns:
            data (dict): Augmented data, with same keys as input data. The
                dict is modified in place and returned.
        """
        if not augment_ops or not data:
            # Nothing to do. An empty dict used to raise NameError here.
            return data
        keys = list(data.keys())
        num_frames = len(data[keys[0]])
        # One single-frame dict per time step; the j:j+1 slice keeps each
        # value a sequence of length one.
        frames = [{key: data[key][j:(j + 1)] for key in keys}
                  for j in range(num_frames)]
        augmented = [
            self.perform_augmentation(
                frame, paired=True, augment_ops=augment_ops)[0]
            for frame in frames]
        for key in keys:
            merged = []
            for frame in augmented:
                merged += frame[key]
            data[key] = merged
        return data

    def perform_augmentation(self, data, paired, augment_ops=None):
        r"""Perform data augmentation on images only.

        Only the keys in ``self.augmentable_types`` (image and keypoint types)
        are passed to ``self.augmentor``; the other keys are left untouched.

        Args:
            data (dict): Keys are from data types. Values can be numpy.ndarray
                or list of numpy.ndarray (image or list of images).
            paired (bool): Apply same augmentation to all input keys?
            augment_ops (list): The augmentation operations.
        Returns:
            (tuple):
              - data (dict): Augmented data, with same keys as input data.
              - is_flipped (bool): Flag which tells if images have been
                left-right flipped.
        """
        aug_inputs = {data_type: data[data_type]
                      for data_type in self.augmentable_types}
        augmented, is_flipped = self.augmentor.perform_augmentation(
            aug_inputs, paired=paired, augment_ops=augment_ops)
        for data_type in self.augmentable_types:
            data[data_type] = augmented[data_type]
        return data, is_flipped

    def flip_hdr(self, data, is_flipped=False):
        r"""Left-right flip HDR images.

        Only the first element of each HDR data type's list is flipped, along
        axis 1 (``[:, ::-1, :]``).

        Args:
            data (dict): Keys are from data types. Values can be numpy.ndarray
                or list of numpy.ndarray (image or list of images).
            is_flipped (bool): Applying left-right flip to the hdr images
        Returns:
            data (dict): The input dict, modified in place when flipped.
        """
        if is_flipped is False:
            return data
        for data_type in self.hdr_image_data_types:
            data[data_type][0] = data[data_type][0][:, ::-1, :].copy()
        return data

    def to_tensor(self, data):
        r"""Convert all images to tensor.

        uint16 images are cast to float32 before their transform is applied.

        Args:
            data (dict): Dict containing data_type as key, with each value
                as a list of numpy.ndarrays.
        Returns:
            data (dict): Dict containing data_type as key, with each value
                as a list of torch.Tensors.
        """
        for data_type in self.image_data_types:
            for idx in range(len(data[data_type])):
                if data[data_type][idx].dtype == np.uint16:
                    data[data_type][idx] = data[data_type][idx].astype(
                        np.float32)
                data[data_type][idx] = self.transform[data_type](
                    data[data_type][idx])
        for data_type in self.hdr_image_data_types:
            for idx in range(len(data[data_type])):
                data[data_type][idx] = self.transform[data_type](
                    data[data_type][idx])
        return data

    def apply_ops(self, data, op_dict, full_data=False):
        r"""Apply any ops from op_dict to data types.

        Args:
            data (dict): Dict containing data_type as key, with each value
                as a list of numpy.ndarrays.
            op_dict (dict or list): When ``full_data`` is False, a dict
                containing data_type as key, with each value a list of op
                names. When ``full_data`` is True, a flat list of op names.
            full_data (bool): Do these ops require access to the full data?
        Returns:
            data (dict): Dict containing data_type as key, with each value
                modified by the op if any.
        Raises:
            NotImplementedError: If a per-type op has an unsupported op type.
        """
        if full_data:
            # op needs entire data dict.
            for op in op_dict:
                if op == 'None':
                    continue
                op, op_type = self.get_op(op)
                assert op_type == 'full_data'
                data = op(data)
        else:
            # op per data type.
            if not op_dict:
                return data
            for data_type in data:
                for op in op_dict[data_type]:
                    if op == 'None':
                        continue
                    op, op_type = self.get_op(op)
                    data[data_type] = op(data[data_type])

                    if op_type == 'vis':
                        # We have converted this data type to an image. Enter
                        # it in self.image_data_types and give it a torch
                        # transform.
                        if data_type not in self.image_data_types:
                            self.image_data_types.append(data_type)
                            normalize = self.normalize[data_type]
                            num_channels = self.num_channels[data_type]
                            self.transform[data_type] = \
                                self._get_transform(normalize, num_channels)
                    elif op_type == 'convert':
                        continue
                    elif op_type is None:
                        continue
                    else:
                        raise NotImplementedError
        return data

    def get_op(self, op):
        r"""Get function to apply for specific op.

        Supported forms of ``op``:

        - ``'to_tensor'``, ``'decode_json'``, ``'decode_pkl'``, ``'to_numpy'``,
          ``'l2_norm'``: built-in ops, returned with op type ``None``.
        - ``'module::function'``: a full-data op. ``function`` must take
          ``(cfgdata, is_inference, full_data)`` or
          ``(cfgdata, is_inference, dataset, full_data)``; all but the last
          argument are bound here.
        - ``'type::module::function'`` with ``type`` ``'vis'`` or
          ``'convert'``. ``function`` may carry keyword arguments written as
          ``function(key1=value1:key2=value2)``.

        Args:
            op (str): Name of the op.
        Returns:
            (tuple):
              - function handle.
              - op type (None, 'full_data', 'vis' or 'convert').
        Raises:
            ValueError: If the op is unknown or a vis/convert function has
                the wrong number of parameters.
        """
        def list_to_tensor(data):
            r"""Convert list of numeric values to tensor."""
            assert isinstance(data, list)
            return torch.from_numpy(np.array(data, dtype=np.float32))

        def decode_json_list(data):
            r"""Decode list of strings in json to objects."""
            assert isinstance(data, list)
            return [json.loads(item) for item in data]

        def decode_pkl_list(data):
            r"""Decode list of pickled strings to objects."""
            assert isinstance(data, list)
            # NOTE(security): pickle.loads can execute arbitrary code; only
            # use this op on trusted datasets.
            return [pickle.loads(item) for item in data]

        def list_to_numpy(data):
            r"""Convert list of numeric values to numpy array."""
            assert isinstance(data, list)
            return np.array(data)

        def l2_normalize(data):
            r"""L2 normalization along dim 1."""
            assert isinstance(data, torch.Tensor)
            import torch.nn.functional as F
            return F.normalize(data, dim=1)

        builtin_ops = {
            'to_tensor': list_to_tensor,
            'decode_json': decode_json_list,
            'decode_pkl': decode_pkl_list,
            'to_numpy': list_to_numpy,
            'l2_norm': l2_normalize,
        }
        if op in builtin_ops:
            return builtin_ops[op], None
        if '::' not in op:
            raise ValueError('Unknown op: %s' % (op))

        parts = op.split('::')
        if len(parts) == 2:
            module, function = parts
            module = importlib.import_module(module)
            function = getattr(module, function)
            sig = signature(function)
            num_params = len(sig.parameters)
            assert num_params in [3, 4], \
                'Full data functions take in (cfgdata, is_inference, ' \
                'full_data) or (cfgdata, is_inference, self, full_data) ' \
                'as input.'
            if num_params == 3:
                function = partial(
                    function, self.cfgdata, self.is_inference)
            elif num_params == 4:
                function = partial(
                    function, self.cfgdata, self.is_inference, self)
            function_type = 'full_data'
        elif len(parts) == 3:
            function_type, module, function = parts
            module = importlib.import_module(module)

            # Get function inputs, if provided.
            partial_fn = False
            if '(' in function and ')' in function:
                partial_fn = True
                function, params = self._get_fn_params(function)
            function = getattr(module, function)

            # Create partial function.
            if partial_fn:
                function = partial(function, **params)

            # Count only positional-or-keyword parameters; those bound by
            # keyword above do not count towards the expected arity.
            sig = signature(function)
            num_params = sum(
                1 for param in sig.parameters.values()
                if param.kind == param.POSITIONAL_OR_KEYWORD)

            if function_type == 'vis':
                if num_params != 9:
                    raise ValueError(
                        'vis function type needs to take ' +
                        '(resize_h, resize_w, crop_h, crop_w, ' +
                        'original_h, original_w, is_flipped, cfgdata, ' +
                        'data) as input.')
                function = partial(function,
                                   self.augmentor.resize_h,
                                   self.augmentor.resize_w,
                                   self.augmentor.crop_h,
                                   self.augmentor.crop_w,
                                   self.augmentor.original_h,
                                   self.augmentor.original_w,
                                   self.augmentor.is_flipped,
                                   self.cfgdata)
            elif function_type == 'convert':
                if num_params != 1:
                    raise ValueError(
                        'convert function type needs to take ' +
                        '(data) as input.')
            else:
                raise ValueError('Unknown op: %s' % (op))
        else:
            raise ValueError('Unknown op: %s' % (op))
        return function, function_type

    def _get_fn_params(self, function_string):
        r"""Find key-value inputs to function from string definition.

        Parameters are separated by ``':'`` (not commas), e.g.
        ``'my_function(a=10:b=nearest)'`` gives
        ``('my_function', {'a': 10.0, 'b': 'nearest'})``. Values that parse
        as floats become ``float``; all others stay strings. Whitespace
        around keys and values is ignored. Empty parameter lists such as
        ``'my_function()'`` are allowed.

        Args:
            function_string (str): String with function name and args.
        Returns:
            (tuple):
              - function (str): Name of function.
              - params (dict): Key-value params for function.
        Raises:
            ValueError: If a parameter is not of the form ``key=value``.
        """
        start = function_string.find('(')
        end = function_string.find(')')
        function = function_string[:start]
        params_str = function_string[start + 1:end]
        params = {}
        for item in params_str.split(':'):
            if not item.strip():
                # Tolerate ``f()`` and stray/trailing separators.
                continue
            key, sep, value = item.partition('=')
            if not sep:
                raise ValueError(
                    'Malformed function parameter %r in %r; expected '
                    'key=value.' % (item, function_string))
            key, value = key.strip(), value.strip()
            try:
                params[key] = float(value)
            except ValueError:
                params[key] = value
        return function, params

    def __len__(self):
        r"""Number of samples per epoch, as computed by ``_create_mapping``."""
        return self.epoch_length