"""
Code for loading models trained with CosPlace as a global feature extractor
for geolocalization through image retrieval.

Multiple models are available with different backbones. Below is a summary of
the available models (backbone: list of available output descriptor
dimensionalities). For example, you can use a model based on a ResNet50 with
descriptor dimensionality 1024.

    ResNet18:  [32, 64, 128, 256, 512]
    ResNet50:  [32, 64, 128, 256, 512, 1024, 2048]
    ResNet101: [32, 64, 128, 256, 512, 1024, 2048]
    ResNet152: [32, 64, 128, 256, 512, 1024, 2048]
    VGG16:     [    64, 128, 256, 512]

CosPlace paper: https://arxiv.org/abs/2204.02287
"""
import torch
import torchvision.transforms as tvf

from ..utils.base_model import BaseModel


class CosPlace(BaseModel):
    default_conf = {"backbone": "ResNet50", "fc_output_dim": 2048}
    required_inputs = ["image"]

    def _init(self, conf):
        # Fetch the pretrained model from the CosPlace repository via torch.hub
        # and switch it to inference mode.
        self.net = torch.hub.load(
            "gmberton/CosPlace",
            "get_trained_model",
            backbone=conf["backbone"],
            fc_output_dim=conf["fc_output_dim"],
        ).eval()

        # Per-channel ImageNet mean and standard deviation.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.norm_rgb = tvf.Normalize(mean=mean, std=std)

    def _forward(self, data):
        # Normalize the RGB input, then compute one global descriptor per image.
        image = self.norm_rgb(data["image"])
        desc = self.net(image)
        return {
            "global_descriptor": desc,
        }
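
The backbone and descriptor-size table in the module docstring can be turned into a small check that runs before the model is loaded. The sketch below is illustrative only: `SUPPORTED_DIMS` and `check_cosplace_conf` are hypothetical names that are not part of this module or of the CosPlace repository, and the values are copied from the docstring table.

```python
# Supported output descriptor dimensions per backbone, copied from the table
# in the module docstring. The name SUPPORTED_DIMS is an assumption made for
# this example.
SUPPORTED_DIMS = {
    "ResNet18": [32, 64, 128, 256, 512],
    "ResNet50": [32, 64, 128, 256, 512, 1024, 2048],
    "ResNet101": [32, 64, 128, 256, 512, 1024, 2048],
    "ResNet152": [32, 64, 128, 256, 512, 1024, 2048],
    "VGG16": [64, 128, 256, 512],
}


def check_cosplace_conf(conf):
    """Hypothetical helper: validate a (backbone, fc_output_dim) pair.

    Returns the configuration unchanged if the pair appears in the table,
    and raises ValueError otherwise.
    """
    backbone = conf["backbone"]
    dim = conf["fc_output_dim"]
    if backbone not in SUPPORTED_DIMS:
        raise ValueError(
            f"Unknown backbone {backbone!r}; expected one of {sorted(SUPPORTED_DIMS)}"
        )
    if dim not in SUPPORTED_DIMS[backbone]:
        raise ValueError(
            f"{backbone} does not offer fc_output_dim={dim}; "
            f"expected one of {SUPPORTED_DIMS[backbone]}"
        )
    return conf


if __name__ == "__main__":
    # The default configuration of the CosPlace class passes the check.
    print(check_cosplace_conf({"backbone": "ResNet50", "fc_output_dim": 2048}))
```

Calling such a check at the start of `_init` would report an unsupported combination, such as ResNet18 with 1024 dimensions, before any download is attempted.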