veichta's picture
Upload folder using huggingface_hub
205a7af verified
raw
history blame
2.28 kB
"""Simple VGG encoder for image features extraction."""
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor
from siclib.models.base_model import BaseModel
# mypy: ignore-errors
class VGG(BaseModel):
    """VGG encoder for image features extraction.

    Wraps a torchvision VGG network in a feature extractor that returns the
    activations of a fixed set of intermediate layers. Input images are
    normalized with the configured per-channel mean and standard deviation
    before being passed to the network.

    Configuration keys (see ``default_conf``):
        encoder: name of the torchvision VGG variant ("vgg13" or "vgg16").
        pretrained: if true, load torchvision's "DEFAULT" weights.
        input_dim: number of input channels; must be 3 when ``pretrained``.
        num_downsample: if not None, keep only the first ``num_downsample``
            feature levels.
        pixel_mean / pixel_std: per-channel normalization statistics.
    """

    default_conf = {
        "encoder": "vgg13",
        "pretrained": True,
        "input_dim": 3,
        "num_downsample": None,  # how many downsample blocks to use
        # The defaults below coincide with the usual ImageNet statistics.
        "pixel_mean": [0.485, 0.456, 0.406],
        "pixel_std": [0.229, 0.224, 0.225],
    }

    required_data_keys = ["image"]

    def build_encoder(self, conf):
        """Build the encoder from the configuration.

        Args:
            conf: configuration object exposing ``encoder``, ``pretrained``,
                ``input_dim`` and ``num_downsample`` as attributes.

        Returns:
            A tuple ``(encoder, layers)`` where ``encoder`` is the torchvision
            feature extractor and ``layers`` is the list of node names whose
            activations it returns.

        Raises:
            AssertionError: if pretrained weights are requested with
                ``input_dim != 3``.
            NotImplementedError: if ``conf.encoder`` is not a supported
                VGG variant.
        """
        if conf.pretrained:
            assert conf.input_dim == 3, "pretrained VGG weights require input_dim == 3"

        # Node names tapped for each supported variant. In torchvision's VGG
        # layout these are the last activation of each convolutional stage,
        # i.e. the layer right before each max-pooling.
        skip_layers = {
            "vgg13": (
                "features.3",
                "features.8",
                "features.13",
                "features.18",
                "features.23",
            ),
            "vgg16": (
                "features.3",
                "features.8",
                "features.15",
                "features.22",
                "features.29",
            ),
        }

        # Validate the name *before* looking it up in torchvision so that any
        # unsupported encoder consistently raises NotImplementedError instead
        # of an AttributeError coming from getattr().
        if conf.encoder not in skip_layers:
            raise NotImplementedError(f"Encoder not implemented: {conf.encoder}")

        layers = list(skip_layers[conf.encoder])
        if conf.num_downsample is not None:
            layers = layers[: conf.num_downsample]

        Encoder = getattr(torchvision.models, conf.encoder)
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None)
        encoder = create_feature_extractor(encoder, return_nodes=layers)
        return encoder, layers

    def _init(self, conf):
        """Register the normalization buffers and build the encoder.

        NOTE(review): presumably invoked by ``BaseModel`` during construction;
        confirm against the base class.
        """
        # Shape (C, 1, 1) so the statistics broadcast over the trailing two
        # dimensions of the image (assumes channel-first images - TODO confirm).
        # Non-persistent: the buffers are not stored in the state dict.
        self.register_buffer(
            "pixel_mean", torch.tensor(conf.pixel_mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.tensor(conf.pixel_std).view(-1, 1, 1), persistent=False
        )
        self.encoder, self.layers = self.build_encoder(conf)

    def _forward(self, data):
        """Normalize the input image and extract the intermediate features.

        Args:
            data: mapping containing the key ``"image"``.

        Returns:
            A dict with the key ``"features"`` holding the list of extracted
            feature maps, ordered like ``self.layers``.
        """
        image = data["image"]
        image = (image - self.pixel_mean) / self.pixel_std
        # The feature extractor returns a dict keyed by node name; materialize
        # the values as a list so callers can index, slice and re-iterate it.
        skip_features = list(self.encoder(image).values())
        return {"features": skip_features}

    def loss(self, pred, data):
        """Compute the loss.

        Not implemented for this encoder; always raises NotImplementedError.
        """
        raise NotImplementedError