|
|
|
|
|
|
|
|
|
from pdb import set_trace as bb |
|
from itertools import starmap |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import test_singlescale as tss |
|
from core import functional as myF |
|
from tools.common import todevice, cpu |
|
from tools.viz import dbgfig, show_correspondences |
|
|
|
|
|
def arg_parser():
    """Build the command-line parser for multi-scale / multi-rotation matching.

    Starts from the parser returned by ``tss.arg_parser()``, overrides its
    ``levels`` and ``verbose`` defaults to 0, and registers the scale,
    rotation, swapping and merging options on top of it.
    """
    parser = tss.arg_parser()
    parser.set_defaults(levels=0, verbose=0)
    add = parser.add_argument

    # scale range explored by the matcher
    add('--min-scale', type=float, default=None, help='min scale ratio')
    add('--max-scale', type=float, default=4, help='max scale ratio')

    # rotation range explored by the matcher
    add('--min-rot', type=float, default=None, help='min rotation (in degrees) in [-180,180]')
    add('--max-rot', type=float, default=0, help='max rotation (in degrees) in [0,180]')
    add('--crop-rot', action='store_true', help='crop rotated image to prevent memory blow-up')
    add('--rot-step', type=int, default=45, help='rotation step (in degrees)')

    # image swapping and pyramid depth
    add('--no-swap', type=int, default=1, nargs='?', const=0, choices=[1, 0, -1],
        help='if 0, img1 will have keypoints on a grid')
    add('--same-levels', action='store_true',
        help='use the same number of pyramid levels for all scales')

    # implementation used to merge the per-scale correspondences
    add('--merge', choices=['torch', 'cpu', 'cuda'], default='cpu')
    return parser
|
|
|
|
|
class MultiScalePUMP (nn.Module):
    """ DeepMatching that loops over all possible {scale x rotation} combinations.

    The wrapped single-scale ``matcher`` is run once per (scale, rotation)
    pair; for every cell of a regular grid laid over each image, only the
    best-scoring correspondence seen so far is kept, and the two per-image
    accumulators are finally combined by ``reciprocal``.
    """
    def __init__(self, matcher,
                 min_scale=1,
                 max_scale=1,
                 max_rot=0,
                 min_rot=0,
                 rot_step=45,
                 swap_mode=1,
                 same_levels=False,
                 crop_rot=False):
        """
        matcher:     single-scale matcher; its ``crop_rot`` attribute is set here.
        min_scale:   smallest scale ratio; a falsy value (None or 0) means 1/max_scale.
        max_scale:   largest scale ratio.
        max_rot:     largest rotation, in degrees.
        min_rot:     smallest rotation, in degrees; a falsy value (None or 0) means -max_rot.
        rot_step:    rotation step, in degrees.
        swap_mode:   sign applied to both image sizes when deciding whether to
                     swap the pair (see ``_enum_scaled_pairs``); 0 disables swapping.
        same_levels: if True, ``forward`` sets ``matcher.levels`` from ``_find_max_levels``.
        crop_rot:    forwarded to ``matcher.crop_rot``.
        """
        super().__init__()
        min_scale = min_scale or 1/max_scale
        min_rot = min_rot or -max_rot
        assert 0.1 <= min_scale <= max_scale <= 10
        assert -180 <= min_rot <= max_rot <= 180
        self.matcher = matcher
        self.matcher.crop_rot = crop_rot

        self.min_sc = min_scale
        self.max_sc = max_scale
        self.min_rot = min_rot
        self.max_rot = max_rot
        self.rot_step = rot_step
        self.swap_mode = swap_mode
        # device on which the accumulators are allocated (None: same as the image)
        self.merge_device = None
        self.same_levels = same_levels

    @torch.no_grad()
    def forward(self, img1, img2, dbg=()):
        """ Match img1 and img2 over every enumerated scale/rotation pair.

        img1, img2: either an image (unpacked as C, H, W by ``_make_accu``) or
                    a tuple (image, 3x3 transform); the transform defaults to identity.
        dbg:        debug selector passed to the matcher and to ``dbgfig``.

        Returns the output of ``myF.affmul`` applied to the two input
        transforms and the reciprocal correspondences.
        """
        img1, sca1 = img1 if isinstance(img1, tuple) else (img1, torch.eye(3, device=img1.device))
        img2, sca2 = img2 if isinstance(img2, tuple) else (img2, torch.eye(3, device=img2.device))

        # NOTE(review): matcher.levels is overwritten in place, so a value
        # computed here is still set on the matcher for subsequent calls.
        if self.same_levels:
            self.matcher.levels = self._find_max_levels(img1,img2)
        elif self.matcher.levels == 0:
            max_psize = int(min(np.mean(img1.shape[-2:]), np.mean(img2.shape[-2:])))
            self.matcher.levels = int(np.log2(max_psize / self.matcher.pixel_desc.get_atomic_patch_size()))

        # one (grid, accumulator) pair per image
        all_corres = (self._make_accu(img1), self._make_accu(img2))

        for scale, ang, code, swap, swapped, (scimg1, scimg2) in self._enum_scaled_pairs(img1, img2):
            print(f"processing {scale=:g} x {ang=} {['','(swapped)'][swapped]} ({code=})...")

            corres, rots = self.process_one_scale(swapped, *[scimg1,scimg2], dbg=dbg)
            # NOTE(review): `viz_correspondences` has no visible definition or
            # import in this file (only `show_correspondences` is imported) --
            # confirm where it is expected to come from.
            if dbgfig('corres-ms', dbg): viz_correspondences(img1, img2, *corres, fig='last')

            # keep, per grid cell, the best correspondence seen so far
            self.merge_corres( corres, rots, all_corres, code )

        corres = self.reciprocal( *all_corres )
        return myF.affmul(todevice((sca1,sca2),corres.device), corres)

    def process_one_scale(self, swapped, *imgs, dbg=()):
        """ Run the matcher on one image pair and undo the swap, if any. """
        return unswap(self.matcher(*imgs, ret='raw', dbg=dbg), swapped)

    def _find_max_levels(self, img1, img2):
        """ Return the smallest level count over all enumerated image pairs.

        Starts from ``matcher.levels`` (or 999 when that is 0) and lowers it
        to the per-pair limits derived from the smallest image side.
        """
        min_levels = self.matcher.levels or 999
        for _, _, _, _, _, (simg1, simg2) in self._enum_scaled_pairs(img1, img2):
            # each element is an (image, transform) tuple
            img1_levels = ceil(np.log2(min(simg1[0].shape[-2:])) - 1)
            img2_levels = ceil(np.log2(min(simg2[0].shape[-2:]) / self.matcher.min_shape))
            min_levels = min(min_levels, img1_levels, img2_levels)
        return min_levels

    def merge_corres(self, corres, rots, all_corres, code):
        " rot : reference --> rotated "
        # columns 0:2 are matched against the first image's grid, 2:4 against the second's
        self.merge_one_side( corres[0], slice(0,2), rots[0], all_corres[0], code )
        self.merge_one_side( corres[1], slice(2,4), rots[1], all_corres[1], code )

    def merge_one_side(self, corres, sel, trf, all_corres, code ):
        """ Merge one set of correspondences into one image's accumulator, in place.

        corres:     tuple (pos, scores); ``pos[:, sel]`` are the 2D positions
                    compared against the grid, ``pos[:, 0:4]`` is what gets stored.
        sel:        column slice of ``pos`` belonging to this image.
        trf:        square transform; sqrt(det(trf)) is used as the scale of
                    the distance threshold.
        all_corres: tuple (grid, accu); ``accu`` is viewed as rows of 6 values
                    (4 position values, score, code).
        code:       value written in column 5 of every updated row.

        A grid cell is updated when one of its (up to) 4 nearest
        correspondences is close enough and scores higher than the stored one.
        """
        pos, scores = corres
        grid, accu = all_corres
        accu = accu.view(-1, 6)

        # topk() raises if asked for more rows than exist, so clamp to the
        # number of correspondences; with none there is nothing to merge.
        num_nearest = min(4, len(pos))
        if num_nearest == 0:
            return

        # for each grid point, the nearest correspondences (sorted by distance)
        best4 = torch.cdist(pos[:,sel].float(), grid).topk(num_nearest, dim=0, largest=False)

        scale = float(torch.sqrt(torch.det(trf)))
        dist_max = 8*scale - 1e-7

        # candidates must be at most twice as far as the nearest one and within dist_max
        close_enough = (best4.values <= 2*best4.values[0:1]) & (best4.values < dist_max)
        neg_inf = torch.tensor(-np.inf, device=scores.device)
        best_score = torch.where(close_enough, scores.ravel()[best4.indices], neg_inf).max(dim=0)
        is_better = best_score.values > accu[:,4].ravel()

        accu[is_better,0:4] = pos[best4.indices[best_score.indices,torch.arange(len(grid))][is_better]]
        accu[is_better,4] = best_score.values[is_better]
        accu[is_better,5] = code

    def reciprocal(self, corres1, corres2 ):
        """ Return the rows of the first accumulator selected by ``myF.intersection``.

        Both accumulators are moved to CPU; each 4-value position is folded
        into a single number so the two sets can be intersected.
        Presumably this keeps the correspondences found from both sides --
        confirm against ``myF.intersection``.
        """
        grid1, corres1 = cpu(corres1)
        grid2, corres2 = cpu(corres2)

        (H1, W1), (H2, W2) = grid1[-1]+1, grid2[-1]+1
        pos1 = corres1[:,:,0:4].view(-1,4)
        pos2 = corres2[:,:,0:4].view(-1,4)

        # mixed-radix weights turning a 4-value position into one scalar key
        to_int = torch.tensor((W1*H2*W2, H2*W2, W2, 1), dtype=torch.float32)
        inter1 = myF.intersection(pos1@to_int, pos2@to_int)
        return corres1.view(-1,6)[inter1]

    def _enum_scales(self):
        """ Yield (i, 2**(i/2)) for every i in [-100, 100] whose scale lies in [min_sc, max_sc]. """
        for i in range(-100,101):
            scale = 2**(i/2)

            if self.min_sc <= scale <= self.max_sc:
                yield i,scale

    def _enum_rotations(self):
        """ Yield (i, -i*rot_step) for every step whose angle lies in [min_rot, max_rot].

        Note the yielded angle is the negated rotation.
        """
        for i in range(-180//self.rot_step, 180//self.rot_step):
            rot = i * self.rot_step
            if self.min_rot <= rot <= self.max_rot:
                yield i,-rot

    def _enum_scaled_pairs(self, img1, img2):
        """ Yield (scale, ang, code, swap, swapped, ((img1, trf1), (img2, trf2))).

        For each scale, img1 is downsampled by min(scale, 1) and img2 by
        min(1/scale, 1). The pair is swapped when
        ``size1*swap_mode < size2*swap_mode`` (sizes being the smallest image
        side); ``swap`` is then -1, otherwise +1. For a non-zero angle, the
        first image's transform becomes the tuple (transform, swap*ang).
        """
        for _, scale in self._enum_scales():
            (i1,sca1), (i2,sca2) = starmap(downsample_img, [(img1, min(scale, 1)), (img2, min(1/scale, 1))])

            size1 = min(i1.shape[-2:])
            size2 = min(i2.shape[-2:])
            swapped = size1*self.swap_mode < size2*self.swap_mode
            swap = (1 - 2*swapped)
            if swapped:
                (i1,sca1), (i2,sca2) = (i2,sca2), (i1,sca1)

            for _, ang in self._enum_rotations():
                code = myF.encode_scale_rot(scale, ang)
                trf1 = (sca1, swap*ang) if ang != 0 else sca1
                yield scale, ang, code, swap, swapped, ((i1,trf1), (i2,sca2))

    def _make_accu(self, img):
        """ Allocate the (grid, accumulator) pair of one image.

        The accumulator is a zero-filled float32 tensor with 6 values per grid
        cell, allocated on ``merge_device`` (or the image's device); cells are
        spaced by half the matcher's atomic patch size.
        """
        C, H, W = img.shape
        step = self.matcher.pixel_desc.get_atomic_patch_size() // 2
        h = step//2 - 1
        accu = img.new_zeros(((H+h)//step, (W+h)//step, 6), dtype=torch.float32, device=self.merge_device or img.device)
        grid = step * myF.mgrid(accu[:,:,0], device=img.device) + (step//2)
        return grid, accu
|
|
|
|
|
def downsample_img(img, scale=0):
    """ Downsample a uint8 image by a factor ``scale`` in (0, 1].

    img:   an image tensor, or a tuple (image, 3x3 transform); the transform
           defaults to identity.
    scale: downsampling ratio. The default of 0 is not a usable value and is
           rejected, so callers must pass a scale explicitly.

    Returns (image, transform). With scale == 1 the inputs are returned
    untouched; otherwise the image is halved by 2x2 average pooling while
    scale <= 0.5, any remaining factor is applied by bicubic interpolation,
    and the upper-left 2x2 block of a copy of the transform is divided by scale.

    Raises AssertionError if scale is not in (0, 1] or the image is not uint8.
    """
    # A non-positive scale would divide the transform by zero and never reach
    # the exit condition of the halving loop below, so reject it up front.
    assert 0 < scale <= 1, f'scale must be in (0, 1], got {scale}'
    img, trf = img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))
    if scale == 1:
        return img, trf

    assert img.dtype == torch.uint8
    trf = trf.clone()  # do not modify the caller's transform
    trf[:2,:2] /= scale

    # successive halvings by average pooling
    while scale <= 0.5:
        img = F.avg_pool2d(img[None].float(), 2, stride=2, count_include_pad=False)[0]
        scale *= 2

    # remaining non-power-of-two factor
    if scale != 1:
        img = F.interpolate(img[None].float(), scale_factor=scale, mode='bicubic', align_corners=False, recompute_scale_factor=False).clamp(min=0, max=255)[0]
    return img.byte(), trf
|
|
|
|
|
def ceil(i):
    """ Round ``i`` up to the nearest integer and return it as a Python int. """
    rounded_up = np.ceil(i)
    return int(rounded_up)
|
|
|
def unswap( corres, swapped ):
    """ Undo an image swap on the matcher's raw output.

    corres:  tuple (pairs, rots), where each element of ``pairs`` is a
             (pos, score) tuple.
    swapped: whether the two images were exchanged before matching.

    When swapped, both sequences are reversed and the first four columns of
    every ``pos`` are reordered in place from [0,1,2,3] to [2,3,0,1].
    Returns (pairs, rots).
    """
    pairs, rots = corres
    if not swapped:
        return pairs[::1], rots[::1]

    # put the two sides back in their original order
    pairs, rots = pairs[::-1], rots[::-1]
    for pos, _score in pairs:
        # exchange the two coordinate pairs of each correspondence
        pos[:, 0:4] = pos[:, [2, 3, 0, 1]].clone()
    return pairs, rots
|
|
|
|
|
def demultiplex_img_trf(self, img, force=False):
    """ Split the input into an image and its transform, applying any pending transform.

    img is one of:
      - an image (the transform is then the 3x3 identity)
      - a tuple (image, trf)
      - a tuple (image, (cur_trf, trf_todo))
    In any case, trf: cur_pix --> old_pix

    A numeric ``trf_todo`` is treated as a rotation angle and handled by
    ``myF.rotate_img``; anything else is applied with
    ``myF.apply_trf_to_img`` and composed onto ``cur_trf``.
    ``force`` is accepted but not used here.
    """
    if isinstance(img, tuple):
        img, trf = img
    else:
        trf = torch.eye(3, device=img.device)

    # no pending transform: nothing more to do
    if not isinstance(trf, tuple):
        return img, trf

    trf, pending = trf
    if isinstance(pending, (int, float)):
        img, trf = myF.rotate_img((img, trf), angle=pending, crop=self.crop_rot)
    else:
        img = myF.apply_trf_to_img(pending, img)
        trf = trf @ pending
    return img, trf
|
|
|
|
|
class Main (tss.Main):
    """ Command-line entry point building a multi-scale matcher on top of ``tss.Main``. """

    @staticmethod
    def get_options( args ):
        """ Collect the MultiScalePUMP constructor keywords from the parsed arguments. """
        return {
            'max_scale': args.max_scale,
            'min_scale': args.min_scale,
            'max_rot': args.max_rot,
            'min_rot': args.min_rot,
            'rot_step': args.rot_step,
            'swap_mode': args.no_swap,
            'same_levels': args.same_levels,
            'crop_rot': args.crop_rot,
        }

    @staticmethod
    def tune_matcher( args, matcher, device ):
        """ Select the merging implementation and move the matcher to ``device``.

        On a 'cpu' device, ``args.merge`` is forced to 'cpu'. For 'cpu' and
        'cuda' merging, ``merge_corres`` is replaced on the matcher's class by
        ``myF.merge_corres``; 'cpu' additionally sets ``merge_device``.
        Any other value leaves the matcher's own ``merge_corres`` in place.
        """
        if device == 'cpu':
            args.merge = 'cpu'

        if args.merge in ('cpu', 'cuda'):
            # patched at class level, so it applies to every instance of that class
            type(matcher).merge_corres = myF.merge_corres
            if args.merge == 'cpu':
                matcher.merge_device = 'cpu'

        return matcher.to(device)

    @staticmethod
    def build_matcher( args, device):
        """ Wrap the single-scale matcher into a tuned MultiScalePUMP. """
        single_scale = tss.Main.build_matcher(args, device)
        # let the single-scale matcher accept (image, transform) inputs
        type(single_scale).demultiplex_img_trf = demultiplex_img_trf

        multi_scale = MultiScalePUMP(single_scale, **Main.get_options(args))
        return Main.tune_matcher(args, multi_scale, device)
|
|
|
|
|
if __name__ == '__main__':
    # parse the command line, then hand the arguments over to the runner
    cli_args = arg_parser().parse_args()
    Main().run_from_args(cli_args)
|
|