# Source: gim-online/third_party/r2d2/extract_kapture.py
# (uploaded by Vincentqyw, commit 8b973ee "fix: roma"; 9.83 kB)
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
from PIL import Image
from tools import common
from tools.dataloader import norm_RGB
from nets.patchnet import *
from os import path
from extract import load_network, NonMaxSuppression, extract_multiscale
# Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure From Motion)
# and, more generally, sensor-acquired data.
# It can be installed with:
#   pip install kapture
# For more information, check out https://github.com/naver/kapture
import kapture
from kapture.io.records import get_image_fullpath
from kapture.io.csv import kapture_from_dir
from kapture.io.csv import (
get_feature_csv_fullpath,
keypoints_to_file,
descriptors_to_file,
)
from kapture.io.features import (
get_keypoints_fullpath,
keypoints_check_dir,
image_keypoints_to_file,
)
from kapture.io.features import (
get_descriptors_fullpath,
descriptors_check_dir,
image_descriptors_to_file,
)
from kapture.io.csv import get_all_tar_handlers
def _extract_image_features(net, detector, img_path, iscuda, args):
    """Run the r2d2 network on a single image file.

    Args:
        net: network returned by ``load_network``.
        detector: ``NonMaxSuppression`` instance used to select keypoints.
        img_path: path of the image file to process.
        iscuda: when true, the input tensor is moved to the GPU.
        args: parsed CLI namespace; reads ``scale_f``, ``min_scale``,
            ``max_scale``, ``min_size``, ``max_size`` and ``top_k``.

    Returns:
        ``(xys, desc)`` as row-aligned numpy arrays, restricted to the
        ``args.top_k`` highest-scoring detections (all detections when
        ``args.top_k`` is 0).
    """
    # Use a context manager so the PIL file handle is released
    # deterministically once the RGB copy has been made.
    with Image.open(img_path) as pil_img:
        rgb_img = pil_img.convert("RGB")
    img = norm_RGB(rgb_img)[None]  # prepend a batch axis of size 1
    if iscuda:
        img = img.cuda()
    xys, desc, scores = extract_multiscale(
        net,
        img,
        detector,
        scale_f=args.scale_f,
        min_scale=args.min_scale,
        max_scale=args.max_scale,
        min_size=args.min_size,
        max_size=args.max_size,
        verbose=True,
    )
    xys = xys.cpu().numpy()
    desc = desc.cpu().numpy()
    scores = scores.cpu().numpy()
    # argsort is ascending, so the tail holds the best scores.
    # `-0 or None` evaluates to None, hence top_k == 0 keeps every detection.
    idxs = scores.argsort()[-args.top_k or None :]
    return xys[idxs], desc[idxs]


def _register_feature_types(kdata, args, xys, desc):
    """Register r2d2 keypoints/descriptors types shaped like *xys* / *desc*.

    Creates ``kapture.Keypoints`` / ``kapture.Descriptors`` entries in *kdata*
    under ``args.keypoints_type`` / ``args.descriptors_type`` and writes their
    config files under ``args.kapture_root``.
    """
    # An existing entry under either name is replaced; this happens when only
    # one of the two types was already present in the kapture.
    kdata.keypoints[args.keypoints_type] = kapture.Keypoints(
        "r2d2", xys.dtype, xys.shape[1]
    )
    kdata.descriptors[args.descriptors_type] = kapture.Descriptors(
        "r2d2",
        desc.dtype,
        desc.shape[1],
        args.keypoints_type,
        "L2",
    )
    keypoints_config_path = get_feature_csv_fullpath(
        kapture.Keypoints, args.keypoints_type, args.kapture_root
    )
    descriptors_config_path = get_feature_csv_fullpath(
        kapture.Descriptors, args.descriptors_type, args.kapture_root
    )
    keypoints_to_file(keypoints_config_path, kdata.keypoints[args.keypoints_type])
    descriptors_to_file(
        descriptors_config_path, kdata.descriptors[args.descriptors_type]
    )


def _check_feature_consistency(kdata, args, xys, desc):
    """Check freshly extracted features against the registered feature types.

    Compares the dtype and per-row size of *xys* / *desc* with the entries
    registered under ``args.keypoints_type`` / ``args.descriptors_type``. It
    also checks that the descriptors entry references those keypoints and
    uses the "L2" metric.

    Raises:
        ValueError: on the first mismatch found.
    """
    keypoints = kdata.keypoints[args.keypoints_type]
    descriptors = kdata.descriptors[args.descriptors_type]
    checks = (
        ("keypoints dtype", keypoints.dtype, xys.dtype),
        ("descriptors dtype", descriptors.dtype, desc.dtype),
        ("keypoints dsize", keypoints.dsize, xys.shape[1]),
        ("descriptors dsize", descriptors.dsize, desc.shape[1]),
        (
            "descriptors keypoints_type",
            descriptors.keypoints_type,
            args.keypoints_type,
        ),
        ("descriptors metric_type", descriptors.metric_type, "L2"),
    )
    for label, registered, extracted in checks:
        if registered != extracted:
            raise ValueError(
                f"{label} mismatch: registered {registered!r}, "
                f"extracted {extracted!r}"
            )


def extract_kapture_keypoints(args):
    """
    Extract r2d2 keypoints and descriptors to the kapture format directly.

    Every image listed in the camera records under ``args.kapture_root`` is
    processed with the network loaded from ``args.model``. Per-image keypoints
    and descriptors are saved into the kapture tree, and the feature-type
    config files are written when the types are first registered. If
    features of the same keypoints/descriptors types already exist, only
    images missing from either set are processed.

    ``args.keypoints_type`` and ``args.descriptors_type`` default to the model
    file name without its extension. They are assigned in place on *args*.

    Raises:
        ValueError: if the kapture has no camera records, or if newly
            extracted features disagree with the registered feature types.
    """
    print("extract_kapture_keypoints...")
    with get_all_tar_handlers(
        args.kapture_root,
        mode={
            kapture.Keypoints: "a",
            kapture.Descriptors: "a",
            kapture.GlobalFeatures: "r",
            kapture.Matches: "r",
        },
    ) as tar_handlers:
        kdata = kapture_from_dir(
            args.kapture_root,
            None,
            skip_list=[
                kapture.GlobalFeatures,
                kapture.Matches,
                kapture.Points3d,
                kapture.Observations,
            ],
            tar_handlers=tar_handlers,
        )
        # Explicit check rather than `assert`, which is stripped under -O.
        if kdata.records_camera is None:
            raise ValueError(f"no camera records found in {args.kapture_root}")
        image_list = [
            filename for _, _, filename in kapture.flatten(kdata.records_camera)
        ]

        model_name = path.splitext(path.basename(args.model))[0]
        if args.keypoints_type is None:
            args.keypoints_type = model_name
            print(f"keypoints_type set to {args.keypoints_type}")
        if args.descriptors_type is None:
            args.descriptors_type = model_name
            print(f"descriptors_type set to {args.descriptors_type}")

        if (
            kdata.keypoints is not None
            and args.keypoints_type in kdata.keypoints
            and kdata.descriptors is not None
            and args.descriptors_type in kdata.descriptors
        ):
            print(
                "detected already computed features of same keypoints_type/descriptors_type, resuming extraction..."
            )
            # Resume: keep only images lacking keypoints or descriptors.
            image_list = [
                name
                for name in image_list
                if name not in kdata.keypoints[args.keypoints_type]
                or name not in kdata.descriptors[args.descriptors_type]
            ]

        if not image_list:
            print("All features were already extracted")
            return
        print(f"Extracting r2d2 features for {len(image_list)} images")

        iscuda = common.torch_set_gpu(args.gpu)
        net = load_network(args.model)
        if iscuda:
            net = net.cuda()
        # Non-maxima detector used to pick keypoints from the network output.
        detector = NonMaxSuppression(
            rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
        )

        if kdata.keypoints is None:
            kdata.keypoints = {}
        if kdata.descriptors is None:
            kdata.descriptors = {}
        # Both types must already exist to validate against them. Otherwise
        # they are (re)registered from the first extracted image.
        types_registered = (
            args.keypoints_type in kdata.keypoints
            and args.descriptors_type in kdata.descriptors
        )

        for image_name in image_list:
            img_path = get_image_fullpath(args.kapture_root, image_name)
            print(f"\nExtracting features for {img_path}")
            xys, desc = _extract_image_features(
                net, detector, img_path, iscuda, args
            )

            if types_registered:
                _check_feature_consistency(kdata, args, xys, desc)
            else:
                _register_feature_types(kdata, args, xys, desc)
                types_registered = True

            keypoints_fullpath = get_keypoints_fullpath(
                args.keypoints_type, args.kapture_root, image_name, tar_handlers
            )
            print(f"Saving {xys.shape[0]} keypoints to {keypoints_fullpath}")
            image_keypoints_to_file(keypoints_fullpath, xys)
            kdata.keypoints[args.keypoints_type].add(image_name)

            descriptors_fullpath = get_descriptors_fullpath(
                args.descriptors_type, args.kapture_root, image_name, tar_handlers
            )
            print(f"Saving {desc.shape[0]} descriptors to {descriptors_fullpath}")
            image_descriptors_to_file(descriptors_fullpath, desc)
            kdata.descriptors[args.descriptors_type].add(image_name)

        if not keypoints_check_dir(
            kdata.keypoints[args.keypoints_type],
            args.keypoints_type,
            args.kapture_root,
            tar_handlers,
        ) or not descriptors_check_dir(
            kdata.descriptors[args.descriptors_type],
            args.descriptors_type,
            args.kapture_root,
            tar_handlers,
        ):
            print(
                "local feature extraction ended successfully but not all files were saved"
            )
def build_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` producing the namespace consumed by
        ``extract_kapture_keypoints``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        # Passed by keyword: the first positional parameter of ArgumentParser
        # is ``prog``, which would print this sentence as the program name.
        description="Extract r2d2 local features for all images in a dataset stored in the kapture format"
    )
    parser.add_argument("--model", type=str, required=True, help="model path")
    parser.add_argument(
        "--keypoints-type",
        default=None,
        help="keypoint type_name, default is filename of model",
    )
    parser.add_argument(
        "--descriptors-type",
        default=None,
        help="descriptors type_name, default is filename of model",
    )
    parser.add_argument(
        "--kapture-root", type=str, required=True, help="path to kapture root directory"
    )
    parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints")
    # Multi-scale pyramid and detector thresholds forwarded to the extractor.
    parser.add_argument("--scale-f", type=float, default=2**0.25)
    parser.add_argument("--min-size", type=int, default=256)
    parser.add_argument("--max-size", type=int, default=1024)
    parser.add_argument("--min-scale", type=float, default=0)
    parser.add_argument("--max-scale", type=float, default=1)
    parser.add_argument("--reliability-thr", type=float, default=0.7)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)
    parser.add_argument(
        "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU"
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    extract_kapture_keypoints(args)