# lambdanet/deblur/src/data/__init__.py
# Uploaded via huggingface_hub (hyliu), commit e98653e (verified).
"""Generic dataset loader"""
from importlib import import_module
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from .sampler import DistributedEvalSampler
class Data():
    """Builds the train/val/test/demo DataLoaders from the parsed CLI args.

    For each mode whose flag in ``args`` is set, the matching dataset class
    is imported dynamically from the ``data`` package (module
    ``data.<name.lower()>``, class ``<name>``) and wrapped in a DataLoader
    with a mode- and distribution-appropriate batch size, sampler, and
    worker count. Inactive modes map to ``None`` in ``self.loaders``.
    """

    def __init__(self, args):
        self.modes = ['train', 'val', 'test', 'demo']
        # Which modes are active, driven by the CLI flags.
        self.action = {
            'train': args.do_train,
            'val': args.do_validate,
            'test': args.do_test,
            'demo': args.demo
        }
        # Dataset class name per mode ('Demo' is fixed for demo mode).
        self.dataset_name = {
            'train': args.data_train,
            'val': args.data_val,
            'test': args.data_test,
            'demo': 'Demo'
        }
        self.args = args

        def _get_data_loader(mode='train'):
            """Construct the DataLoader for a single mode."""
            dataset_name = self.dataset_name[mode]
            # Module data.<name.lower()> must expose a class named <name>.
            module = import_module('data.' + dataset_name.lower())
            dataset = getattr(module, dataset_name)(args, mode)
            if args.distributed:
                # num_workers per GPU process (single-node training):
                # ceil(num_workers / n_GPUs). Integer ceil division avoids
                # the float round-trip of int((a + b - 1) / b), which can be
                # inexact for large values.
                num_workers = (args.num_workers + args.n_GPUs - 1) // args.n_GPUs
            else:
                num_workers = args.num_workers
            if mode == 'train':
                if args.distributed:
                    # Batch size per GPU (single-node training).
                    batch_size = args.batch_size // args.n_GPUs
                    sampler = DistributedSampler(
                        dataset, shuffle=True,
                        num_replicas=args.world_size, rank=args.rank)
                else:
                    batch_size = args.batch_size
                    sampler = RandomSampler(dataset, replacement=False)
                drop_last = True   # keep the batch shape fixed during training
            else:  # 'val', 'test', 'demo'
                if args.distributed:
                    batch_size = 1  # 1 image per GPU
                    # DistributedEvalSampler (unlike DistributedSampler) does
                    # not pad/duplicate samples, so each image is evaluated
                    # exactly once across ranks.
                    sampler = DistributedEvalSampler(
                        dataset, shuffle=False,
                        num_replicas=args.world_size, rank=args.rank)
                else:
                    batch_size = args.n_GPUs  # 1 image per GPU (DataParallel)
                    sampler = SequentialSampler(dataset)
                drop_last = False  # evaluate every sample
            return DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=False,  # shuffling is delegated to the sampler
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=True,
                drop_last=drop_last,
            )

        self.loaders = {}
        for mode in self.modes:
            if self.action[mode]:
                self.loaders[mode] = _get_data_loader(mode)
                print('===> Loading {} dataset: {}'.format(mode, self.dataset_name[mode]))
            else:
                self.loaders[mode] = None

    def get_loader(self):
        """Return the mode -> DataLoader (or None) mapping."""
        return self.loaders