# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import os, pdb

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F  # explicit import; was provided via the patchnet star import

from .tools import common
from .tools.dataloader import norm_RGB
from .nets.patchnet import *


def load_network(model_fn):
    """Load a trained R2D2 network from a checkpoint file.

    Parameters
    ----------
    model_fn : str
        Path to a checkpoint dict containing 'net' (a Python expression
        naming the architecture) and 'state_dict' (trained weights).

    Returns
    -------
    torch.nn.Module
        The instantiated network with weights loaded, in eval mode.
    """
    checkpoint = torch.load(model_fn)
    print("\n>> Creating net = " + checkpoint["net"])
    # SECURITY NOTE: the architecture is instantiated by eval-ing a string
    # stored inside the (pickled) checkpoint. Both torch.load and eval can
    # execute arbitrary code — only load checkpoints from trusted sources.
    net = eval(checkpoint["net"])
    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # initialization: strip the 'module.' prefix left by DataParallel training
    weights = checkpoint["state_dict"]
    net.load_state_dict({k.replace("module.", ""): v for k, v in weights.items()})
    return net.eval()


class NonMaxSuppression(torch.nn.Module):
    """Keypoint detector: keeps local maxima of the repeatability map whose
    reliability and repeatability both pass their respective thresholds."""

    def __init__(self, rel_thr=0.7, rep_thr=0.7):
        # FIX: was `nn.Module.__init__(self)`, which relied on `nn` being
        # re-exported by the patchnet star import; super() is explicit.
        super().__init__()
        self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.rel_thr = rel_thr
        self.rep_thr = rep_thr

    def forward(self, reliability, repeatability, **kw):
        # Single-scale input expected: one reliability/repeatability map each.
        assert len(reliability) == len(repeatability) == 1
        reliability, repeatability = reliability[0], repeatability[0]

        # local maxima: pixels equal to the max of their 3x3 neighbourhood
        maxima = repeatability == self.max_filter(repeatability)

        # remove low peaks (in-place multiply acts as a boolean AND)
        maxima *= repeatability >= self.rep_thr
        maxima *= reliability >= self.rel_thr

        # nonzero() yields (batch, channel, y, x) index rows; after transposing,
        # keep only the two spatial rows -> returns (y, x)
        return maxima.nonzero().t()[2:4]


def extract_multiscale(
    net,
    img,
    detector,
    scale_f=2**0.25,
    min_scale=0.0,
    max_scale=1,
    min_size=256,
    max_size=1024,
    verbose=False,
):
    """Extract keypoints and descriptors over a pyramid of image scales.

    Starting from the full-resolution image (scale 1.0), repeatedly shrink it
    by `scale_f` and run the network + detector at every scale whose longer
    side fits within [min_size, max_size] and whose factor lies within
    [min_scale, max_scale]. Keypoint coordinates are mapped back to the
    original resolution.

    Parameters
    ----------
    net : torch.nn.Module
        R2D2 network; called as net(imgs=[img]) and expected to return a dict
        with 'descriptors', 'reliability' and 'repeatability' lists.
    img : torch.Tensor
        Input batch of shape (1, 3, H, W).
    detector : callable
        Non-maxima suppression module returning (y, x) keypoint indices.

    Returns
    -------
    (XYS, D, scores) : keypoints as (x, y, scale) rows, their descriptors,
        and per-keypoint scores (reliability * repeatability).
    """
    old_bm = torch.backends.cudnn.benchmark
    torch.backends.cudnn.benchmark = False  # speedup

    # extract keypoints at multiple scales
    B, three, H, W = img.shape
    assert B == 1 and three == 3, "should be a batch with a single RGB image"

    assert max_scale <= 1
    s = 1.0  # current scale factor

    X, Y, S, C, Q, D = [], [], [], [], [], []
    # Loop with +/-0.001 epsilons to avoid float round-off excluding the
    # boundary scales.
    while s + 0.001 >= max(min_scale, min_size / max(H, W)):
        if s - 0.001 <= min(max_scale, max_size / max(H, W)):
            nh, nw = img.shape[2:]
            if verbose:
                print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
            # extract descriptors
            with torch.no_grad():
                res = net(imgs=[img])

            # get output and reliability map
            descriptors = res["descriptors"][0]
            reliability = res["reliability"][0]
            repeatability = res["repeatability"][0]

            # normalize the reliability for nms
            # extract maxima and descs
            y, x = detector(**res)  # nms
            c = reliability[0, 0, y, x]
            q = repeatability[0, 0, y, x]
            d = descriptors[0, :, y, x].t()
            n = d.shape[0]

            # accumulate multiple scales; coordinates are rescaled back to the
            # original image resolution, 32/s is the keypoint patch size
            X.append(x.float() * W / nw)
            Y.append(y.float() * H / nh)
            S.append((32 / s) * torch.ones(n, dtype=torch.float32, device=d.device))
            C.append(c)
            Q.append(q)
            D.append(d)
        s /= scale_f

        # down-scale the image for next iteration
        nh, nw = round(H * s), round(W * s)
        img = F.interpolate(img, (nh, nw), mode="bilinear", align_corners=False)

    # restore value
    torch.backends.cudnn.benchmark = old_bm

    Y = torch.cat(Y)
    X = torch.cat(X)
    S = torch.cat(S)  # scale
    scores = torch.cat(C) * torch.cat(Q)  # scores = reliability * repeatability
    XYS = torch.stack([X, Y, S], dim=-1)
    D = torch.cat(D)
    return XYS, D, scores


def extract_keypoints(args):
    """Run multi-scale keypoint extraction for every image in args.images and
    save keypoints/descriptors/scores next to each image as '<image>.<tag>'.

    A '.txt' entry in args.images is treated as a list file: its lines are
    spliced into the remaining work queue.
    """
    iscuda = common.torch_set_gpu(args.gpu)

    # load the network...
    net = load_network(args.model)
    if iscuda:
        net = net.cuda()

    # create the non-maxima detector
    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
    )

    while args.images:
        img_path = args.images.pop(0)

        if img_path.endswith(".txt"):
            args.images = open(img_path).read().splitlines() + args.images
            continue

        print(f"\nExtracting features for {img_path}")
        img = Image.open(img_path).convert("RGB")
        W, H = img.size
        img = norm_RGB(img)[None]
        if iscuda:
            img = img.cuda()

        # extract keypoints/descriptors for a single image
        xys, desc, scores = extract_multiscale(
            net,
            img,
            detector,
            scale_f=args.scale_f,
            min_scale=args.min_scale,
            max_scale=args.max_scale,
            min_size=args.min_size,
            max_size=args.max_size,
            verbose=True,
        )

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()
        # keep the top_k highest-scoring keypoints; `-args.top_k or None`
        # makes top_k == 0 mean "keep everything"
        idxs = scores.argsort()[-args.top_k or None :]

        outpath = img_path + "." + args.tag
        print(f"Saving {len(idxs)} keypoints to {outpath}")
        # FIX: close the output file deterministically (was an unclosed
        # open() handed straight to np.savez). A file object — not a path —
        # is passed on purpose so numpy does not append a '.npz' suffix.
        with open(outpath, "wb") as f:
            np.savez(
                f,
                imsize=(W, H),
                keypoints=xys[idxs],
                descriptors=desc[idxs],
                scores=scores[idxs],
            )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Extract keypoints for a given image")
    parser.add_argument("--model", type=str, required=True, help="model path")
    parser.add_argument(
        "--images", type=str, required=True, nargs="+", help="images / list"
    )
    parser.add_argument("--tag", type=str, default="r2d2", help="output file tag")
    parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints")
    parser.add_argument("--scale-f", type=float, default=2**0.25)
    parser.add_argument("--min-size", type=int, default=256)
    parser.add_argument("--max-size", type=int, default=1024)
    parser.add_argument("--min-scale", type=float, default=0)
    parser.add_argument("--max-scale", type=float, default=1)
    parser.add_argument("--reliability-thr", type=float, default=0.7)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)
    parser.add_argument(
        "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU"
    )
    args = parser.parse_args()

    extract_keypoints(args)