Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Tuple | |
import torch | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule, Scale | |
from torch import Tensor, nn | |
from mmseg.registry import MODELS | |
from mmseg.utils import SampleList, add_prefix | |
from ..utils import SelfAttentionBlock as _SelfAttentionBlock | |
from .decode_head import BaseDecodeHead | |
class PAM(_SelfAttentionBlock):
    """Position Attention Module (PAM).

    A thin wrapper around ``_SelfAttentionBlock`` that uses the same feature
    map as both the query and the key source. The attended features are
    multiplied by a learnable scalar ``gamma`` (``Scale(0)``, i.e. starting
    at 0) and added back to the input as a residual.

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
    """

    def __init__(self, in_channels, channels):
        # Configuration handed to the generic self-attention block: no
        # downsampling, no normalisation/activation, no output projection,
        # and the output keeps ``in_channels`` channels so it can be added
        # back to the input.
        attention_cfg = dict(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=1,
            key_query_norm=False,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=False,
            with_out=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)
        super().__init__(**attention_cfg)
        # Learnable residual weight, initialised to 0.
        self.gamma = Scale(0)

    def forward(self, x):
        """Apply position attention to ``x`` with a residual connection.

        Args:
            x (Tensor): Feature map used as both query and key input.

        Returns:
            Tensor: ``gamma(attention(x, x)) + x``.
        """
        return self.gamma(super().forward(x, x)) + x
class CAM(nn.Module):
    """Channel Attention Module (CAM).

    Builds a ``C x C`` affinity matrix between the channels of the input
    feature map, turns it into attention weights with a softmax, re-weights
    the channels with it, and adds the result back to the input scaled by a
    learnable scalar ``gamma`` (``Scale(0)``, i.e. starting at 0).
    """

    def __init__(self):
        super().__init__()
        # Learnable residual weight, initialised to 0.
        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Feature map of shape ``(N, C, H, W)``.

        Returns:
            Tensor: ``gamma(channel_attention(x)) + x``, same shape as ``x``.
        """
        batch_size, channels, height, width = x.size()
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
        # The same flattened tensor serves as query and value.
        flat = x.reshape(batch_size, channels, -1)
        # (N, C, HW) @ (N, HW, C) -> (N, C, C) channel affinity matrix.
        energy = torch.bmm(flat, flat.transpose(1, 2))
        # NOTE: softmax over (row max - energy) gives the largest weights to
        # the *least* similar channels. This is kept as-is on purpose, since
        # changing it would alter the behaviour of existing checkpoints.
        energy_new = torch.max(
            energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = F.softmax(energy_new, dim=-1)
        # (N, C, C) @ (N, C, HW) -> (N, C, HW), then restore spatial dims.
        out = torch.bmm(attention, flat)
        out = out.reshape(batch_size, channels, height, width)
        return self.gamma(out) + x
@MODELS.register_module()
class DAHead(BaseDecodeHead):
    """Dual Attention Network for Scene Segmentation.

    This head is the implementation of `DANet
    <https://arxiv.org/abs/1809.02983>`_. The transformed input feature is
    processed by two parallel branches:

    * a Position Attention Module (PAM) branch, and
    * a Channel Attention Module (CAM) branch.

    Each branch has its own 1x1 classifier. The sum of the two branch
    features is classified by the inherited ``cls_seg``.

    Args:
        pam_channels (int): The channels of Position Attention Module(PAM).
    """

    def __init__(self, pam_channels, **kwargs):
        super().__init__(**kwargs)
        self.pam_channels = pam_channels
        # Settings shared by every 3x3 conv wrapping the attention modules.
        conv_kwargs = dict(
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # Position attention branch.
        self.pam_in_conv = ConvModule(
            self.in_channels, self.channels, 3, **conv_kwargs)
        self.pam = PAM(self.channels, pam_channels)
        self.pam_out_conv = ConvModule(
            self.channels, self.channels, 3, **conv_kwargs)
        self.pam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

        # Channel attention branch.
        self.cam_in_conv = ConvModule(
            self.in_channels, self.channels, 3, **conv_kwargs)
        self.cam = CAM()
        self.cam_out_conv = ConvModule(
            self.channels, self.channels, 3, **conv_kwargs)
        self.cam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

    def _branch_cls_seg(self, feat, conv_seg):
        """Apply the head's dropout (if configured), then ``conv_seg``."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        return conv_seg(feat)

    def pam_cls_seg(self, feat):
        """PAM feature classification."""
        return self._branch_cls_seg(feat, self.pam_conv_seg)

    def cam_cls_seg(self, feat):
        """CAM feature classification."""
        return self._branch_cls_seg(feat, self.cam_conv_seg)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Input features, as accepted by the inherited
                ``_transform_inputs``.

        Returns:
            tuple[Tensor]: ``(pam_cam_out, pam_out, cam_out)``. These are the
            logits from the fused features, the PAM branch and the CAM
            branch respectively.
        """
        x = self._transform_inputs(inputs)

        pam_feat = self.pam_in_conv(x)
        pam_feat = self.pam(pam_feat)
        pam_feat = self.pam_out_conv(pam_feat)
        pam_out = self.pam_cls_seg(pam_feat)

        cam_feat = self.cam_in_conv(x)
        cam_feat = self.cam(cam_feat)
        cam_feat = self.cam_out_conv(cam_feat)
        cam_out = self.cam_cls_seg(cam_feat)

        # Fuse both branches before the main classifier.
        feat_sum = pam_feat + cam_feat
        pam_cam_out = self.cls_seg(feat_sum)

        return pam_cam_out, pam_out, cam_out

    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> List[Tensor]:
        """Forward function for testing, only ``pam_cam`` is used.

        ``test_cfg`` is accepted for interface compatibility but is not read
        here. ``kwargs`` are forwarded to ``predict_by_feat``.
        """
        seg_logits = self.forward(inputs)[0]
        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)

    def loss_by_feat(self, seg_logit: Tuple[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute ``pam_cam``, ``pam``, ``cam`` loss.

        Args:
            seg_logit (tuple[Tensor]): Exactly three logits in the order
                returned by :meth:`forward`.
            batch_data_samples (SampleList): Ground-truth data samples.

        Returns:
            dict: Union of the parent class's losses for each output. Each
            output's keys are prefixed (via ``add_prefix``) with ``pam_cam``,
            ``pam`` or ``cam``.
        """
        # Strict unpacking: a wrong number of logits still raises.
        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
        loss = dict()
        for prefix, logit in (('pam_cam', pam_cam_seg_logit),
                              ('pam', pam_seg_logit),
                              ('cam', cam_seg_logit)):
            loss.update(
                add_prefix(
                    super().loss_by_feat(logit, batch_data_samples), prefix))
        return loss