# Source: TTP/opencd/models/data_preprocessor.py
# (uploaded by KyanChen in commit 3b96cb1, "Upload 1861 files")
# Copyright (c) Open-CD. All rights reserved.
from numbers import Number
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor
from mmseg.utils import SampleList

from opencd.registry import MODELS
def _pad_pixel_data(pixel_data, padding_size: tuple,
                    pad_val: Union[int, float]) -> None:
    """Replace ``pixel_data.data`` in place by its constant-padded version."""
    data = pixel_data.data
    # The old field is deleted before the larger padded tensor is assigned.
    # NOTE(review): presumably required because the data container checks
    # new fields against the shape of existing ones -- confirm in mmengine.
    del pixel_data.data
    pixel_data.data = F.pad(data, padding_size, value=pad_val)


def stack_batch(inputs: List[torch.Tensor],
                data_samples: Optional['SampleList'] = None,
                size: Optional[tuple] = None,
                size_divisor: Optional[int] = None,
                pad_val: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255
                ) -> Tuple[torch.Tensor, list]:
    """Stack input images into a batch, padding each one at the right and
    bottom edges only.

    Each image is padded to ``size`` when it is given. Otherwise it is
    padded to the largest height/width in the batch, rounded up to a
    multiple of ``size_divisor`` when ``size_divisor > 1``. Images already
    larger than the target are not cropped (their padding is clamped to 0),
    so ``torch.stack`` fails if the padded shapes still differ.

    When ``data_samples`` is given, ``gt_sem_seg`` and, if present,
    ``gt_edge_map``, ``gt_seg_map_from`` and ``gt_seg_map_to`` are padded in
    place with ``seg_pad_val``. The ``img_shape``, ``pad_shape`` and
    ``padding_size`` meta info is also recorded on each sample.

    Args:
        inputs (List[Tensor]): Input images, each a CHW 3D tensor. All of
            them must have the same number of channels.
        data_samples (list[:obj:`SegDataSample`], optional): One data sample
            per input. They are modified in place. Defaults to None.
        size (tuple, optional): Fixed padding size. Its last two entries are
            read as (H, W).
        size_divisor (int, optional): The divisor of the padded size.
            Exactly one of ``size`` and ``size_divisor`` must be given.
        pad_val (int, float): Padding value for inputs. Defaults to 0.
        seg_pad_val (int, float): Padding value for label maps.
            Defaults to 255.

    Returns:
        Tuple[Tensor, list]: The stacked 4D tensor ``(N, C, H, W)`` and, per
        input, either the padded data sample or (when ``data_samples`` is
        None) a dict with ``img_padding_size`` and ``pad_shape``. Padding
        sizes are ``(left, right, top, bottom)`` tuples of Python ints, with
        left and top always 0.

    Raises:
        AssertionError: If ``inputs`` is not a non-empty list of 3D tensors
            with equal channel counts, or if not exactly one of ``size`` and
            ``size_divisor`` is given.
    """
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len(inputs) > 0, 'Expected at least one input tensor, but got 0'
    assert len({tensor.ndim for tensor in inputs}) == 1, \
        f'Expected the dimensions of all inputs must be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len({tensor.shape[0] for tensor in inputs}) == 1, \
        f'Expected the channels of all inputs must be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'
    # only one of size and size_divisor should be valid
    assert (size is not None) ^ (size_divisor is not None), \
        'only one of size and size_divisor should be valid'

    # Resolve the common target (H, W) once. Plain Python ints keep the
    # recorded padding meta info free of numpy scalar types.
    if size is not None:
        target_h, target_w = size[-2], size[-1]
    else:
        target_h = max(img.shape[-2] for img in inputs)
        target_w = max(img.shape[-1] for img in inputs)
        if size_divisor > 1:
            # Both H and W are rounded up to a multiple of size_divisor.
            target_h = (target_h + size_divisor - 1) // size_divisor \
                * size_divisor
            target_w = (target_w + size_divisor - 1) // size_divisor \
                * size_divisor

    padded_inputs = []
    padded_samples = []
    for idx, tensor in enumerate(inputs):
        height, width = tensor.shape[-2:]
        # (padding_left, padding_right, padding_top, padding_bottom)
        padding_size = (0, max(target_w - width, 0),
                        0, max(target_h - height, 0))
        pad_img = F.pad(tensor, padding_size, value=pad_val)
        padded_inputs.append(pad_img)

        if data_samples is None:
            padded_samples.append(
                dict(
                    img_padding_size=padding_size,
                    pad_shape=pad_img.shape[-2:]))
            continue

        data_sample = data_samples[idx]
        _pad_pixel_data(data_sample.gt_sem_seg, padding_size, seg_pad_val)
        # Optional label maps share the segmentation padding value.
        for key in ('gt_edge_map', 'gt_seg_map_from', 'gt_seg_map_to'):
            if key in data_sample:
                _pad_pixel_data(
                    getattr(data_sample, key), padding_size, seg_pad_val)
        data_sample.set_metainfo({
            'img_shape': tensor.shape[-2:],
            'pad_shape': data_sample.gt_sem_seg.shape,
            'padding_size': padding_size
        })
        padded_samples.append(data_sample)
    return torch.stack(padded_inputs, dim=0), padded_samples
@MODELS.register_module()
class DualInputSegDataPreProcessor(BaseDataPreprocessor):
    """Data pre-processor for dual-input change detection tasks.

    Compared with :class:`mmengine.model.ImgDataPreprocessor`:

    1. It won't do normalization if ``mean`` is not specified.
    2. It does color space conversion and normalization per image, before
       the batch is stacked and padded.
    3. It accepts ``batch_augments``, but does not build them (see the TODO
       in ``__init__``). During training the stored object is called
       directly as ``batch_augments(inputs, data_samples)``, so it must
       already be callable.

    The data pre-processing is as follows:

    - Collate and move data to the target device (``cast_data``).
    - If ``bgr_to_rgb`` or ``rgb_to_bgr`` is set and the inputs have 6
      channels, reverse the order of channels 0-2 and of channels 3-5
      independently. This presumably targets two concatenated 3-channel
      images (TODO confirm). Inputs with any other channel count are left
      unchanged.
    - Cast inputs to float and normalize them with ``mean`` and ``std``.
    - Training: pad inputs with ``pad_val`` and segmentation maps with
      ``seg_pad_val``, then stack them into a batch.
    - Testing: only a single input is supported. It is padded according to
      ``test_cfg`` if given, otherwise it is just stacked.

    Args:
        mean (Sequence[Number], optional): Per-channel pixel mean. It is
            reshaped to ``(len(mean), 1, 1)``, so it needs one entry per
            input channel (or a single entry). Defaults to None.
        std (Sequence[Number], optional): Per-channel pixel standard
            deviation, shaped like ``mean``. Required when ``mean`` is set.
            Defaults to None.
        size (tuple, optional): Fixed padding size used in training.
        size_divisor (int, optional): The divisor of the padded size used in
            training. Exactly one of ``size`` and ``size_divisor`` must be
            set when ``forward`` is called with ``training=True``.
        pad_val (Number): Padding value of inputs. Defaults to 0.
        seg_pad_val (Number): Padding value of segmentation maps.
            Defaults to 255.
        bgr_to_rgb (bool): Whether to convert images from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert images from RGB to BGR.
            Defaults to False.
        batch_augments (list[dict], optional): Batch-level augmentations.
            They are stored as given (see note 3 above). Defaults to None.
        test_cfg (dict, optional): The padding config used in testing. Only
            the keys ``size`` and ``size_divisor`` are read, and exactly one
            of them must be given. If None or empty, test inputs are not
            padded. Defaults to None.
    """

    def __init__(
        self,
        mean: Optional[Sequence[Number]] = None,
        std: Optional[Sequence[Number]] = None,
        size: Optional[tuple] = None,
        size_divisor: Optional[int] = None,
        pad_val: Number = 0,
        seg_pad_val: Number = 255,
        bgr_to_rgb: bool = False,
        rgb_to_bgr: bool = False,
        batch_augments: Optional[List[dict]] = None,
        test_cfg: Optional[dict] = None,
    ):
        super().__init__()
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        assert not (bgr_to_rgb and rgb_to_bgr), (
            '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time')
        # Reversing channel order is its own inverse, so either flag enables
        # the same operation in ``forward``.
        self.channel_conversion = rgb_to_bgr or bgr_to_rgb
        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                                    'preprocessing, please specify both ' \
                                    '`mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            # Non-persistent buffers of shape (len(mean), 1, 1): they follow
            # the module's device but are not saved in the state dict.
            self.register_buffer('mean',
                                 torch.tensor(mean).view(-1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(-1, 1, 1), False)
        else:
            self._enable_normalize = False
        # TODO: support batch augmentations.
        self.batch_augments = batch_augments
        # Support different padding methods in testing
        self.test_cfg = test_cfg

    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
        """Perform channel conversion, normalization and padding.

        Args:
            data (dict): Data sampled from the dataloader. ``data['inputs']``
                is a list of per-image tensors. ``data['data_samples']`` is
                optional at test time.
            training (bool): Whether to enable training-time padding and
                batch augmentation.

        Returns:
            Dict: ``dict(inputs=..., data_samples=...)`` where ``inputs`` is a
            stacked 4D tensor.
        """
        data = self.cast_data(data)  # type: ignore
        inputs = data['inputs']
        data_samples = data.get('data_samples', None)
        # TODO: whether normalize should be after stack_batch
        if self.channel_conversion and inputs[0].size(0) == 6:
            # Reverse channels 0-2 and 3-5 independently.
            inputs = [_input[[2, 1, 0, 5, 4, 3], ...] for _input in inputs]
        inputs = [_input.float() for _input in inputs]
        if self._enable_normalize:
            inputs = [(_input - self.mean) / self.std for _input in inputs]

        if training:
            assert data_samples is not None, (
                'During training, `data_samples` must be defined.')
            inputs, data_samples = stack_batch(
                inputs=inputs,
                data_samples=data_samples,
                size=self.size,
                size_divisor=self.size_divisor,
                pad_val=self.pad_val,
                seg_pad_val=self.seg_pad_val)
            if self.batch_augments is not None:
                inputs, data_samples = self.batch_augments(
                    inputs, data_samples)
            return dict(inputs=inputs, data_samples=data_samples)

        assert len(inputs) == 1, (
            'Batch inference is not support currently, '
            'as the image size might be different in a batch')
        if self.test_cfg:
            # Pad images when testing; pad info goes to the data samples.
            inputs, padded_samples = stack_batch(
                inputs=inputs,
                size=self.test_cfg.get('size', None),
                size_divisor=self.test_cfg.get('size_divisor', None),
                pad_val=self.pad_val,
                seg_pad_val=self.seg_pad_val)
            # Without data samples there is nowhere to record the pad info.
            if data_samples is not None:
                for data_sample, pad_info in zip(data_samples,
                                                 padded_samples):
                    data_sample.set_metainfo({**pad_info})
        else:
            inputs = torch.stack(inputs, dim=0)
        return dict(inputs=inputs, data_samples=data_samples)