TTP / mmseg /models /decode_heads /decode_head.py
KyanChen's picture
Upload 1861 files
3b96cb1
raw
history blame
14.6 kB
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from typing import List, Tuple
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from torch import Tensor
from mmseg.structures import build_pixel_sampler
from mmseg.utils import ConfigType, SampleList
from ..builder import build_loss
from ..losses import accuracy
from ..utils import resize
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
1. The ``init_weights`` method is used to initialize decode_head's
model parameters. After segmentor initialization, ``init_weights``
is triggered when ``segmentor.init_weights()`` is called externally.
2. The ``loss`` method is used to calculate the loss of decode_head,
which includes two steps: (1) the decode_head model performs forward
propagation to obtain the feature maps (2) The ``loss_by_feat`` method
is called based on the feature maps to calculate the loss.
.. code:: text
loss(): forward() -> loss_by_feat()
3. The ``predict`` method is used to predict segmentation results,
which includes two steps: (1) the decode_head model performs forward
propagation to obtain the feature maps (2) The ``predict_by_feat`` method
is called based on the feature maps to predict segmentation results
including post-processing.
.. code:: text
predict(): forward() -> predict_by_feat()
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
out_channels (int): Output channels of conv_seg. Default: None.
threshold (float): Threshold for binary segmentation in the case of
`num_classes==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
Default: None.
loss_decode (dict | Sequence[dict]): Config of decode loss.
The `loss_name` is property of corresponding loss function which
could be shown in training log. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255.
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
out_channels=None,
threshold=None,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
ignore_index=255,
sampler=None,
align_corners=False,
init_cfg=dict(
type='Normal', std=0.01, override=dict(name='conv_seg'))):
super().__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.ignore_index = ignore_index
self.align_corners = align_corners
if out_channels is None:
if num_classes == 2:
warnings.warn('For binary segmentation, we suggest using'
'`out_channels = 1` to define the output'
'channels of segmentor, and use `threshold`'
'to convert `seg_logits` into a prediction'
'applying a threshold')
out_channels = num_classes
if out_channels != num_classes and out_channels != 1:
raise ValueError(
'out_channels should be equal to num_classes,'
'except binary segmentation set out_channels == 1 and'
f'num_classes == 2, but got out_channels={out_channels}'
f'and num_classes={num_classes}')
if out_channels == 1 and threshold is None:
threshold = 0.3
warnings.warn('threshold is not defined for binary, and defaults'
'to 0.3')
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold = threshold
if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
else:
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
but got {type(loss_decode)}')
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'ignore_index={self.ignore_index}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType) -> dict:
"""Forward function for training.
Args:
inputs (Tuple[Tensor]): List of multi-level img features.
batch_data_samples (list[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `img_metas` or `gt_semantic_seg`.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs)
losses = self.loss_by_feat(seg_logits, batch_data_samples)
return losses
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tensor:
"""Forward function for prediction.
Args:
inputs (Tuple[Tensor]): List of multi-level img features.
batch_img_metas (dict): List Image info where each dict may also
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
test_cfg (dict): The testing config.
Returns:
Tensor: Outputs segmentation logits map.
"""
seg_logits = self.forward(inputs)
return self.predict_by_feat(seg_logits, batch_img_metas)
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
gt_semantic_segs = [
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
]
return torch.stack(gt_semantic_segs, dim=0)
def loss_by_feat(self, seg_logits: Tensor,
batch_data_samples: SampleList) -> dict:
"""Compute segmentation loss.
Args:
seg_logits (Tensor): The output from decode head forward function.
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_label = self._stack_batch_gt(batch_data_samples)
loss = dict()
seg_logits = resize(
input=seg_logits,
size=seg_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
if self.sampler is not None:
seg_weight = self.sampler.sample(seg_logits, seg_label)
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logits,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[loss_decode.loss_name] += loss_decode(
seg_logits,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
loss['acc_seg'] = accuracy(
seg_logits, seg_label, ignore_index=self.ignore_index)
return loss
def predict_by_feat(self, seg_logits: Tensor,
batch_img_metas: List[dict]) -> Tensor:
"""Transform a batch of output seg_logits to the input shape.
Args:
seg_logits (Tensor): The output from decode head forward function.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Returns:
Tensor: Outputs segmentation logits map.
"""
if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
# slide inference
size = batch_img_metas[0]['img_shape']
elif 'pad_shape' in batch_img_metas[0]:
size = batch_img_metas[0]['pad_shape'][:2]
else:
size = batch_img_metas[0]['img_shape']
seg_logits = resize(
input=seg_logits,
size=size,
mode='bilinear',
align_corners=self.align_corners)
return seg_logits