from numbers import Number
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor

from mmseg.utils import SampleList

from opencd.registry import MODELS


def stack_batch(inputs: List[torch.Tensor],
                data_samples: Optional[SampleList] = None,
                size: Optional[tuple] = None,
                size_divisor: Optional[int] = None,
                pad_val: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
    """Stack multiple inputs to form a batch and pad the images and
    gt_sem_segs to the max shape using the right-bottom padding mode.

    Args:
        inputs (List[Tensor]): The input tensors. Each is a CHW 3D tensor.
        data_samples (list[:obj:`SegDataSample`]): The list of data samples.
            It usually includes information such as `gt_sem_seg`.
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (int, float): The padding value of the image. Defaults to 0.
        seg_pad_val (int, float): The padding value of the segmentation map.
            Defaults to 255.

    Returns:
        Tensor: The 4D tensor of the stacked and padded images.
        List[:obj:`SegDataSample`]: The data samples with padded gt_seg_maps.
    """
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len({tensor.ndim for tensor in inputs}) == 1, \
        f'Expected the dimensions of all inputs to be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len({tensor.shape[0] for tensor in inputs}) == 1, \
        f'Expected the channels of all inputs to be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'

    assert (size is not None) ^ (size_divisor is not None), \
        'only one of size and size_divisor should be valid'
    padded_inputs = []
    padded_samples = []
    inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
    max_size = np.stack(inputs_sizes).max(0)
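    # If `size_divisor` is given, round the per-batch max size up to the
    # nearest multiple of it, e.g. a max (H, W) of (100, 130) with a divisor
    # of 32 becomes (128, 160).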
    if size_divisor is not None and size_divisor > 1:
        max_size = (max_size +
                    (size_divisor - 1)) // size_divisor * size_divisor

    for i in range(len(inputs)):
        tensor = inputs[i]
        if size is not None:
            width = max(size[-1] - tensor.shape[-1], 0)
            height = max(size[-2] - tensor.shape[-2], 0)
            # (padding_left, padding_right, padding_top, padding_bottom)
            padding_size = (0, width, 0, height)
        elif size_divisor is not None:
            width = max(max_size[-1] - tensor.shape[-1], 0)
            height = max(max_size[-2] - tensor.shape[-2], 0)
            padding_size = (0, width, 0, height)
        else:
            padding_size = (0, 0, 0, 0)

        # pad the image with `pad_val`
        pad_img = F.pad(tensor, padding_size, value=pad_val)
        padded_inputs.append(pad_img)

        if data_samples is not None:
            data_sample = data_samples[i]
            gt_sem_seg = data_sample.gt_sem_seg.data
            del data_sample.gt_sem_seg.data
            data_sample.gt_sem_seg.data = F.pad(
                gt_sem_seg, padding_size, value=seg_pad_val)
            if 'gt_edge_map' in data_sample:
                gt_edge_map = data_sample.gt_edge_map.data
                del data_sample.gt_edge_map.data
                data_sample.gt_edge_map.data = F.pad(
                    gt_edge_map, padding_size, value=seg_pad_val)
            if 'gt_seg_map_from' in data_sample:
                gt_seg_map_from = data_sample.gt_seg_map_from.data
                del data_sample.gt_seg_map_from.data
                data_sample.gt_seg_map_from.data = F.pad(
                    gt_seg_map_from, padding_size, value=seg_pad_val)
            if 'gt_seg_map_to' in data_sample:
                gt_seg_map_to = data_sample.gt_seg_map_to.data
                del data_sample.gt_seg_map_to.data
                data_sample.gt_seg_map_to.data = F.pad(
                    gt_seg_map_to, padding_size, value=seg_pad_val)
            data_sample.set_metainfo({
                'img_shape': tensor.shape[-2:],
                'pad_shape': data_sample.gt_sem_seg.shape,
                'padding_size': padding_size
            })
            padded_samples.append(data_sample)
        else:
            padded_samples.append(
                dict(
                    img_padding_size=padding_size,
                    pad_shape=pad_img.shape[-2:]))

    return torch.stack(padded_inputs, dim=0), padded_samples
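# Illustrative note (not part of the original module): with ``size_divisor=32``,
# two inputs of shape (3, 100, 120) and (3, 90, 130) are both right-bottom
# padded to (3, 128, 160), i.e. the per-batch max (100, 130) rounded up to
# multiples of 32, and the returned batch tensor has shape (2, 3, 128, 160).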


@MODELS.register_module()
class DualInputSegDataPreProcessor(BaseDataPreprocessor):
    """Image pre-processor for change detection tasks.

    Compared with the :class:`mmengine.ImgDataPreprocessor`,

    1. It won't do normalization if ``mean`` is not specified.
    2. It does normalization and color space conversion after stacking batch.
    3. It supports batch augmentations like mixup and cutmix.

    It provides the data pre-processing as follows:

    - Collate and move data to the target device.
    - Pad inputs to the input size with defined ``pad_val``, and pad seg map
      with defined ``seg_pad_val``.
    - Stack inputs to batch_inputs.
    - Convert inputs from bgr to rgb if the shape of input is (6, H, W),
      i.e. two concatenated 3-channel images.
    - Normalize image with defined std and mean.
    - Do batch augmentations like Mixup and Cutmix during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (float, optional): Padding value. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        padding_mode (str): Type of padding. Default: constant.
            - constant: pads with a constant value, this value is specified
              with pad_val.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        batch_augments (list[dict], optional): Batch-level augmentations.
            Defaults to None.
        test_cfg (dict, optional): The padding size config in testing. If not
            specified, the ``size`` and ``size_divisor`` params are used as
            default. Defaults to None; only supports the keys ``size`` or
            ``size_divisor``.
    """

    def __init__(
        self,
        mean: Optional[Sequence[Number]] = None,
        std: Optional[Sequence[Number]] = None,
        size: Optional[tuple] = None,
        size_divisor: Optional[int] = None,
        pad_val: Number = 0,
        seg_pad_val: Number = 255,
        bgr_to_rgb: bool = False,
        rgb_to_bgr: bool = False,
        batch_augments: Optional[List[dict]] = None,
        test_cfg: Optional[dict] = None,
    ):
        super().__init__()
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

        assert not (bgr_to_rgb and rgb_to_bgr), (
            '`bgr_to_rgb` and `rgb_to_bgr` cannot be set to True at the '
            'same time')
        self.channel_conversion = rgb_to_bgr or bgr_to_rgb

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                                    'preprocessing, please specify both ' \
                                    '`mean` and `std`.'
            self._enable_normalize = True
            self.register_buffer('mean',
                                 torch.tensor(mean).view(-1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(-1, 1, 1), False)
        else:
            self._enable_normalize = False

        self.batch_augments = batch_augments
        self.test_cfg = test_cfg

    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            Dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)
        inputs = data['inputs']
        data_samples = data.get('data_samples', None)
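        # Bi-temporal inputs are concatenated along the channel dim (6, H, W);
        # flipping channels [2, 1, 0, 5, 4, 3] converts BGR->RGB (or
        # RGB->BGR) for each of the two 3-channel images independently.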
        if self.channel_conversion and inputs[0].size(0) == 6:
            inputs = [_input[[2, 1, 0, 5, 4, 3], ...] for _input in inputs]

        inputs = [_input.float() for _input in inputs]
        if self._enable_normalize:
            inputs = [(_input - self.mean) / self.std for _input in inputs]

        if training:
            assert data_samples is not None, (
                'During training, `data_samples` must be defined.')
            inputs, data_samples = stack_batch(
                inputs=inputs,
                data_samples=data_samples,
                size=self.size,
                size_divisor=self.size_divisor,
                pad_val=self.pad_val,
                seg_pad_val=self.seg_pad_val)

            if self.batch_augments is not None:
                inputs, data_samples = self.batch_augments(
                    inputs, data_samples)
        else:
            assert len(inputs) == 1, (
                'Batch inference is not supported currently, '
                'as the image size might be different in a batch')

            if self.test_cfg:
                inputs, padded_samples = stack_batch(
                    inputs=inputs,
                    size=self.test_cfg.get('size', None),
                    size_divisor=self.test_cfg.get('size_divisor', None),
                    pad_val=self.pad_val,
                    seg_pad_val=self.seg_pad_val)
                for data_sample, pad_info in zip(data_samples, padded_samples):
                    data_sample.set_metainfo({**pad_info})
            else:
                inputs = torch.stack(inputs, dim=0)

        return dict(inputs=inputs, data_samples=data_samples)
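

# A minimal, self-contained sketch of how the preprocessor behaves in test
# mode (illustrative only; the mean/std values are the common ImageNet stats
# repeated for the two temporal images and are not mandated by this module).
if __name__ == '__main__':
    preprocessor = DualInputSegDataPreProcessor(
        mean=[123.675, 116.28, 103.53] * 2,
        std=[58.395, 57.12, 57.375] * 2,
        bgr_to_rgb=True,
        size_divisor=32)
    # One bi-temporal sample: two 3-channel images concatenated to 6 channels.
    dummy = dict(
        inputs=[torch.randint(0, 256, (6, 250, 250), dtype=torch.uint8)])
    out = preprocessor(dummy, training=False)
    # Without `test_cfg`, test-time inputs are only stacked, not padded.
    print(out['inputs'].shape)  # torch.Size([1, 6, 250, 250])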