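# NOTE: the next three comment lines are residue copied from the hosting page
# (uploader avatar text, commit message, commit hash); not part of the program.
# james-oldfield's picture
# Upload 194 files
# 2a76164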
# python3.7
"""Contains the class of data loader."""
import argparse
from torch.utils.data import DataLoader
from .distributed_sampler import DistributedSampler
from .datasets import BaseDataset
__all__ = ['IterDataLoader']
class IterDataLoader(object):
    """Iteration-based data loader.

    Wraps a `torch.utils.data.DataLoader` fed by `DistributedSampler` and
    hands out batches through `next()`. When the wrapped loader is exhausted
    it is restarted automatically, so batches can be drawn for an arbitrary
    number of iterations. The number of batches drawn so far is tracked and
    passed on to the sampler whenever the loader is (re)built or restarted.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 shuffle=True,
                 num_workers=1,
                 current_iter=0,
                 repeat=1):
        """Initializes the data loader.

        Args:
            dataset: The dataset to load data from.
            batch_size: The batch size on each GPU.
            shuffle: Whether to shuffle the data. (default: True)
            num_workers: Number of data workers for each GPU. (default: 1)
            current_iter: The current number of iterations. (default: 0)
            repeat: The repeating number of the whole dataloader. (default: 1)
        """
        self._dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self._dataloader = None
        self.iter_loader = None
        self._iter = current_iter
        self.repeat = repeat

        self.build_dataloader()

    def build_dataloader(self):
        """Builds the sampler, the data loader and a fresh batch iterator.

        Uses the current values of `batch_size`, `shuffle`, `num_workers`,
        `repeat` and the iteration counter, so it can be called again after
        any of them changes.
        """
        dist_sampler = DistributedSampler(self._dataset,
                                          shuffle=self.shuffle,
                                          current_iter=self._iter,
                                          repeat=self.repeat)

        self._dataloader = DataLoader(self._dataset,
                                      batch_size=self.batch_size,
                                      # The shuffle flag is handed to the
                                      # sampler above; the loader itself only
                                      # shuffles if there is no sampler.
                                      shuffle=(dist_sampler is None),
                                      num_workers=self.num_workers,
                                      # An incomplete last batch is dropped
                                      # exactly when shuffling is enabled.
                                      drop_last=self.shuffle,
                                      pin_memory=True,
                                      sampler=dist_sampler)
        self.iter_loader = iter(self._dataloader)

    def overwrite_param(self, batch_size=None, resolution=None):
        """Overwrites some parameters for progressive training.

        A falsy argument means "keep the current value". The data loader is
        rebuilt only when at least one supplied value actually differs from
        the one currently in use; otherwise this is a no-op.

        Args:
            batch_size: New batch size, or `None` to keep the current one.
            resolution: New value for `dataset.resolution`, or `None` to keep
                the current one.
        """
        # Compare only the values that were supplied. Comparing a missing
        # (None) argument against the current setting would look like a
        # change and trigger a needless rebuild.
        batch_size_changed = bool(batch_size) and (
            batch_size != self.batch_size)
        resolution_changed = bool(resolution) and (
            resolution != self._dataset.resolution)
        if not (batch_size_changed or resolution_changed):
            return

        if batch_size:
            self.batch_size = batch_size
        if resolution:
            self._dataset.resolution = resolution
        self.build_dataloader()

    @property
    def iter(self):
        """Returns the current iteration."""
        return self._iter

    @property
    def dataset(self):
        """Returns the dataset."""
        return self._dataset

    @property
    def dataloader(self):
        """Returns the data loader."""
        return self._dataloader

    def __iter__(self):
        """Returns the loader itself, completing the iterator protocol.

        Note that `__next__` restarts the wrapped loader whenever it runs
        out, so iterating over this object does not stop at the end of a
        pass over the data.
        """
        return self

    def __next__(self):
        """Returns the next batch and advances the iteration counter.

        If the wrapped loader is exhausted, the sampler is reset with the
        current iteration count and a new pass is started before fetching.
        """
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._dataloader.sampler.__reset__(self._iter)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        # Counted only after a batch has actually been fetched.
        self._iter += 1
        return data

    def __len__(self):
        """Returns the length of the wrapped data loader."""
        return len(self._dataloader)
def dataloader_test(root_dir, test_num=10):
    """Smoke-tests `IterDataLoader` with a changing batch size and resolution.

    Starts at resolution 2 and batch size 2. Every round draws one batch,
    checks the shape of its 'image' entry, then doubles the resolution and
    raises the batch size by one through `overwrite_param()`.

    Args:
        root_dir: Root directory passed to `BaseDataset`.
        test_num: Number of rounds to run. (default: 10)
    """
    resolution = 2
    batch_size = 2
    loader = IterDataLoader(
        dataset=BaseDataset(root_dir=root_dir, resolution=resolution),
        batch_size=batch_size,
        shuffle=False)

    for _ in range(test_num):
        batch = next(loader)
        assert batch['image'].shape == (batch_size, 3, resolution, resolution)

        # Move on to the next setting before drawing the next batch.
        resolution, batch_size = resolution * 2, batch_size + 1
        loader.overwrite_param(batch_size=batch_size, resolution=resolution)
if __name__ == '__main__':
    # Command-line entry point: read the dataset location and the number of
    # rounds, then run the data loader smoke test.
    cli = argparse.ArgumentParser(description='Test Data Loader.')
    cli.add_argument(
        'root_dir',
        type=str,
        help='Root directory of the dataset.')
    cli.add_argument(
        '--test_num',
        type=int,
        default=10,
        help='Number of tests. (default: %(default)s)')
    options = cli.parse_args()
    dataloader_test(options.root_dir, options.test_num)