from pathlib import Path import subprocess import logging import sys import torch import torchvision.transforms as tvf from ..utils.base_model import BaseModel logger = logging.getLogger(__name__) fire_path = Path(__file__).parent / "../../third_party/fire" sys.path.append(str(fire_path)) import fire_network class FIRe(BaseModel): default_conf = { "global": True, "asmk": False, "model_name": "fire_SfM_120k.pth", "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params "features_num": 1000, # TODO:not supported now "asmk_name": "asmk_codebook.bin", # TODO:not supported now "config_name": "eval_fire.yml", } required_inputs = ["image"] # Models exported using fire_models = { "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth", "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth", } def _init(self, conf): assert conf["model_name"] in self.fire_models.keys() # Config paths model_path = fire_path / "model" / conf["model_name"] # Download the model. if not model_path.exists(): model_path.parent.mkdir(exist_ok=True) link = self.fire_models[conf["model_name"]] cmd = ["wget", link, "-O", str(model_path)] logger.info(f"Downloading the FIRe model with `{cmd}`.") subprocess.run(cmd, check=True) logger.info(f"Loading fire model...") # Load net state = torch.load(model_path) state["net_params"]["pretrained"] = None net = fire_network.init_network(**state["net_params"]) net.load_state_dict(state["state_dict"]) self.net = net self.norm_rgb = tvf.Normalize( **dict(zip(["mean", "std"], net.runtime["mean_std"])) ) # params self.scales = conf["scales"] def _forward(self, data): image = self.norm_rgb(data["image"]) # Feature extraction. desc = self.net.forward_global(image, scales=self.scales) return {"global_descriptor": desc}