File size: 4,028 Bytes
10b4a5f
 
 
358ab8f
10b4a5f
358ab8f
10b4a5f
 
 
 
 
 
 
 
358ab8f
 
 
10b4a5f
358ab8f
 
 
 
 
 
 
 
 
 
 
 
10b4a5f
358ab8f
 
 
 
 
 
 
 
 
 
 
 
 
10b4a5f
358ab8f
 
 
 
 
 
 
 
 
 
10b4a5f
358ab8f
 
 
 
 
 
 
 
10b4a5f
 
358ab8f
 
 
 
 
10b4a5f
358ab8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10b4a5f
358ab8f
10b4a5f
358ab8f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import numpy as np
import os
from collections import OrderedDict, namedtuple
import sys

# Put the repository root (the parent of this file's directory) at the front of
# the import path so the project-local `sgmnet`, `superglue` and `utils`
# modules resolve no matter which working directory the script is run from.
ROOT_DIR = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
sys.path.insert(0, ROOT_DIR)

from sgmnet import matcher as SGM_Model
from superglue import matcher as SG_Model
from utils import evaluation_utils


class GNN_Matcher(object):
    """Keypoint matcher backed by a learned graph network.

    Wraps either the SGMNet model (``model_name="SGM"``) or the SuperGlue
    model (``model_name="SG"``), loads ``model_best.pth`` from the configured
    model directory, and converts the model's score matrix ``p`` into
    thresholded mutual-nearest matches.
    """

    def __init__(self, config, model_name):
        """Build the model, move it to the GPU and load its checkpoint.

        Args:
            config: mapping of settings. ``p_th`` (score threshold used by
                ``match_p``) and ``model_dir`` (directory containing
                ``model_best.pth``) are read here; the whole config is also
                passed to the model constructor.
            model_name: ``"SGM"`` or ``"SG"``.
        """
        assert model_name in ("SGM", "SG"), f"unknown model_name: {model_name!r}"

        # Turn the mapping into an object with attribute access (config.p_th, ...).
        config = namedtuple("config", config.keys())(*config.values())
        self.p_th = config.p_th
        self.model = SGM_Model(config) if model_name == "SGM" else SG_Model(config)
        self.model.cuda()
        self.model.eval()
        # Load tensors onto the CPU: load_state_dict copies them into the
        # CUDA-resident parameters anyway, and this avoids failures when the
        # checkpoint was saved on a GPU index that is not present here.
        checkpoint = torch.load(
            os.path.join(config.model_dir, "model_best.pth"), map_location="cpu"
        )
        self.model.load_state_dict(self._strip_ddp_prefix(checkpoint["state_dict"]))

    @staticmethod
    def _strip_ddp_prefix(state_dict):
        """Return ``state_dict`` without the ``"module."`` key prefix, if present.

        Checkpoints saved from a DistributedDataParallel-wrapped model prefix
        every parameter name with ``"module."``. If the first key carries that
        prefix, it is removed from every key that has it; otherwise (including
        an empty state dict) the input is returned unchanged.
        """
        prefix = "module."
        first_key = next(iter(state_dict), "")
        if not first_key.startswith(prefix):
            return state_dict
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    def run(self, test_data):
        """Match the keypoints of an image pair.

        Args:
            test_data: mapping with keypoint arrays ``x1``/``x2`` (first two
                columns are coordinates, third column is passed to the model
                unnormalised), image sizes ``size1``/``size2`` and descriptor
                arrays ``desc1``/``desc2``.

        Returns:
            (corr1, corr2): arrays of matched coordinates taken from the first
            two columns of ``x1`` and ``x2``; row k of each forms one match.
        """
        x1_raw, x2_raw = test_data["x1"], test_data["x2"]
        norm_x1 = evaluation_utils.normalize_size(x1_raw[:, :2], test_data["size1"])
        norm_x2 = evaluation_utils.normalize_size(x2_raw[:, :2], test_data["size2"])
        # Re-attach the third keypoint column next to the normalised coordinates.
        x1 = np.concatenate([norm_x1, x1_raw[:, 2, np.newaxis]], axis=-1)
        x2 = np.concatenate([norm_x2, x2_raw[:, 2, np.newaxis]], axis=-1)

        def to_batched_cuda(array):
            # Add a leading batch axis of size 1 and move to the GPU as float32.
            return torch.from_numpy(array[np.newaxis]).cuda().float()

        feed_data = {
            "x1": to_batched_cuda(x1),
            "x2": to_batched_cuda(x2),
            "desc1": to_batched_cuda(test_data["desc1"]),
            "desc2": to_batched_cuda(test_data["desc2"]),
        }
        with torch.no_grad():
            p = self.model(feed_data, test_mode=True)["p"]
        # The last row/column are excluded before matching
        # (presumably the unmatched "dustbin" slots — TODO confirm).
        index1, index2 = self.match_p(p[0, :-1, :-1])
        corr1 = x1_raw[:, :2][index1.cpu().numpy()]
        corr2 = x2_raw[:, :2][index2.cpu().numpy()]
        if corr1.ndim == 1:
            corr1, corr2 = corr1[np.newaxis], corr2[np.newaxis]
        return corr1, corr2

    def match_p(self, p):  # p N*M
        """Extract mutual-nearest, thresholded matches from a score matrix.

        Args:
            p: 2-D tensor of shape (N, M) with a score for every pair of
                keypoints in image 1 (rows) and image 2 (columns).

        Returns:
            (index1, index2): 1-D long tensors on ``p.device``; row
            ``index1[k]`` is matched to column ``index2[k]``. A pair is kept
            when its score is strictly greater than ``self.p_th`` and it is
            the maximum of both its row and its column. Both tensors are empty
            when ``p`` has no rows or no columns.
        """
        n_rows, n_cols = p.shape
        if n_rows == 0 or n_cols == 0:
            # topk with k=1 would raise on an empty dimension.
            empty = torch.empty(0, dtype=torch.long, device=p.device)
            return empty, empty.clone()
        score, index = torch.topk(p, k=1, dim=-1)
        _, index2 = torch.topk(p, k=1, dim=-2)
        score, index, index2 = score[:, 0], index[:, 0], index2[0]
        mask_th = score > self.p_th
        # Mutual check: row i's best column must in turn pick row i as its best.
        mask_mc = index2[index] == torch.arange(n_rows, device=p.device)
        mask = mask_th & mask_mc
        index1 = torch.nonzero(mask).squeeze(1)
        return index1, index[mask]


class NN_Matcher(object):
    """Nearest-neighbour descriptor matcher.

    Uses Lowe's ratio test and, optionally, a mutual-nearest-neighbour check.
    """

    def __init__(self, config):
        """Read the matcher settings.

        Args:
            config: mapping with ``mutual_check`` (keep only mutual nearest
                neighbours when truthy) and ``ratio_th`` (keep a match only if
                nearest / second-nearest distance is below this value).
        """
        config = namedtuple("config", config.keys())(*config.values())
        self.mutual_check = config.mutual_check
        self.ratio_th = config.ratio_th

    def run(self, test_data):
        """Match keypoints of two images by Euclidean descriptor distance.

        Args:
            test_data: mapping with descriptor arrays ``desc1`` (N, D) and
                ``desc2`` (M, D), and keypoint arrays ``x1`` (N rows) and
                ``x2`` (M rows) whose first two columns are coordinates.

        Returns:
            (corr1, corr2): (K, 2) arrays of matched coordinates from image 1
            and image 2. Both are empty when image 1 has no keypoints or image
            2 has fewer than two (the ratio test needs a second neighbour).
        """
        desc1, desc2 = test_data["desc1"], test_data["desc2"]
        x1, x2 = test_data["x1"], test_data["x2"]
        if len(desc1) == 0 or len(desc2) < 2:
            return x1[:0, :2], x2[:0, :2]

        # Pairwise distances via |a|^2 + |b|^2 - 2 a.b; floating-point
        # cancellation can make near-zero entries slightly negative, so clamp
        # at 0 before the square root.
        sq_dist = (
            (desc1**2).sum(-1)[:, np.newaxis]
            + (desc2**2).sum(-1)[np.newaxis]
            - 2 * desc1 @ desc2.T
        )
        desc_mat = np.sqrt(np.maximum(sq_dist, 0))
        # Partitioning at kth=1 puts the nearest neighbour in column 0 and the
        # second nearest in column 1; only those two columns are needed.
        nn_index = np.argpartition(desc_mat, kth=1, axis=-1)[:, :2]
        dis_value12 = np.take_along_axis(desc_mat, nn_index, axis=-1)
        # 0/0 (duplicate descriptors) yields NaN, which fails the ratio test.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio_score = dis_value12[:, 0] / dis_value12[:, 1]
        nn_index1 = nn_index[:, 0]
        mask = ratio_score < self.ratio_th
        if self.mutual_check:
            # Keep i only if its nearest neighbour j has i as its own nearest.
            nn_index2 = np.argmin(desc_mat, axis=0)
            mask &= nn_index2[nn_index1] == np.arange(len(x1))
        return x1[:, :2][mask], x2[:, :2][nn_index1][mask]