from datasets import get_dataset
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate


def collate_fn(batch):
    """Order a batch by each sample's element at index 3, then collate it.

    The list is sorted in place, largest value first, and then handed to
    PyTorch's ``default_collate``.

    Args:
        batch: List of indexable samples, each with at least four elements.
            NOTE(review): index 3 is presumably a sequence length -- confirm
            against the dataset's ``__getitem__``.

    Returns:
        Whatever ``default_collate`` produces for the reordered samples.
    """
    batch.sort(key=lambda sample: sample[3], reverse=True)
    return default_collate(batch)


def get_dataset_loader(opt, batch_size, mode="eval", split="test", accelerator=None):
    """Build a shuffling ``DataLoader`` over the dataset chosen by the arguments.

    Args:
        opt: Options object, forwarded unchanged to ``get_dataset``.
        batch_size: Number of samples per batch.
        mode: Forwarded to ``get_dataset``. ``"eval"`` and ``"gt_eval"``
            additionally select ``collate_fn`` for batching; every other
            value uses the default collation with persistent workers.
        split: Dataset split name, forwarded to ``get_dataset``.
        accelerator: Forwarded unchanged to ``get_dataset``.

    Returns:
        A ``DataLoader`` that shuffles, uses 4 worker processes and drops
        the final incomplete batch.
    """
    dataset = get_dataset(opt, split, mode, accelerator)

    # Settings common to both kinds of loader.
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 4,
        "drop_last": True,
    }

    if mode in ("eval", "gt_eval"):
        # Evaluation modes batch through the sorting collate function.
        loader_kwargs["collate_fn"] = collate_fn
    else:
        # All other modes keep worker processes alive between epochs.
        loader_kwargs["persistent_workers"] = True

    return DataLoader(dataset, **loader_kwargs)