| |
| |
|
|
| """ |
| @Author : Peike Li |
| @Contact : peike.li@yahoo.com |
| @File : psp.py |
| @Time : 8/4/19 3:36 PM |
| @Desc : |
| @License : This source code is licensed under the license found in the |
| LICENSE file in the root directory of this source tree. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
|
|
| from modules import InPlaceABNSync |
|
|
|
|
class PSPModule(nn.Module):
    """Pyramid pooling module.

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*

    Each pyramid branch does three things:
      - average-pools the input to a ``size x size`` grid;
      - projects the result to ``out_features`` channels with a bias-free 1x1
        convolution followed by ``InPlaceABNSync``;
      - bilinearly upsamples the result back to the input's spatial size.

    The branch outputs and the input feature map are concatenated along the
    channel dimension. A bias-free 3x3 convolution followed by
    ``InPlaceABNSync`` then fuses them.

    Args:
        features: Number of channels of the input feature map.
        out_features: Number of channels produced by every pyramid branch and
            by the final bottleneck.
        sizes: Output grid size of the adaptive average pooling for each
            pyramid branch (one branch per entry).
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(features, out_features, size) for size in sizes]
        )
        # The bottleneck sees one ``out_features``-channel map per branch plus
        # the untouched ``features``-channel input.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                features + len(sizes) * out_features,
                out_features,
                kernel_size=3,
                padding=1,
                dilation=1,
                bias=False,
            ),
            InPlaceABNSync(out_features),
        )

    def _make_stage(self, features, out_features, size):
        """Build one pyramid branch: adaptive avg-pool -> 1x1 conv -> InPlaceABNSync.

        The three layers stay in a single ``nn.Sequential`` in this order so that
        parameter names (``stages.<i>.1.weight``, ``stages.<i>.2.*``) stay stable
        for checkpoint loading.
        """
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        bn = InPlaceABNSync(out_features)
        return nn.Sequential(prior, conv, bn)

    def forward(self, feats):
        """Apply pyramid pooling to ``feats``.

        Args:
            feats: 4-D feature map laid out as (N, C, H, W), with C equal to the
                ``features`` passed to the constructor.

        Returns:
            The bottleneck output, with ``out_features`` channels at the same
            H x W as ``feats``.
        """
        h, w = feats.size(2), feats.size(3)
        priors = [
            F.interpolate(
                input=stage(feats), size=(h, w), mode='bilinear', align_corners=True
            )
            for stage in self.stages
        ]
        # The input map goes LAST: the bottleneck conv's weight layout depends
        # on this channel order.
        priors.append(feats)
        return self.bottleneck(torch.cat(priors, 1))