Spaces:
Build error
Build error
# python3.7 | |
"""Contains the class of data loader.""" | |
import argparse | |
from torch.utils.data import DataLoader | |
from .distributed_sampler import DistributedSampler | |
from .datasets import BaseDataset | |
__all__ = ['IterDataLoader'] | |
class IterDataLoader(object):
    """Iteration-based data loader.

    Wraps a `torch.utils.data.DataLoader` driven by a `DistributedSampler`
    and exposes it as an iterator that keeps going across passes: when the
    underlying loader is exhausted, the sampler is reset with the current
    iteration count and a fresh pass is started. `next()` therefore only
    raises `StopIteration` if a fresh pass yields no batch at all.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 shuffle=True,
                 num_workers=1,
                 current_iter=0,
                 repeat=1):
        """Initializes the data loader.

        Args:
            dataset: The dataset to load data from.
            batch_size: The batch size on each GPU.
            shuffle: Whether to shuffle the data. Also decides whether the
                last incomplete batch is dropped. (default: True)
            num_workers: Number of data workers for each GPU. (default: 1)
            current_iter: The current number of iterations; forwarded to the
                sampler (presumably so sampling can resume from that point).
                (default: 0)
            repeat: The repeating number of the whole dataloader.
                (default: 1)
        """
        self._dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self._dataloader = None
        self.iter_loader = None
        self._iter = current_iter
        self.repeat = repeat
        self.build_dataloader()

    def build_dataloader(self):
        """(Re)builds the underlying data loader and its iterator.

        A new `DistributedSampler` is created from the current iteration
        count each time, so this is also used to apply changed parameters.
        """
        dist_sampler = DistributedSampler(self._dataset,
                                          shuffle=self.shuffle,
                                          current_iter=self._iter,
                                          repeat=self.repeat)
        # Shuffling is delegated to the sampler: `DataLoader` rejects
        # `shuffle=True` together with a `sampler`, and `dist_sampler` is
        # always set here, so the `shuffle=` argument evaluates to False.
        # The trailing incomplete batch is dropped only when shuffling.
        self._dataloader = DataLoader(self._dataset,
                                      batch_size=self.batch_size,
                                      shuffle=(dist_sampler is None),
                                      num_workers=self.num_workers,
                                      drop_last=self.shuffle,
                                      pin_memory=True,
                                      sampler=dist_sampler)
        self.iter_loader = iter(self._dataloader)

    def overwrite_param(self, batch_size=None, resolution=None):
        """Overwrites some parameters for progressive training.

        Falsy arguments (e.g. `None`) leave the corresponding setting
        untouched. The data loader is rebuilt only when at least one
        provided value differs from the current one.

        Args:
            batch_size: New batch size on each GPU. (default: None)
            resolution: New value assigned to the dataset's `resolution`
                attribute. (default: None)
        """
        batch_changed = bool(batch_size) and batch_size != self.batch_size
        # Read from `_dataset` directly: `dataset` is a plain method on this
        # class, so `self.dataset.resolution` would raise AttributeError.
        resolution_changed = bool(resolution) and (
            resolution != getattr(self._dataset, 'resolution', None))
        if not (batch_changed or resolution_changed):
            # Nothing changes; avoid rebuilding the loader (and its workers).
            return
        if batch_changed:
            self.batch_size = batch_size
        if resolution_changed:
            self._dataset.resolution = resolution
        self.build_dataloader()

    # NOTE(review): `iter`, `dataset` and `dataloader` are plain methods here
    # (call them as `loader.iter()` etc.); they read like accessors that may
    # have been meant as `@property`s -- confirm against callers.
    def iter(self):
        """Returns the current iteration."""
        return self._iter

    def dataset(self):
        """Returns the dataset."""
        return self._dataset

    def dataloader(self):
        """Returns the data loader."""
        return self._dataloader

    def __iter__(self):
        """Returns the loader itself so it works with `iter()` and `for`.

        Note that a `for` loop over it does not stop at the end of a pass.
        """
        return self

    def __next__(self):
        """Returns the next batch, starting a new pass when one is exhausted.

        On exhaustion, the sampler is reset with the current iteration count
        and a fresh iterator over the data loader is created. The iteration
        counter is incremented once per returned batch.
        """
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._dataloader.sampler.__reset__(self._iter)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        self._iter += 1
        return data

    def __len__(self):
        """Returns `len()` of the underlying data loader (batches per pass)."""
        return len(self._dataloader)
def dataloader_test(root_dir, test_num=10):
    """Smoke-tests `IterDataLoader` while growing batch size and resolution.

    Starts at resolution 2 and batch size 2. Each step fetches one batch,
    asserts that its `'image'` entry has shape
    `(batch_size, 3, resolution, resolution)`, then increments the batch size,
    doubles the resolution and pushes both into the loader.

    Args:
        root_dir: Root directory of the dataset.
        test_num: Number of batches to fetch and check. (default: 10)
    """
    resolution, batch_size = 2, 2
    loader = IterDataLoader(
        dataset=BaseDataset(root_dir=root_dir, resolution=resolution),
        batch_size=batch_size,
        shuffle=False,
    )
    for _step in range(test_num):
        images = next(loader)['image']
        assert images.shape == (batch_size, 3, resolution, resolution)
        batch_size += 1
        resolution *= 2
        loader.overwrite_param(batch_size=batch_size, resolution=resolution)
if __name__ == '__main__':
    # Command-line entry point: run the data loader smoke test.
    arg_parser = argparse.ArgumentParser(description='Test Data Loader.')
    arg_parser.add_argument(
        'root_dir', type=str,
        help='Root directory of the dataset.')
    arg_parser.add_argument(
        '--test_num', type=int, default=10,
        help='Number of tests. (default: %(default)s)')
    cli_args = arg_parser.parse_args()
    dataloader_test(root_dir=cli_args.root_dir, test_num=cli_args.test_num)