s12 / utils /dataloader.py
srikanthp07's picture
Upload 27 files
9022436
raw
history blame contribute delete
No virus
374 Bytes
import torch
from utils.dataset import get_dataset
# Default number of samples per mini-batch.
batch_size = 10
# Default keyword arguments forwarded to torch.utils.data.DataLoader
# by get_dataloader().
kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
def get_dataloader(**overrides):
    """Build a DataLoader over the dataset returned by ``get_dataset()``.

    The loader is configured from the module-level ``kwargs`` dict
    (batch_size, shuffle, num_workers, pin_memory). Keyword arguments
    passed to this function replace or extend those entries for this call
    only; the module-level ``kwargs`` dict is never mutated.

    Args:
        **overrides: Additional or replacement keyword arguments for
            ``torch.utils.data.DataLoader`` (e.g. ``shuffle=False``).
            Calling with no arguments behaves exactly as before.

    Returns:
        torch.utils.data.DataLoader: A loader over ``get_dataset()``.
    """
    test_data = get_dataset()
    # Merge into a fresh dict so per-call overrides never leak into the
    # shared module-level defaults.
    loader_kwargs = {**kwargs, **overrides}
    # NOTE(review): the defaults shuffle this "test" data (shuffle=True);
    # pass shuffle=False if a deterministic evaluation order is needed.
    return torch.utils.data.DataLoader(test_data, **loader_kwargs)