File size: 5,467 Bytes
3ef85e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import torch
import torch.nn as nn
import torch.nn.functional as F
from core import functional as myF
class DeepMatchingLoss (nn.Module):
    """ This loss is based on DeepMatching (IJCV'16).

    A 4D correlation volume is computed between two dense descriptor maps,
    aggregated through a multi-level pyramid (max-pooling, rectification and
    sparse convolution), and the loss is the negated sum of squares of the
    L1-normalized top-level correlations.

    atleast:   (int) minimum image size at which the pyramid construction stops.
    sub:       (int) prior subsampling
    way:       (str) which way to compute the asymmetric matching ('1', '2' or '12')
    border:    (int) ignore pixels too close to the border
    rectify_p: (float) non-linear power-rectification in DeepMatching
    eps:       (float) epsilon for the L1 normalization. Kinda handles unmatched pixels.
    """
    def __init__(self, eps=0.03, atleast=5, sub=2, way='12', border=16, rectify_p=1.5):
        super().__init__()
        assert way in ('1','2','12')
        self.subsample = sub
        self.border = border
        self.way = way
        self.atleast = atleast
        self.rectify_p = rectify_p
        self.eps = eps
        # Per-level lookup tables built by `myF.children`, keyed by
        # (level, H1, W1, H2, W2, device). This is a plain dict, not a registered
        # buffer, so `Module.to()` does not move its content: the device is part
        # of the key to never hand out a table living on another device.
        self._cache = {}

    def rectify(self, corr):
        """ Clamp negative correlations to 0, then raise to the power `rectify_p`.

        NOTE: the clamping is done in-place, i.e. `corr` itself is modified.
        """
        corr = corr.clip_(min=0)
        corr = corr ** self.rectify_p
        return corr

    def forward(self, desc1, desc2, **kw):
        """ Compute the loss in the direction(s) selected by `way`.

        Returns a dict with a single entry 'deepm_loss': the average of the
        one-way losses that were requested.
        """
        # 1 --> 2
        loss1 = self.forward_oneway(desc1, desc2, **kw) \
                if '1' in self.way else 0

        # 2 --> 1
        loss2 = self.forward_oneway(desc2, desc1, **kw) \
                if '2' in self.way else 0

        return dict(deepm_loss=(loss1+loss2)/len(self.way))

    def forward_oneway(self, desc1, desc2, dbg=(), **kw):
        """ Asymmetric loss matching patches of `desc1` against `desc2`.

        desc1, desc2: descriptor maps indexed as (B, D, H, W); they must agree
                      on their first two dimensions.
        dbg, **kw:    accepted but unused here.

        Returns a scalar tensor, or the integer 0 for an empty batch.
        """
        assert desc1.shape[:2] == desc2.shape[:2]

        # prior subsampling (`or None` keeps the slice valid when border == 0)
        s = slice(self.border, -self.border or None, self.subsample)
        desc1, desc2 = desc1[...,s,s], desc2[...,s,s]
        desc1 = desc1[:,:,2::4,2::4] # subsample patches in 1st image
        B, D, H1, W1, H2, W2 = desc1.shape + desc2.shape[-2:]
        if B == 0: return 0 # empty batch

        # initial 4D correlation volume
        corr = torch.bmm(desc1.reshape(B,D,-1).transpose(1,2), desc2.reshape(B,D,-1)).view(B,H1,W1,H2,W2)

        # build pyramid
        pyramid = self.deep_matching(corr)
        corr = pyramid[-1] # high-level correlation
        corr = self.rectify(corr)

        # L1 norm (over the last two dimensions, for each position of the first image)
        B, H1, W1, H2, W2 = corr.shape
        corr = corr / (corr.reshape(B,H1*W1,-1).sum(dim=-1).view(B,H1,W1,1,1) + self.eps)

        # squared L2 norm (negated: minimizing the loss maximizes it)
        loss = - torch.square(corr).sum() / (B*H1*W1)
        return loss

    def deep_matching(self, corr):
        """ Build the correlation pyramid, starting from the 4D volume `corr`.

        Returns the list of correlation volumes, level 0 (the input) first.
        """
        # print(f'level=0 {corr.shape=}')
        weights = None
        pyramid = [corr]
        for level in range(1,999):
            corr, weights = self.forward_level(level, corr, weights)
            pyramid.append(corr)
            # print(f'{level=} {corr.shape=}')
            if weights.sum() == 0: break # img1 has become too small
            if min(corr.shape[-2:]) < 2*self.atleast: break # img2 has become too small
        return pyramid

    def forward_level(self, level, corr, weights):
        """ Compute pyramid level `level` from the previous level `corr`.

        Returns the `(corr, weights)` pair produced by `sparse_conv`.
        """
        B, H1, W1, H2, W2 = corr.shape

        # max-pooling
        pooled = F.max_pool2d(corr.view(B,H1*W1,H2,W2), 3, padding=1, stride=2)
        pooled = pooled.view(B, H1, W1, *pooled.shape[-2:])

        # print(f'rectifying corr at {level=}')
        pooled = self.rectify(pooled)

        # sparse conv
        # The device belongs to the key: the cached table is created on
        # `corr.device` and would otherwise be reused as-is for inputs
        # located on another device.
        key = level, H1, W1, H2, W2, corr.device
        if key not in self._cache:
            _, true_H1, true_W1, true_H2, true_W2 = myF.true_corr_shape(pooled.shape, level-1)
            self._cache[key] = myF.children(level, true_H1, true_W1, true_H2, true_W2).to(corr.device)

        return sparse_conv(level, pooled, self._cache[key], weights)
def sparse_conv(level, corr, parents, weights=None, border_norm=0.9):
    """ Aggregate the correlation maps of `corr` into their parent maps.

    level:       pyramid level being computed.
    corr:        correlation volume; its true 5D shape is obtained from
                 `myF.true_corr_shape(corr.shape, level-1)`.
    parents:     2D index table with one row per output map and 4 columns
                 (one per diagonal neighbor); presumably holds the index of
                 the child map to read, or a negative value when there is
                 none -- TODO confirm against `myF.children`.
    weights:     optional per-map weights, applied to `corr` before aggregation.
    border_norm: forwarded to `myF.norm_borders` as `norm`.

    Returns `(res, weights)`: the aggregated volume, viewed with its first two
    spatial dimensions enlarged by 1, and the output of `myF.norm_borders`
    without its last entry.
    """
    B, H1, W1, H2, W2 = myF.true_corr_shape(corr.shape, level-1)
    n_parents = len(parents)

    def shifted(direction):
        # negative direction: drop the last row/column; otherwise drop the first one
        return slice(0, -1) if direction < 0 else slice(1, None)

    # The sparse convolution is carried out by hand with index_add_,
    # one diagonal neighbor at a time.
    maps = corr.view(B, -1, H2, W2)
    n_children = maps.shape[1]

    # One extra output channel (the last) collects the maps that have no parent.
    out = maps.new_zeros((B, n_parents + 1, H2, W2))
    counts = maps.new_full((n_parents + 1, 3, 3), torch.finfo(maps.dtype).eps)
    unit = counts.new_ones((n_children, 1, 1))

    if weights is not None:
        weights = weights.view(n_children, 1, 1)
        maps = maps * weights[None]  # weight the correlation maps up front
        unit *= weights

    # Same visiting order as two nested loops over (-1, 1): y outer, x inner.
    offsets = ((-1, -1), (-1, 1), (1, -1), (1, 1))
    for column, (dy, dx) in enumerate(offsets):
        children = parents[:, column]
        # Goal: out += maps[children]; but index_add_ works the other way round
        # (out[targets] += maps), hence the inverted mapping.
        targets = myF.inverse_mapping(children, max_elem=n_children, default=n_parents)[:-1]
        # Maps without a parent point at the extra channel: leave them out.
        keep = myF.good_slice(targets < n_parents)

        dst_rows, dst_cols = shifted(-dy), shifted(-dx)
        src_rows, src_cols = shifted(dy), shifted(dx)
        out[:, :, dst_rows, dst_cols].index_add_(1, targets[keep], maps[:, keep, src_rows, src_cols])
        counts[:, dst_rows, dst_cols].index_add_(0, targets[keep], unit[keep].expand(-1, 2, 2))

    # normalize borders
    weights = myF.norm_borders(out, counts, norm=border_norm)[:-1]

    out = out[:, :-1]  # drop the extra channel
    return out.view(B, H1 + 1, W1 + 1, *out.shape[-2:]), weights
|