# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead

class SpatialGatherModule(nn.Module):
    """Aggregate the context features according to the initial predicted
    probability distribution.

    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, probs):
        """Forward function."""
        batch_size, num_classes, height, width = probs.size()
        channels = feats.size(1)
        probs = probs.view(batch_size, num_classes, -1)
        feats = feats.view(batch_size, channels, -1)
        # [batch_size, height*width, channels]
        feats = feats.permute(0, 2, 1)
        # [batch_size, num_classes, height*width]
        probs = F.softmax(self.scale * probs, dim=2)
        # [batch_size, num_classes, channels]
        ocr_context = torch.matmul(probs, feats)
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
        return ocr_context
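
# Shape walk-through for SpatialGatherModule (illustrative note, not upstream
# code): with feats of shape [N, C, H, W] and probs of shape [N, K, H, W],
# the matmul above is equivalent to
#     torch.einsum('nkl,nlc->nkc',
#                  F.softmax(scale * probs.flatten(2), dim=2),
#                  feats.flatten(2).transpose(1, 2))
# i.e. each of the K class vectors is a softmax-weighted average of the
# C-dimensional pixel features, returned as [N, C, K, 1].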

class ObjectAttentionBlock(_SelfAttentionBlock):
    """A SelfAttentionBlock specialized for OCR: pixel features attend over
    the gathered object-region representations."""

    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
                 act_cfg):
        if scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=scale)
        else:
            query_downsample = None
        super().__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=query_downsample,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.bottleneck = ConvModule(
            in_channels * 2,
            in_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
    def forward(self, query_feats, key_feats):
        """Forward function."""
        context = super().forward(query_feats, key_feats)
        output = self.bottleneck(torch.cat([context, query_feats], dim=1))
        if self.query_downsample is not None:
            # Only reached when scale > 1: bring the fused output back to the
            # resolution of the input pixel features.
            output = resize(output, size=query_feats.shape[2:])
        return output
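
# Data-flow note (illustrative, mirrors the forward passes above): in the OCR
# head, query_feats are the per-pixel features of shape [N, C, H, W] produced
# by the 3x3 bottleneck, while key_feats are the K object-region
# representations of shape [N, C, K, 1] from SpatialGatherModule. The
# attention redistributes the K object contexts to every pixel, and the 1x1
# bottleneck fuses them with the original pixel features, giving [N, C, H, W].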

@MODELS.register_module()
class OCRHead(BaseCascadeDecodeHead):
    """Object-Contextual Representations for Semantic Segmentation.

    This head is the implementation of `OCRNet
    <https://arxiv.org/abs/1909.11065>`_.

    Args:
        ocr_channels (int): The intermediate channels of OCR block.
        scale (int): The scale of probability map in SpatialGatherModule.
            Default: 1.
    """
    def __init__(self, ocr_channels, scale=1, **kwargs):
        super().__init__(**kwargs)
        self.ocr_channels = ocr_channels
        self.scale = scale
        self.object_context_block = ObjectAttentionBlock(
            self.channels,
            self.ocr_channels,
            self.scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.spatial_gather_module = SpatialGatherModule(self.scale)
        self.bottleneck = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs, prev_output):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.bottleneck(x)
        context = self.spatial_gather_module(feats, prev_output)
        object_context = self.object_context_block(feats, context)
        output = self.cls_seg(object_context)
        return output
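
# Illustrative smoke test (a sketch, not part of the upstream module): checks
# the tensor shapes documented above with random inputs. The channel, class
# and spatial sizes below are arbitrary assumptions chosen for the example;
# running it requires mmcv and mmseg to be installed (for ConvModule and
# SelfAttentionBlock).
if __name__ == '__main__':
    pixel_feats = torch.randn(2, 64, 32, 32)   # [N, C, H, W] pixel features
    coarse_probs = torch.randn(2, 19, 32, 32)  # [N, K, H, W] coarse class logits

    gather = SpatialGatherModule(scale=1)
    object_context = gather(pixel_feats, coarse_probs)
    assert object_context.shape == (2, 64, 19, 1)  # one C-dim vector per class

    attention = ObjectAttentionBlock(
        in_channels=64,
        channels=32,
        scale=1,
        conv_cfg=None,
        norm_cfg=None,
        act_cfg=dict(type='ReLU'))
    augmented = attention(pixel_feats, object_context)
    assert augmented.shape == (2, 64, 32, 32)  # fused features, same [N, C, H, W]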