# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

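# PSAMask is a compiled mmcv op that is only present when mmcv is built with
# its C++/CUDA extensions; fall back to None so a missing op surfaces as a
# clear RuntimeError in __init__ rather than as an import-time failure.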
try:
    from mmcv.ops import PSAMask
except ModuleNotFoundError:
    PSAMask = None


@MODELS.register_module()
class PSAHead(BaseDecodeHead):
"""Point-wise Spatial Attention Network for Scene Parsing. | |
This head is the implementation of `PSANet | |
<https://hszhao.github.io/papers/eccv18_psanet.pdf>`_. | |
Args: | |
mask_size (tuple[int]): The PSA mask size. It usually equals input | |
size. | |
psa_type (str): The type of psa module. Options are 'collect', | |
'distribute', 'bi-direction'. Default: 'bi-direction' | |
compact (bool): Whether use compact map for 'collect' mode. | |
Default: True. | |
shrink_factor (int): The downsample factors of psa mask. Default: 2. | |
normalization_factor (float): The normalize factor of attention. | |
psa_softmax (bool): Whether use softmax for attention. | |
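
    Example:
        >>> # A minimal sketch with illustrative values (not a tested
        >>> # config); requires mmcv built with the PSAMask op.
        >>> import torch
        >>> head = PSAHead(
        ...     mask_size=(97, 97),
        ...     in_channels=2048,
        ...     channels=512,
        ...     num_classes=19)
        >>> inputs = [torch.randn(1, 2048, 97, 97)]
        >>> seg_logits = head(inputs)  # shape: (1, 19, 97, 97)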
""" | |

    def __init__(self,
                 mask_size,
                 psa_type='bi-direction',
                 compact=False,
                 shrink_factor=2,
                 normalization_factor=1.0,
                 psa_softmax=True,
                 **kwargs):
        if PSAMask is None:
            raise RuntimeError('Please install mmcv-full for PSAMask ops')
        super().__init__(**kwargs)
        assert psa_type in ['collect', 'distribute', 'bi-direction']
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_size = mask_size
        mask_h, mask_w = mask_size
        self.psa_softmax = psa_softmax
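        # Without an explicit factor, normalize the aggregated features by
        # the number of attention positions.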
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor
        self.reduce = ConvModule(
            self.in_channels,
            self.channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
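        # The attention branch predicts, at every position, an over-complete
        # attention map with mask_h * mask_w channels; PSAMask later turns it
        # into a (h*w)-way pairwise attention over the feature grid.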
        self.attention = nn.Sequential(
            ConvModule(
                self.channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Conv2d(
                self.channels, mask_h * mask_w, kernel_size=1, bias=False))
        if psa_type == 'bi-direction':
            self.reduce_p = ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.attention_p = nn.Sequential(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.Conv2d(
                    self.channels, mask_h * mask_w, kernel_size=1,
                    bias=False))
            self.psamask_collect = PSAMask('collect', mask_size)
            self.psamask_distribute = PSAMask('distribute', mask_size)
        else:
            self.psamask = PSAMask(psa_type, mask_size)
        self.proj = ConvModule(
            self.channels * (2 if psa_type == 'bi-direction' else 1),
            self.in_channels,
            kernel_size=1,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            self.in_channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        identity = x
        align_corners = self.align_corners
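        # Single-direction mode: one reduce/attention pair produces either a
        # 'collect' or a 'distribute' attention map.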
        if self.psa_type in ['collect', 'distribute']:
            out = self.reduce(x)
            n, c, h, w = out.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
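                    # Odd height and width: shrink ceil-style and resize with
                    # align_corners=True so corner pixels stay aligned.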
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                out = resize(
                    out,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y = self.attention(out)
            if self.compact:
                if self.psa_type == 'collect':
                    y = y.view(n, h * w,
                               h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y = self.psamask(y)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
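            # Aggregate features over all positions:
            # (n, c, h*w) @ (n, h*w, h*w) -> (n, c, h*w), reshaped back to
            # (n, c, h, w) and rescaled by the normalization factor.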
            out = torch.bmm(
                out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
        else:
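            # Bi-direction mode: run separate 'collect' and 'distribute'
            # branches and concatenate their aggregated features.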
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                x_col = resize(
                    x_col,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
                x_dis = resize(
                    x_dis,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                y_dis = y_dis.view(n, h * w,
                                   h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y_col = self.psamask_collect(y_col)
                y_dis = self.psamask_distribute(y_dis)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x_col = torch.bmm(
                x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            x_dis = torch.bmm(
                x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            out = torch.cat([x_col, x_dis], 1)
        out = self.proj(out)
        out = resize(
            out,
            size=identity.shape[2:],
            mode='bilinear',
            align_corners=align_corners)
        out = self.bottleneck(torch.cat((identity, out), dim=1))
        out = self.cls_seg(out)
        return out