# KyanChen's picture
# Upload 1861 files
# 3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList
@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    The head has two branches with the same structure (see
    ``_make_base_head``). The main branch produces the segmentation logits
    in both training and eval mode. The auxiliary branch is only run in
    training mode, where its logits are supervised by a second loss.

    Args:
        in_channels (int): Number of channels of the feature map fed to the
            main branch. The auxiliary branch expects ``in_channels // 2``
            channels.
        channels (int): Number of intermediate channels produced by each
            branch before its classification layer.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        **kwargs: Forwarded unchanged to ``BaseDecodeHead``.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        # Main branch; its logits come from the inherited ``cls_seg``.
        self.head = self._make_base_head(self.in_channels, self.channels)
        # Auxiliary branch; its input has half the main branch's channels.
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        # Dedicated 1x1 classifier for the auxiliary branch.
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)

    def init_weights(self):
        """Initialize all convs and batch-norm layers of the head.

        Every ``nn.Conv2d`` weight gets Kaiming-normal init (conv biases are
        left untouched). Batch-norm layers (regular and synchronized) get
        weight 1 and bias 0; layers built without affine parameters have no
        weight/bias and are skipped instead of crashing.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                # ``weight``/``bias`` are None when affine=False.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        """Compute segmentation logits.

        Args:
            inputs: In training mode, a pair ``(c3_feat, c5_feat)``. The
                first item feeds the auxiliary branch and the second feeds
                the main branch. In eval mode, the single feature map for
                the main branch.

        Returns:
            ``(main_logits, aux_logits)`` in training mode, otherwise only
            the main-branch logits.

        Raises:
            TypeError: In training mode, if ``inputs`` is not a tuple/list
                holding exactly two feature maps.
        """
        if not self.training:
            x_c = self.head(inputs)
            return self.cls_seg(x_c)

        # Unpacking a bare Tensor would split it along the batch dimension
        # (silently wrong for a batch of 2), so insist on an explicit pair.
        if not isinstance(inputs, (tuple, list)) or len(inputs) != 2:
            got = type(inputs).__name__
            if isinstance(inputs, (tuple, list)):
                got = f'{got} of length {len(inputs)}'
            raise TypeError(
                'DDRHead expects a (c3_feat, c5_feat) pair of feature maps '
                f'in training mode, got {got}')

        c3_feat, c5_feat = inputs
        x_c = self.cls_seg(self.head(c5_feat))
        x_s = self.aux_cls_seg(self.aux_head(c3_feat))
        return x_c, x_s

    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        """Build one branch: norm-act-conv3x3, then a trailing norm + act.

        Args:
            in_channels (int): Channels of the branch input.
            channels (int): Channels of the branch output.

        Returns:
            nn.Sequential: The assembled branch.
        """
        layers = [
            # ``order`` places norm and activation before the 3x3 conv.
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            # build_norm_layer returns (name, module); keep the module.
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]
        return nn.Sequential(*layers)

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        """Compute the main and auxiliary losses plus accuracy.

        Args:
            seg_logits (Tuple[Tensor]): ``(context_logit, spatial_logit)``,
                i.e. the pair returned by :meth:`forward` in training mode.
            batch_data_samples (SampleList): Samples carrying the ground
                truth segmentation maps.

        Returns:
            dict: ``loss_context`` (from ``loss_decode[0]`` on the context
            logits), ``loss_spatial`` (from ``loss_decode[1]`` on the
            spatial logits) and ``acc_seg`` (accuracy of the context logits).

        Raises:
            ValueError: If ``loss_decode`` does not hold at least two losses.

        Note:
            Each loss module is called as ``loss(logit, label)`` only. No
            ``ignore_index`` or per-pixel weights are passed, so any label
            ignoring must come from the loss module's own configuration.
        """
        if (not isinstance(self.loss_decode, nn.ModuleList)
                or len(self.loss_decode) < 2):
            raise ValueError(
                'DDRHead requires `loss_decode` to be a list of at least two '
                'losses (context loss, spatial loss)')

        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        # Upsample both logits to the spatial size of the labels.
        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        # Stacked labels presumably are (N, 1, H, W); drop dim 1 for losses.
        seg_label = seg_label.squeeze(1)

        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)
        return loss