SivaResearch's picture
demo
b6d5990
raw
history blame
617 Bytes
import torch.utils.data
class DataProvider():
def __init__(self, cfg, dataset, batch_size=None, shuffle=True):
super().__init__()
self.dataset = dataset
if batch_size is None:
batch_size = cfg.BATCH_SIZE
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=int(cfg.WORKERS),
drop_last=False)
def __len__(self):
return len(self.dataset)
def __iter__(self):
for i, data in enumerate(self.dataloader):
yield data