File size: 3,404 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
a80d6bb
 
 
 
c74a070
 
 
a80d6bb
c74a070
 
 
 
 
 
 
a80d6bb
 
c74a070
a80d6bb
c74a070
 
a80d6bb
c74a070
a80d6bb
 
 
c74a070
 
 
 
 
 
a80d6bb
 
c74a070
 
 
 
 
 
 
 
a80d6bb
c74a070
 
 
 
 
 
 
a80d6bb
 
 
 
 
 
 
 
 
 
c74a070
 
 
 
 
 
 
 
a80d6bb
c74a070
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.geometry.subpix import dsnt
from kornia.utils.grid import create_meshgrid


class FineMatching(nn.Module):
    """FineMatching with s2d paradigm.

    For every coarse match, one feature is picked from the fine window of
    image 0 and correlated against all features of the matching window of
    image 1.  The softmax of that correlation is treated as a heatmap whose
    spatial expectation gives the refined position in image 1; image-0
    keypoints are left at their coarse positions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feat_f0, feat_f1, data):
        """Refine coarse matches and write the results into ``data``.

        Args:
            feat_f0 (torch.Tensor): [M, WW, C] fine window features of image 0
            feat_f1 (torch.Tensor): [M, WW, C] fine window features of image 1
            data (dict): reads 'hw0_i', 'hw0_f', 'mkpts0_c', 'mkpts1_c' and,
                when M > 0, 'mconf', 'b_ids' and the optional 'scale1'.
        Update:
            data (dict):{
                'expec_f' (torch.Tensor): [M, 3],
                'mkpts0_f' (torch.Tensor): [M, 2],
                'mkpts1_f' (torch.Tensor): [M, 2],
                'descriptors0' (torch.Tensor): [M, C] (only when M > 0),
                'descriptors1' (torch.Tensor): [M, C] (only when M > 0)}
        Returns:
            None. All outputs are written into ``data``.
        """
        M, WW, C = feat_f0.shape
        W = int(math.sqrt(WW))
        scale = data["hw0_i"][0] / data["hw0_f"][0]
        # Cached on the module because get_fine_match() reads W and scale
        # from self rather than taking them as arguments.
        self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale

        # corner case: if no coarse matches found
        if M == 0:
            assert (
                not self.training
            ), "M is always >0, when training, see coarse_matching.py"
            # Nothing to refine: pass the (empty) coarse keypoints through.
            data.update(
                {
                    "expec_f": torch.empty(0, 3, device=feat_f0.device),
                    "mkpts0_f": data["mkpts0_c"],
                    "mkpts1_f": data["mkpts1_c"],
                }
            )
            return

        # Feature at flat index WW // 2 of each image-0 window
        # (the centre cell when W is odd).
        feat_f0_picked = feat_f0[:, WW // 2, :]  # [M, C]

        # Correlate the picked feature with every position of the image-1
        # window and turn the scaled similarities into a distribution.
        sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1)
        softmax_temp = 1.0 / C**0.5
        heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)  # [M, WW]
        # Heatmap-weighted average of the image-1 window features.
        feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1)  # [M, C]
        heatmap = heatmap.view(-1, W, W)

        # compute coordinates from heatmap (kornia, normalized coordinates)
        coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[
            0
        ]  # [M, 2]
        grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
            1, -1, 2
        )  # [1, WW, 2]

        # compute std over <x, y>: var = E[g^2] - E[g]^2 under the heatmap
        var = (
            torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1)
            - coords1_normalized**2
        )  # [M, 2]
        std = torch.sum(
            torch.sqrt(torch.clamp(var, min=1e-10)), -1
        )  # [M]  clamp needed for numerical stability

        # for fine-level supervision
        data.update(
            {
                "expec_f": torch.cat([coords1_normalized, std.unsqueeze(1)], -1),
                "descriptors0": feat_f0_picked.detach(),
                "descriptors1": feat_f1_picked.detach(),
            }
        )

        # compute absolute kpt coords
        self.get_fine_match(coords1_normalized, data)

    @torch.no_grad()
    def get_fine_match(self, coords1_normed, data):
        """Convert normalized fine offsets into refined keypoints.

        Relies on ``self.W`` and ``self.scale`` as set by the preceding
        ``forward`` call.

        Args:
            coords1_normed (torch.Tensor): [M, 2] normalized offsets inside
                the image-1 fine window.
            data (dict): reads 'mkpts0_c', 'mkpts1_c', 'mconf' and, when
                'scale1' is present, 'scale1' and 'b_ids'.
        Update:
            data (dict):{
                'mkpts0_f' (torch.Tensor): same object as data['mkpts0_c'],
                'mkpts1_f' (torch.Tensor): data['mkpts1_c'] plus the offset}
        """
        W, scale = self.W, self.scale

        # Image-0 keypoints are not refined; they keep their coarse position.
        mkpts0_f = data["mkpts0_c"]

        # Apply the per-match image scale when the batch provides one.
        scale1 = scale * data["scale1"][data["b_ids"]] if "scale1" in data else scale
        # Normalized offset -> window cells (W // 2) -> pixels (scale1); the
        # slice keeps only the first len(data["mconf"]) offsets.
        mkpts1_f = (
            data["mkpts1_c"]
            + (coords1_normed * (W // 2) * scale1)[: len(data["mconf"])]
        )

        data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f})