# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
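""" Multi-scale, multi-rotation matching with PUMP: the single-scale matcher from
    test_singlescale is run over every {scale x rotation} combination of the input
    image pair, and the resulting correspondences are merged and kept only if reciprocal.
"""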
from pdb import set_trace as bb
from itertools import starmap
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import test_singlescale as tss
from core import functional as myF
from tools.common import todevice, cpu
from tools.viz import dbgfig, show_correspondences
def arg_parser():
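    """ Extend the single-scale argument parser with scale, rotation and merging options. """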
parser = tss.arg_parser()
parser.set_defaults(levels = 0, verbose=0)
parser.add_argument('--min-scale', type=float, default=None, help='min scale ratio')
parser.add_argument('--max-scale', type=float, default=4, help='max scale ratio')
parser.add_argument('--min-rot', type=float, default=None, help='min rotation (in degrees) in [-180,180]')
parser.add_argument('--max-rot', type=float, default=0, help='max rotation (in degrees) in [0,180]')
parser.add_argument('--crop-rot', action='store_true', help='crop rotated image to prevent memory blow-up')
parser.add_argument('--rot-step', type=int, default=45, help='rotation step (in degrees)')
parser.add_argument('--no-swap', type=int, default=1, nargs='?', const=0, choices=[1,0,-1], help='if 0, img1 will have keypoints on a grid')
parser.add_argument('--same-levels', action='store_true', help='use the same number of pyramid levels for all scales')
parser.add_argument('--merge', choices='torch cpu cuda'.split(), default='cpu')
return parser
class MultiScalePUMP (nn.Module):
""" DeepMatching that loops over all possible {scale x rotation} combinations.
"""
def __init__(self, matcher,
min_scale=1,
max_scale=1,
max_rot=0,
min_rot=0,
rot_step=45,
swap_mode=1,
same_levels=False,
crop_rot=False):
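        """ matcher: single-scale PUMP matcher (see test_singlescale).
            min_scale, max_scale: range of scale ratios explored between the two images.
            min_rot, max_rot, rot_step: range and step of in-plane rotations, in degrees.
            swap_mode: 1 puts the larger image first, -1 the smaller one, 0 never swaps.
            same_levels: use the same number of pyramid levels for every configuration.
            crop_rot: crop rotated images to prevent memory blow-up.
        """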
super().__init__()
min_scale = min_scale or 1/max_scale
min_rot = min_rot or -max_rot
assert 0.1 <= min_scale <= max_scale <= 10
assert -180 <= min_rot <= max_rot <= 180
self.matcher = matcher
self.matcher.crop_rot = crop_rot
self.min_sc = min_scale
self.max_sc = max_scale
self.min_rot = min_rot
self.max_rot = max_rot
self.rot_step = rot_step
self.swap_mode = swap_mode
self.merge_device = None
self.same_levels = same_levels
@torch.no_grad()
def forward(self, img1, img2, dbg=()):
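        """ Match img1 with img2 over every {scale x rotation} combination and return
            the merged, reciprocal correspondences at the original image scale.
            img1 and img2 can be plain images or (image, 3x3 transform) tuples.
        """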
img1, sca1 = img1 if isinstance(img1, tuple) else (img1, torch.eye(3, device=img1.device))
img2, sca2 = img2 if isinstance(img2, tuple) else (img2, torch.eye(3, device=img2.device))
# prepare correspondences accumulators
if self.same_levels: # limit number of levels
self.matcher.levels = self._find_max_levels(img1,img2)
elif self.matcher.levels == 0:
max_psize = int(min(np.mean(img1.shape[-2:]), np.mean(img2.shape[-2:])))
self.matcher.levels = int(np.log2(max_psize / self.matcher.pixel_desc.get_atomic_patch_size()))
all_corres = (self._make_accu(img1), self._make_accu(img2))
for scale, ang, code, swap, swapped, (scimg1, scimg2) in self._enum_scaled_pairs(img1, img2):
print(f"processing {scale=:g} x {ang=} {['','(swapped)'][swapped]} ({code=})...")
# compute correspondences with rotated+scaled image
corres, rots = self.process_one_scale(swapped, *[scimg1,scimg2], dbg=dbg)
            if dbgfig('corres-ms', dbg): show_correspondences(img1, img2, *corres, fig='last')
# merge correspondences in the reference frame
self.merge_corres( corres, rots, all_corres, code )
# final intersection
corres = self.reciprocal( *all_corres )
return myF.affmul(todevice((sca1,sca2),corres.device), corres) # rescaling to original image scale
def process_one_scale(self, swapped, *imgs, dbg=()):
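        """ Run the single-scale matcher on one rescaled/rotated image pair, then un-swap
            the result so correspondences are always ordered as (img1, img2). """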
return unswap(self.matcher(*imgs, ret='raw', dbg=dbg), swapped)
def _find_max_levels(self, img1, img2):
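        """ Number of pyramid levels affordable by every {scale x rotation} configuration,
            so that all configurations can share the same number of levels. """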
min_levels = self.matcher.levels or 999
for _, _, code, _, _, (img1, img2) in self._enum_scaled_pairs(img1, img2):
            # first level at which a parent doesn't have children: gap >= min(shape), with gap = 2**(level-2)
img1_levels = ceil(np.log2(min(img1[0].shape[-2:])) - 1)
# first level when img2's shape becomes smaller than self.min_shape, with shape = min(shape) / 2**level
img2_levels = ceil(np.log2(min(img2[0].shape[-2:]) / self.matcher.min_shape))
# print(f'predicted levels for {code=}:\timg1 --> {img1_levels},\timg2 --> {img2_levels} levels')
min_levels = min(min_levels, img1_levels, img2_levels)
return min_levels
def merge_corres(self, corres, rots, all_corres, code):
" rot : reference --> rotated "
self.merge_one_side( corres[0], slice(0,2), rots[0], all_corres[0], code )
self.merge_one_side( corres[1], slice(2,4), rots[1], all_corres[1], code )
def merge_one_side(self, corres, sel, trf, all_corres, code ):
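        """ Accumulate correspondences for one side of the pair: each accumulator grid
            cell keeps the best-scoring correspondence among its 4 nearest ones (if close
            enough), together with its score and its {scale x rotation} code. """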
pos, scores = corres
grid, accu = all_corres
accu = accu.view(-1, 6)
# compute 4-nn in transformed image for each grid point
best4 = torch.cdist(pos[:,sel].float(), grid).topk(4, dim=0, largest=False)
# best4.shape = (4, len(grid))
# update if score is better AND distance less than 2x best dist
scale = float(torch.sqrt(torch.det(trf))) # == scale (with scale >= 1)
dist_max = 8*scale - 1e-7 # 2x the distance between contiguous patches
close_enough = (best4.values <= 2*best4.values[0:1]) & (best4.values < dist_max)
neg_inf = torch.tensor(-np.inf, device=scores.device)
best_score = torch.where(close_enough, scores.ravel()[best4.indices], neg_inf).max(dim=0)
is_better = best_score.values > accu[:,4].ravel()
accu[is_better,0:4] = pos[best4.indices[best_score.indices,torch.arange(len(grid))][is_better]]
accu[is_better,4] = best_score.values[is_better]
accu[is_better,5] = code
def reciprocal(self, corres1, corres2 ):
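        """ Keep only correspondences present in both accumulators: (x1,y1,x2,y2) tuples
            are encoded as scalar keys and intersected. Returns an (N, 6) tensor. """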
grid1, corres1 = cpu(corres1)
grid2, corres2 = cpu(corres2)
(H1, W1), (H2, W2) = grid1[-1]+1, grid2[-1]+1
pos1 = corres1[:,:,0:4].view(-1,4)
pos2 = corres2[:,:,0:4].view(-1,4)
to_int = torch.tensor((W1*H2*W2, H2*W2, W2, 1), dtype=torch.float32)
inter1 = myF.intersection(pos1@to_int, pos2@to_int)
return corres1.view(-1,6)[inter1]
def _enum_scales(self):
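        """ Enumerate scale ratios by steps of sqrt(2) within [min_sc, max_sc]. """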
for i in range(-100,101):
scale = 2**(i/2)
# if i != -2: continue
if self.min_sc <= scale <= self.max_sc:
yield i,scale
def _enum_rotations(self):
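        """ Enumerate rotation angles by steps of rot_step degrees within
            [min_rot, max_rot] (angles are yielded with opposite sign). """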
for i in range(-180//self.rot_step, 180//self.rot_step):
rot = i * self.rot_step
if self.min_rot <= rot <= self.max_rot:
yield i,-rot
def _enum_scaled_pairs(self, img1, img2):
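        """ For each scale ratio, downscale whichever image must shrink, optionally swap
            the pair so the larger image comes first (depending on swap_mode), and yield
            one entry per rotation angle with its {scale x rotation} code. """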
for s, scale in self._enum_scales():
(i1,sca1), (i2,sca2) = starmap(downsample_img, [(img1, min(scale, 1)), (img2, min(1/scale, 1))])
# set bigger image as the first one
size1 = min(i1.shape[-2:])
size2 = min(i2.shape[-2:])
swapped = size1*self.swap_mode < size2*self.swap_mode
swap = (1 - 2*swapped) # swapped ==> swap = -1
if swapped:
(i1,sca1), (i2,sca2) = (i2,sca2), (i1,sca1)
for r, ang in self._enum_rotations():
code = myF.encode_scale_rot(scale, ang)
trf1 = (sca1, swap*ang) if ang != 0 else sca1
yield scale, ang, code, swap, swapped, ((i1,trf1), (i2,sca2))
def _make_accu(self, img):
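        """ Allocate the correspondence accumulator for one image: a regular grid of patch
            centers plus, per grid cell, 6 values (x1, y1, x2, y2, score, code). """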
C, H, W = img.shape
step = self.matcher.pixel_desc.get_atomic_patch_size() // 2
h = step//2 - 1
accu = img.new_zeros(((H+h)//step, (W+h)//step, 6), dtype=torch.float32, device=self.merge_device or img.device)
grid = step * myF.mgrid(accu[:,:,0], device=img.device) + (step//2)
return grid, accu
def downsample_img(img, scale=1):
    assert 0 < scale <= 1
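    # Repeated 2x average-pooling followed by a final bicubic resize down to `scale`;
    # returns the uint8 image and the updated transform (scaled pixels --> original pixels).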
img, trf = img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))
if scale == 1: return img, trf
assert img.dtype == torch.uint8
trf = trf.clone() # dont modify inplace
trf[:2,:2] /= scale
while scale <= 0.5:
img = F.avg_pool2d(img[None].float(), 2, stride=2, count_include_pad=False)[0]
scale *= 2
if scale != 1:
img = F.interpolate(img[None].float(), scale_factor=scale, mode='bicubic', align_corners=False, recompute_scale_factor=False).clamp(min=0, max=255)[0]
return img.byte(), trf # scaled --> pxl
def ceil(i):
return int(np.ceil(i))
def unswap( corres, swapped ):
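    """ If the pair was matched in swapped order, swap the correspondences and rotations
        back and exchange the (x1,y1) / (x2,y2) columns of each correspondence. """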
swap = -1 if swapped else 1
corres, rots = corres
corres = corres[::swap]
rots = rots[::swap]
if swapped:
for pos, _ in corres:
pos[:,0:4] = pos[:,[2,3,0,1]].clone()
return corres, rots
def demultiplex_img_trf(self, img, force=False):
""" img is:
- an image
- a tuple (image, trf)
- a tuple (image, (cur_trf, trf_todo))
In any case, trf: cur_pix --> old_pix
"""
img, trf = img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))
if isinstance(trf, tuple):
trf, todo = trf
if isinstance(todo, (int,float)): # pure rotation
img, trf = myF.rotate_img((img,trf), angle=todo, crop=self.crop_rot)
else:
img = myF.apply_trf_to_img(todo, img)
trf = trf @ todo
return img, trf
class Main (tss.Main):
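    """ Same entry point as the single-scale script, with the matcher wrapped in
        MultiScalePUMP and configured from the extra command-line options. """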
@staticmethod
def get_options( args ):
return dict(max_scale=args.max_scale, min_scale=args.min_scale,
max_rot=args.max_rot, min_rot=args.min_rot, rot_step=args.rot_step,
swap_mode=args.no_swap, same_levels=args.same_levels, crop_rot=args.crop_rot)
@staticmethod
def tune_matcher( args, matcher, device ):
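        """ Select how correspondences are merged: keep the default implementation, or
            switch to myF.merge_corres with merging on CPU or CUDA (CPU is forced when
            the matcher itself runs on CPU). """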
if device == 'cpu':
args.merge = 'cpu'
        if args.merge == 'cpu':
            type(matcher).merge_corres = myF.merge_corres
            matcher.merge_device = 'cpu'
        elif args.merge == 'cuda':
            type(matcher).merge_corres = myF.merge_corres
return matcher.to(device)
@staticmethod
def build_matcher( args, device):
# get a normal matcher
matcher = tss.Main.build_matcher(args, device)
type(matcher).demultiplex_img_trf = demultiplex_img_trf # update transformer
options = Main.get_options(args)
return Main.tune_matcher(args, MultiScalePUMP(matcher, **options), device)
if __name__ == '__main__':
Main().run_from_args(arg_parser().parse_args())