# haakohu's picture
# fix
# 44539fc
import torch
import tops
from .utils import collate_fn
def get_dataloader(
        dataset, gpu_transform: torch.nn.Module,
        num_workers,
        batch_size,
        infinite: bool,
        drop_last: bool,
        prefetch_factor: int,
        shuffle,
        channels_last=False
        ):
    """Build a ``DataLoader`` over ``dataset`` wrapped in ``tops.DataPrefetcher``.

    Sampler selection:

    * ``infinite=True``: a ``tops.InfiniteSampler`` configured with the
      current rank / world size and ``shuffle``. ``drop_last`` is not
      forwarded to the ``DataLoader`` in this mode.
    * otherwise, when ``tops.world_size() > 1``: a
      ``torch.utils.data.DistributedSampler``; ``drop_last`` is forwarded to
      the ``DataLoader``.
    * otherwise: no sampler; ``shuffle`` and ``drop_last`` are forwarded to
      the ``DataLoader`` directly.

    Args:
        dataset: Dataset handed to the sampler and the ``DataLoader``.
        gpu_transform: Module passed to ``tops.DataPrefetcher``
            (presumably applied to batches on the device; confirm against
            ``tops.DataPrefetcher``).
        num_workers: Number of ``DataLoader`` worker processes.
        batch_size: Batch size given to the ``DataLoader``.
        infinite: Select the infinite sampler (see above).
        drop_last: Forwarded to the ``DataLoader`` in the non-infinite modes.
        prefetch_factor: Forwarded to the ``DataLoader`` only when
            ``num_workers > 0``; ignored for single-process loading.
        shuffle: Forwarded to whichever sampler (or the ``DataLoader``) is
            selected.
        channels_last: Forwarded to ``tops.DataPrefetcher``.

    Returns:
        The ``tops.DataPrefetcher`` wrapping the constructed ``DataLoader``.
    """
    loader_options = dict(
        pin_memory=True,
    )
    sampler = None
    if infinite:
        sampler = tops.InfiniteSampler(
            dataset,
            rank=tops.rank(),
            num_replicas=tops.world_size(),
            shuffle=shuffle,
        )
    elif tops.world_size() > 1:
        # Shuffling is delegated to the sampler; the DataLoader must not be
        # given ``shuffle`` together with an explicit sampler.
        sampler = torch.utils.data.DistributedSampler(
            dataset,
            shuffle=shuffle,
            num_replicas=tops.world_size(),
            rank=tops.rank(),
        )
        loader_options["drop_last"] = drop_last
    else:
        loader_options["shuffle"] = shuffle
        loader_options["drop_last"] = drop_last

    # DataLoader raises ValueError when ``prefetch_factor`` is specified
    # without worker processes, so only forward it for multiprocess loading.
    if num_workers > 0:
        loader_options["prefetch_factor"] = prefetch_factor

    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers,
        **loader_options
    )
    return tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)