File size: 1,562 Bytes
aa46ae9
8320ccc
aa46ae9
8320ccc
aa46ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8320ccc
aa46ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch

from hloc import logger

from ..utils.base_model import BaseModel


class XFeatDense(BaseModel):
    """Dense XFeat matcher: extract dense features on both images, match, refine.

    The XFeat network is loaded from torch hub (``verlab/accelerated_features``).
    For a pair of image batches, only the refined correspondences of the
    first pair in the batch are returned.
    """

    default_conf = {
        # NOTE(review): not read anywhere in this class; kept for config
        # compatibility — confirm whether it should be forwarded to XFeat.
        "keypoint_threshold": 0.005,
        # Passed to XFeat as ``top_k``, both at load time and per call.
        "max_keypoints": 8000,
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    def _init(self, conf):
        """Load the pretrained XFeat model from torch hub.

        ``conf`` is not read directly; ``self.conf`` is used instead
        (presumably the merged default + user config set up by
        ``BaseModel`` before ``_init`` runs — TODO confirm).
        """
        self.net = torch.hub.load(
            "verlab/accelerated_features",
            "XFeat",
            pretrained=True,
            top_k=self.conf["max_keypoints"],
        )
        logger.info("Load XFeat(dense) model done.")

    def _forward(self, data):
        """Densely match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: dict holding the ``"image0"`` and ``"image1"`` image
                batches (assumed ``(B, C, H, W)`` tensors — TODO confirm).

        Returns:
            dict with
              - ``"keypoints0"``: first two columns of the refined matches
                of batch element 0,
              - ``"keypoints1"``: remaining columns of those matches,
              - ``"mconf"``: a tensor of ones, one per match (no per-match
                confidence is produced here).
        """
        top_k = self.conf["max_keypoints"]

        # Compute coarse dense features for both images.
        out0 = self.net.detectAndComputeDense(data["image0"], top_k=top_k)
        out1 = self.net.detectAndComputeDense(data["image1"], top_k=top_k)

        # Coarse matching over the whole batch of pairs.
        idxs_list = self.net.batch_match(
            out0["descriptors"], out1["descriptors"]
        )

        # Only the first pair's result is returned, so refine just that pair.
        # The previous version refined every pair and then discarded all but
        # index 0.
        matches = self.net.refine_matches(
            out0, out1, matches=idxs_list, batch_idx=0
        )

        return {
            "keypoints0": matches[:, :2],
            "keypoints1": matches[:, 2:],
            "mconf": torch.ones_like(matches[:, 0]),
        }