"""Visualize the keypoint detector and the reliability map of a patch network.

Loads a checkpointed network, runs it on a single image, extracts keypoints
with ``NonMaxSuppression`` and saves a three-panel figure to ``--out``:

* default: image + keypoints | smoothed repeatability overlay | reliability overlay
* ``--dbg reliability``: image | keypoints alone | raw reliability map
"""
import pdb
import os
import sys
import tqdm
import numpy as np
import torch

from PIL import Image
from matplotlib import pyplot as pl
pl.ion()  # interactive mode
from scipy.ndimage import uniform_filter


def smooth(arr):
    """Return ``arr`` filtered with a uniform (box) filter of size 3 along every axis."""
    return uniform_filter(arr, 3)


def transparent(img, alpha, cmap, **kw):
    """Colorize a 2-D map with ``cmap`` and give every pixel the same opacity.

    Args:
        img: 2-D array of scalar values to colorize.
        alpha: opacity written into the last (alpha) channel of every pixel.
        cmap: matplotlib colormap applied to the normalized values.
        **kw: forwarded to ``matplotlib.colors.Normalize`` (e.g. ``vmin``,
            ``vmax``); values outside that range are clipped.

    Returns:
        The RGBA image produced by ``cmap`` with its alpha channel overwritten.
    """
    from matplotlib.colors import Normalize
    colored_img = cmap(Normalize(clip=True, **kw)(img))
    colored_img[:, :, -1] = alpha
    return colored_img


from tools import common
from tools.dataloader import norm_RGB
from nets.patchnet import *  # network classes referenced by name in the `--net` / checkpoint string
from extract import NonMaxSuppression


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser("Visualize the patch detector and descriptor")

    parser.add_argument("--img", type=str, default="imgs/brooklyn.png")
    parser.add_argument("--resize", type=int, default=512)
    parser.add_argument("--out", type=str, default="viz.png")

    parser.add_argument("--checkpoint", type=str, required=True, help='network path')
    parser.add_argument("--net", type=str, default="", help='network command')

    parser.add_argument("--max-kpts", type=int, default=200)
    parser.add_argument("--reliability-thr", type=float, default=0.8)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)
    parser.add_argument("--border", type=int, default=20, help='rm keypoints close to border')

    parser.add_argument("--gpu", type=int, nargs='+', required=True, help='-1 for CPU')
    parser.add_argument("--dbg", type=str, nargs='+', default=(), help='debug options')

    args = parser.parse_args()
    args.dbg = set(args.dbg)

    iscuda = common.torch_set_gpu(args.gpu)
    device = torch.device('cuda' if iscuda else 'cpu')

    # create network
    # map_location callable returning the storage unchanged: tensors stay on the CPU
    checkpoint = torch.load(args.checkpoint, lambda storage, location: storage)
    args.net = args.net or checkpoint['net']
    print("\n>> Creating net = " + args.net)
    # NOTE(review): eval() executes arbitrary code taken from the command line or
    # from the checkpoint file -- only use trusted checkpoints.
    net = eval(args.net)
    net.load_state_dict({k.replace('module.', ''): v
                         for k, v in checkpoint['state_dict'].items()})
    # Inference only: disable training-mode behavior of layers such as
    # dropout / batch-norm (the original script never switched modes).
    net = net.eval()
    if iscuda:
        net = net.cuda()
    print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")

    # load the image; the context manager closes the underlying file
    # (convert() loads the pixel data, so the result outlives the handle)
    with Image.open(args.img) as pil_img:
        img = pil_img.convert('RGB')
    if args.resize:
        img.thumbnail((args.resize, args.resize))
    img = np.asarray(img)

    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr,
        rep_thr=args.repeatability_thr)

    with torch.no_grad():
        print(">> computing features...")
        res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)])
        rela = res.get('reliability')
        repe = res.get('repeatability')
        # transpose the detector output and swap its two rows so that
        # kpts[:, 0] is x (column) and kpts[:, 1] is y (row), as used below
        kpts = detector(**res).T[:, [1, 0]]
        # keep the `max_kpts` keypoints with the highest repeatability score
        scores = repe[0][0, 0][kpts[:, 1], kpts[:, 0]]
        kpts = kpts[scores.argsort()[-args.max_kpts:]]

    pl.figure("viz")
    kw = dict(cmap=pl.cm.RdYlGn, vmax=1)
    # Strip `border` pixels on every side. The stop must fall back to None (not 1)
    # when border == 0, otherwise the crop would keep a single row/column.
    crop = (slice(args.border, -args.border or None),) * 2
    # keypoint coordinates expressed in the cropped frame
    x, y = kpts[:, 0:2].cpu().numpy().T - args.border

    if 'reliability' in args.dbg:
        pl.subplot(131)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        pl.xticks(()); pl.yticks(())

        pl.subplot(132)
        # fully transparent image: only establishes the axes extent for the keypoints
        pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0)
        pl.xticks(()); pl.yticks(())
        pl.plot(x, y, '+', c=(0, 1, 0), ms=10, scalex=0, scaley=0)

        pl.subplot(133)
        rela = rela[0][0, 0].cpu().numpy()
        pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9)
        pl.xticks(()); pl.yticks(())
    else:
        pl.subplot(131)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        pl.xticks(()); pl.yticks(())
        pl.plot(x, y, '+', c=(0, 1, 0), ms=10, scalex=0, scaley=0)

        pl.subplot(132)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        pl.xticks(()); pl.yticks(())
        c = repe[0][0, 0].cpu().numpy()
        pl.imshow(transparent(smooth(c)[crop], 0.5, vmin=0, **kw))

        pl.subplot(133)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        pl.xticks(()); pl.yticks(())
        rela = rela[0][0, 0].cpu().numpy()
        pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw))

    pl.gcf().set_size_inches(9, 2.73)
    pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, hspace=0.1)
    pl.savefig(args.out)
    # Pause in the debugger -- presumably so the interactive figure stays open
    # for inspection instead of closing when the script exits.
    pdb.set_trace()