# PUMP / test_singlescale.py
# Author: Philippe Weinzaepfel
# Commit: "huggingface demo" (3ef85e9)
# File size: 11.8 kB
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from core import functional as myF
from core.pixel_desc import PixelDesc
from tools.common import mkdir_for, todevice, cudnn_benchmark, nparray, image, image_with_trf
from tools.viz import dbgfig, show_correspondences
def arg_parser():
    """Build the command-line parser for the single-scale PUMP matcher.

    Returns:
        An ``argparse.ArgumentParser`` with the image paths, pyramid
        settings, implementation selectors and debugging options registered.
    """
    import argparse

    parser = argparse.ArgumentParser('SingleScalePUMP on GPU with PyTorch')
    add = parser.add_argument

    # input images and optional pre-resizing
    add('--img1', required=True, help='path to img1')
    add('--img2', required=True, help='path to img2')
    add('--resize', type=int, default=512, nargs='+',
        help='prior downsize of img1 and img2')
    add('--output', default=None, help='output path for correspondences')

    # pyramid hyper-parameters
    add('--levels', type=int, default=99, help='number of pyramid levels')
    add('--min-shape', type=int, default=5, help='minimum size of corr maps')
    add('--nlpow', type=float, default=1.5,
        help='non-linear activation power in [1,2]')
    add('--border', type=float, default=0.9,
        help='border invariance level in [0,1]')
    add('--dtype', default='float16',
        choices=['float16', 'float32', 'float64'])
    add('--desc', default='PUMP-stytrf', help='checkpoint name')

    # implementation selected for each stage of the pipeline
    add('--first-level', choices=['torch'], default='torch')
    add('--activation', choices=['torch'], default='torch')
    add('--forward', choices=['torch', 'cuda', 'cuda-lowmem'],
        default='cuda-lowmem')
    add('--backward', choices=['python', 'torch', 'cuda'], default='cuda')
    add('--reciprocal', choices=['cpu', 'cuda'], default='cpu')
    add('--post-filter', default=None, const=True, nargs='?',
        help='post-filtering (See post_filter.py)')

    # runtime / debugging
    add('--verbose', type=int, default=0, help='verbosity')
    add('--device', default='cuda', help='gpu device')
    add('--dbg', nargs='*', default=(), help='debug options')
    return parser
class SingleScalePUMP (nn.Module):
    """Single-scale matcher built on a pyramid of correlation maps.

    The module extracts pixel descriptors for two images, correlates 4x4
    patches of the first image with the second one, builds a pyramid of
    correlation maps (forward pass), propagates the pyramid back down
    (backward pass) and finally extracts correspondences from the result.
    """

    def __init__(self, levels = 9, nlpow = 1.4, cutoff = 1,
                       border_inv=0.9, min_shape=5, renorm=(),
                       pixel_desc = None, dtype = torch.float32,
                       verbose = True ):
        """Store the matcher configuration.

        Args:
            levels: maximum number of pyramid levels built above level 0.
            nlpow: exponent of the non-linear activation (must be in [1, 3]).
            cutoff: accepted for compatibility; not used by this class.
            border_inv: passed as ``norm`` to the sparse convolution.
            min_shape: the pyramid stops growing once the smallest spatial
                side of a correlation map is <= this value.
            renorm: accepted for compatibility; not used by this class.
            pixel_desc: descriptor extractor; must provide ``configure(self)``
                and be callable on an image, returning ``(desc, trf)``.
            dtype: dtype the descriptors are cast to in ``forward``.
            verbose: if truthy, print the shape of every pyramid level.
        """
        super().__init__()
        self.levels = levels
        self.min_shape = min_shape
        self.nlpow = nlpow
        self.border_inv = border_inv
        assert pixel_desc, 'Requires a pixel descriptor'
        self.pixel_desc = pixel_desc.configure(self)
        self.dtype = dtype
        self.verbose = verbose

    @torch.no_grad()
    def forward(self, img1, img2, ret='corres', dbg=()):
        """Match ``img1`` against ``img2``.

        Args:
            img1, img2: an image tensor, or an ``(image, transform)`` tuple.
            ret: ``'raw'`` returns ``(corres, trfs)`` before the reciprocal
                check; any other value returns ``self.reciprocal(*corres)``.
            dbg: collection of debug-figure names forwarded to ``dbgfig``.
        """
        with cudnn_benchmark(False):
            # compute descriptors
            (img1, img2), pixel_descs, trfs = self.extract_descs(img1, img2, dtype=self.dtype)

            # forward pass (build the pyramid), then backward pass
            pixel_corr = self.first_level(*pixel_descs, dbg=dbg)
            pixel_corr = self.backward_pass(self.forward_pass(pixel_corr, dbg=dbg), dbg=dbg)

        # recover correspondences
        corres = myF.best_correspondences( pixel_corr )

        # NOTE(review): `viz_correspondences` is not defined or imported in the
        # visible part of this file; this debug-only branch would raise
        # NameError if triggered -- confirm where it should come from.
        if dbgfig('corres', dbg): viz_correspondences(img1[0], img2[0], *corres, fig='last')
        corres = [(myF.affmul(trfs,pos),score) for pos, score in corres] # rectify scaling etc.

        if ret == 'raw': return corres, trfs
        return self.reciprocal(*corres)

    def extract_descs(self, img1, img2, dtype=None):
        """Compute pixel descriptors and the combined transform of each image.

        Args:
            img1, img2: an image tensor, or an ``(image, transform)`` tuple.
            dtype: if given, descriptors are cast to it; if ``None`` they are
                returned in the dtype produced by ``self.pixel_desc``.

        Returns:
            ``(img1, img2), (desc1, desc2), (trf1, trf2)`` where each ``trf``
            is the input transform multiplied by the descriptor's transform.
        """
        img1, sca1 = self.demultiplex_img_trf(img1)
        img2, sca2 = self.demultiplex_img_trf(img2)
        desc1, trf1 = self.pixel_desc(img1)
        desc2, trf2 = self.pixel_desc(img2)
        if dtype is not None:
            # Tensor.type() called without a dtype returns the type *name*
            # (a str) instead of a tensor, so only cast when a dtype is given.
            desc1, desc2 = desc1.type(dtype), desc2.type(dtype)
        return (img1, img2), (desc1, desc2), (sca1@trf1, sca2@trf2)

    def demultiplex_img_trf(self, img, **kw):
        """Return ``(image, transform)``; a bare image gets a 3x3 identity."""
        return img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))

    def forward_pass(self, pixel_corr, dbg=()):
        """Build the pyramid of activated correlation maps.

        Args:
            pixel_corr: the level-0 correlation tensor, or a
                ``(correlation, weights)`` tuple as returned by ``first_level``.
            dbg: collection of debug-figure names forwarded to ``dbgfig``.

        Returns:
            The list of pyramid levels, level 0 first.
        """
        weights = None
        if isinstance(pixel_corr, tuple):
            pixel_corr, weights = pixel_corr

        # first-level with activation
        if self.verbose: print(f' Pyramid level {0} shape={tuple(pixel_corr.shape)}')
        pyramid = [ self.activation(0,pixel_corr) ]
        # NOTE(review): `viz_correlation_maps` and `from_stack` are not defined
        # in the visible part of this file (debug-only branches).
        if dbgfig(f'corr0', dbg): viz_correlation_maps(*from_stack('img1','img2'), pyramid[0], fig='last')

        for level in range(1, self.levels+1):
            upper, weights = self.forward_level(level, pyramid[-1], weights)
            if weights.sum() == 0: break # img1 has become too small

            # activation
            pyramid.append( self.activation(level,upper) )

            if self.verbose: print(f' Pyramid level {level} shape={tuple(upper.shape)}')
            if dbgfig(f'corr{level}', dbg): viz_correlation_maps(*from_stack('img1','img2'), upper, level=level, fig='last')
            if min(upper.shape[-2:]) <= self.min_shape: break # img2 has become too small
        return pyramid

    def forward_level(self, level, corr, weights):
        """Compute one pyramid level: 3x3 stride-2 max-pooling, then sparse conv."""
        # max-pooling
        pooled = F.max_pool2d(corr, 3, padding=1, stride=2)

        # sparse conv
        return myF.sparse_conv(level, pooled, weights, norm=self.border_inv)

    def backward_pass(self, pyramid, dbg=()):
        """Propagate the pyramid from its top level down to level 0.

        The list is consumed: each processed top level is deleted to free
        memory. Returns ``pyramid[0]`` once every upper level was processed.
        """
        # same as the forward pass, in reverse order
        for level in range(len(pyramid)-1, 0, -1):
            lower = self.backward_level(level, pyramid)
            # assert not torch.isnan(lower).any(), bb()
            if self.verbose: print(f' Pyramid level {level-1} shape={tuple(lower.shape)}')
            del pyramid[-1] # free memory
            # NOTE(review): `viz_correlation_maps`, `img1` and `img2` are not
            # defined in this scope (debug-only branch) -- confirm intent.
            if dbgfig(f'corr{level}-bw', dbg): viz_correlation_maps(img1, img2, lower, fig='last')
        return pyramid[0]

    def backward_level(self, level, pyramid):
        """Reverse one pyramid level and merge it into the level below."""
        # reverse sparse-conv
        pooled = myF.sparse_conv(level, pyramid[level], reverse=True)

        # reverse max-pool and add to lower level
        return myF.max_unpool(pooled, pyramid[level-1])

    def activation(self, level, corr):
        """Rectify ``corr`` at 0 and raise it to ``self.nlpow``, in place.

        ``level`` is not used here. Returns the same (modified) tensor.
        """
        assert 1 <= self.nlpow <= 3
        corr.clamp_(min=0).pow_(self.nlpow)
        return corr

    def first_level(self, desc1, desc2, dbg=()):
        """Correlate every 4x4 patch of ``desc1`` with ``desc2``.

        Args:
            desc1, desc2: 4-D descriptor tensors with a batch size of 1.
            dbg: collection of debug-figure names forwarded to ``dbgfig``.

        Returns:
            ``(corr, weights)`` where ``corr`` is viewed as
            ``(H1//4, W1//4, H2+1, W2+1)`` and ``weights`` is a float mask of
            shape ``(H1//4, W1//4)`` that is 1 where the patch norm is > 0.
        """
        assert desc1.ndim == desc2.ndim == 4
        assert len(desc1) == len(desc2) == 1, "not implemented"
        H1, W1 = desc1.shape[-2:]
        H2, W2 = desc2.shape[-2:]

        patches = F.unfold(desc1, 4, stride=4) # C*4*4, H1*W1//16
        B, C, N = patches.shape # here C is channels*16 and N the number of patches
        # rearrange(patches, 'B (C Kh Kw) H1W1 -> B H1W1 C Kh Kw', Kh=4, Kw=4)
        # (the number of patches is N; the previous code used an undefined name here)
        patches = patches.permute(0, 2, 1).view(B, N, C//16, 4, 4)

        corr, norms = myF.normalized_corr(patches[0], desc2[0], ret_norms=True)
        if dbgfig('ncc',dbg):
            # NOTE(review): `pl` (presumably matplotlib's pyplot) is not
            # imported in the visible part of this file -- debug-only branch.
            for j in range(0,len(corr),9):
                for i in range(9):
                    pl.subplot(3,3,i+1).cla()
                    i += j
                    pl.imshow(corr[i], vmin=0.9, vmax=1)
                    pl.plot(2+(i%16)*4, 2+(i//16)*4,'xr', ms=10)
                bb()
        return corr.view(H1//4, W1//4, H2+1, W2+1), (norms.view(H1//4, W1//4)>0).float()

    def reciprocal(self, corres1, corres2 ):
        """Move both correspondence sets to the CPU and run the reciprocal check."""
        corres1, corres2 = todevice(corres1, 'cpu'), todevice(corres2, 'cpu')
        return myF.reciprocal(self, corres1, corres2)
class Main:
    """Command-line / programmatic driver around ``SingleScalePUMP``.

    Builds the matcher from parsed arguments, loads and resizes the two
    images, runs the matcher, optionally post-filters the correspondences
    and optionally saves them to disk.
    """

    def __init__(self):
        # False means "no post-filtering"; otherwise a dict of keyword
        # options forwarded to `post_filter.filter_corres`.
        self.post_filtering = False

    def _setup(self, args):
        """Build the matcher and parse post-filter options; return the device."""
        device = args.device
        self.matcher = self.build_matcher(args, device)
        if args.post_filter:
            # NOTE(security): `eval` executes arbitrary Python taken from the
            # `--post-filter` value. Acceptable for a trusted local CLI only;
            # never feed it untrusted input (e.g. from a web form).
            self.post_filtering = {} if args.post_filter is True else eval(f'dict({args.post_filter})')
        return device

    @staticmethod
    def _normalize_resize(resize):
        """Return one target size per image from an int or a 1/2-item sequence."""
        if isinstance(resize, int): # user can provide 2 separate sizes for each image
            resize = (resize, resize)
        if len(resize) == 1:
            resize = 2 * resize
        return resize

    def _match_and_save(self, images, args):
        """Run the matcher on ``images`` and save the result if requested."""
        corres = self(*images, dbg=set(args.dbg))
        if args.output:
            self.save_output( args.output, corres )
        return corres

    def run_from_args(self, args):
        """Run the full pipeline on the image paths given in ``args``.

        Returns the correspondences (as ``run_from_args_with_images`` does).
        """
        device = self._setup(args)
        return self._match_and_save(self.load_images(args, device), args)

    def run_from_args_with_images(self, img1, img2, args):
        """Run the full pipeline on two already-loaded images.

        Args:
            img1, img2: image objects exposing ``convert('RGB')`` whose
                result can be turned into a numpy array (e.g. PIL images).
            args: parsed options; ``args.resize`` is normalised in place.

        Returns:
            The correspondences computed by ``__call__``.
        """
        device = self._setup(args)
        args.resize = self._normalize_resize(args.resize)

        images = []
        for imgx, size in zip([img1, img2], args.resize):
            img = torch.from_numpy(np.array(imgx.convert('RGB'))).permute(2,0,1).to(device)
            img = myF.imresize(img, size)
            images.append( img )
        return self._match_and_save(images, args)

    @staticmethod
    def get_options( args ):
        """Translate parsed arguments into ``SingleScalePUMP`` keyword options."""
        # configure the pipeline
        pixel_desc = PixelDesc(path=f'checkpoints/{args.desc}.pt')
        # look the dtype up by name instead of eval-ing a string
        dtype = getattr(torch, args.dtype)
        return dict(levels=args.levels, min_shape=args.min_shape, border_inv=args.border, nlpow=args.nlpow,
                    pixel_desc=pixel_desc, dtype=dtype, verbose=args.verbose)

    @staticmethod
    def tune_matcher( args, matcher, device ):
        """Select the implementation of each stage and move the matcher to ``device``.

        On ``'cpu'`` the matcher is forced to float32 and ``args`` is
        modified to use the torch/cpu implementations.
        """
        if device == 'cpu':
            matcher.dtype = torch.float32
            args.forward = 'torch'
            args.backward = 'torch'
            args.reciprocal = 'cpu'

        # NOTE(review): these assignments patch the matcher's *class*, so they
        # affect every instance and persist across calls -- confirm intended.
        if args.forward == 'cuda': type(matcher).forward_level = myF.forward_cuda
        if args.forward == 'cuda-lowmem':type(matcher).forward_level = myF.forward_cuda_lowmem
        # NOTE(review): `legacy` is not imported in the visible part of this
        # file; `--backward python` would raise NameError -- confirm its origin.
        if args.backward == 'python': type(matcher).backward_pass = legacy.backward_python
        if args.backward == 'cuda': type(matcher).backward_level = myF.backward_cuda
        if args.reciprocal == 'cuda': type(matcher).reciprocal = myF.reciprocal

        return matcher.to(device)

    @staticmethod
    def build_matcher(args, device):
        """Create a ``SingleScalePUMP`` from ``args`` and tune it for ``device``."""
        options = Main.get_options(args)
        matcher = SingleScalePUMP(**options)
        return Main.tune_matcher(args, matcher, device)

    def __call__(self, *imgs, dbg=()):
        """Match the images and return the correspondences as a numpy array.

        Post-filtering is applied unless ``self.post_filtering`` is False.
        ``'print'`` in ``dbg`` prints the result; the ``'viz'`` debug figure
        displays it.
        """
        corres = self.matcher( *imgs, dbg=dbg).cpu().numpy()
        if self.post_filtering is not False:
            corres = self.post_filter( imgs, corres )

        if 'print' in dbg: print(corres)
        if dbgfig('viz',dbg): show_correspondences(*imgs, corres)
        return corres

    @staticmethod
    def load_images( args, device='cpu' ):
        """Read ``args.img1`` / ``args.img2``, move them to ``device`` and resize.

        ``args.resize`` is normalised in place to one size per image.
        Returns the list of the two resized image tensors.
        """
        def read_image(impath):
            # Prefer torchvision's reader; fall back to PIL when torchvision
            # is unavailable or cannot read the file.
            try:
                from torchvision.io.image import read_image as tv_read_image, ImageReadMode
                return tv_read_image(impath, mode=ImageReadMode.RGB)
            except (ImportError, RuntimeError):
                from PIL import Image
                return torch.from_numpy(np.array(Image.open(impath).convert('RGB'))).permute(2,0,1)

        args.resize = Main._normalize_resize(args.resize)

        images = []
        for impath, size in zip([args.img1, args.img2], args.resize):
            img = read_image(impath).to(device)
            img = myF.imresize(img, size)
            images.append( img )
        return images

    def post_filter(self, imgs, corres ):
        """Filter ``corres`` with ``post_filter.filter_corres`` and the stored options."""
        from post_filter import filter_corres
        return filter_corres(*map(image_with_trf,imgs), corres, **self.post_filtering)

    def save_output(self, output_path, corres ):
        """Save ``corres`` under the key ``'corres'`` in an ``.npz`` archive."""
        mkdir_for( output_path )
        # use a context manager so the file handle is closed (and flushed)
        with open(output_path, 'wb') as out_file:
            np.savez(out_file, corres=corres)
# Script entry point: parse the command line and run the matcher once.
if __name__ == '__main__':
    runner = Main()
    runner.run_from_args(arg_parser().parse_args())