# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Collate extensions
# --------------------------------------------------------
import torch
import collections
from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format
from typing import Callable, Dict, Optional, Tuple, Type, Union, List


def cat_collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Concatenate a batch of tensors along dim 0 instead of stacking them.

    ``torch.cat`` rejects zero-dimensional tensors, so a batch whose first
    element is 0-d is stacked into a 1-d tensor instead.

    Args:
        batch: sequence of tensors whose shapes agree on every dim except dim 0.
        collate_fn_map: unused; accepted to match the collate-map calling convention.

    Returns:
        A single tensor holding all elements of ``batch``.
    """
    if batch[0].dim() == 0:
        return torch.stack(batch, dim=0)
    return torch.cat(batch, dim=0)


def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Flatten a batch of lists into one list (list concatenation).

    Args:
        batch: sequence of lists.
        collate_fn_map: unused; accepted to match the collate-map calling convention.
    """
    return [item for bb in batch for item in bb]  # concatenate all lists


# Start from PyTorch's default rules (numbers, strings, numpy arrays, ...) and
# override the tensor/list entries so that they concatenate instead of stacking.
cat_collate_fn_map = default_collate_fn_map.copy()
cat_collate_fn_map[torch.Tensor] = cat_collate_tensor_fn
# Keyed on the builtin ``list`` (not ``typing.List``): ``type(elem)`` is ``list``,
# so the exact-type lookup in ``cat_collate`` hits directly; list subclasses
# are still caught by the isinstance() scan.
cat_collate_fn_map[list] = cat_collate_list_fn
# When the first element is None, simply return a single None for the batch.
cat_collate_fn_map[type(None)] = lambda _, **kw: None


def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    r"""Custom collate function that concatenates stuff instead of stacking them, and handles NoneTypes.

    Leaf values are dispatched through ``collate_fn_map``: first by exact type,
    then by ``isinstance`` in map order. Mappings, namedtuples and other
    sequences are collated recursively, field by field.

    Args:
        batch: non-empty sequence of samples sharing the same structure.
        collate_fn_map: maps a type (or tuple of types) to a leaf collate
            function called as ``fn(batch, collate_fn_map=...)``. Defaults to
            ``cat_collate_fn_map``.

    Returns:
        The collated batch, mirroring the structure of ``batch[0]``. Tuples
        are returned as lists (same as ``default_collate``).

    Raises:
        RuntimeError: if sequence elements (e.g. tuples) have different lengths.
        TypeError: if no rule applies to the type of ``batch[0]``.
    """
    if collate_fn_map is None:
        # Without a map no leaf (tensor, number, str) can be collated: tensors
        # would raise and strings would recurse forever through the Sequence
        # branch. Fall back to this module's concatenating rules.
        collate_fn_map = cat_collate_fn_map

    elem = batch[0]
    elem_type = type(elem)

    if elem_type in collate_fn_map:
        return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
    for collate_type in collate_fn_map:
        if isinstance(elem, collate_type):
            return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, collections.abc.Mapping):
        # Collate once, then try to preserve the original mapping type.
        collated = {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
        try:
            return elem_type(collated)
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return collated
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # zip() silently truncates to the shortest element; refuse ragged input
        # instead of dropping data.
        elem_size = len(elem)
        if any(len(sample) != elem_size for sample in batch):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.
        collated = [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
        if isinstance(elem, tuple):
            return collated  # Backwards compatibility.
        try:
            return elem_type(collated)
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return collated

    raise TypeError(default_collate_err_msg_format.format(elem_type))