# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn.functional as F


def affmul( aff, vecs ):
    """ affine multiplication: computes aff @ vecs.T

    aff: None (identity), a single (2,3) or (3,3) affine matrix
         (bottom row of a (3,3) must be [0,0,1]), or a pair of such
         matrices (tuple/list or 3-dim tensor) applied separately to
         columns 0:2 and 2:4 of `vecs`.
    vecs: (..., 2) points, or (..., >=4) in the two-matrix case.

    Returns transformed points. In the two-matrix case, `vecs` is
    cloned/copied first, so the caller's array is not modified.
    NOTE: the assert messages call bb() == pdb.set_trace, i.e. a failing
    assert drops into the debugger (deliberate debugging aid here).
    """
    if aff is None: return vecs
    if isinstance(aff, (tuple,list)) or aff.ndim==3:
        # two affine transforms: one for (x1,y1), one for (x2,y2)
        assert len(aff) == 2
        assert 4 <= vecs.shape[-1], bb()
        vecs = vecs.clone() if isinstance(vecs, torch.Tensor) else vecs.copy()
        vecs[...,0:2] = affmul(aff[0], vecs[...,0:2])
        vecs[...,2:4] = affmul(aff[1], vecs[...,2:4])
        return vecs
    else:
        assert vecs.shape[-1] == 2, bb()
        assert aff.shape == (2,3) or (aff.shape==(3,3) and aff[2,0] == aff[2,1] == 0 and aff[2,2] == 1), bb()
        return (vecs @ aff[:2,:2].T) + aff[:2,2]

def imresize( img, max_size, mode='area' ):
    """ Resize `img` so that its largest side is at most `max_size`
    (no-op if max_size <= 0 or the image is already small enough).

    img: a (C,H,W) tensor, or a tuple (tensor, trf) where trf is the
         accumulated 3x3 transform mapping current pixels --> original pixels.
    Returns (resized uint8 image, updated trf).
    """
    # trf: cur_pix --> old_pix
    img, trf = img if isinstance(img,tuple) else (img, torch.eye(3,device=img.device))
    shape = img.shape[-2:]
    if max_size > 0 and max(shape) > max_size:
        # integer scaling keeps the aspect ratio (up to rounding)
        new_shape = tuple(i * max_size // max(shape) for i in shape)
        img = F.interpolate( img[None].float(), size=new_shape, mode=mode )[0]
        img.clamp_(min=0, max=255)
        # NOTE(review): diag is (H-ratio, W-ratio, 1) while affmul treats points
        # as (x,y); looks swapped, but both ratios are equal up to rounding since
        # the aspect ratio is preserved -- confirm before relying on exactness.
        sca = torch.diag(torch.tensor((shape[0]/new_shape[0],shape[1]/new_shape[1],1), device=img.device))
        img = img.byte()
        trf = trf @ sca # undo sca first
    return img, trf

def rotate_img( img, angle, crop=False ):
    """ Rotate image by `angle` degrees (any angle; multiples of 90 are
    delegated to the exact transpose/flip path in rotate_img_90).

    img: tuple (C,H,W tensor, 3x3 trf mapping current --> original pixels)
    crop: if True, keep the original canvas size; otherwise enlarge the
          output so the whole rotated image fits (size rounded to /4, legacy).
    Returns (rotated image, trf @ rot) so the chain still maps new --> original.
    """
    if angle in (0, 90, 180, 270): return rotate_img_90(img,angle)
    img, trf = img
    assert trf.shape == (3,3)

    def centered_rotation(rotation, shape, **device):
        # Build the 3x3 rotation matrix such that
        # pt_in_original_image = rot * pt_in_rotated_image
        angle = rotation * np.pi / 180
        c, s = np.cos(angle), np.sin(angle)
        rot = torch.tensor([(c, -s, 0), (s, c, 0), (0, 0, 1)], dtype=torch.float32, **device)
        # determine center of rotation before
        H, W = shape
        c_before = torch.tensor((W,H), **device) / 2
        if crop:
            c_after = c_before
            rot_size = (W,H)
        else:
            # enlarge image to fit everything: rotate the 4 corners and
            # take the bounding box of the result
            corners = torch.tensor([(0, W, W, 0), (0, 0, H, H)], dtype=torch.float32, **device)
            corners = affmul(rot, corners.T).T
            rot_size = (corners.max(dim=1).values - corners.min(dim=1).values + 0.5).int()
            rot_size = (rot_size // 4) * 4 # legacy
            c_after = rot_size / 2
        rot[:2,2] = c_before - affmul(rot, c_after) # fix translation
        return rot, tuple(rot_size)[::-1]

    C, H, W = img.shape
    rot, (OH, OW) = centered_rotation(angle, (H,W), device=img.device)

    # pt_in_original_image = rot * pt_in_rotated_image
    # but pytorch works in [-1,1] coordinates... annoying
    # pt_in_original_1_1 = orig_px_to_1_1 * rot * rotated_1_1_to_px * pt_in_rotated_1_1
    _1_1_to_px = lambda W,H: torch.tensor(((W/2, 0, W/2), (0, H/2, H/2), (0, 0, 1)), device=img.device)
    theta = torch.inverse(_1_1_to_px(W-1,H-1)) @ rot @ _1_1_to_px(OW-1,OH-1)

    grid = F.affine_grid(theta[None,:2], (1, C, OH, OW), align_corners=True)
    res = F.grid_sample(img[None].float(), grid, align_corners=True).to(dtype=img.dtype)[0]
    return res, trf @ rot

def rotate_img_90( img, angle ):
    """ Rotate an image by a multiple of 90 degrees using simple transpose and flip ops.
        img = tuple( image, existing_trf )
        existing_trf: current --> old

    Works for numpy HWC arrays and torch CHW tensors (1 to 3 channels).
    Returns (rotated image, updated trf); trf is right-multiplied by the
    inverse transform so the chain still maps new pixels --> old pixels.
    """
    angle = angle % 360
    assert angle in (0, 90, 180, 270), 'cannot handle rotation other than multiple of 90 degrees'
    img, trf = img
    assert trf.shape == (3,3)

    if isinstance(img, np.ndarray):
        # numpy image: channels last; x = axis 1, y = axis 0
        assert img.ndim == 3 and 1 <= img.shape[2] <= 3
        new, x, y = np.float32, 1, 0
        flip = lambda i,d: np.flip(i,axis=d)
    elif isinstance(img, torch.Tensor):
        # torch image: channels first; x = last axis, y = second-to-last
        assert img.ndim == 3 and 1 <= img.shape[0] <= 3
        new, x, y = trf.new, -1, -2
        flip = lambda i,d: i.flip(dims=[d])

    H, W = img.shape[y], img.shape[x]
    if angle == 90:
        # point 0,0 --> (0, H-1); W-1,0 --> 0,0
        img = flip(img.swapaxes(x,y),y)
        trf = trf @ new([[0,-1,W-1],[1,0,0],[0,0,1]]) # inverse transform: new --> current
    if angle == 180:
        # point 0,0 --> (W-1, H-1)
        img = flip(flip(img,x),y)
        trf = trf @ new([[-1,0,W-1],[0,-1,H-1],[0,0,1]]) # inverse transform: new --> current
    if angle == 270:
        # point 0,0 --> (H-1, 0); 0,H-1 --> 0,0
        img = flip(img.swapaxes(x,y),x)
        trf = trf @ new([[0,1,0],[-1,0,H-1],[0,0,1]]) # inverse transform: new --> current
    return img, trf

def encode_scale_rot(scale, rot):
    """ Pack a (scale, rotation) pair into a single integer code:
    scale quantized in half-octave steps (log2 scale / 0.5),
    rotation quantized to the nearest multiple of 45 degrees (8 bins).
    Code layout: 8*scale_index + rotation_bin. Inverse of decode_scale_rot.
    """
    s = np.int32(np.rint(np.log(scale) / (0.5*np.log(2))))
    r = np.int32(np.rint(((-rot) % 360) / 45)) % 8
    return 8*s + (r%8)

def decode_scale_rot( code ):
    """ Inverse of encode_scale_rot: return (scale, rotation_in_degrees),
    with rotation mapped back into the (-180, 180] range.
    """
    s = code // 8
    r = (code % 8)
    return 2 ** (s/2), -((45 * r + 180) % 360 - 180)

def normalized_corr(patches, img, padding='ncc', extra_patch=False, ret_norms=False):
    """ Normalized cross-correlation of a bank of patches against an image.

    patches: (P, C, K, K) patch bank; img: (C, H, W) image (same C).
    padding: 'ncc'   -> also renormalize patch-side near image borders
             'ninth' -> pad with the mean of the image's last row instead of 0
    ret_norms: if True, also return the per-patch L2 norms.
    Returns corr of shape (P, H, W) (or (corr, norms)).
    """
    assert patches.ndim == 4, 'patches shape must be (H*W, C, K, K)'
    P, C, K, K = patches.shape
    assert img.ndim == 3 and img.shape[0] == C, 'img shape must be (C, W, H)'
    eps = torch.finfo(patches.dtype).tiny # smallest positive value, avoids div-by-0

    # normalize on patches side
    norms = patches.view(P,-1).norm(dim=-1)
    patches = patches / norms[:,None,None,None].clamp(min=eps)

    # convolve normalized patches on unnormalized image
    ninth = 0
    if padding == 'ninth':
        ninth = img[:,-1].mean() # ninth dimension
    img = F.pad(img[None], (K//2,K//2)*2, mode='constant', value=ninth)[0]
    corr = F.conv2d(img[None], patches, padding=0, bias=None)[0]

    # normalize on img's side: local L2 norm of each KxK image window
    ones = patches.new_ones((1, C, K, K))
    local_norm = torch.sqrt(F.conv2d(img[None]**2, ones))[0]
    corr /= local_norm

    # normalize on patches' side (image borders)
    if padding == 'ncc':
        # per-patch norm restricted to the part overlapping the image near
        # borders; interior (i == j == 2 after remap) is already normalized
        local_norm = torch.sqrt(F.conv2d(ones, patches**2, padding=2))[0]
        local_norm.clamp_(min=eps)
        for j in range(-2, 3):
            for i in range(-2,3):
                if i == j == 2: continue # normal case is already normalized
                if i == 2: i = slice(2,-2)
                if j == 2: j = slice(2,-2)
                corr[:,j,i] /= local_norm[:,j,i]

    return (corr, norms) if ret_norms else corr

def true_corr_shape( corr_shape, level ):
    """ Recover the true (H1, W1, H2, W2) tail of a correlation-volume shape:
    at levels > 0 the stored H1, W1 are one larger than the true size.
    """
    H1, W1, H2, W2 = corr_shape[-4:]
    if level > 0: # recover true size
        H1, W1 = H1-1, W1-1
    return corr_shape[:-4] + (H1, W1, H2, W2)

def children(level, H1, W1, H2, W2):
    """ level: parent level (> 1)

    Build the parent -> children index table for the sparse pyramid:
    one row per parent cell, 4 flat child indices per row
    (top-left, top-right, bot-left, bot-right; -1 = child outside grid).
    Returns an int64 tensor of shape ((H1+1)*(W1+1), 4).
    """
    gap = 2**(level-2)
    # @ level 1: gap=0.5 (parent at x=1 has children at x=[0.5, 1.5])
    # @ level 2: gap=1   (parent at x=1 has children at x=[0, 2])
    # @ level 3: gap=2   (parent at x=2 has children at x=[0, 4])
    # etc.

    def ravel_child(x, y):
        # x,y is the center of the child patch
        inside = (0 <= x <= W1) and (0 <= y <= H1)
        if gap < 1:
            # half-integer child grid (level 1): W1 columns
            assert x % 1 == y % 1 == 0.5, bb()
            return int((x-0.5) + (y-0.5) * W1) if inside else -1
        else:
            # integer child grid: W1+1 columns
            assert x % 1 == y % 1 == 0, bb()
            return int(x + y * (W1+1)) if inside else -1

    # 4 children for each parent patch (top-left, top-right, bot-left, bot-right, -1 = None)
    parents = []
    for h in range(H1+1):
        for w in range(W1+1):
            # enumerate the 4 children for this patch
            children = [ravel_child(w + gap*tx, h + gap*ty) for ty in (-1,1) for tx in (-1,1)]
            parents.append(children)
    return torch.tensor(parents, dtype=torch.int64)

def sparse_conv(level, corr, weights=None, reverse=False, norm=0.9):
    """ One aggregation step of the correlation pyramid, implemented as a
    'manual' sparse convolution (sparse conv is not available in pytorch).

    forward  (reverse=False): aggregate child correlation maps into parents
        (index_add_ into a result with one extra garbage channel), then
        border-normalize via norm_borders. Returns (res, weights).
    backward (reverse=True): propagate parent scores back to children with
        element-wise max. Requires `corr` to be a view whose ._base is the
        forward result (the garbage channel is reset here). Returns res only.

    weights: optional per-child weights, applied in-place and cancelled
        afterwards to avoid extra memory.
    """
    H1, W1, H2, W2 = true_corr_shape(corr.shape, level-1 + reverse)
    parents = children(level, H1, W1, H2, W2).to(corr.device)
    n_parents = len(parents)

    # perform the sparse convolution 'manually'
    # since sparse convolutions are not implemented in pytorch currently
    corr = corr.view(-1, *corr.shape[-2:])
    if not reverse:
        res = corr.new_zeros((n_parents+1,)+corr.shape[-2:]) # last one = garbage channel
        nrm = corr.new_full((n_parents+1,3,3), 1e-8)
        ones = nrm.new_ones((len(corr),1,1))
        ex = 1
        if weights is not None:
            weights = weights.view(len(corr),1,1)
            corr *= weights # apply weights to correlation maps without increasing memory footprint
            ones *= weights
    else:
        # reverse mode relies on corr being a view of the forward result
        assert corr._base is not None and corr._base.shape[0] == n_parents+1
        corr._base[-1] = 0 # reset garbage layer
        ex = 1 if level > 1 else 0
        n_children = (H1+ex) * (W1+ex)
        res = corr.new_zeros((n_children,)+corr.shape[-2:])

    # sl(v): drop the last row/col for v<0, the first one for v>0
    sl = lambda v: slice(0,-1 or None) if v < 0 else slice(1,None)
    c = 0
    for y in (-1, 1):
        for x in (-1, 1):
            src_layers = parents[:,c]; c+= 1
            # we want to do: res += corr[src_layers] (for all children != -1)
            # but we only have 'res.index_add_()' <==> res[tgt_layers] += corr
            tgt_layers = inverse_mapping(src_layers, max_elem=len(corr), default=n_parents)[:-1]
            if not reverse:
                # All of corr's channels MUST be utilized. for level>1, this doesn't hold,
                # so we'll send them to a garbage channel ==> res[n_parents]
                sel = good_slice( tgt_layers < n_parents )
                res[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], corr[sel,sl(y),sl(x)])
                nrm[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], ones[sel].expand(-1,2,2))
            else:
                ''' parent=199=11*17+12 @ (x=48, y=44) at level=1
                      |-- child=171 @ (x=46,y=42) at level0
                      |-- child=172 @ (x=50,y=42) at level0
                      |-- child=187 @ (x=46,y=46) at level0
                      |-- child=188 @ (x=50,y=46) at level0
                '''
                out = res[:,sl(y),sl(x)]
                sel = tgt_layers[:n_children]
                torch.maximum(out, corr._base[sel,sl(-y),sl(-x)], out=out)

    if not reverse:
        if weights is not None:
            corr /= weights.clamp(min=1e-12) # cancel weights
        weights = norm_borders(res, nrm, norm=norm)[:-1]
        res = res[:-1] # remove garbage channel

    res = res.view(H1+ex, W1+ex, *res.shape[-2:])
    return res if reverse else (res, weights)

def norm_borders( res, nrm, norm=0.9 ):
    """ apply some border normalization, modulated by `norm`
        - if norm=0: no normalization at all
        - if norm=1: full normalization
        Formula: nrm = k * (nrm/k)**p = k**(1-p) * nrm**p, with k=nrm[:,1,1] and p=norm

    res: (..., H, W) maps, normalized in-place; nrm: (..., 3, 3) counters
    (corners / edges / center) accumulated by sparse_conv.
    Returns the raw center weights nrm[...,1,1] (cloned before modulation).
    """
    new_weights = nrm[...,1,1].clone()
    nrm = (nrm[...,1:2,1:2] ** (1-norm)) * (nrm ** norm)
    # assert not torch.isnan(nrm).any()

    # normalize results on the borders: 9 zones (4 corners, 4 edges, interior)
    res[...,0   ,0   ] /= nrm[...,0  ,0  ]
    res[...,0   ,1:-1] /= nrm[...,0  ,1:2]
    res[...,0   , -1 ] /= nrm[...,0  ,2  ]
    res[...,1:-1,0   ] /= nrm[...,1:2,0  ]
    res[...,1:-1,1:-1] /= nrm[...,1:2,1:2]
    res[...,1:-1, -1 ] /= nrm[...,1:2,2  ]
    res[..., -1 ,0   ] /= nrm[...,2  ,0  ]
    res[..., -1 ,1:-1] /= nrm[...,2  ,1:2]
    res[..., -1 , -1 ] /= nrm[...,2  ,2  ]
    return new_weights

def inverse_mapping( map, max_elem=None, default=None):
    """ given a mapping {i:j} we output {j:i} (the mapping is a torch array)

    default: value used for j's that have no pre-image (if None, those
    entries are left uninitialized -- caller must know they are all covered).
    Returns an int64 tensor of length max_elem+1.
    """
    assert isinstance(map, torch.Tensor) and map.ndim == 1
    if max_elem is None: max_elem = map.max()
    if default is None:
        index = torch.empty(max_elem+1, dtype=torch.int64, device=map.device) # same size as corr, last elem == garbage
    else:
        index = torch.full((max_elem+1,), default, dtype=torch.int64, device=map.device) # same size as corr, last elem == garbage
    index[map] = torch.arange(len(map), device=map.device)
    return index

def good_slice( nonzero ):
    """ Return the smallest contiguous slice covering all True entries of
    the boolean 1-d tensor `nonzero` (assumes at least one True entry).
    """
    good = nonzero.nonzero().ravel()
    return slice(good.min().item(), good.max().item()+1)

def max_unpool(upper, lower, exclude_border=True):
    """ Unpool `upper` back onto `lower`'s grid using freshly recomputed
    3x3/stride-2 max-pool indices, and add the result into `lower` in-place.

    exclude_border: mimic legacy behavior where the bottom/right borders
    cannot be unpooled (indices computed on lower[...,:-1,:-1], then remapped
    to the full width).
    Returns the updated `lower`.
    """
    # re-compute max-pool indices
    if exclude_border:
        # apparently, we cannot unpool on the bottom and right borders in legacy code (local_argmax with ex=1)
        _, pos = F.max_pool2d(lower[:,:,:-1,:-1], 3, padding=1, stride=2, return_indices=True, ceil_mode=True)
        W1 = lower.shape[-1]
        pos = (pos//(W1-1))*W1 + (pos%(W1-1)) # fix the shortening
    else:
        _, pos = F.max_pool2d(lower, 3, padding=1, stride=2, return_indices=True)

    # because there are potential collisions between overlapping 3x3 cells,
    # that pytorch does not handle, we unpool in 4 successive non-overlapping steps.
    for i in range(2):
        for j in range(2):
            # stride=0 instead of 1 because pytorch does some size checking, this is a hack
            tmp = F.max_unpool2d(upper[:,:,i::2,j::2], pos[:,:,i::2,j::2], kernel_size=3, padding=0, stride=4, output_size=lower.shape[-2:])
            if i == j == 0:
                res = tmp
            else:
                torch.maximum(res, tmp, out=res)

    # add scores to existing lower correlation map
    lower += res
    return lower

def mgrid( shape, **kw ):
    """ Returns in (x, y) order (contrary to numpy which is (y, x)).

    shape: a sequence of sizes, or a tensor (its .shape is used).
    kw is forwarded to torch.arange (e.g. device=...).
    Returns a float32 tensor of shape (prod(shape), 2).
    """
    if isinstance(shape, torch.Tensor): shape = shape.shape
    res = torch.meshgrid(*[torch.arange(n, dtype=torch.float32, **kw) for n in shape], indexing='ij')
    return torch.stack(res[::-1], dim=-1).view(-1,2)

def check_corres( corres, step, rot=None ):
    """ Sanity check: correspondences (H, W, 2), optionally de-rotated via
    `rot`, must land (within 2 cells) on the regular grid with the given
    `step` for at least 99% of positions; otherwise drop into the debugger.
    """
    H, W, two = corres.shape
    assert two == 2
    if isinstance(corres, np.ndarray): corres = torch.from_numpy(corres)
    if rot is not None:
        corres = affmul(rot, corres)
    gt = mgrid(corres.shape[:2]).view(H,W,2)
    assert ((gt - corres // step).abs() <= 2).float().mean() > 0.99, bb()

def best_correspondences(corr):
    """ All positions are returned as x1, y1, x2, y2

    corr: full correlation volume (H1, W1, H2, W2) (a tuple is passed
    through unchanged, for legacy callers).
    Returns ((pos1, score1), (pos2, score2)): best match per img1 cell and
    best match per 4x4 img2 block; img1 coordinates are mapped back to
    pixel centers via 4*i+2.
    """
    if isinstance(corr, tuple): return corr # for legacy
    H1, W1, H2, W2 = corr.shape
    fix1 = lambda arr: 4*arr+2 # center of cells in img1
    div = lambda a,b: torch.div(a, b, rounding_mode='trunc') # because of warning in pytorch 1.9+

    # best scores in img1
    score1, pos1 = corr.view(H1, W1, H2*W2).max(dim=-1)
    pos1 = torch.cat((fix1(mgrid(score1, device=pos1.device)), pos1.view(-1,1)%W2, div(pos1.view(-1,1),W2)), dim=-1)

    # best scores in img2
    score2, pos2 = max_pool3d( corr, kernel_size=4, stride=4 )
    pos2, score2 = pos2.view(-1,1), score2.squeeze()
    # unravel flat indices back into (x1, y1, x2, y2)
    pos2 = torch.cat((fix1(div(pos2,W2*H2)%W1), fix1(div(pos2,(W1*H2*W2))), pos2%W2, div(pos2,W2)%H2), dim=-1).float()
    return (pos1, score1), (pos2, score2)

def intersection( set1_, set2_ ):
    """ Returns the indices of values in set1 that are duplicated in set2 """
    set1, map1 = set1_.squeeze().unique(return_inverse=True) # map1: i1 -> j1
    set2 = set2_.squeeze().unique()
    combined = torch.cat((set1, set2))
    uniques, inverse, counts = combined.unique(return_counts=True, return_inverse=True) # j -> u, i -> j, j -> n
    # we are interested only in (j -> i) for n > 1:
    # assert counts.max() <= 2, 'there were non-unique values in either set1 or set2'+bb()
    # intersected_values = uniques[counts > 1]
    inverse1 = inverse_mapping(inverse[:len(set1)], max_elem=len(uniques)-1)
    intersected_indices1 = inverse1[counts>1]
    return inverse_mapping(map1, max_elem=len(set1)-1)[intersected_indices1]

def reciprocal(self, corres1, corres2 ):
    """ Keep only correspondences from corres1 whose (x1,y1,x2,y2) also
    appear in corres2 (reciprocity check via integer hashing of positions).

    NOTE(review): takes `self` but never uses it -- presumably kept for a
    method-style call site elsewhere; confirm before changing the signature.
    Returns rows (x1, y1, x2, y2, score, 0).
    """
    pos1, score1 = corres1
    pos2, score2 = corres2
    (H1, W1), (H2, W2) = score1.shape, map(lambda i: 4*i+1, score2.shape)

    # hash each 4-d position into a single scalar for set intersection
    to_int = pos1.new_tensor((W1*H2*W2, H2*W2, W2, 1), dtype=torch.float32)
    inter1 = intersection(pos1@to_int, pos2@to_int)
    res = torch.cat((pos1[inter1], score1.view(-1,1)[inter1], 0*score1.view(-1,1)[inter1]), dim=-1)
    return res

def max_pool3d( corr, kernel_size=4, stride=4 ):
    """ 3d max-pool over the full correlation volume: pool over all of
    (H1*W1) and over 4x4 blocks of (H2, W2), returning values and indices.
    Falls back to a custom CUDA kernel for tensors >= 2**31 elements on GPU
    (pytorch bug with large tensors).
    """
    H1, W1, H2, W2 = corr.shape
    ks, st = kernel_size, stride
    if corr.numel() >= 2**31 and corr.device != torch.device('cpu'):
        # re-implementation due to a bug in pytorch
        import core.cuda_deepm as kernels
        return kernels.max_pool3d( corr.view(1, H1*W1, H2, W2), kernel_size, stride)
    else:
        return F.max_pool3d( corr.view(1, 1, H1*W1, H2, W2), kernel_size=(H1*W1,ks,ks), stride=(1,st,st), return_indices=True)

def forward_cuda(self, level, lower, weights=None, pooled=False):
    """ CUDA forward aggregation step (see sparse_conv for the CPU logic):
    max-pool `lower` (unless already pooled) then aggregate via the custom
    kernel. Requires self.border_inv. Big tensors need forward_cuda_lowmem.
    """
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    assert lower.numel() < 2**31, 'please use cuda-lowmem, pytorch cannot handle big tensors'
    pooled = lower if pooled else F.max_pool2d(lower, 3, padding=1, stride=2)
    return kernels.forward_agg(level, self.border_inv, pooled, weights)

def forward_cuda_lowmem(self, level, lower, weights=None):
    """ Low-memory variant of forward_cuda: pooling and aggregation fused
    in a single custom kernel call. """
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    return kernels.forward_pool_agg(level, self.border_inv, lower, weights)

def backward_cuda(self, level, pyramid):
    """ CUDA backward step: propagate scores from pyramid[level] down into
    pyramid[level-1] in-place (custom kernel), and return the updated level. """
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    kernels.backward_agg_unpool(level, pyramid[level], pyramid[level-1], True)
    # assert not torch.isnan(pyramid[level-1]).any(), bb()
    return pyramid[level-1]

def merge_corres(self, corres, rots, all_corres, code):
    " rot : reference --> rotated "
    # Merge per-rotation correspondences (pos, score) into the global
    # all_corres accumulators on both sides, tagging each row with `code`
    # (see encode_scale_rot). Mutates all_corres in-place via CUDA kernel.
    all_step = self.matcher.pixel_desc.get_atomic_patch_size() // 2 # step size in all_corres
    dev = all_corres[0][1].device

    # stack correspondences: rows of (x1, y1, x2, y2, score, code)
    corres = [torch.cat((p.view(*s.shape,4),s[:,:,None],torch.full_like(s[:,:,None],code)),dim=2) for (p,s) in corres]

    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    kernels.merge_corres_one_side( corres[0].to(dev), 0, rots[0].to(dev), all_corres[0][1], all_step )
    kernels.merge_corres_one_side( corres[1].to(dev), 2, rots[1].to(dev), all_corres[1][1], all_step )