File size: 1,397 Bytes
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import warnings
from ..utils.base_model import BaseModel
import sys
from pathlib import Path

# Add the sibling ``third_party`` directory to the import search path so the
# ``TopicFM`` package imported just below can be resolved. This must run
# before those imports, which is why they are not grouped at the top.
sys.path.append(str(Path(__file__).parent / "../../third_party"))
from TopicFM.src.models.topic_fm import TopicFM as _TopicFM
from TopicFM.src import get_model_cfg

# Root of the TopicFM checkout; the pretrained checkpoint is looked up
# relative to this directory.
topicfm_path = Path(__file__).parent / "../../third_party/TopicFM"


class TopicFM(BaseModel):
    """Adapter that runs the third-party TopicFM network behind ``BaseModel``.

    The wrapped network is built from the configuration returned by
    ``get_model_cfg()``, with the coarse-match threshold and the number of
    sampled topics overridden from this wrapper's ``conf``.
    """

    # NOTE(review): "weights" is not read anywhere in this class; the
    # checkpoint path is fixed in ``_init``. Confirm whether it is consumed
    # elsewhere (e.g. by BaseModel or callers) before removing it.
    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "n_sampling_topics": 4,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the TopicFM network and load its pretrained weights.

        Args:
            conf: Mapping providing ``"match_threshold"`` and
                ``"n_sampling_topics"``.
        """
        _conf = dict(get_model_cfg())
        # Override the two upstream settings this wrapper exposes.
        _conf["match_coarse"]["thr"] = conf["match_threshold"]
        _conf["coarse"]["n_samples"] = conf["n_sampling_topics"]
        weight_path = topicfm_path / "pretrained/model_best.ckpt"
        self.net = _TopicFM(config=_conf)
        # Load onto CPU regardless of where the checkpoint was saved.
        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from a trusted source.
        ckpt_dict = torch.load(weight_path, map_location="cpu")
        self.net.load_state_dict(ckpt_dict["state_dict"])

    def _forward(self, data):
        """Match an image pair and return the matched keypoints.

        Args:
            data: Mapping containing at least ``"image0"`` and ``"image1"``.

        Returns:
            dict with ``"keypoints0"`` and ``"keypoints1"`` (the network's
            ``mkpts0_f`` / ``mkpts1_f`` outputs) and ``"mconf"``.
        """
        # Hand the network a fresh dict holding only the two images; the
        # network writes its results back into this same dict.
        data_ = {
            "image0": data["image0"],
            "image1": data["image1"],
        }
        self.net(data_)

        return {
            "keypoints0": data_["mkpts0_f"],
            "keypoints1": data_["mkpts1_f"],
            "mconf": data_["mconf"],
        }