import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

from .base import BaseModel
from .schema import ResNetConfiguration

logger = logging.getLogger(__name__)


class DecoderBlock(nn.Module):
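    """Upsamples the previous (coarser) decoder map and fuses it with an
    encoder skip connection. The convolutions are applied to the skip map,
    so the first constructor argument is the skip's channel count."""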
    def __init__(
        self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()
        layers = []
        for i in range(num_convs):
            conv = nn.Conv2d(
                previous if i == 0 else out,
                out,
                kernel_size=ksize,
                padding=ksize // 2,
                bias=norm is None,
                padding_mode=padding,
            )
            layers.append(conv)
            if norm is not None:
                layers.append(norm(out))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, previous, skip):
        _, _, hp, wp = previous.shape
        _, _, hs, ws = skip.shape
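        # Upsample by the power of two closest to the skip/previous size ratio.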
        scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
        upsampled = nn.functional.interpolate(
            previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
        )
        # If the shape of the input map `skip` is not a multiple of 2,
        # it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
        # If it uses ceil_mode=True (not supported here), we should pad it.
        _, _, hu, wu = upsampled.shape
        _, _, hs, ws = skip.shape
        if (hu <= hs) and (wu <= ws):
            skip = skip[:, :, :hu, :wu]
        elif (hu >= hs) and (wu >= ws):
            skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
        else:
            raise ValueError(
                f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
            )

        return self.layers(skip) + upsampled


class FPN(nn.Module):
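    """Minimal feature pyramid network: the deepest encoder map is projected
    with a 1x1 convolution, then progressively fused with shallower skip
    connections. The output resolution is that of the finest encoder layer."""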
    def __init__(self, in_channels_list, out_channels, **kw):
        super().__init__()
        self.first = nn.Conv2d(
            in_channels_list[-1], out_channels, 1, padding=0, bias=True
        )
        self.blocks = nn.ModuleList(
            [
                DecoderBlock(c, out_channels, ksize=1, **kw)
                for c in in_channels_list[::-1][1:]
            ]
        )
        self.out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, layers):
        feats = None
        for idx, x in enumerate(reversed(layers.values())):
            if feats is None:
                feats = self.first(x)
            else:
                feats = self.blocks[idx - 1](feats, x)
        out = self.out(feats)
        return out


def remove_conv_stride(conv):
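    """Return a copy of `conv` with stride 1 that shares the original weight
    and bias tensors."""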
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
    )
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new


class FeatureExtractor(BaseModel):
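    """Multi-scale image feature extractor: a torchvision encoder (ResNet or
    VGG) whose intermediate activations are fused by an FPN decoder into a
    single dense feature map."""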
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,  # # of channels in output feature maps
        "encoder": "resnet50",  # torchvision net as string
        "remove_stride_from_first_conv": False,
        "num_downsample": None,  # how many downsample block
        "decoder_norm": "nn.BatchNorm2d",  # normalization ind decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
    }
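    # ImageNet normalization statistics.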
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def freeze_encoder(self):
        """
        Freeze the encoder part of the model, i.e., set requires_grad = False
        for all parameters in the encoder.
        """
        for param in self.encoder.parameters():
            param.requires_grad = False
        logger.debug("Encoder has been frozen.")

    def unfreeze_encoder(self):
        """
        Unfreeze the encoder part of the model, i.e., set requires_grad = True
        for all parameters in the encoder.
        """
        for param in self.encoder.parameters():
            param.requires_grad = True
        logger.debug("Encoder has been unfrozen.")

    def build_encoder(self, conf: ResNetConfiguration):
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)

        kw = {}
        if conf.encoder.startswith("resnet"):
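            # Node names: the stem output ("relu") and the four residual stages.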
            layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
            kw["replace_stride_with_dilation"] = [False, False, False]
        elif conf.encoder == "vgg13":
            layers = [
                "features.3",
                "features.8",
                "features.13",
                "features.18",
                "features.23",
            ]
        elif conf.encoder == "vgg16":
            layers = [
                "features.3",
                "features.8",
                "features.15",
                "features.22",
                "features.29",
            ]
        else:
            raise NotImplementedError(conf.encoder)

        if conf.num_downsample is not None:
            layers = layers[: conf.num_downsample]
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
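        # Wrap the backbone so it returns only the activations listed in `layers`.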
        encoder = create_feature_extractor(encoder, return_nodes=layers)
        if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
            encoder.conv1 = remove_conv_stride(encoder.conv1)

        if conf.do_average_pooling:
            raise NotImplementedError
        if conf.checkpointed:
            raise NotImplementedError

        return encoder, layers

    def _init(self, conf):
        # Preprocessing
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        # Encoder
        self.encoder, self.layers = self.build_encoder(conf)
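        # Infer each skip connection's channel count and stride with a dummy forward pass.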
        s = 128
        inp = torch.zeros(1, 3, s, s)
        features = list(self.encoder(inp).values())
        self.skip_dims = [x.shape[1] for x in features]
        self.layer_strides = [s / f.shape[-1] for f in features]
        self.scales = [self.layer_strides[0]]

        # Decoder
        norm = eval(conf.decoder_norm) if conf.decoder_norm else None  # noqa
        self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)

        logger.debug(
            "Built feature extractor with layers {name:dim:stride}:\n"
            f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
            f"and output scales {self.scales}."
        )

    def _forward(self, data):
        image = data["image"]
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

        skip_features = self.encoder(image)
        output = self.decoder(skip_features)
        return output, data["camera"]
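

# Minimal usage sketch (an assumption: BaseModel, as used elsewhere in this
# codebase, accepts a configuration mapping and dispatches `model(data)` to
# `_forward`):
#
#   model = FeatureExtractor({"encoder": "resnet50", "output_dim": 128})
#   data = {"image": torch.rand(1, 3, 256, 256), "camera": camera}  # `camera` from the calling pipeline
#   features, camera = model(data)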