# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# The star import places transform names in this module's globals, which are
# visible to dataset strings eval'd in get_data_loader below.
from .utils.transforms import *
from .base.batched_sampler import BatchedRandomSampler  # noqa: F401
from .co3d import Co3d  # noqa: F401


def _default_sampler(dataset, world_size, rank, shuffle, drop_last):
    """Return a generic sampler for datasets that do not supply their own."""
    import torch

    if torch.distributed.is_initialized():
        return torch.utils.data.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
        )
    if shuffle:
        return torch.utils.data.RandomSampler(dataset)
    return torch.utils.data.SequentialSampler(dataset)


def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Build a ``torch.utils.data.DataLoader`` for *dataset*.

    Args:
        dataset: a dataset object, or a Python expression string that is
            ``eval``-ed in this module's namespace to construct one (so it can
            reference names imported at the top of this module, e.g. ``Co3d``).
        batch_size: number of samples per batch, forwarded to the DataLoader
            and to ``dataset.make_sampler`` when that method exists.
        num_workers: number of DataLoader worker processes.
        shuffle: whether samples are drawn in random order.
        drop_last: whether to drop the final incomplete batch.
        pin_mem: forwarded to the DataLoader as ``pin_memory``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.

    Sampler selection: if the dataset defines ``make_sampler`` and calling it
    does not raise ``NotImplementedError``, its result is used. Otherwise a
    ``DistributedSampler`` is used when torch.distributed is initialized, else
    a ``RandomSampler`` or ``SequentialSampler`` depending on *shuffle*.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    if isinstance(dataset, str):
        # SECURITY: eval executes arbitrary code. The string must come from a
        # trusted source (e.g. the training config / command line), never from
        # external or user-supplied data.
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Resolve the method separately from calling it. An AttributeError raised
    # *inside* make_sampler then surfaces as the bug it is, instead of
    # silently switching to the fallback sampler.
    make_sampler = getattr(dataset, "make_sampler", None)
    if make_sampler is None:
        sampler = _default_sampler(dataset, world_size, rank, shuffle, drop_last)
    else:
        try:
            sampler = make_sampler(batch_size, shuffle=shuffle, world_size=world_size, rank=rank, drop_last=drop_last)
        except NotImplementedError:
            # The dataset declares its custom sampler unavailable here.
            sampler = _default_sampler(dataset, world_size, rank, shuffle, drop_last)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader