|
"""Generic dataset loader""" |
|
|
|
from importlib import import_module |
|
|
|
from torch.utils.data import DataLoader |
|
from torch.utils.data import SequentialSampler, RandomSampler |
|
from torch.utils.data.distributed import DistributedSampler |
|
from .sampler import DistributedEvalSampler |
|
|
|
class Data:
    """Build and hold a DataLoader for every enabled run mode.

    One loader is created per mode ('train', 'val', 'test', 'demo') whose
    corresponding action flag on ``args`` is truthy; disabled modes map to
    ``None`` in ``self.loaders``.
    """

    def __init__(self, args):
        """Construct the per-mode loaders from parsed command-line ``args``.

        Args:
            args: namespace providing the do_train / do_validate / do_test /
                demo flags, the data_train / data_val / data_test dataset
                names, and loader settings (batch_size, num_workers, n_GPUs,
                distributed, world_size, rank).
        """
        self.modes = ['train', 'val', 'test', 'demo']

        # Which modes are enabled for this run.
        self.action = {
            'train': args.do_train,
            'val': args.do_validate,
            'test': args.do_test,
            'demo': args.demo,
        }

        # Dataset class name per mode; demo mode always uses the 'Demo' class.
        self.dataset_name = {
            'train': args.data_train,
            'val': args.data_val,
            'test': args.data_test,
            'demo': 'Demo',
        }

        self.args = args

        self.loaders = {}
        for mode in self.modes:
            if self.action[mode]:
                self.loaders[mode] = self._get_data_loader(mode)
                print(f'===> Loading {mode} dataset: {self.dataset_name[mode]}')
            else:
                self.loaders[mode] = None

    def _get_data_loader(self, mode='train'):
        """Instantiate the dataset class for ``mode`` and wrap it in a DataLoader.

        The dataset class is imported dynamically from the ``data`` package:
        dataset name 'Foo' resolves to ``data.foo.Foo`` and is called as
        ``Foo(args, mode)``.
        """
        args = self.args
        dataset_name = self.dataset_name[mode]
        module = import_module('data.' + dataset_name.lower())
        dataset = getattr(module, dataset_name)(args, mode)

        if mode == 'train':
            if args.distributed:
                # Per-process batch size; workers split across GPUs (ceiling).
                batch_size = args.batch_size // args.n_GPUs
                sampler = DistributedSampler(
                    dataset, shuffle=True,
                    num_replicas=args.world_size, rank=args.rank)
                num_workers = (args.num_workers + args.n_GPUs - 1) // args.n_GPUs
            else:
                batch_size = args.batch_size
                sampler = RandomSampler(dataset, replacement=False)
                num_workers = args.num_workers
            drop_last = True
        else:  # 'val', 'test', 'demo'
            if args.distributed:
                batch_size = 1
                # Eval sampler visits every sample exactly once (no padding
                # across replicas, unlike DistributedSampler).
                sampler = DistributedEvalSampler(
                    dataset, shuffle=False,
                    num_replicas=args.world_size, rank=args.rank)
                num_workers = (args.num_workers + args.n_GPUs - 1) // args.n_GPUs
            else:
                batch_size = args.n_GPUs
                sampler = SequentialSampler(dataset)
                num_workers = args.num_workers
            drop_last = False

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,  # ordering is fully controlled by the sampler
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    def get_loader(self):
        """Return the mode -> DataLoader (or None when disabled) mapping."""
        return self.loaders
|
|