# Copyright 2019-present NAVER Corp. # CC BY-NC-SA 3.0 # Available only for non-commercial use from PIL import Image from tools import common from tools.dataloader import norm_RGB from nets.patchnet import * from os import path from extract import load_network, NonMaxSuppression, extract_multiscale # Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure From Motion) # and more generally sensor-acquired data # it can be installed with # pip install kapture # for more information check out https://github.com/naver/kapture import kapture from kapture.io.records import get_image_fullpath from kapture.io.csv import kapture_from_dir from kapture.io.csv import ( get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file, ) from kapture.io.features import ( get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file, ) from kapture.io.features import ( get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file, ) from kapture.io.csv import get_all_tar_handlers def extract_kapture_keypoints(args): """ Extract r2d2 keypoints and descritors to the kapture format directly """ print("extract_kapture_keypoints...") with get_all_tar_handlers( args.kapture_root, mode={ kapture.Keypoints: "a", kapture.Descriptors: "a", kapture.GlobalFeatures: "r", kapture.Matches: "r", }, ) as tar_handlers: kdata = kapture_from_dir( args.kapture_root, None, skip_list=[ kapture.GlobalFeatures, kapture.Matches, kapture.Points3d, kapture.Observations, ], tar_handlers=tar_handlers, ) assert kdata.records_camera is not None image_list = [ filename for _, _, filename in kapture.flatten(kdata.records_camera) ] if args.keypoints_type is None: args.keypoints_type = path.splitext(path.basename(args.model))[0] print(f"keypoints_type set to {args.keypoints_type}") if args.descriptors_type is None: args.descriptors_type = path.splitext(path.basename(args.model))[0] print(f"descriptors_type set to {args.descriptors_type}") if ( kdata.keypoints is not None and args.keypoints_type in kdata.keypoints and kdata.descriptors is not None and args.descriptors_type in kdata.descriptors ): print( "detected already computed features of same keypoints_type/descriptors_type, resuming extraction..." ) image_list = [ name for name in image_list if name not in kdata.keypoints[args.keypoints_type] or name not in kdata.descriptors[args.descriptors_type] ] if len(image_list) == 0: print("All features were already extracted") return else: print(f"Extracting r2d2 features for {len(image_list)} images") iscuda = common.torch_set_gpu(args.gpu) # load the network... net = load_network(args.model) if iscuda: net = net.cuda() # create the non-maxima detector detector = NonMaxSuppression( rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr ) if kdata.keypoints is None: kdata.keypoints = {} if kdata.descriptors is None: kdata.descriptors = {} if args.keypoints_type not in kdata.keypoints: keypoints_dtype = None keypoints_dsize = None else: keypoints_dtype = kdata.keypoints[args.keypoints_type].dtype keypoints_dsize = kdata.keypoints[args.keypoints_type].dsize if args.descriptors_type not in kdata.descriptors: descriptors_dtype = None descriptors_dsize = None else: descriptors_dtype = kdata.descriptors[args.descriptors_type].dtype descriptors_dsize = kdata.descriptors[args.descriptors_type].dsize for image_name in image_list: img_path = get_image_fullpath(args.kapture_root, image_name) print(f"\nExtracting features for {img_path}") img = Image.open(img_path).convert("RGB") W, H = img.size img = norm_RGB(img)[None] if iscuda: img = img.cuda() # extract keypoints/descriptors for a single image xys, desc, scores = extract_multiscale( net, img, detector, scale_f=args.scale_f, min_scale=args.min_scale, max_scale=args.max_scale, min_size=args.min_size, max_size=args.max_size, verbose=True, ) xys = xys.cpu().numpy() desc = desc.cpu().numpy() scores = scores.cpu().numpy() idxs = scores.argsort()[-args.top_k or None :] xys = xys[idxs] desc = desc[idxs] if keypoints_dtype is None or descriptors_dtype is None: keypoints_dtype = xys.dtype descriptors_dtype = desc.dtype keypoints_dsize = xys.shape[1] descriptors_dsize = desc.shape[1] kdata.keypoints[args.keypoints_type] = kapture.Keypoints( "r2d2", keypoints_dtype, keypoints_dsize ) kdata.descriptors[args.descriptors_type] = kapture.Descriptors( "r2d2", descriptors_dtype, descriptors_dsize, args.keypoints_type, "L2", ) keypoints_config_absolute_path = get_feature_csv_fullpath( kapture.Keypoints, args.keypoints_type, args.kapture_root ) descriptors_config_absolute_path = get_feature_csv_fullpath( kapture.Descriptors, args.descriptors_type, args.kapture_root ) keypoints_to_file( keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type] ) descriptors_to_file( descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type], ) else: assert kdata.keypoints[args.keypoints_type].dtype == xys.dtype assert kdata.descriptors[args.descriptors_type].dtype == desc.dtype assert kdata.keypoints[args.keypoints_type].dsize == xys.shape[1] assert kdata.descriptors[args.descriptors_type].dsize == desc.shape[1] assert ( kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type ) assert kdata.descriptors[args.descriptors_type].metric_type == "L2" keypoints_fullpath = get_keypoints_fullpath( args.keypoints_type, args.kapture_root, image_name, tar_handlers ) print(f"Saving {xys.shape[0]} keypoints to {keypoints_fullpath}") image_keypoints_to_file(keypoints_fullpath, xys) kdata.keypoints[args.keypoints_type].add(image_name) descriptors_fullpath = get_descriptors_fullpath( args.descriptors_type, args.kapture_root, image_name, tar_handlers ) print(f"Saving {desc.shape[0]} descriptors to {descriptors_fullpath}") image_descriptors_to_file(descriptors_fullpath, desc) kdata.descriptors[args.descriptors_type].add(image_name) if not keypoints_check_dir( kdata.keypoints[args.keypoints_type], args.keypoints_type, args.kapture_root, tar_handlers, ) or not descriptors_check_dir( kdata.descriptors[args.descriptors_type], args.descriptors_type, args.kapture_root, tar_handlers, ): print( "local feature extraction ended successfully but not all files were saved" ) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser( "Extract r2d2 local features for all images in a dataset stored in the kapture format" ) parser.add_argument("--model", type=str, required=True, help="model path") parser.add_argument( "--keypoints-type", default=None, help="keypoint type_name, default is filename of model", ) parser.add_argument( "--descriptors-type", default=None, help="descriptors type_name, default is filename of model", ) parser.add_argument( "--kapture-root", type=str, required=True, help="path to kapture root directory" ) parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints") parser.add_argument("--scale-f", type=float, default=2**0.25) parser.add_argument("--min-size", type=int, default=256) parser.add_argument("--max-size", type=int, default=1024) parser.add_argument("--min-scale", type=float, default=0) parser.add_argument("--max-scale", type=float, default=1) parser.add_argument("--reliability-thr", type=float, default=0.7) parser.add_argument("--repeatability-thr", type=float, default=0.7) parser.add_argument( "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU" ) args = parser.parse_args() extract_kapture_keypoints(args)