Vincentqyw
update: limit keypoints number
e15a186
raw history blame
No virus
2.21 kB
from pathlib import Path
import subprocess
import logging
import sys
import torch
import torchvision.transforms as tvf
from ..utils.base_model import BaseModel
logger = logging.getLogger(__name__)
fire_path = Path(__file__).parent / "../../third_party/fire"
sys.path.append(str(fire_path))
import fire_network
class FIRe(BaseModel):
default_conf = {
"global": True,
"asmk": False,
"model_name": "fire_SfM_120k.pth",
"scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params
"features_num": 1000, # TODO:not supported now
"asmk_name": "asmk_codebook.bin", # TODO:not supported now
"config_name": "eval_fire.yml",
}
required_inputs = ["image"]
# Models exported using
fire_models = {
"fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth",
"fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth",
}
def _init(self, conf):
assert conf["model_name"] in self.fire_models.keys()
# Config paths
model_path = fire_path / "model" / conf["model_name"]
# Download the model.
if not model_path.exists():
model_path.parent.mkdir(exist_ok=True)
link = self.fire_models[conf["model_name"]]
cmd = ["wget", link, "-O", str(model_path)]
logger.info(f"Downloading the FIRe model with `{cmd}`.")
subprocess.run(cmd, check=True)
logger.info(f"Loading fire model...")
# Load net
state = torch.load(model_path)
state["net_params"]["pretrained"] = None
net = fire_network.init_network(**state["net_params"])
net.load_state_dict(state["state_dict"])
self.net = net
self.norm_rgb = tvf.Normalize(
**dict(zip(["mean", "std"], net.runtime["mean_std"]))
)
# params
self.scales = conf["scales"]
def _forward(self, data):
image = self.norm_rgb(data["image"])
# Feature extraction.
desc = self.net.forward_global(image, scales=self.scales)
return {"global_descriptor": desc}