import logging
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor
from .base import BaseModel
from .schema import ResNetConfiguration
logger = logging.getLogger(__name__)


class DecoderBlock(nn.Module):
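    """FPN-style decoder block: upsamples the coarser `previous` features to the
    resolution of the encoder's `skip` features and fuses the two by addition."""
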
def __init__(
self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
):
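        """Args:
            previous: channel count of the incoming skip features. Despite the
                name, the convolutions are applied to `skip` (see `forward`),
                so this must match the skip map's channels.
            out: number of output channels.
            ksize: kernel size of each convolution.
            num_convs: number of conv(-norm)-ReLU units applied to the skip.
            norm: normalization layer class, or None to use a conv bias instead.
            padding: padding mode passed to nn.Conv2d.
        """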
super().__init__()
layers = []
for i in range(num_convs):
conv = nn.Conv2d(
previous if i == 0 else out,
out,
kernel_size=ksize,
padding=ksize // 2,
bias=norm is None,
padding_mode=padding,
)
layers.append(conv)
if norm is not None:
layers.append(norm(out))
layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

def forward(self, previous, skip):
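        """Upsample `previous` to the resolution of `skip`, crop or pad `skip`
        to align the two exactly, and return conv(skip) + upsampled."""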
_, _, hp, wp = previous.shape
_, _, hs, ws = skip.shape
scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
upsampled = nn.functional.interpolate(
previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
)
# If the shape of the input map `skip` is not a multiple of 2,
# it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
# If it uses ceil_mode=True (not supported here), we should pad it.
_, _, hu, wu = upsampled.shape
_, _, hs, ws = skip.shape
if (hu <= hs) and (wu <= ws):
skip = skip[:, :, :hu, :wu]
elif (hu >= hs) and (wu >= ws):
skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
else:
raise ValueError(
f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
)
        return self.layers(skip) + upsampled


class FPN(nn.Module):
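    """Feature Pyramid Network decoder: projects the coarsest encoder map with a
    1x1 convolution, then fuses progressively finer skip connections top-down,
    returning a single map at the resolution of the finest encoder layer."""
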
def __init__(self, in_channels_list, out_channels, **kw):
super().__init__()
self.first = nn.Conv2d(
in_channels_list[-1], out_channels, 1, padding=0, bias=True
)
self.blocks = nn.ModuleList(
[
DecoderBlock(c, out_channels, ksize=1, **kw)
for c in in_channels_list[::-1][1:]
]
)
self.out = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
        )

def forward(self, layers):
feats = None
for idx, x in enumerate(reversed(layers.values())):
if feats is None:
feats = self.first(x)
else:
feats = self.blocks[idx - 1](feats, x)
out = self.out(feats)
        return out


def remove_conv_stride(conv):
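    """Return a stride-1 copy of `conv` that shares the original's weight and
    bias Parameters, so both layers stay tied during training."""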
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        padding_mode=conv.padding_mode,
    )
conv_new.weight = conv.weight
conv_new.bias = conv.bias
    return conv_new


class FeatureExtractor(BaseModel):
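    """Multi-scale feature extractor: a torchvision backbone (ResNet or VGG)
    followed by an FPN decoder that outputs a single `output_dim`-channel map."""
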
default_conf = {
"pretrained": True,
"input_dim": 3,
"output_dim": 128, # # of channels in output feature maps
"encoder": "resnet50", # torchvision net as string
"remove_stride_from_first_conv": False,
"num_downsample": None, # how many downsample block
"decoder_norm": "nn.BatchNorm2d", # normalization ind decoder blocks
"do_average_pooling": False,
"checkpointed": False, # whether to use gradient checkpointing
}
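    # ImageNet normalization statistics, matching torchvision's pretrained weights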
mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

def freeze_encoder(self):
"""
Freeze the encoder part of the model, i.e., set requires_grad = False
for all parameters in the encoder.
"""
for param in self.encoder.parameters():
param.requires_grad = False
logger.debug("Encoder has been frozen.")
def unfreeze_encoder(self):
"""
Unfreeze the encoder part of the model, i.e., set requires_grad = True
for all parameters in the encoder.
"""
for param in self.encoder.parameters():
param.requires_grad = True
logger.debug("Encoder has been unfrozen.")
def build_encoder(self, conf: ResNetConfiguration):
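        """Instantiate the torchvision model named by `conf.encoder` and wrap it
        with create_feature_extractor so it returns the intermediate feature
        maps listed in `layers`, ordered from finest to coarsest."""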
assert isinstance(conf.encoder, str)
if conf.pretrained:
assert conf.input_dim == 3
Encoder = getattr(torchvision.models, conf.encoder)
kw = {}
if conf.encoder.startswith("resnet"):
layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
kw["replace_stride_with_dilation"] = [False, False, False]
elif conf.encoder == "vgg13":
layers = [
"features.3",
"features.8",
"features.13",
"features.18",
"features.23",
]
elif conf.encoder == "vgg16":
layers = [
"features.3",
"features.8",
"features.15",
"features.22",
"features.29",
]
else:
raise NotImplementedError(conf.encoder)
if conf.num_downsample is not None:
layers = layers[: conf.num_downsample]
encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
encoder = create_feature_extractor(encoder, return_nodes=layers)
if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
encoder.conv1 = remove_conv_stride(encoder.conv1)
if conf.do_average_pooling:
raise NotImplementedError
if conf.checkpointed:
raise NotImplementedError
        return encoder, layers

def _init(self, conf):
# Preprocessing
self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
self.register_buffer("std_", torch.tensor(self.std), persistent=False)
# Encoder
self.encoder, self.layers = self.build_encoder(conf)
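        # Run a dummy forward pass to infer each layer's channels and stride.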
s = 128
inp = torch.zeros(1, 3, s, s)
features = list(self.encoder(inp).values())
self.skip_dims = [x.shape[1] for x in features]
self.layer_strides = [s / f.shape[-1] for f in features]
self.scales = [self.layer_strides[0]]
# Decoder
norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa
self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)
logger.debug(
"Built feature extractor with layers {name:dim:stride}:\n"
f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
f"and output scales {self.scales}."
        )

def _forward(self, data):
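        """Normalize the input image, encode it, decode with the FPN, and return
        the feature map together with the camera metadata passed through."""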
image = data["image"]
image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
skip_features = self.encoder(image)
output = self.decoder(skip_features)
        return output, data["camera"]
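

# A minimal usage sketch (hypothetical: it assumes `BaseModel` merges the given
# dict with `default_conf`, calls `_init`, and dispatches `model(data)` to
# `_forward`; the actual BaseModel in this repo may behave differently):
#
#   model = FeatureExtractor({"encoder": "resnet50", "output_dim": 128})
#   data = {"image": torch.rand(1, 3, 256, 256), "camera": None}
#   features, camera = model(data)
#   # features should have shape (1, 128, 128, 128): the finest ResNet skip
#   # ("relu") sits at stride 2, and the FPN output keeps that resolution.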