import torch
from utils.dataset import get_dataset

batch_size = 10

# Default keyword arguments forwarded to ``torch.utils.data.DataLoader``.
# Pinned host memory only helps when batches are later copied to a CUDA
# device; without one PyTorch emits a warning and skips pinning anyway, so
# enable it only when CUDA is available.
kwargs = {
    'batch_size': batch_size,
    'shuffle': True,
    'num_workers': 2,
    'pin_memory': torch.cuda.is_available(),
}


def get_dataloader(dataset=None, **overrides):
    """Return a ``torch.utils.data.DataLoader`` over the test dataset.

    Args:
        dataset: Dataset to wrap. Defaults to the object returned by
            ``get_dataset()`` (the original, no-argument behavior).
        **overrides: Keyword arguments that replace entries of the
            module-level ``kwargs`` for this call only, e.g.
            ``get_dataloader(shuffle=False)``.

    Returns:
        A ``DataLoader`` built with ``kwargs`` updated by ``overrides``.
    """
    if dataset is None:
        dataset = get_dataset()
    # Build a fresh dict so the shared module-level defaults are never mutated.
    loader_kwargs = {**kwargs, **overrides}
    # NOTE(review): the default shuffle=True makes evaluation batch order
    # nondeterministic; pass shuffle=False if order matters to the caller.
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)