Spaces:
Sleeping
Sleeping
| import logging | |
| import torch | |
| import torch.utils.data | |
| from importlib import import_module | |
def create_dataloader(phase, dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for the given phase.

    Args:
        phase: ``'train'`` builds a training loader; any other value builds a
            validation/test loader (no shuffling, no pinned memory).
        dataset: The dataset to wrap.
        dataset_opt: Dataset options; must provide ``'n_workers'`` and
            ``'batch_size'``.
        opt: Global options. In the train phase the worker count is
            ``dataset_opt['n_workers'] * opt['world_size']``. When ``opt`` is
            ``None`` a world size of 1 is assumed.
        sampler: Optional sampler used in the train phase. When given, the
            loader does not shuffle and ordering is left to the sampler.

    Returns:
        The configured ``DataLoader``.
    """
    logger = logging.getLogger('base')
    n_workers = dataset_opt['n_workers']
    batch_size = dataset_opt['batch_size']

    if phase != 'train':
        logger.info(
            'N_workers: {}, batch_size: {} validate/test dataloader has been established'.format(
                n_workers,
                batch_size))
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=n_workers,
                                           pin_memory=False)

    # ``opt`` defaults to None, so it must not be subscripted unconditionally.
    world_size = 1 if opt is None else opt['world_size']
    # NOTE(review): the worker count is scaled by world_size even on the
    # sampler (DDP) path, where each process builds its own loader — confirm
    # this is intended.
    num_workers = n_workers * world_size

    if sampler is not None:
        # DataLoader rejects shuffle=True together with a sampler; the sampler
        # owns the ordering here.
        logger.info('N_workers: {}, batch_size: {} DDP train dataloader has been established'.format(num_workers,
                                                                                                    batch_size))
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           num_workers=num_workers, sampler=sampler,
                                           pin_memory=True)

    logger.info('N_workers: {}, batch_size: {} train dataloader has been established'.format(num_workers,
                                                                                            batch_size))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       num_workers=num_workers, shuffle=True,
                                       pin_memory=True)
def create_dataset(dataset_opt, dataInfo, phase, dataset_name):
    """Instantiate the training dataset defined in module ``data.<dataset_name>``.

    Args:
        dataset_opt: Dataset options. Passed to the dataset constructor; its
            ``'mode'``, ``'type'`` and ``'name'`` entries are also logged.
        dataInfo: Passed unchanged to ``VideoBasedDataset``.
        phase: Only ``'train'`` is supported.
        dataset_name: Name of the module under the ``data`` package that
            defines ``VideoBasedDataset``.

    Returns:
        The constructed ``VideoBasedDataset`` instance.

    Raises:
        ValueError: If ``phase`` is not ``'train'`` (validate/test datasets
            are not built here).
    """
    if phase != 'train':  # validate and test dataset
        # Raise rather than return, so callers cannot mistake the error object
        # for a dataset.
        raise ValueError('No dataset initialized for valdataset')

    dataset_package = import_module('data.{}'.format(dataset_name))
    dataset = dataset_package.VideoBasedDataset(dataset_opt, dataInfo)
    mode = dataset_opt['mode']
    logger = logging.getLogger('base')
    logger.info(
        '{} train dataset [{:s} - {:s} - {:s}] is created.'.format(dataset_opt['type'].upper(),
                                                                   dataset.__class__.__name__,
                                                                   dataset_opt['name'], mode))
    return dataset