"""This file is modified version from mmsegmentation (https://github.com/open-mmlab/mmsegmentation)"""
import torch
import torch.nn as nn
from torch.nn import functional as F
class PPM(nn.ModuleList):
"""Pooling Pyramid Module used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""
def __init__(self, pool_scales, in_channels, channels):
super(PPM, self).__init__()
self.pool_scales = pool_scales
self.in_channels = in_channels
self.channels = channels
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
nn.Conv2d(self.in_channels, self.channels, kernel_size=1),
nn.ReLU()
)
)
    def forward(self, x):
        """Forward function.

        Returns a list of len(pool_scales) tensors, each with `channels`
        channels and the same spatial size as `x`.
        """
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            # Upsample in float32, then cast back to bfloat16 (this head is
            # expected to run on bfloat16 feature maps).
            upsampled_ppm_out = F.interpolate(
                ppm_out.float(),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=False).to(torch.bfloat16)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs
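
# Note (illustrative, not from the original file): for an input of shape
# (N, in_channels, H, W) in bfloat16, PPM((1, 2, 3, 6), in_channels, channels)
# returns four tensors of shape (N, channels, H, W); UPerHead concatenates them
# with the input feature map before its PSP bottleneck.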
class UPerHead(nn.Module):
"""Unified Perceptual Parsing for Scene Understanding.
This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
"""
    def __init__(self, in_channels=(96, 192, 384, 768), channels=256, pool_scales=(1, 2, 3, 6)):
super(UPerHead, self).__init__()
# PSP Module
self.in_channels = in_channels
self.channels = channels
self.psp_modules = PPM(
pool_scales,
self.in_channels[-1],
self.channels
)
self.bottleneck = nn.Sequential(
nn.Conv2d(self.in_channels[-1] + len(pool_scales) * self.channels, self.channels, kernel_size=3, padding=1),
nn.ReLU())
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
        for in_ch in self.in_channels[:-1]:  # skip the top (lowest-resolution) layer, handled by PSP
            l_conv = nn.Sequential(
                nn.Conv2d(in_ch, self.channels, kernel_size=1, padding=0),
                nn.ReLU())
fpn_conv = nn.Sequential(
nn.Conv2d(self.channels, self.channels, kernel_size=3, padding=1),
nn.ReLU())
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = nn.Sequential(
nn.Conv2d(len(self.in_channels) * self.channels, self.channels, kernel_size=3, padding=1),
nn.ReLU())
def psp_forward(self, inputs):
"""Forward function of PSP module."""
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def forward(self, inputs):
"""Forward function.
inputs = {x_96, x_192, x_384, x_768}
"""
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
laterals.append(self.psp_forward(inputs))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
            # Upsample the coarser lateral in float32, cast back to bfloat16,
            # and add it to the next finer lateral.
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i].float(),
                size=prev_shape,
                mode='bilinear',
                align_corners=False
            ).to(torch.bfloat16)
# build outputs
fpn_outs = [
self.fpn_convs[i](laterals[i])
for i in range(used_backbone_levels - 1)
]
# append psp feature
fpn_outs.append(laterals[-1])
        # Upsample all levels to the spatial size of the finest level
        # (fpn_outs[0]) before concatenation, using the same
        # float32-interpolate / bfloat16-cast pattern as above.
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = F.interpolate(
                fpn_outs[i].float(),
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=False).to(torch.bfloat16)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
        return output
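

# --- Minimal usage sketch (added for illustration, not part of the original
# mmsegmentation-derived file). The batch size, input resolution and bfloat16
# dtype below are assumptions, chosen to match the casts in the forward passes.
if __name__ == "__main__":
    head = UPerHead(in_channels=(96, 192, 384, 768), channels=256).to(torch.bfloat16)
    # Four backbone feature maps at strides 4, 8, 16 and 32 of a 224x224 image.
    feats = [
        torch.randn(1, 96, 56, 56, dtype=torch.bfloat16),
        torch.randn(1, 192, 28, 28, dtype=torch.bfloat16),
        torch.randn(1, 384, 14, 14, dtype=torch.bfloat16),
        torch.randn(1, 768, 7, 7, dtype=torch.bfloat16),
    ]
    with torch.no_grad():
        out = head(feats)
    # The fused output sits at the finest resolution with `channels` channels.
    print(out.shape)  # expected: torch.Size([1, 256, 56, 56])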