| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from mmengine.model.weight_init import caffe2_xavier_init, kaiming_init |
|
|
| from mmdet.registry import MODELS |
|
|
|
|
class PyramidPoolingModule(nn.Module):
    """Pyramid Pooling Module (PPM).

    The input is average-pooled to several fixed grid sizes. Each pooled map is
    projected with a 1x1 conv followed by the activation and upsampled
    (bilinear) back to the input resolution. The branches are then
    concatenated with the input itself and fused by a 1x1 "bottleneck" conv
    followed by the activation.

    Args:
        in_channels (int): Channels of the input feature map; also the number
            of output channels.
        channels (int): Output channels of each pooling branch.
            Defaults to 512.
        sizes (Sequence[int]): Output grid size of the adaptive average
            pooling in each branch. Defaults to (1, 2, 3, 6).
        act_cfg (dict): Config used to build the activation layer through
            ``MODELS.build``. Defaults to ``dict(type='ReLU')``.
    """

    def __init__(self,
                 in_channels,
                 channels=512,
                 sizes=(1, 2, 3, 6),
                 act_cfg=dict(type='ReLU')):
        super().__init__()
        # One pooling branch per grid size.
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, channels, size) for size in sizes])
        # Fuses all branch outputs plus the untouched input back down to
        # ``in_channels`` channels.
        self.bottleneck = nn.Conv2d(in_channels + len(sizes) * channels,
                                    in_channels, 1)
        # A single activation instance is shared by every branch and by the
        # bottleneck.
        self.act = MODELS.build(act_cfg)

    def _make_stage(self, features, out_features, size):
        """Build one branch.

        The branch applies adaptive average pooling to a ``size`` x ``size``
        grid, then a 1x1 conv mapping ``features`` -> ``out_features``
        channels.
        """
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, 1)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        """Apply pyramid pooling.

        Args:
            feats (Tensor): Input features, expected as (N, C, H, W) with
                ``C == in_channels``.

        Returns:
            Tensor: Features with ``in_channels`` channels at the same
            spatial size (H, W) as ``feats``.
        """
        h, w = feats.size(2), feats.size(3)
        # Each branch runs at its pooled resolution and is then brought back
        # to (h, w) so it can be concatenated with the input.
        priors = [
            F.interpolate(
                input=self.act(stage(feats)),
                size=(h, w),
                mode='bilinear',
                align_corners=False) for stage in self.stages
        ] + [feats]
        out = self.act(self.bottleneck(torch.cat(priors, 1)))
        return out
|
|
|
|
@MODELS.register_module()
class InstanceContextEncoder(nn.Module):
    """Instance Context Encoder.

    1. Construct feature pyramids from the backbone (e.g. ResNet) outputs
       with a top-down FPN pathway.
    2. Enlarge the receptive field of the coarsest level with a pyramid
       pooling module (optional, see ``with_ppm``).
    3. Fuse all pyramid levels at the resolution of the finest level.

    Args:
        in_channels (Sequence[int]): Channels of each input level, ordered
            from the finest (highest resolution) to the coarsest level.
        out_channels (int): Channels of every pyramid level and of the fused
            output. Defaults to 256.
        with_ppm (bool): Whether to apply a ``PyramidPoolingModule`` to the
            coarsest level. Defaults to True.
        act_cfg (dict): Activation config forwarded to the pyramid pooling
            module. Defaults to ``dict(type='ReLU')``.

    Raises:
        ValueError: If ``in_channels`` is empty.
    """

    def __init__(self,
                 in_channels,
                 out_channels=256,
                 with_ppm=True,
                 act_cfg=dict(type='ReLU')):
        super().__init__()
        if len(in_channels) == 0:
            raise ValueError('in_channels must contain at least one level')
        self.num_channels = out_channels
        self.in_channels = in_channels
        self.with_ppm = with_ppm

        # Lateral (1x1) and output (3x3) convs, built coarsest-first so that
        # index 0 matches the first level visited by the top-down pass.
        fpn_laterals = []
        fpn_outputs = []
        for in_channel in reversed(self.in_channels):
            lateral_conv = nn.Conv2d(in_channel, self.num_channels, 1)
            output_conv = nn.Conv2d(
                self.num_channels, self.num_channels, 3, padding=1)
            caffe2_xavier_init(lateral_conv)
            caffe2_xavier_init(output_conv)
            fpn_laterals.append(lateral_conv)
            fpn_outputs.append(output_conv)
        self.fpn_laterals = nn.ModuleList(fpn_laterals)
        self.fpn_outputs = nn.ModuleList(fpn_outputs)

        if self.with_ppm:
            self.ppm = PyramidPoolingModule(
                self.num_channels, self.num_channels // 4, act_cfg=act_cfg)

        # One ``num_channels`` slice per pyramid level is concatenated before
        # fusion, so the input width follows the number of levels.
        self.fusion = nn.Conv2d(self.num_channels * len(self.in_channels),
                                self.num_channels, 1)
        kaiming_init(self.fusion)

    def forward(self, features):
        """Build the feature pyramid and fuse it into a single map.

        Args:
            features (Sequence[Tensor]): Multi-level features ordered from
                finest to coarsest, matching ``in_channels``.

        Returns:
            Tensor: Fused features with ``out_channels`` channels at the
            spatial resolution of the finest level.
        """
        # Walk the pyramid top-down, starting from the coarsest level.
        coarse_to_fine = features[::-1]
        prev_features = self.fpn_laterals[0](coarse_to_fine[0])
        if self.with_ppm:
            prev_features = self.ppm(prev_features)
        outputs = [self.fpn_outputs[0](prev_features)]
        for feature, lat_conv, output_conv in zip(coarse_to_fine[1:],
                                                  self.fpn_laterals[1:],
                                                  self.fpn_outputs[1:]):
            lat_features = lat_conv(feature)
            # Upsample to the lateral's exact size instead of a fixed 2x so
            # that odd spatial sizes (inputs not divisible by the total
            # stride) do not break the element-wise sum below. For exact 2x
            # ratios, nearest upsampling gives the same result either way.
            top_down_features = F.interpolate(
                prev_features, size=lat_features.shape[-2:], mode='nearest')
            prev_features = lat_features + top_down_features
            outputs.append(output_conv(prev_features))
        # ``outputs`` was filled coarsest-first; put the finest level first.
        outputs.reverse()

        size = outputs[0].shape[-2:]
        aligned = [outputs[0]] + [
            F.interpolate(x, size, mode='bilinear', align_corners=False)
            for x in outputs[1:]
        ]
        return self.fusion(torch.cat(aligned, dim=1))
|
|