# [page-extraction residue, not part of the module] Spaces: Runtime error / Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from abc import ABCMeta, abstractmethod | |
from typing import List, Tuple | |
import torch | |
import torch.nn as nn | |
from mmengine.model import BaseModule | |
from torch import Tensor | |
from mmseg.structures import build_pixel_sampler | |
from mmseg.utils import ConfigType, SampleList | |
from ..builder import build_loss | |
from ..losses import accuracy | |
from ..utils import resize | |
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    1. The ``init_weights`` method is used to initialize decode_head's
    model parameters. After segmentor initialization, ``init_weights``
    is triggered when ``segmentor.init_weights()`` is called externally.

    2. The ``loss`` method is used to calculate the loss of decode_head,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps (2) The ``loss_by_feat`` method
    is called based on the feature maps to calculate the loss.

    .. code:: text

        loss(): forward() -> loss_by_feat()

    3. The ``predict`` method is used to predict segmentation results,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps (2) The ``predict_by_feat`` method
    is called based on the feature maps to predict segmentation results
    including post-processing.

    .. code:: text

        predict(): forward() -> predict_by_feat()

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        out_channels (int): Output channels of conv_seg. When None it is
            set to ``num_classes``. Default: None.
        threshold (float): Threshold for binary segmentation in the case of
            `num_classes==1`. When ``out_channels == 1`` and no threshold
            is given, 0.3 is used. Default: None.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is property of corresponding loss function which
            could be shown in training log. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
             e.g. dict(type='CrossEntropyLoss'),
             [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
              dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Raises:
        ValueError: If ``out_channels`` is neither ``num_classes`` nor 1.
        TypeError: If ``loss_decode`` is not a dict, list or tuple.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 out_channels=None,
                 threshold=None,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super().__init__(init_cfg)
        # Validates the in_channels / in_index / input_transform combination
        # and sets self.in_channels, self.in_index, self.input_transform.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # Already assigned in ``_init_inputs``; re-assigned here so the
        # attribute exists even if a subclass overrides that hook.
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if out_channels is None:
            if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using '
                              '`out_channels = 1` to define the output '
                              'channels of segmentor, and use `threshold` '
                              'to convert `seg_logits` into a prediction '
                              'applying a threshold')
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, '
                'except binary segmentation set out_channels == 1 and '
                f'num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')

        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary, and defaults '
                          'to 0.3')
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.threshold = threshold

        # A single config builds one loss module; a sequence of configs
        # builds an nn.ModuleList with one loss module per entry.
        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        else:
            raise TypeError('loss_decode must be a dict or sequence of dict, '
                            f'but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        # 1x1 convolution mapping `channels` features to `out_channels`.
        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def extra_repr(self):
        """Return the extra fields shown in the module's repr string."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Sets ``self.input_transform``, ``self.in_index`` and
        ``self.in_channels`` (the sum of ``in_channels`` for
        'resize_concat', otherwise ``in_channels`` unchanged).

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated
                    together. Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one selected feature map is allowed.

        Raises:
            AssertionError: If the arguments do not match as described.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                # Selected maps are concatenated along the channel dim in
                # ``_transform_inputs``, so the channel counts add up.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs. A single
            concatenated tensor for 'resize_concat', a list of the selected
            features for 'multiple_select', otherwise the single feature at
            ``self.in_index``.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            # Resize every selected map to the spatial size of the first
            # selected map, then concatenate along dim 1.
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs):
        """Placeholder of forward function.

        This base implementation does nothing and returns None; ``loss`` and
        ``predict`` call it and pass its result on as the segmentation
        logits.
        """
        pass

    def cls_seg(self, feat):
        """Classify each pixel.

        Applies ``self.dropout`` (when configured) followed by
        ``self.conv_seg`` to ``feat`` and returns the result.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config. Not used by this base
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_img_metas (list[dict]): List Image info where each dict
                may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config. Not used by this base
                implementation.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        seg_logits = self.forward(inputs)

        return self.predict_by_feat(seg_logits, batch_img_metas)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Stack ``gt_sem_seg.data`` of every data sample along a new dim 0."""
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components, keyed by
            each loss module's ``loss_name``, plus an ``'acc_seg'`` entry.
        """
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        # Bring the logits to the label's spatial size before comparing.
        seg_logits = resize(
            input=seg_logits,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits, seg_label)
        else:
            seg_weight = None
        # Drop dim 1 of the stacked labels (after the sampler has seen the
        # un-squeezed labels).
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            loss_value = loss_decode(
                seg_logits,
                seg_label,
                weight=seg_weight,
                ignore_index=self.ignore_index)
            # Loss modules sharing the same ``loss_name`` are accumulated
            # into a single entry.
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_value
            else:
                loss[loss_decode.loss_name] += loss_value

        loss['acc_seg'] = accuracy(
            seg_logits, seg_label, ignore_index=self.ignore_index)
        return loss

    def predict_by_feat(self, seg_logits: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Transform a batch of output seg_logits to the input shape.

        The target size is taken from the first entry of
        ``batch_img_metas`` only.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
            # slide inference
            size = batch_img_metas[0]['img_shape']
        elif 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape'][:2]
        else:
            size = batch_img_metas[0]['img_shape']

        seg_logits = resize(
            input=seg_logits,
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)
        return seg_logits