|
import pdb |
|
import os |
|
import sys |
|
import tqdm |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from PIL import Image |
|
from matplotlib import pyplot as pl |
|
|
|
pl.ion() |
|
from scipy.ndimage import uniform_filter |
|
|
|
def smooth(arr):
    """Return ``arr`` filtered by SciPy's ``uniform_filter`` with size 3."""
    return uniform_filter(arr, 3)
|
|
|
|
|
def transparent(img, alpha, cmap, **kw):
    """Color-map ``img`` and overwrite the last channel of the result.

    Args:
        img: Array of scalar values to colorize. The result is indexed with
            three axes, so a 2-D map is presumably expected -- TODO confirm.
        alpha: Value written into the last channel of every pixel
            (the opacity, assuming ``cmap`` returns RGBA).
        cmap: Callable applied to the normalized image; the call sites in
            this file pass a matplotlib colormap.
        **kw: Forwarded to ``matplotlib.colors.Normalize`` (the call sites
            pass ``vmin`` / ``vmax``).

    Returns:
        The array produced by ``cmap`` with ``[:, :, -1]`` set to ``alpha``.
    """
    from matplotlib.colors import Normalize

    # clip=True is passed through so Normalize clips values outside the
    # configured range.
    normalizer = Normalize(clip=True, **kw)
    rgba = cmap(normalizer(img))
    rgba[:, :, -1] = alpha
    return rgba
|
|
|
|
|
from tools import common |
|
from tools.dataloader import norm_RGB |
|
from nets.patchnet import * |
|
from extract import NonMaxSuppression |
|
|
|
|
|
def border_crop(border):
    """Return a 2-D index that trims ``border`` pixels from each image side.

    Args:
        border: Number of rows/columns to drop at both ends of the first two
            axes.

    Returns:
        A tuple ``(rows, cols)`` of two identical ``slice`` objects.

    A zero border must keep the whole image: ``-0`` is falsy, so the stop
    falls back to ``None`` (open-ended).  The previous fallback of ``1``
    produced ``slice(0, 1)`` and cropped everything down to a single pixel.
    """
    return (slice(border, -border or None),) * 2


def show_panel(position, image, **imshow_kw):
    """Select subplot ``position``, draw ``image`` in it and hide the ticks."""
    pl.subplot(position)
    pl.imshow(image, **imshow_kw)
    pl.xticks(())
    pl.yticks(())


def plot_keypoints(kpts, border):
    """Plot keypoints as green crosses on the current axes.

    The first two columns of ``kpts`` are used as (x, y) and shifted by
    ``border`` so they line up with an image cropped by ``border_crop``.
    """
    x, y = kpts[:, 0:2].cpu().numpy().T - border
    # scalex/scaley=0 keep the axes limits set by the image untouched.
    pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Visualize the patch detector and descriptor")

    parser.add_argument("--img", type=str, default="imgs/brooklyn.png")
    parser.add_argument("--resize", type=int, default=512)
    parser.add_argument("--out", type=str, default="viz.png")

    parser.add_argument("--checkpoint", type=str, required=True, help="network path")
    parser.add_argument("--net", type=str, default="", help="network command")

    parser.add_argument("--max-kpts", type=int, default=200)
    parser.add_argument("--reliability-thr", type=float, default=0.8)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)
    parser.add_argument(
        "--border", type=int, default=20, help="rm keypoints close to border"
    )

    parser.add_argument("--gpu", type=int, nargs="+", required=True, help="-1 for CPU")
    parser.add_argument("--dbg", type=str, nargs="+", default=(), help="debug options")

    args = parser.parse_args()
    args.dbg = set(args.dbg)

    # NOTE(review): torch_set_gpu is presumably truthy when CUDA will be
    # used -- confirm in tools.common.
    iscuda = common.torch_set_gpu(args.gpu)
    device = torch.device("cuda" if iscuda else "cpu")

    # --- build the network -------------------------------------------------
    # The second positional argument is torch.load's map_location; returning
    # the storage unchanged leaves the loaded tensors where they were
    # deserialized instead of moving them to their saved device.
    checkpoint = torch.load(args.checkpoint, lambda a, b: a)
    args.net = args.net or checkpoint["net"]
    print("\n>> Creating net = " + args.net)
    # NOTE(security): eval executes arbitrary Python taken from --net or from
    # the checkpoint file; only use commands and checkpoints you trust.
    net = eval(args.net)
    # Strip "module." from the parameter names (presumably the prefix added
    # by a DataParallel wrapper at training time -- TODO confirm).
    net.load_state_dict(
        {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
    )
    if iscuda:
        net = net.cuda()
    print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")

    # --- load the image ----------------------------------------------------
    img = Image.open(args.img).convert("RGB")
    if args.resize:
        # thumbnail() shrinks in place so the image fits in a resize x resize box.
        img.thumbnail((args.resize, args.resize))
    img = np.asarray(img)

    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
    )

    # --- run the network and select keypoints -------------------------------
    with torch.no_grad():
        print(">> computing features...")
        res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)])
        rela = res.get("reliability")
        repe = res.get("repeatability")
        # Swap the first two columns of the detector output; below, column 0
        # indexes map columns (x) and column 1 indexes map rows (y).
        kpts = detector(**res).T[:, [1, 0]]
        # argsort is ascending, so the tail holds the keypoints with the
        # highest repeatability score.
        kpt_scores = repe[0][0, 0][kpts[:, 1], kpts[:, 0]]
        kpts = kpts[kpt_scores.argsort()[-args.max_kpts :]]

    # --- draw the three panels ----------------------------------------------
    fig = pl.figure("viz")
    kw = dict(cmap=pl.cm.RdYlGn, vmax=1)
    crop = border_crop(args.border)
    rela_map = rela[0][0, 0].cpu().numpy()

    if "reliability" in args.dbg:
        # Panels: image | keypoints only | opaque reliability map.
        show_panel(131, img[crop], cmap=pl.cm.gray)

        # alpha=0 makes the image fully transparent: it only sets the axes
        # extent for the keypoints drawn on top.
        show_panel(132, img[crop], cmap=pl.cm.gray, alpha=0)
        plot_keypoints(kpts, args.border)

        show_panel(133, rela_map[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9)
    else:
        # Panels: image + keypoints | repeatability overlay | reliability overlay.
        show_panel(131, img[crop], cmap=pl.cm.gray)
        plot_keypoints(kpts, args.border)

        show_panel(132, img[crop], cmap=pl.cm.gray)
        repe_map = repe[0][0, 0].cpu().numpy()
        pl.imshow(transparent(smooth(repe_map)[crop], 0.5, vmin=0, **kw))

        show_panel(133, img[crop], cmap=pl.cm.gray)
        pl.imshow(transparent(rela_map[crop], 0.5, vmin=0.9, **kw))

    fig.set_size_inches(9, 2.73)
    pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, hspace=0.1)
    pl.savefig(args.out)
    # NOTE(review): dropping into the debugger after saving appears to be a
    # deliberate way to keep the interactive figure open for inspection;
    # kept as in the original -- confirm before removing.
    pdb.set_trace()
|
|