Spaces:
Running
Running
import pdb | |
import os | |
import sys | |
import tqdm | |
import numpy as np | |
import torch | |
from PIL import Image | |
from matplotlib import pyplot as pl | |
pl.ion() | |
from scipy.ndimage import uniform_filter | |
def smooth(arr, size=3):
    """Return *arr* smoothed by a uniform (box) filter.

    Thin wrapper around ``scipy.ndimage.uniform_filter``.

    Args:
        arr: array to smooth.
        size: width of the averaging window along every axis. The default of
            3 keeps the original 3-wide smoothing used for the heatmaps.

    Returns:
        The filtered array (same shape as *arr*).
    """
    return uniform_filter(arr, size)
def transparent(img, alpha, cmap, **kw):
    """Colorize a scalar map with *cmap* and give it a uniform opacity.

    Args:
        img: array of scalar values to colorize.
        alpha: opacity written into the alpha channel of every element.
        cmap: colormap callable (e.g. ``pl.cm.RdYlGn``). It is assumed to
            return RGBA values along a new trailing axis, as matplotlib
            colormaps do.
        **kw: forwarded to ``matplotlib.colors.Normalize`` (e.g. ``vmin``,
            ``vmax``). ``clip=True`` is always set, so out-of-range values
            are clamped before colorizing.

    Returns:
        The colorized array with its last channel set to *alpha*.
    """
    from matplotlib.colors import Normalize

    normalize = Normalize(clip=True, **kw)
    colored_img = cmap(normalize(img))
    # Ellipsis indexing targets the trailing (RGBA) axis whatever the rank
    # of *img*. A fixed `[:, :, -1]` is only correct for 2-D maps.
    colored_img[..., -1] = alpha
    return colored_img
from tools import common | |
from tools.dataloader import norm_RGB | |
from nets.patchnet import * | |
from extract import NonMaxSuppression | |
def _hide_ticks():
    """Remove the x and y tick marks from the current pyplot axes."""
    pl.xticks(())
    pl.yticks(())


def _parse_args():
    """Parse the command line of this visualization script.

    Returns:
        argparse.Namespace whose ``dbg`` attribute has been converted to a
        set of debug-option strings.
    """
    import argparse

    parser = argparse.ArgumentParser("Visualize the patch detector and descriptor")
    parser.add_argument("--img", type=str, default="imgs/brooklyn.png")
    parser.add_argument("--resize", type=int, default=512)
    parser.add_argument("--out", type=str, default="viz.png")
    parser.add_argument("--checkpoint", type=str, required=True, help="network path")
    parser.add_argument("--net", type=str, default="", help="network command")
    parser.add_argument("--max-kpts", type=int, default=200)
    parser.add_argument("--reliability-thr", type=float, default=0.8)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)
    parser.add_argument(
        "--border", type=int, default=20, help="rm keypoints close to border"
    )
    parser.add_argument("--gpu", type=int, nargs="+", required=True, help="-1 for CPU")
    parser.add_argument("--dbg", type=str, nargs="+", default=(), help="debug options")
    args = parser.parse_args()
    args.dbg = set(args.dbg)
    return args


def _load_network(args, iscuda):
    """Build the network named by ``args.net`` (or the checkpoint) and load its weights.

    Side effect: sets ``args.net`` from ``checkpoint["net"]`` when it was empty.

    Args:
        args: parsed command line (reads ``checkpoint`` and ``net``).
        iscuda: whether to move the network to the GPU.

    Returns:
        The network in evaluation mode.
    """
    # The map_location callable returns each storage unchanged, which loads
    # every tensor onto the CPU.
    checkpoint = torch.load(args.checkpoint, lambda a, b: a)
    args.net = args.net or checkpoint["net"]
    print("\n>> Creating net = " + args.net)
    # SECURITY: eval() executes arbitrary Python taken from --net or from the
    # checkpoint file. Only use checkpoints from a trusted source.
    net = eval(args.net)
    # Strip "module." from parameter names. This is presumably a prefix left
    # by a (Distributed)DataParallel wrapper at training time.
    net.load_state_dict(
        {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
    )
    if iscuda:
        net = net.cuda()
    # This is inference only, so switch any normalization/dropout layers to
    # their evaluation behaviour.
    return net.eval()


def main():
    """Run the network on one image and save a 3-panel keypoint/heatmap figure."""
    args = _parse_args()
    iscuda = common.torch_set_gpu(args.gpu)
    device = torch.device("cuda" if iscuda else "cpu")

    net = _load_network(args, iscuda)
    print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")

    img = Image.open(args.img).convert("RGB")
    if args.resize:
        img.thumbnail((args.resize, args.resize))
    img = np.asarray(img)

    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
    )
    with torch.no_grad():
        print(">> computing features...")
        res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)])
        rela = res.get("reliability")
        repe = res.get("repeatability")
        # Transpose the detector output and swap its two columns. Each row
        # then reads (x, y) = (column index, row index).
        kpts = detector(**res).T[:, [1, 0]]
        # Keep the args.max_kpts keypoints with the highest repeatability.
        kpts = kpts[repe[0][0, 0][kpts[:, 1], kpts[:, 0]].argsort()[-args.max_kpts :]]

    rep_map = repe[0][0, 0].cpu().numpy()
    rel_map = rela[0][0, 0].cpu().numpy()

    pl.figure("viz")
    # Hide args.border pixels on each side. "-0 or None" evaluates to None,
    # so --border 0 shows the whole image rather than a 1-pixel slice.
    crop = (slice(args.border, -args.border or None),) * 2
    # Keypoint coordinates in the cropped frame. Border keypoints are not
    # removed: they land outside the axes limits, which plot(scalex=0,
    # scaley=0) leaves unchanged, so they are simply not visible.
    x, y = kpts[:, 0:2].cpu().numpy().T - args.border

    if "reliability" in args.dbg:
        # Panel 1: the input image.
        pl.subplot(131)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        _hide_ticks()
        # Panel 2: keypoints only. The image is drawn fully transparent just
        # to set the axes extent.
        pl.subplot(132)
        pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0)
        _hide_ticks()
        pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0)
        # Panel 3: raw reliability map.
        pl.subplot(133)
        pl.imshow(rel_map[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9)
        _hide_ticks()
    else:
        kw = dict(cmap=pl.cm.RdYlGn, vmax=1)
        # Panel 1: image with keypoints.
        pl.subplot(131)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        _hide_ticks()
        pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0)
        # Panel 2: smoothed repeatability overlaid at 50% opacity.
        pl.subplot(132)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        _hide_ticks()
        pl.imshow(transparent(smooth(rep_map)[crop], 0.5, vmin=0, **kw))
        # Panel 3: reliability overlaid at 50% opacity.
        pl.subplot(133)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        _hide_ticks()
        pl.imshow(transparent(rel_map[crop], 0.5, vmin=0.9, **kw))

    pl.gcf().set_size_inches(9, 2.73)
    pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, hspace=0.1)
    pl.savefig(args.out)

    # Enter the debugger only on request ("--dbg pdb"). An unconditional
    # breakpoint would block or crash non-interactive runs after saving.
    if "pdb" in args.dbg:
        pdb.set_trace()


if __name__ == "__main__":
    main()