Spaces:
Running
Running
from pathlib import Path | |
import subprocess | |
import logging | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
from scipy.io import loadmat | |
from ..utils.base_model import BaseModel | |
# Logger for this module.
logger = logging.getLogger(name=__name__)

# Slack allowed when checking that input images lie within [0, 1].
EPS = 1.0e-6
class NetVLADLayer(nn.Module):
    """NetVLAD aggregation of local descriptors into one global descriptor.

    Each local descriptor is softly assigned to ``K`` learned centers (1x1
    convolution followed by a softmax over the centers). The
    assignment-weighted residuals to every center are summed over all
    locations, optionally L2-normalized per center, flattened, and finally
    L2-normalized.

    Args:
        input_dim: Channel dimension D of the local descriptors.
        K: Number of cluster centers.
        score_bias: Whether the soft-assignment convolution has a bias term.
        intranorm: If True, L2-normalize each center's aggregated residual
            before flattening.
    """

    def __init__(self, input_dim=512, K=64, score_bias=False, intranorm=True):
        super().__init__()
        self.score_proj = nn.Conv1d(input_dim, K, kernel_size=1, bias=score_bias)
        centers = nn.parameter.Parameter(torch.empty([input_dim, K]))
        nn.init.xavier_uniform_(centers)
        self.register_parameter("centers", centers)  # Shape: D x K
        self.intranorm = intranorm
        self.output_dim = input_dim * K

    def forward(self, x):
        """Aggregate local descriptors.

        Args:
            x: Local descriptors of shape (B, D, N).

        Returns:
            Tensor of shape (B, D * K), L2-normalized along dim 1 and
            flattened D-major (element index ``d * K + k``).
        """
        b = x.size(0)
        scores = F.softmax(self.score_proj(x), dim=1)  # B x K x N
        # The aggregated residual
        #     sum_n s[k, n] * (x[d, n] - c[d, k])
        # is expanded as
        #     sum_n s[k, n] * x[d, n]  -  c[d, k] * sum_n s[k, n],
        # which avoids materializing the B x D x K x N residual tensor that
        # dominates memory for large feature maps.
        desc = torch.einsum("bkn,bdn->bdk", scores, x)  # B x D x K
        mass = scores.sum(dim=-1).unsqueeze(1)  # B x 1 x K
        desc = desc - self.centers.unsqueeze(0) * mass
        if self.intranorm:
            # From the official MATLAB implementation.
            desc = F.normalize(desc, dim=1)
        # reshape (not view): the einsum output is not guaranteed contiguous.
        desc = desc.reshape(b, -1)
        desc = F.normalize(desc, dim=1)
        return desc
class NetVLAD(BaseModel):
    """VGG16-NetVLAD global image descriptor with pretrained MATLAB weights.

    Pipeline:
    1. VGG16 convolutional backbone, truncated before its last ReLU and
       MaxPool2d.
    2. Per-location L2 normalization.
    3. A NetVLAD layer.
    4. Optionally, a linear whitening layer to 4096 dimensions.

    Weights are read from a MATLAB struct that is downloaded into the torch
    hub directory on first use.

    Config keys:
        model_name: Key of `dir_models` selecting the pretrained weights.
        whiten: Whether to create and load the whitening layer.
    """

    default_conf = {"model_name": "VGG16-NetVLAD-Pitts30K", "whiten": True}
    required_inputs = ["image"]

    # Models exported using
    # https://github.com/uzh-rpg/netvlad_tf_open/blob/master/matlab/net_class2struct.m.
    dir_models = {
        "VGG16-NetVLAD-Pitts30K": "https://cvg-data.inf.ethz.ch/hloc/netvlad/Pitts30K_struct.mat",
        "VGG16-NetVLAD-TokyoTM": "https://cvg-data.inf.ethz.ch/hloc/netvlad/TokyoTM_struct.mat",
    }

    def _init(self, conf):
        """Build the network and load the pretrained MATLAB weights.

        Args:
            conf: Model configuration (see the class docstring).
        """
        assert conf["model_name"] in self.dir_models, (
            f"Unknown NetVLAD model {conf['model_name']!r}, "
            f"expected one of {sorted(self.dir_models)}."
        )

        # Download the checkpoint.
        checkpoint = Path(torch.hub.get_dir(), "netvlad", conf["model_name"] + ".mat")
        if not checkpoint.exists():
            checkpoint.parent.mkdir(exist_ok=True, parents=True)
            link = self.dir_models[conf["model_name"]]
            logger.info("Downloading the NetVLAD model from %s.", link)
            # `wget -O` creates the destination file up front. If it is
            # interrupted, a truncated file is left behind, and the `exists()`
            # check above would then accept it as a valid cache on every later
            # run. download_url_to_file instead writes to a temporary file and
            # moves it into place only once complete. It also avoids depending
            # on an external `wget` binary.
            torch.hub.download_url_to_file(link, str(checkpoint))

        # Create the network.
        # Remove classification head.
        backbone = list(models.vgg16().children())[0]
        # Remove last ReLU + MaxPool2d.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        self.netvlad = NetVLADLayer()

        if conf["whiten"]:
            self.whiten = nn.Linear(self.netvlad.output_dim, 4096)

        # Parse MATLAB weights using https://github.com/uzh-rpg/netvlad_tf_open
        mat = loadmat(checkpoint, struct_as_record=False, squeeze_me=True)

        # CNN weights. The MATLAB layer list is zipped positionally with the
        # backbone modules, so the two orderings must match.
        for layer, mat_layer in zip(self.backbone.children(), mat["net"].layers):
            if isinstance(layer, nn.Conv2d):
                w = mat_layer.weights[0]  # Shape: S x S x IN x OUT
                b = mat_layer.weights[1]  # Shape: OUT
                # Prepare for PyTorch - enforce float32 and right shape.
                # w should have shape: OUT x IN x S x S
                # b should have shape: OUT
                w = torch.tensor(w).float().permute([3, 2, 0, 1])
                b = torch.tensor(b).float()
                # Update layer weights.
                layer.weight = nn.Parameter(w)
                layer.bias = nn.Parameter(b)

        # NetVLAD weights (layer index 30 of the exported MATLAB network).
        score_w = mat["net"].layers[30].weights[0]  # D x K
        # centers are stored as opposite in official MATLAB code
        center_w = -mat["net"].layers[30].weights[1]  # D x K
        # Prepare for PyTorch - make sure it is float32 and has right shape.
        # score_w should have shape K x D x 1
        # center_w should have shape D x K
        score_w = torch.tensor(score_w).float().permute([1, 0]).unsqueeze(-1)
        center_w = torch.tensor(center_w).float()
        # Update layer weights.
        self.netvlad.score_proj.weight = nn.Parameter(score_w)
        self.netvlad.centers = nn.Parameter(center_w)

        # Whitening weights (layer index 33 of the exported MATLAB network).
        if conf["whiten"]:
            w = mat["net"].layers[33].weights[0]  # Shape: 1 x 1 x IN x OUT
            b = mat["net"].layers[33].weights[1]  # Shape: OUT
            # Prepare for PyTorch - make sure it is float32 and has right shape
            w = torch.tensor(w).float().squeeze().permute([1, 0])  # OUT x IN
            b = torch.tensor(b.squeeze()).float()  # Shape: OUT
            # Update layer weights.
            self.whiten.weight = nn.Parameter(w)
            self.whiten.bias = nn.Parameter(b)

        # Preprocessing parameters. The MATLAB model subtracts a per-channel
        # mean image in the 0-255 range and does no std scaling.
        # NOTE(review): the [0, 0] indexing depends on the exported
        # averageImage layout after squeeze_me=True — confirm if re-exporting.
        self.preprocess = {
            "mean": mat["net"].meta.normalization.averageImage[0, 0],
            "std": np.array([1, 1, 1], dtype=np.float32),
        }

    def _forward(self, data):
        """Compute the global descriptor of a batch of images.

        Args:
            data: Dict with key "image": a tensor of shape (B, 3, H, W) with
                values in [0, 1].

        Returns:
            Dict with key "global_descriptor". Its value is an L2-normalized
            tensor of shape (B, 4096) if whitening is enabled, otherwise
            (B, 512 * 64).
        """
        image = data["image"]
        assert image.shape[1] == 3
        assert image.min() >= -EPS and image.max() <= 1 + EPS
        image = torch.clamp(image * 255, 0.0, 255.0)  # Input should be 0-255.
        mean = self.preprocess["mean"]
        std = self.preprocess["std"]
        image = image - image.new_tensor(mean).view(1, -1, 1, 1)
        image = image / image.new_tensor(std).view(1, -1, 1, 1)

        # Feature extraction.
        descriptors = self.backbone(image)
        b, c, _, _ = descriptors.size()
        descriptors = descriptors.view(b, c, -1)

        # NetVLAD layer.
        descriptors = F.normalize(descriptors, dim=1)  # Pre-normalization.
        desc = self.netvlad(descriptors)

        # Whiten if needed.
        if hasattr(self, "whiten"):
            desc = self.whiten(desc)
            desc = F.normalize(desc, dim=1)  # Final L2 normalization.

        return {"global_descriptor": desc}