import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

from .base import BaseModel
from .schema import ResNetConfiguration

logger = logging.getLogger(__name__)


class DecoderBlock(nn.Module):
    """Fuse a coarse feature map with higher-resolution skip features:
    bilinearly upsample the coarse map, project the skip features to `out`
    channels, and add the two."""

    def __init__(
        self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()
        layers = []
        for i in range(num_convs):
            conv = nn.Conv2d(
                previous if i == 0 else out,
                out,
                kernel_size=ksize,
                padding=ksize // 2,
                bias=norm is None,
                padding_mode=padding,
            )
            layers.append(conv)
            if norm is not None:
                layers.append(norm(out))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, previous, skip):
        _, _, hp, wp = previous.shape
        _, _, hs, ws = skip.shape
        # Upsample by the nearest power-of-two factor between the two maps.
        scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
        upsampled = nn.functional.interpolate(
            previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
        )

        # If the shapes still do not match exactly, crop or zero-pad the skip
        # features to the upsampled resolution.
        _, _, hu, wu = upsampled.shape
        _, _, hs, ws = skip.shape
        if (hu <= hs) and (wu <= ws):
            skip = skip[:, :, :hu, :wu]
        elif (hu >= hs) and (wu >= ws):
            skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
        else:
            raise ValueError(
                f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
            )

        return self.layers(skip) + upsampled
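
# A minimal usage sketch (hedged; shapes and channel counts are illustrative,
# not part of the original module). The first argument of DecoderBlock is the
# channel count of the skip features, the second the output width:
#
#   block = DecoderBlock(256, 128)
#   coarse = torch.rand(1, 128, 8, 8)   # low-resolution decoder features
#   skip = torch.rand(1, 256, 16, 16)   # high-resolution encoder features
#   fused = block(coarse, skip)         # -> torch.Size([1, 128, 16, 16])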


class FPN(nn.Module):
    """Feature pyramid decoder: project the deepest encoder layer with a 1x1
    convolution, fuse shallower skip layers top-down with DecoderBlock, and
    finish with a 3x3 convolution."""

    def __init__(self, in_channels_list, out_channels, **kw):
        super().__init__()
        self.first = nn.Conv2d(
            in_channels_list[-1], out_channels, 1, padding=0, bias=True
        )
        self.blocks = nn.ModuleList(
            [
                DecoderBlock(c, out_channels, ksize=1, **kw)
                for c in in_channels_list[::-1][1:]
            ]
        )
        self.out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, layers):
        feats = None
        # Iterate from the deepest (lowest-resolution) layer to the shallowest.
        for idx, x in enumerate(reversed(layers.values())):
            if feats is None:
                feats = self.first(x)
            else:
                feats = self.blocks[idx - 1](feats, x)
        out = self.out(feats)
        return out
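
# A minimal usage sketch with ResNet-50-like channel counts (hedged; values are
# illustrative). The output is at the resolution of the shallowest input layer:
#
#   fpn = FPN([256, 512, 1024, 2048], out_channels=128)
#   feats = {
#       "layer1": torch.rand(1, 256, 64, 64),
#       "layer2": torch.rand(1, 512, 32, 32),
#       "layer3": torch.rand(1, 1024, 16, 16),
#       "layer4": torch.rand(1, 2048, 8, 8),
#   }
#   out = fpn(feats)  # -> torch.Size([1, 128, 64, 64])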


def remove_conv_stride(conv):
    """Return a copy of `conv` with stride 1 that shares its weight and bias."""
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
    )
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new
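
# Hedged sketch: applied to the stride-2 stem of a torchvision ResNet, this
# doubles the resolution of every downstream feature map:
#
#   resnet = torchvision.models.resnet50()
#   resnet.conv1 = remove_conv_stride(resnet.conv1)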


class FeatureExtractor(BaseModel):
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,
        "encoder": "resnet50",
        "remove_stride_from_first_conv": False,
        "num_downsample": None,
        "decoder_norm": "nn.BatchNorm2d",
        "do_average_pooling": False,
        "checkpointed": False,
    }
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def freeze_encoder(self):
        """
        Freeze the encoder part of the model, i.e., set requires_grad = False
        for all parameters in the encoder.
        """
        for param in self.encoder.parameters():
            param.requires_grad = False
        logger.debug("Encoder has been frozen.")

    def unfreeze_encoder(self):
        """
        Unfreeze the encoder part of the model, i.e., set requires_grad = True
        for all parameters in the encoder.
        """
        for param in self.encoder.parameters():
            param.requires_grad = True
        logger.debug("Encoder has been unfrozen.")

    def build_encoder(self, conf: ResNetConfiguration):
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)

        # Node names at the end of each downsampling stage, used to extract
        # intermediate feature maps.
        kw = {}
        if conf.encoder.startswith("resnet"):
            layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
            kw["replace_stride_with_dilation"] = [False, False, False]
        elif conf.encoder == "vgg13":
            layers = [
                "features.3",
                "features.8",
                "features.13",
                "features.18",
                "features.23",
            ]
        elif conf.encoder == "vgg16":
            layers = [
                "features.3",
                "features.8",
                "features.15",
                "features.22",
                "features.29",
            ]
        else:
            raise NotImplementedError(conf.encoder)

        if conf.num_downsample is not None:
            layers = layers[: conf.num_downsample]
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
        encoder = create_feature_extractor(encoder, return_nodes=layers)
        if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
            encoder.conv1 = remove_conv_stride(encoder.conv1)

        if conf.do_average_pooling:
            raise NotImplementedError
        if conf.checkpointed:
            raise NotImplementedError

        return encoder, layers
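
    # Hedged note (illustrative, not from the original module): for the default
    # "resnet50" encoder, the extractor returns an ordered dict of feature maps
    #   {"relu": (B, 64, H/2, W/2), "layer1": (B, 256, H/4, W/4),
    #    "layer2": (B, 512, H/8, W/8), "layer3": (B, 1024, H/16, W/16),
    #    "layer4": (B, 2048, H/32, W/32)}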

    def _init(self, conf):
        # Preprocessing buffers: ImageNet normalization statistics.
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        # Encoder: run a dummy forward pass to record the channel count and
        # stride of each returned feature map.
        self.encoder, self.layers = self.build_encoder(conf)
        s = 128
        inp = torch.zeros(1, 3, s, s)
        features = list(self.encoder(inp).values())
        self.skip_dims = [x.shape[1] for x in features]
        self.layer_strides = [s / f.shape[-1] for f in features]
        self.scales = [self.layer_strides[0]]

        # Decoder
        norm = eval(conf.decoder_norm) if conf.decoder_norm else None
        self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)

        logger.debug(
            "Built feature extractor with layers {name:dim:stride}:\n"
            f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
            f"and output scales {self.scales}."
        )

    def _forward(self, data):
        image = data["image"]
        # Normalize with the ImageNet statistics, broadcast over H and W.
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

        skip_features = self.encoder(image)
        output = self.decoder(skip_features)
        return output, data["camera"]
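
# Hedged end-to-end sketch (assumes BaseModel takes a configuration and routes
# calls to _forward; the exact construction API is defined in .base, so the
# names below are illustrative):
#
#   model = FeatureExtractor({"encoder": "resnet50", "output_dim": 128})
#   data = {"image": torch.rand(1, 3, 256, 256), "camera": camera}
#   features, camera = model(data)  # dense 128-dim features at model.scales[0]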