Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved. | |
from collections.abc import Mapping, Sequence | |
import torch | |
import torch.nn.functional as F | |
from torch.utils.data.dataloader import default_collate | |
from .data_container import DataContainer | |
def _collate_stacked_group(group):
    """Pad (if requested) and stack the tensors of one per-GPU group.

    Args:
        group (Sequence[DataContainer]): Samples destined for the same GPU.
            The first sample decides ``pad_dims`` for the whole group.

    Returns:
        The result of ``default_collate`` applied to the (padded) ``data``
        tensors of the group.

    Raises:
        AssertionError: If the first sample's data is not a tensor, if
            ``pad_dims`` is not smaller than the number of dimensions, or if
            the non-padded leading dimensions differ between samples.
    """
    first = group[0]
    assert isinstance(first.data, torch.Tensor)

    pad_dims = first.pad_dims
    if pad_dims is None:
        # No padding requested: samples are stacked as they are.
        return default_collate([sample.data for sample in group])

    ndim = first.dim()
    assert ndim > pad_dims

    leading_dims = range(ndim - pad_dims)
    # Trailing dims are addressed from the end: 1 -> last, 2 -> second last.
    trailing_dims = range(1, pad_dims + 1)

    # Largest extent of every padded (trailing) dimension within the group.
    max_shape = [first.size(-dim) for dim in trailing_dims]
    for sample in group:
        for dim in leading_dims:
            assert first.size(dim) == sample.size(dim)
        for dim in trailing_dims:
            max_shape[dim - 1] = max(max_shape[dim - 1], sample.size(-dim))

    padded_samples = []
    for sample in group:
        # F.pad takes (before, after) pairs starting at the last dimension;
        # only the "after" slot (odd index) is filled, so data is padded at
        # the end of each trailing dimension.
        pad = [0] * (pad_dims * 2)
        for dim in trailing_dims:
            pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim)
        padded_samples.append(
            F.pad(sample.data, pad, value=sample.padding_value))
    return default_collate(padded_samples)


def collate(batch, samples_per_gpu=1):
    """Puts each data field into a tensor/DataContainer with outer dimension
    batch size.

    Extend default_collate to add support for
    :type:`~mmcv.parallel.DataContainer`. There are 3 cases.

    1. cpu_only = True, e.g., meta data
    2. cpu_only = False, stack = True, e.g., images tensors
    3. cpu_only = False, stack = False, e.g., gt bboxes

    Args:
        batch (Sequence): Samples to collate. Elements may be
            ``DataContainer`` objects, sequences or mappings of them
            (handled recursively), or anything ``default_collate`` accepts.
        samples_per_gpu (int): Number of consecutive samples that form one
            group; ``DataContainer`` samples are collated group by group.
            Defaults to 1.

    Returns:
        A ``DataContainer`` holding one entry per group when the elements
        are ``DataContainer`` objects, a list when they are sequences, a
        dict when they are mappings, otherwise the ``default_collate``
        result.

    Raises:
        TypeError: If ``batch`` is not a sequence.
    """
    if not isinstance(batch, Sequence):
        # Report the Python type: arbitrary objects have no ``dtype``
        # attribute, so reading it here would raise AttributeError and mask
        # the intended TypeError.
        raise TypeError(f'{type(batch)} is not supported.')

    first = batch[0]
    if isinstance(first, DataContainer):
        groups = [
            batch[start:start + samples_per_gpu]
            for start in range(0, len(batch), samples_per_gpu)
        ]
        if first.cpu_only:
            return DataContainer(
                [[sample.data for sample in group] for group in groups],
                first.stack,
                first.padding_value,
                cpu_only=True)
        if first.stack:
            stacked = [_collate_stacked_group(group) for group in groups]
        else:
            stacked = [[sample.data for sample in group] for group in groups]
        return DataContainer(stacked, first.stack, first.padding_value)
    elif isinstance(first, (str, bytes)):
        # Strings are Sequences whose items are again strings; transposing
        # them would recurse forever, so treat them as leaf values.
        return default_collate(batch)
    elif isinstance(first, Sequence):
        transposed = zip(*batch)
        return [collate(samples, samples_per_gpu) for samples in transposed]
    elif isinstance(first, Mapping):
        return {
            key: collate([d[key] for d in batch], samples_per_gpu)
            for key in first
        }
    else:
        return default_collate(batch)