# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Mapping, Sequence

import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate

from .data_container import DataContainer


def _pad_and_stack(group):
    """Pad the trailing ``pad_dims`` dimensions of a group and stack it.

    Args:
        group (Sequence[DataContainer]): Samples destined for one GPU. The
            first sample supplies ``pad_dims`` and the reference shape for
            the leading (non-padded) dimensions.

    Returns:
        The result of ``default_collate`` applied to the padded tensors.
    """
    reference = group[0]
    pad_dims = reference.pad_dims
    ndim = reference.dim()
    assert ndim > pad_dims

    # max_shape[k] tracks the largest extent seen for dimension -(k + 1).
    max_shape = [reference.size(-dim) for dim in range(1, pad_dims + 1)]
    for sample in group:
        # Leading dimensions are not padded, so they must already agree.
        for dim in range(ndim - pad_dims):
            assert reference.size(dim) == sample.size(dim)
        for dim in range(1, pad_dims + 1):
            max_shape[dim - 1] = max(max_shape[dim - 1], sample.size(-dim))

    padded_samples = []
    for sample in group:
        # F.pad takes (front, back) pairs starting from the last dimension;
        # only the "back" slot (odd index) is filled, the front stays 0.
        pad = [0] * (pad_dims * 2)
        for dim in range(1, pad_dims + 1):
            pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim)
        padded_samples.append(
            F.pad(sample.data, pad, value=sample.padding_value))
    return default_collate(padded_samples)


def collate(batch, samples_per_gpu=1):
    """Puts each data field into a tensor/DataContainer with outer dimension
    batch size.

    Extend default_collate to add support for
    :type:`~mmcv.parallel.DataContainer`. There are 3 cases.

    1. cpu_only = True, e.g., meta data
    2. cpu_only = False, stack = True, e.g., images tensors
    3. cpu_only = False, stack = False, e.g., gt bboxes

    Args:
        batch (Sequence): Samples to collate. Elements may be
            ``DataContainer`` objects, sequences or mappings of them
            (handled recursively), or anything ``default_collate`` accepts.
        samples_per_gpu (int): Number of consecutive samples grouped
            together. Defaults to 1.

    Returns:
        A ``DataContainer`` holding one entry per group when the elements
        are ``DataContainer`` objects; a list or dict of collated fields
        for sequence or mapping elements; otherwise the result of
        ``default_collate``.

    Raises:
        TypeError: If ``batch`` is not a sequence.
    """
    if not isinstance(batch, Sequence):
        # Report the type rather than ``batch.dtype``: most non-sequence
        # inputs have no ``dtype``, which used to surface as AttributeError.
        raise TypeError(f'{type(batch)} is not supported.')

    first = batch[0]
    if isinstance(first, DataContainer):
        group_starts = range(0, len(batch), samples_per_gpu)
        if first.cpu_only:
            stacked = [[
                sample.data for sample in batch[i:i + samples_per_gpu]
            ] for i in group_starts]
            return DataContainer(
                stacked, first.stack, first.padding_value, cpu_only=True)

        stacked = []
        if first.stack:
            for i in group_starts:
                group = batch[i:i + samples_per_gpu]
                assert isinstance(group[0].data, torch.Tensor)
                if group[0].pad_dims is not None:
                    stacked.append(_pad_and_stack(group))
                else:
                    stacked.append(
                        default_collate([sample.data for sample in group]))
        else:
            for i in group_starts:
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
        return DataContainer(stacked, first.stack, first.padding_value)

    # ``str`` is itself a Sequence whose items are strings, so transposing
    # it would recurse without end; let default_collate handle strings.
    if isinstance(first, Sequence) and not isinstance(first, str):
        transposed = zip(*batch)
        return [collate(samples, samples_per_gpu) for samples in transposed]

    if isinstance(first, Mapping):
        return {
            key: collate([d[key] for d in batch], samples_per_gpu)
            for key in first
        }

    return default_collate(batch)