from typing import Tuple, Union

import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList
|
|
@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    Args:
        in_channels (int): Number of input channels of the context-branch
            feature; the spatial-branch feature is expected to carry
            ``in_channels // 2`` channels.
        channels (int): Number of intermediate feature channels before the
            final classification layer.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """
|
    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        # The main head consumes the context-branch feature; the auxiliary
        # head consumes the spatial-branch feature, which carries half as
        # many channels.
        self.head = self._make_base_head(self.in_channels, self.channels)
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)
|
    def init_weights(self):
        # Kaiming initialization for convolutions, constant initialization
        # for batch-norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
|
    def forward(
        self, inputs: Union[Tensor, Tuple[Tensor, Tensor]]
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if self.training:
            # Training: the backbone provides both the spatial-branch
            # feature (c3) and the context-branch feature (c5); an
            # auxiliary prediction is made from the spatial branch.
            c3_feat, c5_feat = inputs
            x_c = self.head(c5_feat)
            x_c = self.cls_seg(x_c)
            x_s = self.aux_head(c3_feat)
            x_s = self.aux_cls_seg(x_s)

            return x_c, x_s
        else:
            # Inference: only the context-branch feature is passed in, so a
            # single prediction is returned.
            x_c = self.head(inputs)
            x_c = self.cls_seg(x_c)
            return x_c
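
    # Illustrative I/O contract for ``forward`` (hypothetical shapes,
    # assuming in_channels=128, so the spatial branch carries
    # 128 // 2 = 64 channels):
    #   training:  forward((c3, c5)) with c3: (N, 64, H, W) and
    #              c5: (N, 128, H, W) -> (context_logits, spatial_logits)
    #   inference: forward(c5) -> context_logits of shape
    #              (N, num_classes, H, W)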
|
    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        # Pre-activation block (BN -> ReLU -> 3x3 conv), followed by a
        # trailing BN + ReLU so the classifier receives normalized,
        # activated features.
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]

        return nn.Sequential(*layers)
|
    def loss_by_feat(self, seg_logits: Tuple[Tensor, Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        # Upsample both predictions to the label resolution before
        # computing the losses.
        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        seg_label = seg_label.squeeze(1)

        # The first configured loss supervises the context branch, the
        # second supervises the auxiliary spatial branch; accuracy is
        # reported on the context prediction.
        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)

        return loss
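
# A minimal usage sketch, not part of the upstream module. The channel,
# class, and shape values are illustrative assumptions; two decode losses
# are configured because ``loss_by_feat`` indexes ``self.loss_decode[0]``
# and ``self.loss_decode[1]``.
#
#     import torch
#
#     head = DDRHead(
#         in_channels=128,
#         channels=64,
#         num_classes=19,
#         loss_decode=[
#             dict(type='CrossEntropyLoss', loss_weight=1.0),
#             dict(type='CrossEntropyLoss', loss_weight=0.4),
#         ])
#     head.eval()
#     with torch.no_grad():
#         # At inference the head receives only the context-branch feature.
#         logits = head(torch.randn(1, 128, 64, 128))
#     assert logits.shape == (1, 19, 64, 128)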
|
|