File size: 2,219 Bytes
9223079
8320ccc
9223079
8320ccc
 
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8320ccc
9223079
 
 
8320ccc
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging
import subprocess
import sys
from pathlib import Path

import torch
import torchvision.transforms as tvf

from ..utils.base_model import BaseModel

logger = logging.getLogger(__name__)
fire_path = Path(__file__).parent / "../../third_party/fire"
sys.path.append(str(fire_path))


import fire_network


class FIRe(BaseModel):
    default_conf = {
        "global": True,
        "asmk": False,
        "model_name": "fire_SfM_120k.pth",
        "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25],  # default params
        "features_num": 1000,  # TODO:not supported now
        "asmk_name": "asmk_codebook.bin",  # TODO:not supported now
        "config_name": "eval_fire.yml",
    }
    required_inputs = ["image"]

    # Models exported using
    fire_models = {
        "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth",
        "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth",
    }

    def _init(self, conf):
        assert conf["model_name"] in self.fire_models.keys()
        # Config paths
        model_path = fire_path / "model" / conf["model_name"]

        # Download the model.
        if not model_path.exists():
            model_path.parent.mkdir(exist_ok=True)
            link = self.fire_models[conf["model_name"]]
            cmd = ["wget", "--quiet", link, "-O", str(model_path)]
            logger.info(f"Downloading the FIRe model with `{cmd}`.")
            subprocess.run(cmd, check=True)

        logger.info("Loading fire model...")

        # Load net
        state = torch.load(model_path)
        state["net_params"]["pretrained"] = None
        net = fire_network.init_network(**state["net_params"])
        net.load_state_dict(state["state_dict"])
        self.net = net

        self.norm_rgb = tvf.Normalize(
            **dict(zip(["mean", "std"], net.runtime["mean_std"]))
        )

        # params
        self.scales = conf["scales"]

    def _forward(self, data):
        image = self.norm_rgb(data["image"])

        # Feature extraction.
        desc = self.net.forward_global(image, scales=self.scales)

        return {"global_descriptor": desc}