import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule, build_norm_layer from mmseg.ops import Encoding, resize from ..builder import HEADS, build_loss from .decode_head import BaseDecodeHead class EncModule(nn.Module): """Encoding Module used in EncNet. Args: in_channels (int): Input channels. num_codes (int): Number of code words. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict): Config of activation layers. """ def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): super(EncModule, self).__init__() self.encoding_project = ConvModule( in_channels, in_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) # TODO: resolve this hack # change to 1d if norm_cfg is not None: encoding_norm_cfg = norm_cfg.copy() if encoding_norm_cfg['type'] in ['BN', 'IN']: encoding_norm_cfg['type'] += '1d' else: encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( '2d', '1d') else: # fallback to BN1d encoding_norm_cfg = dict(type='BN1d') self.encoding = nn.Sequential( Encoding(channels=in_channels, num_codes=num_codes), build_norm_layer(encoding_norm_cfg, num_codes)[1], nn.ReLU(inplace=True)) self.fc = nn.Sequential( nn.Linear(in_channels, in_channels), nn.Sigmoid()) def forward(self, x): """Forward function.""" encoding_projection = self.encoding_project(x) encoding_feat = self.encoding(encoding_projection).mean(dim=1) batch_size, channels, _, _ = x.size() gamma = self.fc(encoding_feat) y = gamma.view(batch_size, channels, 1, 1) output = F.relu_(x + x * y) return encoding_feat, output @HEADS.register_module() class EncHead(BaseDecodeHead): """Context Encoding for Semantic Segmentation. This head is the implementation of `EncNet `_. Args: num_codes (int): Number of code words. Default: 32. use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to regularize the training. Default: True. add_lateral (bool): Whether use lateral connection to fuse features. Default: False. loss_se_decode (dict): Config of decode loss. Default: dict(type='CrossEntropyLoss', use_sigmoid=True). """ def __init__(self, num_codes=32, use_se_loss=True, add_lateral=False, loss_se_decode=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2), **kwargs): super(EncHead, self).__init__( input_transform='multiple_select', **kwargs) self.use_se_loss = use_se_loss self.add_lateral = add_lateral self.num_codes = num_codes self.bottleneck = ConvModule( self.in_channels[-1], self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if add_lateral: self.lateral_convs = nn.ModuleList() for in_channels in self.in_channels[:-1]: # skip the last one self.lateral_convs.append( ConvModule( in_channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) self.fusion = ConvModule( len(self.in_channels) * self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.enc_module = EncModule( self.channels, num_codes=num_codes, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if self.use_se_loss: self.loss_se_decode = build_loss(loss_se_decode) self.se_layer = nn.Linear(self.channels, self.num_classes) def forward(self, inputs): """Forward function.""" inputs = self._transform_inputs(inputs) feat = self.bottleneck(inputs[-1]) if self.add_lateral: laterals = [ resize( lateral_conv(inputs[i]), size=feat.shape[2:], mode='bilinear', align_corners=self.align_corners) for i, lateral_conv in enumerate(self.lateral_convs) ] feat = self.fusion(torch.cat([feat, *laterals], 1)) encode_feat, output = self.enc_module(feat) output = self.cls_seg(output) if self.use_se_loss: se_output = self.se_layer(encode_feat) return output, se_output else: return output def forward_test(self, inputs, img_metas, test_cfg): """Forward function for testing, ignore se_loss.""" if self.use_se_loss: return self.forward(inputs)[0] else: return self.forward(inputs) @staticmethod def _convert_to_onehot_labels(seg_label, num_classes): """Convert segmentation label to onehot. Args: seg_label (Tensor): Segmentation label of shape (N, H, W). num_classes (int): Number of classes. Returns: Tensor: Onehot labels of shape (N, num_classes). """ batch_size = seg_label.size(0) onehot_labels = seg_label.new_zeros((batch_size, num_classes)) for i in range(batch_size): hist = seg_label[i].float().histc( bins=num_classes, min=0, max=num_classes - 1) onehot_labels[i] = hist > 0 return onehot_labels def losses(self, seg_logit, seg_label): """Compute segmentation and semantic encoding loss.""" seg_logit, se_seg_logit = seg_logit loss = dict() loss.update(super(EncHead, self).losses(seg_logit, seg_label)) se_loss = self.loss_se_decode( se_seg_logit, self._convert_to_onehot_labels(seg_label, self.num_classes)) loss['loss_se'] = se_loss return loss