|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import collections |
|
from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format |
|
from typing import Callable, Dict, Optional, Tuple, Type, Union, List |
|
|
|
|
|
def cat_collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Concatenate a batch of tensors along dimension 0.

    Unlike ``torch.stack``, no new batch dimension is added: per-sample tensors
    of shape ``(n_i, *rest)`` yield a single tensor of shape ``(sum(n_i), *rest)``.

    Args:
        batch: Sequence of tensors whose sizes agree in every dimension except 0.
        collate_fn_map: Unused. It is accepted to match the ``collate_fn_map``
            handler calling convention, and defaults to ``None`` so the
            function can also be called directly.

    Returns:
        The concatenated tensor.

    Raises:
        RuntimeError: Propagated from ``torch.cat``, e.g. for mismatched
            trailing sizes or zero-dimensional inputs.
    """
    return torch.cat(batch, dim=0)
|
|
|
|
|
def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Flatten a batch of per-sample lists into one list, preserving order.

    The items of each sample are appended in turn, so ``[[a, b], [c]]``
    becomes ``[a, b, c]``.

    Args:
        batch: Sequence of per-sample iterables.
        collate_fn_map: Unused; accepted for the ``collate_fn_map`` handler
            calling convention.

    Returns:
        A new list holding every item of every sample.
    """
    flattened = []
    for sample in batch:
        flattened.extend(sample)
    return flattened
|
|
|
|
|
def cat_collate_none_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Collate a batch whose first sample is ``None`` to ``None``.

    Defined as a named module-level function rather than a lambda so that
    ``cat_collate_fn_map`` stays picklable. This matters if a collate function
    bound to the map has to be pickled, e.g. for multiprocessing.
    """
    return None


# Handler table for ``cat_collate``: torch's default handlers, except that
# tensors are concatenated instead of stacked, lists are flattened, and
# ``None`` samples pass through as ``None``.
cat_collate_fn_map = default_collate_fn_map.copy()
cat_collate_fn_map[torch.Tensor] = cat_collate_tensor_fn
# Register the builtin ``list`` so list samples hit the exact-type lookup in
# ``cat_collate``; a ``typing.List`` key is only reachable via the isinstance scan.
cat_collate_fn_map[list] = cat_collate_list_fn
# Kept for backward compatibility with code that looks up the ``typing.List`` key.
cat_collate_fn_map[List] = cat_collate_list_fn
cat_collate_fn_map[type(None)] = cat_collate_none_fn
|
|
|
|
|
def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    r"""Collate a batch by concatenating samples instead of stacking them.

    The type of ``batch[0]`` selects a handler from ``collate_fn_map``. An
    exact type match is tried first, then the first ``isinstance`` match in
    the map's iteration order. If no handler applies, mappings, namedtuples
    and other sequences are collated recursively, field by field.

    Args:
        batch: Non-empty sequence of samples, all assumed to share the
            structure of ``batch[0]``.
        collate_fn_map: Handlers keyed by type (or tuple of types). Each is
            called as ``fn(batch, collate_fn_map=collate_fn_map)``. When
            ``None``, no handlers are consulted, so leaf values such as
            tensors are unsupported and raise ``TypeError``. Pass
            ``cat_collate_fn_map`` to get the concatenating behaviour.

    Returns:
        The collated batch:

        - Mappings keep their type when it can be rebuilt from a ``dict``,
          otherwise they become a plain ``dict``.
        - Namedtuples keep their type.
        - Plain tuples become lists.
        - Other sequences keep their type when it can be rebuilt from a
          list, otherwise they become a list.

    Raises:
        IndexError: If ``batch`` is empty.
        RuntimeError: If sequence samples have different lengths.
        TypeError: If no handler or structural rule applies to ``batch[0]``.
    """
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        # O(1) exact-type dispatch first; then an isinstance scan so that
        # subclasses and tuple-of-types keys are honoured.
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
        for collate_type in collate_fn_map:
            if isinstance(elem, collate_type):
                return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, collections.abc.Mapping):
        # Collate each value exactly once; only rebuilding the container may fail.
        collated_map = {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
        try:
            return elem_type(collated_map)
        except TypeError:
            # The mapping type cannot be constructed from a dict (e.g. defaultdict).
            return collated_map

    if isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))

    if isinstance(elem, collections.abc.Sequence):
        # zip() would silently drop the trailing fields of longer samples.
        elem_size = len(elem)
        if any(len(sample) != elem_size for sample in batch):
            raise RuntimeError('each element in list of batch should be of equal size')
        collated_seq = [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)]
        if isinstance(elem, tuple):
            # Plain tuples are returned as lists.
            return collated_seq
        try:
            return elem_type(collated_seq)
        except TypeError:
            # The sequence type cannot be constructed from a list (e.g. range).
            return collated_seq

    raise TypeError(default_collate_err_msg_format.format(elem_type))
|
|