import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class FinePreprocess(nn.Module):
    """Crop fine-level feature windows centered on the predicted coarse matches
    and optionally fuse each window with its corresponding coarse-level feature."""

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.cat_c_feat = config["fine_concat_coarse_feat"]
        self.W = self.config["fine_window_size"]

        d_model_c = self.config["coarse"]["d_model"]
        d_model_f = self.config["fine"]["d_model"]
        self.d_model_f = d_model_f
        if self.cat_c_feat:
            self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
            self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)

        self._reset_parameters()

    def _reset_parameters(self):
        # Kaiming-initialize all weight matrices; 1-D parameters (biases) keep their default init.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")

    def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
        W = self.W
        # Ratio between the fine and coarse feature-map resolutions.
        stride = data["hw0_f"][0] // data["hw0_c"][0]

        data.update({"W": W})
        if data["b_ids"].shape[0] == 0:
            # No coarse matches were predicted: return empty window tensors.
            feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
            feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
            return feat0, feat1

        # 1. unfold(crop) all local windows
        feat_f0_unfold = F.unfold(
            feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2
        )
        feat_f0_unfold = rearrange(feat_f0_unfold, "n (c ww) l -> n l ww c", ww=W**2)
        feat_f1_unfold = F.unfold(
            feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2
        )
        feat_f1_unfold = rearrange(feat_f1_unfold, "n (c ww) l -> n l ww c", ww=W**2)

        # 2. select only the predicted matches
        feat_f0_unfold = feat_f0_unfold[data["b_ids"], data["i_ids"]]  # [n, ww, cf]
        feat_f1_unfold = feat_f1_unfold[data["b_ids"], data["j_ids"]]

        # option: use coarse-level feature as context: concat and linear
        if self.cat_c_feat:
            feat_c_win = self.down_proj(
                torch.cat(
                    [
                        feat_c0[data["b_ids"], data["i_ids"]],
                        feat_c1[data["b_ids"], data["j_ids"]],
                    ],
                    0,
                )
            )  # [2n, c]
            feat_cf_win = self.merge_feat(
                torch.cat(
                    [
                        torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
                        repeat(feat_c_win, "n c -> n ww c", ww=W**2),  # [2n, ww, cf]
                    ],
                    -1,
                )
            )
            feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)

        return feat_f0_unfold, feat_f1_unfold
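

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The config values,
# tensor shapes, and match indices below are illustrative assumptions; in the
# real pipeline they come from the LoFTR configuration and from the coarse
# matching stage, which populates `data` with b_ids / i_ids / j_ids.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _config = {
        "fine_concat_coarse_feat": True,
        "fine_window_size": 5,
        "coarse": {"d_model": 256},
        "fine": {"d_model": 128},
    }
    module = FinePreprocess(_config)

    # Hypothetical shapes: 1 image pair, 60x60 coarse grid, 240x240 fine grid.
    feat_f0 = torch.randn(1, 128, 240, 240)
    feat_f1 = torch.randn(1, 128, 240, 240)
    feat_c0 = torch.randn(1, 60 * 60, 256)  # coarse features flattened to [N, L, C]
    feat_c1 = torch.randn(1, 60 * 60, 256)

    # Fake coarse matches: three matched coarse cells, all in batch element 0.
    data = {
        "hw0_f": (240, 240),
        "hw0_c": (60, 60),
        "b_ids": torch.zeros(3, dtype=torch.long),
        "i_ids": torch.tensor([10, 200, 3000]),
        "j_ids": torch.tensor([12, 205, 3010]),
    }

    win0, win1 = module(feat_f0, feat_f1, feat_c0, feat_c1, data)
    print(win0.shape, win1.shape)  # expected: [3, 25, 128] each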