import argparse import numpy as np from PIL import Image import torch import math from tqdm import tqdm from os import path # Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure From Motion) and more generally sensor-acquired data # it can be installed with # pip install kapture # for more information check out https://github.com/naver/kapture import kapture from kapture.io.records import get_image_fullpath from kapture.io.csv import kapture_from_dir, get_all_tar_handlers from kapture.io.csv import ( get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file, ) from kapture.io.features import ( get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file, ) from kapture.io.features import ( get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file, ) from lib.model_test import D2Net from lib.utils import preprocess_image from lib.pyramid import process_multiscale # import imageio # CUDA use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if use_cuda else "cpu") # Argument parsing parser = argparse.ArgumentParser(description="Feature extraction script") parser.add_argument( "--kapture-root", type=str, required=True, help="path to kapture root directory" ) parser.add_argument( "--preprocessing", type=str, default="caffe", help="image preprocessing (caffe or torch)", ) parser.add_argument( "--model_file", type=str, default="models/d2_tf.pth", help="path to the full model" ) parser.add_argument( "--keypoints-type", type=str, default=None, help="keypoint type_name, default is filename of model", ) parser.add_argument( "--descriptors-type", type=str, default=None, help="descriptors type_name, default is filename of model", ) parser.add_argument( "--max_edge", type=int, default=1600, help="maximum image size at network input" ) parser.add_argument( "--max_sum_edges", type=int, default=2800, help="maximum sum of image sizes at network input", ) parser.add_argument( "--multiscale", dest="multiscale", action="store_true", help="extract multiscale features", ) parser.set_defaults(multiscale=False) parser.add_argument( "--no-relu", dest="use_relu", action="store_false", help="remove ReLU after the dense feature extraction module", ) parser.set_defaults(use_relu=True) parser.add_argument( "--max-keypoints", type=int, default=float("+inf"), help="max number of keypoints save to disk", ) args = parser.parse_args() print(args) with get_all_tar_handlers( args.kapture_root, mode={ kapture.Keypoints: "a", kapture.Descriptors: "a", kapture.GlobalFeatures: "r", kapture.Matches: "r", }, ) as tar_handlers: kdata = kapture_from_dir( args.kapture_root, skip_list=[ kapture.GlobalFeatures, kapture.Matches, kapture.Points3d, kapture.Observations, ], tar_handlers=tar_handlers, ) if kdata.keypoints is None: kdata.keypoints = {} if kdata.descriptors is None: kdata.descriptors = {} assert kdata.records_camera is not None image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)] if args.keypoints_type is None: args.keypoints_type = path.splitext(path.basename(args.model_file))[0] print(f"keypoints_type set to {args.keypoints_type}") if args.descriptors_type is None: args.descriptors_type = path.splitext(path.basename(args.model_file))[0] print(f"descriptors_type set to {args.descriptors_type}") if ( args.keypoints_type in kdata.keypoints and args.descriptors_type in kdata.descriptors ): image_list = [ name for name in image_list if name not in kdata.keypoints[args.keypoints_type] or name not in kdata.descriptors[args.descriptors_type] ] if len(image_list) == 0: print("All features were already extracted") exit(0) else: print(f"Extracting d2net features for {len(image_list)} images") # Creating CNN model model = D2Net(model_file=args.model_file, use_relu=args.use_relu, use_cuda=use_cuda) if args.keypoints_type not in kdata.keypoints: keypoints_dtype = None keypoints_dsize = None else: keypoints_dtype = kdata.keypoints[args.keypoints_type].dtype keypoints_dsize = kdata.keypoints[args.keypoints_type].dsize if args.descriptors_type not in kdata.descriptors: descriptors_dtype = None descriptors_dsize = None else: descriptors_dtype = kdata.descriptors[args.descriptors_type].dtype descriptors_dsize = kdata.descriptors[args.descriptors_type].dsize # Process the files for image_name in tqdm(image_list, total=len(image_list)): img_path = get_image_fullpath(args.kapture_root, image_name) image = Image.open(img_path).convert("RGB") width, height = image.size resized_image = image resized_width = width resized_height = height max_edge = args.max_edge max_sum_edges = args.max_sum_edges if max(resized_width, resized_height) > max_edge: scale_multiplier = max_edge / max(resized_width, resized_height) resized_width = math.floor(resized_width * scale_multiplier) resized_height = math.floor(resized_height * scale_multiplier) resized_image = image.resize((resized_width, resized_height)) if resized_width + resized_height > max_sum_edges: scale_multiplier = max_sum_edges / (resized_width + resized_height) resized_width = math.floor(resized_width * scale_multiplier) resized_height = math.floor(resized_height * scale_multiplier) resized_image = image.resize((resized_width, resized_height)) fact_i = width / resized_width fact_j = height / resized_height resized_image = np.array(resized_image).astype("float") input_image = preprocess_image(resized_image, preprocessing=args.preprocessing) with torch.no_grad(): if args.multiscale: keypoints, scores, descriptors = process_multiscale( torch.tensor( input_image[np.newaxis, :, :, :].astype(np.float32), device=device, ), model, ) else: keypoints, scores, descriptors = process_multiscale( torch.tensor( input_image[np.newaxis, :, :, :].astype(np.float32), device=device, ), model, scales=[1], ) # Input image coordinates keypoints[:, 0] *= fact_i keypoints[:, 1] *= fact_j # i, j -> u, v keypoints = keypoints[:, [1, 0, 2]] if args.max_keypoints != float("+inf"): # keep the last (the highest) indexes idx_keep = scores.argsort()[-min(len(keypoints), args.max_keypoints) :] keypoints = keypoints[idx_keep] descriptors = descriptors[idx_keep] if keypoints_dtype is None or descriptors_dtype is None: keypoints_dtype = keypoints.dtype descriptors_dtype = descriptors.dtype keypoints_dsize = keypoints.shape[1] descriptors_dsize = descriptors.shape[1] kdata.keypoints[args.keypoints_type] = kapture.Keypoints( "d2net", keypoints_dtype, keypoints_dsize ) kdata.descriptors[args.descriptors_type] = kapture.Descriptors( "d2net", descriptors_dtype, descriptors_dsize, args.keypoints_type, "L2" ) keypoints_config_absolute_path = get_feature_csv_fullpath( kapture.Keypoints, args.keypoints_type, args.kapture_root ) descriptors_config_absolute_path = get_feature_csv_fullpath( kapture.Descriptors, args.descriptors_type, args.kapture_root ) keypoints_to_file( keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type] ) descriptors_to_file( descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type], ) else: assert kdata.keypoints[args.keypoints_type].dtype == keypoints.dtype assert kdata.descriptors[args.descriptors_type].dtype == descriptors.dtype assert kdata.keypoints[args.keypoints_type].dsize == keypoints.shape[1] assert ( kdata.descriptors[args.descriptors_type].dsize == descriptors.shape[1] ) assert ( kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type ) assert kdata.descriptors[args.descriptors_type].metric_type == "L2" keypoints_fullpath = get_keypoints_fullpath( args.keypoints_type, args.kapture_root, image_name, tar_handlers ) print(f"Saving {keypoints.shape[0]} keypoints to {keypoints_fullpath}") image_keypoints_to_file(keypoints_fullpath, keypoints) kdata.keypoints[args.keypoints_type].add(image_name) descriptors_fullpath = get_descriptors_fullpath( args.descriptors_type, args.kapture_root, image_name, tar_handlers ) print(f"Saving {descriptors.shape[0]} descriptors to {descriptors_fullpath}") image_descriptors_to_file(descriptors_fullpath, descriptors) kdata.descriptors[args.descriptors_type].add(image_name) if not keypoints_check_dir( kdata.keypoints[args.keypoints_type], args.keypoints_type, args.kapture_root, tar_handlers, ) or not descriptors_check_dir( kdata.descriptors[args.descriptors_type], args.descriptors_type, args.kapture_root, tar_handlers, ): print( "local feature extraction ended successfully but not all files were saved" )