|
|
|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from mmcv.cnn import ConvModule, Scale |
|
from torch import Tensor, nn |
|
|
|
from mmseg.registry import MODELS |
|
from mmseg.utils import SampleList, add_prefix |
|
from ..utils import SelfAttentionBlock as _SelfAttentionBlock |
|
from .decode_head import BaseDecodeHead |
|
|
|
|
|
class PAM(_SelfAttentionBlock):
    """Position Attention Module (PAM).

    A self-attention block in which the same feature map is used as both
    the query and the key input. The attention result is passed through
    ``self.gamma`` and added back onto the input.

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
    """

    def __init__(self, in_channels, channels):
        # Fixed configuration handed to the self-attention base class:
        # separate single-conv key/query projections, no downsampling,
        # no norm/activation, and no extra output projection.
        attention_cfg = dict(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=1,
            key_query_norm=False,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=False,
            with_out=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None,
        )
        super().__init__(**attention_cfg)

        # Scale layer created with an initial value of 0; it weights the
        # attention branch before the residual sum in ``forward``.
        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function.

        Runs the parent attention with ``x`` as both of its inputs and
        returns ``gamma(attention) + x``.
        """
        attended = super().forward(x, x)
        return self.gamma(attended) + x
|
|
|
|
|
class CAM(nn.Module):
    """Channel Attention Module (CAM).

    Computes a channel-to-channel attention map from the input itself and
    uses it to re-weight the channels. The re-weighted features are passed
    through ``self.gamma`` and added back onto the input.
    """

    def __init__(self):
        super().__init__()
        # Scale layer created with an initial value of 0; it weights the
        # attention branch before the residual sum in ``forward``.
        self.gamma = Scale(0)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): 4-D feature map unpacked as
                (batch_size, channels, height, width).

        Returns:
            Tensor: ``gamma(attention_output) + x`` with the shape of ``x``.
        """
        batch_size, channels, height, width = x.size()

        # Collapse the spatial dims once -> (B, C, H*W). ``flatten`` only
        # copies when the memory layout requires it, so non-contiguous
        # inputs are accepted (``view`` would raise for them). Query, key
        # and value are all this same tensor.
        feat = x.flatten(2)

        # Channel affinity matrix: (B, C, H*W) x (B, H*W, C) -> (B, C, C).
        energy = torch.bmm(feat, feat.transpose(1, 2))

        # Each row is replaced by (row max - entry) before the softmax;
        # broadcasting expands the (B, C, 1) maxima across the row.
        energy_new = energy.max(dim=-1, keepdim=True)[0] - energy
        attention = F.softmax(energy_new, dim=-1)

        # Re-weight the channels and restore the spatial layout.
        out = torch.bmm(attention, feat)
        out = out.view(batch_size, channels, height, width)

        return self.gamma(out) + x
|
|
|
|
|
@MODELS.register_module()
class DAHead(BaseDecodeHead):
    """Dual Attention Network for Scene Segmentation.

    This head is the implementation of `DANet
    <https://arxiv.org/abs/1809.02983>`_.

    Two parallel branches process the same input feature: one built around
    :class:`PAM` and one around :class:`CAM`. Each branch has its own 1x1
    classifier, and the sum of both branch features is classified by
    ``self.cls_seg``.

    Args:
        pam_channels (int): The channels of Position Attention Module(PAM).
    """

    def __init__(self, pam_channels, **kwargs):
        super().__init__(**kwargs)
        self.pam_channels = pam_channels

        # Keyword arguments shared by the four 3x3 ConvModules below.
        conv3x3_kwargs = dict(
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # Position-attention branch: conv -> PAM -> conv -> 1x1 classifier.
        self.pam_in_conv = ConvModule(
            self.in_channels, self.channels, 3, **conv3x3_kwargs)
        self.pam = PAM(self.channels, pam_channels)
        self.pam_out_conv = ConvModule(
            self.channels, self.channels, 3, **conv3x3_kwargs)
        self.pam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

        # Channel-attention branch: conv -> CAM -> conv -> 1x1 classifier.
        self.cam_in_conv = ConvModule(
            self.in_channels, self.channels, 3, **conv3x3_kwargs)
        self.cam = CAM()
        self.cam_out_conv = ConvModule(
            self.channels, self.channels, 3, **conv3x3_kwargs)
        self.cam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

    def pam_cls_seg(self, feat):
        """PAM feature classification.

        Applies ``self.dropout`` when it is set, then the PAM 1x1 classifier.
        """
        dropped = feat if self.dropout is None else self.dropout(feat)
        return self.pam_conv_seg(dropped)

    def cam_cls_seg(self, feat):
        """CAM feature classification.

        Applies ``self.dropout`` when it is set, then the CAM 1x1 classifier.
        """
        dropped = feat if self.dropout is None else self.dropout(feat)
        return self.cam_conv_seg(dropped)

    def forward(self, inputs):
        """Forward function.

        Returns:
            tuple: ``(pam_cam_out, pam_out, cam_out)`` -- the logits of the
            summed features, of the PAM branch and of the CAM branch.
        """
        x = self._transform_inputs(inputs)

        # Position-attention branch.
        pam_feat = self.pam_out_conv(self.pam(self.pam_in_conv(x)))
        pam_out = self.pam_cls_seg(pam_feat)

        # Channel-attention branch.
        cam_feat = self.cam_out_conv(self.cam(self.cam_in_conv(x)))
        cam_out = self.cam_cls_seg(cam_feat)

        # Fuse both branches by element-wise sum, then classify.
        pam_cam_out = self.cls_seg(pam_feat + cam_feat)

        return pam_cam_out, pam_out, cam_out

    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> List[Tensor]:
        """Forward function for testing, only ``pam_cam`` is used.

        ``test_cfg`` is accepted but not read by this method.
        """
        # Index 0 of the forward output is the fused ``pam_cam`` prediction.
        fused_logits = self.forward(inputs)[0]
        return self.predict_by_feat(fused_logits, batch_img_metas, **kwargs)

    def loss_by_feat(self, seg_logit: Tuple[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute ``pam_cam``, ``pam``, ``cam`` loss.

        Each of the three logits is scored by the parent ``loss_by_feat``
        and the resulting entries are collected under the matching prefix.
        """
        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
        prefixed_logits = (
            ('pam_cam', pam_cam_seg_logit),
            ('pam', pam_seg_logit),
            ('cam', cam_seg_logit),
        )

        losses = dict()
        for prefix, logit in prefixed_logits:
            branch_loss = super().loss_by_feat(logit, batch_data_samples)
            losses.update(add_prefix(branch_loss, prefix))
        return losses
|
|