|
|
|
|
|
|
|
|
|
|
|
import os, pdb |
|
from PIL import Image |
|
import numpy as np |
|
import torch |
|
|
|
from .tools import common |
|
from .tools.dataloader import norm_RGB |
|
from .nets.patchnet import * |
|
|
|
|
|
def load_network(model_fn):
    """Build a network from a checkpoint file and load its weights.

    The checkpoint is read with ``torch.load`` and must contain:

    * ``"net"``: a Python expression (string) that constructs the model. It is
      evaluated in this module's namespace, so it can use the names imported
      from ``.nets.patchnet``.
    * ``"state_dict"``: the model weights.

    Args:
        model_fn: path to the checkpoint file.

    Returns:
        The network with its weights loaded, switched to evaluation mode.
    """
    # Load tensors onto the CPU so that checkpoints saved from a GPU can be
    # read on CPU-only machines; callers move the net to the GPU if wanted.
    checkpoint = torch.load(model_fn, map_location="cpu")

    print("\n>> Creating net = " + checkpoint["net"])

    # SECURITY: eval() executes arbitrary code stored in the checkpoint (and
    # torch.load itself unpickles data). Only load checkpoints you trust.
    net = eval(checkpoint["net"])

    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # Checkpoints saved through torch.nn.DataParallel prefix every key with
    # "module."; strip only that leading prefix so that parameter names which
    # merely contain "module." (e.g. "submodule.weight") stay intact.
    prefix = "module."
    weights = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }
    net.load_state_dict(weights)

    return net.eval()
|
|
|
|
|
class NonMaxSuppression(torch.nn.Module):
    """Select keypoints as thresholded local maxima of the repeatability map.

    A pixel is kept when three conditions hold:

    * its repeatability equals the maximum of its 3x3 neighbourhood;
    * its repeatability is at least ``rep_thr``;
    * its reliability is at least ``rel_thr``.
    """

    def __init__(self, rel_thr=0.7, rep_thr=0.7):
        """
        Args:
            rel_thr: minimum reliability score for a pixel to be kept.
            rep_thr: minimum repeatability score for a pixel to be kept.
        """
        # super() avoids depending on an ``nn`` alias that would otherwise
        # only be available through a wildcard import.
        super().__init__()
        # 3x3 max filter with stride 1 and padding 1 keeps the spatial size,
        # so its output can be compared element-wise with the input.
        self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.rel_thr = rel_thr
        self.rep_thr = rep_thr

    def forward(self, reliability, repeatability, **kw):
        """Return the ``(y, x)`` coordinates of the selected keypoints.

        Args:
            reliability: sequence holding exactly one reliability tensor.
            repeatability: sequence holding exactly one repeatability tensor.
                Both tensors are assumed to be 4-D ``(batch, channel, H, W)``
                (only rows 2 and 3 of the nonzero indices are returned).
            **kw: ignored; allows passing the full network output dict.

        Returns:
            A ``(2, N)`` long tensor whose rows are the y and x indices.
        """
        assert len(reliability) == len(repeatability) == 1
        reliability, repeatability = reliability[0], repeatability[0]

        # Local maxima: pixels unchanged by the 3x3 max filter.
        maxima = repeatability == self.max_filter(repeatability)

        # Keep only maxima that are repeatable and reliable enough.
        maxima &= repeatability >= self.rep_thr
        maxima &= reliability >= self.rel_thr

        return maxima.nonzero().t()[2:4]
|
|
|
|
|
def extract_multiscale(
    net,
    img,
    detector,
    scale_f=2**0.25,
    min_scale=0.0,
    max_scale=1,
    min_size=256,
    max_size=1024,
    verbose=False,
):
    """Detect and describe keypoints over an image pyramid.

    Scales ``s = 1, 1/scale_f, 1/scale_f**2, ...`` are visited in turn. A
    level is processed only when ``s`` (with a 0.001 tolerance) is no larger
    than ``min(max_scale, max_size / max(H, W))``. Iteration stops once ``s``
    falls below ``max(min_scale, min_size / max(H, W))``. Each level is
    obtained by bilinearly resizing the previous one.

    Args:
        net: callable invoked as ``net(imgs=[img])``. It returns a dict with
            ``"descriptors"``, ``"reliability"`` and ``"repeatability"`` lists,
            and the first element of each is indexed as ``[0, channel, y, x]``.
        img: image batch of shape ``(1, 3, H, W)``.
        detector: callable invoked as ``detector(**net_output)`` that returns
            a pair ``(y, x)`` of index tensors into the current level.
        scale_f: ratio between consecutive pyramid scales.
        min_scale: lower bound on the scale factor.
        max_scale: upper bound on the scale factor (must be <= 1).
        min_size: lower bound on the longest image side, in pixels.
        max_size: upper bound on the longest image side, in pixels.
        verbose: print the size of each processed level.

    Returns:
        ``(XYS, D, scores)``:

        * ``XYS`` stacks, per keypoint, x and y mapped back to the original
          ``W x H`` image, plus ``32 / s``.
        * ``D`` holds one descriptor row per keypoint.
        * ``scores`` is reliability times repeatability.

    Raises:
        AssertionError: if ``img`` is not a single RGB image or
            ``max_scale > 1``.
        RuntimeError: if no pyramid level satisfies the scale/size bounds.
    """
    B, three, H, W = img.shape
    assert B == 1 and three == 3, "should be a batch with a single RGB image"
    assert max_scale <= 1

    # Loop-invariant bounds on the scale factor.
    longest_side = max(H, W)
    lowest_scale = max(min_scale, min_size / longest_side)
    highest_scale = min(max_scale, max_size / longest_side)

    old_bm = torch.backends.cudnn.benchmark
    # Each pyramid level has a different input shape, and cuDNN benchmark
    # mode autotunes for every new shape, so it is disabled here. The
    # try/finally restores the caller's setting even if net/detector raise.
    torch.backends.cudnn.benchmark = False
    try:
        X, Y, S, C, Q, D = [], [], [], [], [], []
        s = 1.0
        while s + 0.001 >= lowest_scale:
            if s - 0.001 <= highest_scale:
                nh, nw = img.shape[2:]
                if verbose:
                    print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")

                with torch.no_grad():
                    res = net(imgs=[img])

                descriptors = res["descriptors"][0]
                reliability = res["reliability"][0]
                repeatability = res["repeatability"][0]

                y, x = detector(**res)
                c = reliability[0, 0, y, x]
                q = repeatability[0, 0, y, x]
                d = descriptors[0, :, y, x].t()
                n = d.shape[0]

                # Map coordinates from this level back to the original image.
                X.append(x.float() * W / nw)
                Y.append(y.float() * H / nh)
                S.append((32 / s) * torch.ones(n, dtype=torch.float32, device=d.device))
                C.append(c)
                Q.append(q)
                D.append(d)

            s /= scale_f
            nh, nw = round(H * s), round(W * s)
            # Use torch.nn.functional explicitly rather than an ``F`` alias
            # that would only be available through a wildcard import.
            img = torch.nn.functional.interpolate(
                img, (nh, nw), mode="bilinear", align_corners=False
            )
    finally:
        torch.backends.cudnn.benchmark = old_bm

    if not D:
        raise RuntimeError(
            "extract_multiscale: no pyramid level satisfies the "
            "min/max scale and size constraints"
        )

    Y = torch.cat(Y)
    X = torch.cat(X)
    S = torch.cat(S)
    scores = torch.cat(C) * torch.cat(Q)
    XYS = torch.stack([X, Y, S], dim=-1)
    D = torch.cat(D)
    return XYS, D, scores
|
|
|
|
|
def extract_keypoints(args):
    """Extract multi-scale keypoints for every image in ``args.images``.

    ``args.images`` is consumed in place as a work queue. An entry ending in
    ``.txt`` is read as a list of further paths (one per line), which are
    pushed to the front of the queue. For each image, the ``args.top_k``
    highest-scoring keypoints (all of them when ``top_k`` is 0) are written
    with ``np.savez`` to ``<image path>.<args.tag>``. Each file holds
    ``imsize``, ``keypoints``, ``descriptors`` and ``scores``.

    Args:
        args: parsed command-line namespace. It provides ``gpu``, ``model``,
            ``reliability_thr``, ``repeatability_thr``, ``images``,
            ``scale_f``, ``min_scale``, ``max_scale``, ``min_size``,
            ``max_size``, ``top_k`` and ``tag``.
    """
    iscuda = common.torch_set_gpu(args.gpu)

    net = load_network(args.model)
    if iscuda:
        net = net.cuda()

    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
    )

    while args.images:
        img_path = args.images.pop(0)

        if img_path.endswith(".txt"):
            # Expand an image list file in place, preserving queue order.
            with open(img_path) as list_file:
                args.images = list_file.read().splitlines() + args.images
            continue

        print(f"\nExtracting features for {img_path}")
        # convert() returns a fully loaded copy, so the source file can be
        # closed immediately.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")
        W, H = img.size
        img = norm_RGB(img)[None]
        if iscuda:
            img = img.cuda()

        xys, desc, scores = extract_multiscale(
            net,
            img,
            detector,
            scale_f=args.scale_f,
            min_scale=args.min_scale,
            max_scale=args.max_scale,
            min_size=args.min_size,
            max_size=args.max_size,
            verbose=True,
        )

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()
        # Ascending argsort, so the last top_k indices are the best ones;
        # top_k == 0 gives the slice [None:], i.e. keep everything.
        idxs = scores.argsort()[-args.top_k or None :]

        outpath = img_path + "." + args.tag
        print(f"Saving {len(idxs)} keypoints to {outpath}")
        with open(outpath, "wb") as out_file:
            np.savez(
                out_file,
                imsize=(W, H),
                keypoints=xys[idxs],
                descriptors=desc[idxs],
                scores=scores[idxs],
            )
|
|
|
|
|
def build_arg_parser():
    """Return the command-line parser for keypoint extraction."""
    import argparse

    # The first positional argument of ArgumentParser is ``prog`` (the
    # program name shown in usage), so the summary must go in ``description``.
    parser = argparse.ArgumentParser(description="Extract keypoints for a given image")

    parser.add_argument("--model", type=str, required=True, help="model path")

    parser.add_argument(
        "--images", type=str, required=True, nargs="+", help="images / list"
    )
    parser.add_argument("--tag", type=str, default="r2d2", help="output file tag")

    parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints")

    parser.add_argument("--scale-f", type=float, default=2**0.25)
    parser.add_argument("--min-size", type=int, default=256)
    parser.add_argument("--max-size", type=int, default=1024)
    parser.add_argument("--min-scale", type=float, default=0)
    parser.add_argument("--max-scale", type=float, default=1)

    parser.add_argument("--reliability-thr", type=float, default=0.7)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)

    parser.add_argument(
        "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU"
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    extract_keypoints(args)
|
|