# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

import torch
import torch.nn as nn
import torch.nn.functional as F

from core import functional as myF


class DeepMatchingLoss(nn.Module):
    """ This loss is based on DeepMatching (IJCV'16).

    atleast:    (int) minimum image size at which the pyramid construction stops.
    sub:        (int) prior subsampling of the descriptor maps.
    way:        (str) which way(s) to compute the asymmetric matching ('1', '2' or '12').
    border:     (int) ignore pixels closer than this to the image border.
    rectify_p:  (float) power of the non-linear rectification used in DeepMatching.
    eps:        (float) epsilon for the L1 normalization; loosely accounts for unmatched pixels.
    """
    def __init__(self, eps=0.03, atleast=5, sub=2, way='12', border=16, rectify_p=1.5):
        super().__init__()
        assert way in ('1','2','12')
        self.subsample = sub
        self.border = border
        self.way = way
        self.atleast = atleast
        self.rectify_p = rectify_p
        self.eps = eps

        self._cache = {} # caches the children mappings used by the sparse convolution

    def rectify(self, corr):
        # power rectification from DeepMatching: clamp negative correlations
        # to zero, then apply a non-linear power to emphasize strong responses
        corr = corr.clip_(min=0)
        corr = corr ** self.rectify_p
        return corr
        
    def forward(self, desc1, desc2, **kw):
        # 1 --> 2
        loss1 = self.forward_oneway(desc1, desc2, **kw) \
                if '1' in self.way else 0

        # 2 --> 1
        loss2 = self.forward_oneway(desc2, desc1, **kw) \
                if '2' in self.way else 0

        return dict(deepm_loss=(loss1+loss2)/len(self.way))

    def forward_oneway(self, desc1, desc2, dbg=(), **kw):
        assert desc1.shape[:2] == desc2.shape[:2]

        # prior subsampling: crop a border of self.border pixels, then keep
        # every self.subsample-th pixel
        s = slice(self.border, -self.border or None, self.subsample)
        desc1, desc2 = desc1[...,s,s], desc2[...,s,s]
        desc1 = desc1[:,:,2::4,2::4] # subsample patches in the 1st image (1 every 4x4)
        B, D, H1, W1, H2, W2 = desc1.shape + desc2.shape[-2:]
        if B == 0: return 0 # empty batch

        # initial 4D correlation volume: dot-product between every descriptor pair
        corr = torch.bmm(desc1.reshape(B,D,-1).transpose(1,2), desc2.reshape(B,D,-1)).view(B,H1,W1,H2,W2)
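        # equivalently, making the axes explicit:
        #   corr = torch.einsum('bdhw,bdkl->bhwkl', desc1, desc2)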

        # build pyramid
        pyramid = self.deep_matching(corr)
        corr = pyramid[-1] # high-level correlation
        corr = self.rectify(corr)

        # L1-normalize each source pixel's correlation map
        B, H1, W1, H2, W2 = corr.shape
        corr = corr / (corr.reshape(B,H1*W1,-1).sum(dim=-1).view(B,H1,W1,1,1) + self.eps)

        # negative squared L2 norm: for an L1-normalized map, the squared L2
        # norm is maximal when all the mass concentrates on a single, confident match
        loss = - torch.square(corr).sum() / (B*H1*W1)
        return loss

    def deep_matching(self, corr):
        # build the correlation pyramid: each level max-pools the maps over
        # image 2 and merges 2x2 neighboring patches of image 1
        weights = None
        pyramid = [corr]
        for level in range(1, 999):
            corr, weights = self.forward_level(level, corr, weights)
            pyramid.append(corr)
            if weights.sum() == 0: break # img1 has become too small
            if min(corr.shape[-2:]) < 2*self.atleast: break # img2 has become too small
        return pyramid

    def forward_level(self, level, corr, weights):
        B, H1, W1, H2, W2 = corr.shape

        # max-pooling over image 2 (3x3 kernel, stride 2): halves H2 and W2
        pooled = F.max_pool2d(corr.view(B,H1*W1,H2,W2), 3, padding=1, stride=2)
        pooled = pooled.view(B, H1, W1, *pooled.shape[-2:])

        pooled = self.rectify(pooled)

        # sparse convolution: the children->parents mapping depends only on
        # the level and the shapes, so compute it once and cache it
        key = level, H1, W1, H2, W2
        if key not in self._cache:
            B, H1, W1, H2, W2 = myF.true_corr_shape(pooled.shape, level-1)
            self._cache[key] = myF.children(level, H1, W1, H2, W2).to(corr.device)

        return sparse_conv(level, pooled, self._cache[key], weights)


def sparse_conv(level, corr, parents, weights=None, border_norm=0.9):
    B, H1, W1, H2, W2 = myF.true_corr_shape(corr.shape, level-1)
    n_cache = len(parents)

    # perform the sparse convolution 'manually', since sparse convolutions
    # are currently not implemented in pytorch
    corr = corr.view(B, -1, H2, W2)

    res = corr.new_zeros((B, n_cache+1, H2, W2)) # last one = garbage channel
    nrm = corr.new_full((n_cache+1, 3, 3), torch.finfo(corr.dtype).eps)
    ones = nrm.new_ones((corr.shape[1], 1, 1))
    ex = 1 # the output grid of merged patches is one larger in each dimension
    if weights is not None:
        weights = weights.view(corr.shape[1],1,1)
        corr = corr * weights[None] # apply weights to correlation maps beforehand
        ones *= weights

    # shift by one pixel: v<0 drops the last row/col, v>0 drops the first
    sl = lambda v: slice(0, -1) if v < 0 else slice(1, None)
    c = 0
    for y in (-1, 1):
        for x in (-1, 1):
            src_layers = parents[:,c]; c+= 1
            # we want to do: res += corr[src_layers]  (for all children != -1)
            # but we only have 'res.index_add_()' <==> res[tgt_layers] += corr
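            # toy example of index_add_ (duplicate indices accumulate):
            #   r = torch.zeros(4)
            #   r.index_add_(0, torch.tensor([1,1,3]), torch.tensor([10.,20.,30.]))
            #   # r is now [0., 30., 0., 30.]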
            tgt_layers = myF.inverse_mapping(src_layers, max_elem=corr.shape[1], default=n_cache)[:-1]

            # All of corr's channels MUST be utilized; for level>1 this doesn't
            # hold, so we send the unused ones to a garbage channel ==> res[:,n_cache]
            sel = myF.good_slice( tgt_layers < n_cache )

            res[:,:,sl(-y),sl(-x)].index_add_(1, tgt_layers[sel], corr[:,sel,sl(y),sl(x)])
            nrm[  :,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], ones[sel].expand(-1,2,2))

    # normalize borders
    weights = myF.norm_borders(res, nrm, norm=border_norm)[:-1]

    res = res[:,:-1] # remove garbage channel
    return res.view(B, H1+ex, W1+ex, *res.shape[-2:]), weights
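

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original code).
# It feeds random L2-normalized descriptor maps to the loss; in practice,
# desc1/desc2 would come from a descriptor network applied to an image pair.
# The shapes (B, D, H, W) = (2, 128, 128, 128) are arbitrary, and running it
# assumes the `core.functional` helpers from this repository are importable.
if __name__ == '__main__':
    B, D, H, W = 2, 128, 128, 128
    desc1 = F.normalize(torch.randn(B, D, H, W), dim=1) # stand-in descriptors
    desc2 = F.normalize(torch.randn(B, D, H, W), dim=1)

    criterion = DeepMatchingLoss()
    losses = criterion(desc1, desc2)
    print(losses['deepm_loss']) # scalar, to be minimized during training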