import torch

from .datasets.ab_dataset import ABDataset


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Restart the wrapped sampler whenever it is exhausted, so the
        # stream of indices (or batches of indices) never ends.
        while True:
            for batch in self.sampler:
                yield batch
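
# Illustrative sketch of how _InfiniteSampler is meant to be used: wrap a
# finite (batch) sampler so the stream of indices never ends. The names
# below are assumptions for the example, not part of this module.
#
#     batch_sampler = torch.utils.data.BatchSampler(
#         torch.utils.data.RandomSampler(some_dataset),
#         batch_size=8, drop_last=True)
#     endless_batches = _InfiniteSampler(batch_sampler)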


class InfiniteDataLoader:
    """
    DataLoader wrapper that yields batches endlessly, for step-based
    (rather than epoch-based) training loops.
    """

    def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
        super().__init__()

        if weights is not None:
            # Weighted sampling with replacement; num_samples=batch_size
            # means one pass over the sampler produces exactly one batch.
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size
            )
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True
        )

        # Passing collate_fn=None is equivalent to omitting it, so a single
        # DataLoader construction covers both cases.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )
        self.dataset = dataset

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        raise ValueError('InfiniteDataLoader has no finite length')
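
# Illustrative sketch: an InfiniteDataLoader is consumed by a step-count
# loop rather than an epoch loop (num_steps below is an assumed name):
#
#     loader = InfiniteDataLoader(dataset, None, batch_size=32, num_workers=2)
#     it = iter(loader)
#     for step in range(num_steps):
#         batch = next(it)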


class FastDataLoader:
    """
    DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch.
    """

    def __init__(self, dataset, batch_size, num_workers, shuffle=False, collate_fn=None):
        super().__init__()

        self.num_workers = num_workers

        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=False,
        )

        # The underlying iterator is infinite so that workers stay alive
        # across epochs; __iter__ caps each epoch at len(batch_sampler)
        # batches. Passing collate_fn=None is equivalent to omitting it.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )

        self.dataset = dataset
        self.batch_size = batch_size
        self._length = len(batch_sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length
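
# Illustrative sketch: FastDataLoader behaves like an ordinary finite
# loader, so a plain epoch loop works (num_epochs and train_step below
# are assumed names):
#
#     loader = FastDataLoader(dataset, batch_size=32, num_workers=2, shuffle=True)
#     for epoch in range(num_epochs):
#         for batch in loader:
#             train_step(batch)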


def build_dataloader(dataset: ABDataset, batch_size: int, num_workers: int,
                     infinite: bool, shuffle_when_finite: bool, collate_fn=None):
    assert batch_size <= len(dataset), \
        f'batch_size {batch_size} exceeds dataset size {len(dataset)}'

    if infinite:
        dataloader = InfiniteDataLoader(
            dataset, None, batch_size, num_workers=num_workers, collate_fn=collate_fn)
    else:
        dataloader = FastDataLoader(
            dataset, batch_size, num_workers, shuffle=shuffle_when_finite, collate_fn=collate_fn)

    return dataloader


def get_a_batch_dataloader(dataset: ABDataset, batch_size: int, num_workers: int,
                           infinite: bool, shuffle_when_finite: bool):
    # Not implemented yet.
    raise NotImplementedError
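

# Minimal runnable sketch (run with `python -m <package>.<this_module>` so the
# relative import above resolves). ABDataset is only a type hint, so any torch
# Dataset works here; the TensorDataset and sizes are illustrative
# assumptions, not part of the module.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    dummy = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

    # Finite loader: one pass over the 100 samples in batches of 16.
    finite_loader = build_dataloader(
        dummy, batch_size=16, num_workers=0, infinite=False,
        shuffle_when_finite=True)
    print(len(finite_loader), 'batches per epoch')  # ceil(100 / 16) = 7

    # Infinite loader: draw a fixed number of steps instead of epochs.
    infinite_loader = build_dataloader(
        dummy, batch_size=16, num_workers=0, infinite=True,
        shuffle_when_finite=False)
    it = iter(infinite_loader)
    for _ in range(3):
        x, y = next(it)
        print(x.shape, y.shape)  # torch.Size([16, 3]) torch.Size([16])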