"""Basic ResNet encoder for image feature extraction.

https://pytorch.org/hub/pytorch_vision_resnet/
"""

import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

from siclib.models.base_model import BaseModel


def remove_conv_stride(conv):
    """Remove the stride from a convolutional layer."""
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
    )
    # Reuse the original (possibly pretrained) parameters so only the stride changes.
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new


class ResNet(BaseModel):
    """ResNet encoder for image feature extraction."""

    default_conf = {
        "encoder": "resnet18",
        "pretrained": True,
        "input_dim": 3,
        "remove_stride_from_first_conv": True,
        "num_downsample": None,
        "pixel_mean": [0.485, 0.456, 0.406],
        "pixel_std": [0.229, 0.224, 0.225],
    }

    required_data_keys = ["image"]

    def build_encoder(self, conf):
        """Build the encoder from the configuration."""
        if conf.pretrained:
            assert conf.input_dim == 3, "Pretrained weights expect 3-channel RGB input."

        Encoder = getattr(torchvision.models, conf.encoder)

        layers = ["layer1", "layer2", "layer3", "layer4"]
        kw = {"replace_stride_with_dilation": [False, False, False]}

        # Optionally keep only the first `num_downsample` stages.
        if conf.num_downsample is not None:
            layers = layers[: conf.num_downsample]

        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
        # Wrap the backbone so the forward pass returns the selected intermediate feature maps.
        encoder = create_feature_extractor(encoder, return_nodes=layers)

        if conf.remove_stride_from_first_conv:
            encoder.conv1 = remove_conv_stride(encoder.conv1)

        return encoder, layers

    def _init(self, conf):
        # Normalization constants, registered as non-persistent buffers (excluded from the state_dict).
        self.register_buffer("pixel_mean", torch.tensor(conf.pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(conf.pixel_std).view(-1, 1, 1), False)

        self.encoder, self.layers = self.build_encoder(conf)

    def _forward(self, data):
        image = data["image"]
        # Normalize with the configured per-channel mean and std.
        image = (image - self.pixel_mean) / self.pixel_std
        skip_features = list(self.encoder(image).values())
        return {"features": skip_features}

    def loss(self, pred, data):
        """Compute the loss."""
        raise NotImplementedError
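

# --- Usage sketch (illustrative only, not part of the module) -----------------
# Assumes that, as in similar BaseModel implementations, the constructor accepts a
# plain dict that is merged into `default_conf`, and that calling the module checks
# `required_data_keys` before running `_forward`. Adjust if siclib's BaseModel differs.
if __name__ == "__main__":
    model = ResNet({"encoder": "resnet18", "pretrained": False})  # False: skip weight download
    data = {"image": torch.rand(1, 3, 224, 224)}  # one dummy RGB image
    features = model(data)["features"]  # one feature map per returned ResNet stage
    print([tuple(f.shape) for f in features])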