sat3density / imaginaire /datasets /paired_images.py
venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from imaginaire.datasets.paired_videos import Dataset as VideoDataset
class Dataset(VideoDataset):
    r"""Paired image dataset for use in pix2pixHD, SPADE.

    Implemented as a video dataset whose sequence length is fixed to 1, so
    that every sample consists of exactly one file.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
        is_test (bool): Forwarded unchanged to the video dataset constructor.
    """

    def __init__(self, cfg, is_inference=False, is_test=False):
        # Assigned ahead of the parent constructor on purpose; presumably the
        # parent reads it while initializing -- keep this ordering.
        self.paired = True
        super().__init__(
            cfg, is_inference, sequence_length=1, is_test=is_test)
        # Set after the parent constructor has run, so this value is the one
        # that remains on the instance.
        self.is_video_dataset = False

    def _create_mapping(self):
        r"""Builds the flat list that maps a sample idx to its LMDB key.

        One entry is produced per filename, walking the LMDB datasets in
        order, then the sequences of each dataset, then the files of each
        sequence.

        Returns:
            (tuple):
              - self.mapping (list): List mapping idx to key.
              - self.epoch_length (int): Number of samples in an epoch.
        """
        self.mapping = [
            {
                'lmdb_root': self.lmdb_roots[dataset_idx],
                'lmdb_idx': dataset_idx,
                'sequence_name': seq_name,
                # Wrapped in a list: each sample is a one-file "sequence".
                'filenames': [single_file],
            }
            for dataset_idx, sequences in enumerate(self.sequence_lists)
            for seq_name, seq_files in sequences.items()
            for single_file in seq_files
        ]
        self.epoch_length = len(self.mapping)
        return self.mapping, self.epoch_length

    def _sample_keys(self, index):
        r"""Looks up the files to load for the sample at ``index``.

        Args:
            index (int): Index in [0, len(dataset)].
        Returns:
            key (dict): The entry of ``self.mapping`` at ``index``. As built
            by ``_create_mapping`` it holds:
              - lmdb_root: Root of the chosen LMDB dataset.
              - lmdb_idx (int): Chosen LMDB dataset root.
              - sequence_name (str): Chosen sequence in chosen dataset.
              - filenames (list of str): Chosen filenames in chosen sequence.
        """
        assert self.sequence_length == 1, (
            'Image dataset can only have sequence length = 1, not %d' % (
                self.sequence_length))
        return self.mapping[index]

    def set_sequence_length(self, sequence_length):
        r"""Intentionally does nothing.

        This is an image loader, so the requested sequence length is
        ignored.

        Args:
            sequence_length (int): Length of output sequences.
        """
        return None

    def set_inference_sequence_idx(self, index):
        r"""Not applicable for images; always raises.

        Overridden from super as an image dataset has no sequences to pick
        from during inference.

        Args:
            index (int): Index of inference sequence.
        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError('Image dataset does not have sequences.')

    def num_inference_sequences(self):
        r"""Not applicable for images; always raises.

        Overridden from super as an image dataset has no sequences to count.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError('Image dataset does not have sequences.')