"""This file is modified version from mmsegmentation (https://github.com/open-mmlab/mmsegmentation)"""

import torch
import torch.nn as nn
from torch.nn import functional as F

class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:

        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid

            Module.

        in_channels (int): Input channels.

        channels (int): Channels after modules, before conv_seg.

        conv_cfg (dict|None): Config of conv layers.

        norm_cfg (dict|None): Config of norm layers.

        act_cfg (dict): Config of activation layers.

        align_corners (bool): align_corners argument of F.interpolate.

    """

    def __init__(self, pool_scales, in_channels, channels):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.in_channels = in_channels
        self.channels = channels
        for pool_scale in pool_scales:
            # One branch per scale: adaptive-pool to pool_scale x pool_scale,
            # then a 1x1 conv (with ReLU) to reduce to `channels`.
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    nn.Conv2d(self.in_channels, self.channels, kernel_size=1),
                    nn.ReLU()))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            # Interpolate in fp32 and cast back: bilinear interpolation is
            # not supported for bfloat16 tensors in some PyTorch builds.
            upsampled_ppm_out = F.interpolate(
                ppm_out.float(),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=False).to(torch.bfloat16)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs

class UPerHead(nn.Module):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet

    <https://arxiv.org/abs/1807.10221>`_.

    Args:

        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid

            Module applied on the last feature. Default: (1, 2, 3, 6).

    """

    def __init__(self, in_channels=(96, 192, 384, 768), channels=256,
                 pool_scales=(1, 2, 3, 6)):
        super(UPerHead, self).__init__()
        self.in_channels = in_channels
        self.channels = channels
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(self.in_channels[-1] + len(pool_scales) * self.channels,
                      self.channels, kernel_size=3, padding=1),
            nn.ReLU())
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_ch in self.in_channels[:-1]:  # skip the top layer
            l_conv = nn.Sequential(
                nn.Conv2d(in_ch, self.channels, kernel_size=1, padding=0),
                nn.ReLU())
            fpn_conv = nn.Sequential(
                nn.Conv2d(self.channels, self.channels, kernel_size=3,
                          padding=1),
                nn.ReLU())
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = nn.Sequential(
            nn.Conv2d(len(self.in_channels) * self.channels, self.channels,
                      kernel_size=3, padding=1),
            nn.ReLU())


    def psp_forward(self, inputs):
        """Forward function of PSP module."""

        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)
        return output

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): Backbone features ordered fine to coarse,
                e.g. [x_96, x_192, x_384, x_768] for the default
                ``in_channels``.
        """

        # build laterals from all but the last level; the PSP output acts
        # as the topmost lateral
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(inputs))
        
        # build top-down path: upsample each coarser lateral (in fp32, then
        # back to bfloat16) and add it to the next finer one
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i].float(),
                size=prev_shape,
                mode='bilinear',
                align_corners=False).to(torch.bfloat16)
        
        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        
        # append psp feature, then upsample all outputs to the finest
        # resolution before concatenation
        fpn_outs.append(laterals[-1])
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = F.interpolate(
                fpn_outs[i].float(),
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=False).to(torch.bfloat16)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)

        return output
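

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # upstream code): run the head on random bfloat16 features. The spatial
    # sizes below are assumed stride-4/8/16/32 outputs of a 4-stage backbone
    # for a 224x224 input; adjust them to match your backbone.
    head = UPerHead(in_channels=(96, 192, 384, 768), channels=256)
    head = head.to(torch.bfloat16)
    feats = [
        torch.randn(1, 96, 56, 56, dtype=torch.bfloat16),
        torch.randn(1, 192, 28, 28, dtype=torch.bfloat16),
        torch.randn(1, 384, 14, 14, dtype=torch.bfloat16),
        torch.randn(1, 768, 7, 7, dtype=torch.bfloat16),
    ]
    out = head(feats)
    # Expect the finest spatial resolution back: (1, 256, 56, 56)
    print(out.shape, out.dtype)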