|
|
|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn import ConvModule, build_norm_layer |
|
from torch import Tensor |
|
|
|
from mmseg.registry import MODELS |
|
from mmseg.utils import ConfigType, SampleList |
|
from ..builder import build_loss |
|
from ..utils import Encoding, resize |
|
from .decode_head import BaseDecodeHead |
|
|
|
|
|
class EncModule(nn.Module):
    """Encoding Module used in EncNet.

    The input is passed through a 1x1 ``ConvModule`` projection and a
    learnable ``Encoding`` layer. The codeword responses are averaged and
    mapped through ``Linear`` + ``Sigmoid`` to a per-channel gate ``y``. The
    module outputs ``relu(x + x * y)``.

    Args:
        in_channels (int): Input channels.
        num_codes (int): Number of code words.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers. For the encoding
            branch, ``'BN'``/``'IN'`` get a ``'1d'`` suffix and any other
            type has ``'2d'`` replaced by ``'1d'``. When None, ``BN1d`` is
            used there.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        # 1x1 projection that keeps the channel count unchanged.
        self.encoding_project = ConvModule(
            in_channels,
            in_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # The norm layer after ``Encoding`` is built with ``num_codes``
        # features and uses the 1d flavour of the configured norm type.
        if norm_cfg is None:
            encoding_norm_cfg = dict(type='BN1d')
        else:
            encoding_norm_cfg = norm_cfg.copy()
            norm_type = encoding_norm_cfg['type']
            if norm_type in ['BN', 'IN']:
                encoding_norm_cfg['type'] = norm_type + '1d'
            else:
                encoding_norm_cfg['type'] = norm_type.replace('2d', '1d')

        # Build the Encoding layer before the norm layer, as the original
        # nested construction did.
        encoding_layer = Encoding(channels=in_channels, num_codes=num_codes)
        norm_layer = build_norm_layer(encoding_norm_cfg, num_codes)[1]
        self.encoding = nn.Sequential(
            encoding_layer,
            norm_layer,
            nn.ReLU(inplace=True),
        )
        # Channel-wise gate in (0, 1).
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): 4D input feature map (N, C, H, W).

        Returns:
            tuple[Tensor, Tensor]: The encoding feature averaged over dim 1
            (reshaped to (N, C, 1, 1) for gating, so effectively (N, C)),
            and ``relu(x + x * gate)`` with the same shape as ``x``.
        """
        codeword_feat = self.encoding(self.encoding_project(x)).mean(dim=1)
        n, c, _, _ = x.size()
        gate = self.fc(codeword_feat).view(n, c, 1, 1)
        return codeword_feat, F.relu_(x + x * gate)
|
|
|
|
|
@MODELS.register_module()
class EncHead(BaseDecodeHead):
    """Context Encoding for Semantic Segmentation.

    This head is the implementation of `EncNet
    <https://arxiv.org/abs/1803.08904>`_.

    Args:
        num_codes (int): Number of code words. Default: 32.
        use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
            regularize the training. When True, :meth:`forward` returns a
            ``(seg_logits, se_logits)`` tuple instead of a single tensor.
            Default: True.
        add_lateral (bool): Whether use lateral connection to fuse features.
            Default: False.
        loss_se_decode (dict): Config of the SE-loss. It is only built when
            ``use_se_loss`` is True.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True,
            loss_weight=0.2).
    """

    def __init__(self,
                 num_codes=32,
                 use_se_loss=True,
                 add_lateral=False,
                 loss_se_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.2),
                 **kwargs):
        # Several input levels are consumed: the last one feeds the
        # bottleneck, the earlier ones feed the optional lateral convs.
        super().__init__(input_transform='multiple_select', **kwargs)
        self.use_se_loss = use_se_loss
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.bottleneck = ConvModule(
            self.in_channels[-1],
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if add_lateral:
            self.lateral_convs = nn.ModuleList()
            for lateral_in_channels in self.in_channels[:-1]:
                self.lateral_convs.append(
                    ConvModule(
                        lateral_in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            # Input is the bottleneck output concatenated with one lateral
            # output per earlier level, each ``self.channels`` wide.
            self.fusion = ConvModule(
                len(self.in_channels) * self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.use_se_loss:
            self.loss_se_decode = build_loss(loss_se_decode)
            # Predicts, per image, which classes are present.
            self.se_layer = nn.Linear(self.channels, self.num_classes)

    def forward(self, inputs):
        """Forward function.

        Returns:
            Tensor | tuple[Tensor, Tensor]: Segmentation logits. When
            ``use_se_loss`` is True, the tuple ``(seg_logits, se_logits)``
            is returned instead.
        """
        inputs = self._transform_inputs(inputs)
        feat = self.bottleneck(inputs[-1])
        if self.add_lateral:
            # Bring every lateral output to the bottleneck's spatial size.
            laterals = [
                resize(
                    lateral_conv(inputs[i]),
                    size=feat.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                for i, lateral_conv in enumerate(self.lateral_convs)
            ]
            feat = self.fusion(torch.cat([feat, *laterals], 1))
        encode_feat, output = self.enc_module(feat)
        output = self.cls_seg(output)
        if self.use_se_loss:
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType):
        """Forward function for testing, ignore se_loss.

        ``test_cfg`` is accepted for interface compatibility and is not used
        here.
        """
        if self.use_se_loss:
            seg_logits = self.forward(inputs)[0]
        else:
            seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas)

    @staticmethod
    def _convert_to_onehot_labels(seg_label, num_classes):
        """Convert segmentation labels to per-image class-presence vectors.

        Args:
            seg_label (Tensor): Integer segmentation labels whose first
                dimension is the batch, e.g. (N, H, W) or (N, 1, H, W).
            num_classes (int): Number of classes.

        Returns:
            Tensor: Onehot labels of shape (N, num_classes), with the dtype
            and device of ``seg_label``. Entry ``[i, c]`` is 1 iff class ``c``
            occurs in image ``i``. Labels outside ``[0, num_classes)``, such
            as an ignore index, are skipped.
        """
        batch_size = seg_label.size(0)
        if seg_label.numel() == 0:
            return seg_label.new_zeros((batch_size, num_classes))

        flat_label = seg_label.reshape(batch_size, -1).long()
        invalid = (flat_label < 0) | (flat_label >= num_classes)
        # Route out-of-range labels to an extra "dump" column that is sliced
        # off below. Every write stores 1, so repeated indices cannot
        # disagree. Unlike ``histc(min=0, max=num_classes - 1)``, this is
        # also correct when num_classes == 1: histc treats min == max as
        # "use the data range" and would count every label.
        index = flat_label.masked_fill(invalid, num_classes)
        onehot_labels = seg_label.new_zeros((batch_size, num_classes + 1))
        onehot_labels.scatter_(1, index, 1)
        return onehot_labels[:, :num_classes].contiguous()

    def loss_by_feat(self, seg_logit: Tuple[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute segmentation and semantic encoding loss.

        Args:
            seg_logit (Tensor | Tuple[Tensor]): Output of :meth:`forward`.
                This is ``(seg_logits, se_logits)`` when ``use_se_loss`` is
                True, otherwise the segmentation logits alone.
            batch_data_samples (SampleList): Data samples holding the
                ground-truth segmentation maps.

        Returns:
            dict: Losses from the base head, plus ``'loss_se'`` when the
            SE-loss is enabled.
        """
        if not self.use_se_loss:
            # forward() returned a plain tensor and no SE-loss was built.
            # Unpacking the tensor would split it along the batch dimension.
            return super().loss_by_feat(seg_logit, batch_data_samples)

        seg_logit, se_seg_logit = seg_logit
        loss = dict()
        loss.update(super().loss_by_feat(seg_logit, batch_data_samples))

        seg_label = self._stack_batch_gt(batch_data_samples)
        se_loss = self.loss_se_decode(
            se_seg_logit,
            self._convert_to_onehot_labels(seg_label, self.num_classes))
        loss['loss_se'] = se_loss
        return loss
|
|