Spaces:
Running
Running
import argparse | |
import numpy as np | |
from PIL import Image | |
import torch | |
import math | |
from tqdm import tqdm | |
from os import path | |
# Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure From Motion) and more generally sensor-acquired data | |
# it can be installed with | |
# pip install kapture | |
# for more information check out https://github.com/naver/kapture | |
import kapture | |
from kapture.io.records import get_image_fullpath | |
from kapture.io.csv import kapture_from_dir, get_all_tar_handlers | |
from kapture.io.csv import ( | |
get_feature_csv_fullpath, | |
keypoints_to_file, | |
descriptors_to_file, | |
) | |
from kapture.io.features import ( | |
get_keypoints_fullpath, | |
keypoints_check_dir, | |
image_keypoints_to_file, | |
) | |
from kapture.io.features import ( | |
get_descriptors_fullpath, | |
descriptors_check_dir, | |
image_descriptors_to_file, | |
) | |
from lib.model_test import D2Net | |
from lib.utils import preprocess_image | |
from lib.pyramid import process_multiscale | |
# import imageio | |
# Device selection: use the first CUDA GPU when torch reports one as available,
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Argument parsing
def _int_at_least(minimum):
    """Return an argparse ``type`` callable accepting integers >= ``minimum``.

    Non-integer or out-of-range text raises ``argparse.ArgumentTypeError`` so
    that argparse reports a usage error up front, instead of the script failing
    (or silently misbehaving) later while images are being processed.
    """

    def _parse(text):
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(
                f"must be an integer >= {minimum}, got {value}"
            )
        return value

    return _parse


parser = argparse.ArgumentParser(description="Feature extraction script")
parser.add_argument(
    "--kapture-root", type=str, required=True, help="path to kapture root directory"
)
parser.add_argument(
    "--preprocessing",
    type=str,
    default="caffe",
    help="image preprocessing (caffe or torch)",
)
parser.add_argument(
    "--model_file", type=str, default="models/d2_tf.pth", help="path to the full model"
)
parser.add_argument(
    "--keypoints-type",
    type=str,
    default=None,
    help="keypoint type_name, default is filename of model",
)
parser.add_argument(
    "--descriptors-type",
    type=str,
    default=None,
    help="descriptors type_name, default is filename of model",
)
# The two size limits are used as scale numerators when shrinking images, so
# zero or negative values can never produce a valid network input.
parser.add_argument(
    "--max_edge",
    type=_int_at_least(1),
    default=1600,
    help="maximum image size at network input",
)
parser.add_argument(
    "--max_sum_edges",
    type=_int_at_least(1),
    default=2800,
    help="maximum sum of image sizes at network input",
)
# store_true already defaults to False; no explicit set_defaults is needed.
parser.add_argument(
    "--multiscale",
    action="store_true",
    help="extract multiscale features",
)
# store_false already defaults to True, i.e. the ReLU is kept unless
# --no-relu is given.
parser.add_argument(
    "--no-relu",
    dest="use_relu",
    action="store_false",
    help="remove ReLU after the dense feature extraction module",
)
# The default stays float("+inf") because the rest of the script compares
# against that sentinel to mean "no limit". argparse only applies ``type`` to
# command-line strings and string defaults, so the float default is passed
# through untouched.
parser.add_argument(
    "--max-keypoints",
    type=_int_at_least(0),
    default=float("+inf"),
    help="max number of keypoints save to disk",
)
args = parser.parse_args()
print(args)
# Main extraction driver.
#
# Loads the kapture dataset at args.kapture_root, determines which images still
# lack keypoints or descriptors of the requested types, runs D2-Net on each of
# them and writes the per-image keypoints/descriptors (plus the feature-type
# config files the first time features of that type are created) back into the
# kapture directory.
with get_all_tar_handlers(
    args.kapture_root,
    mode={
        kapture.Keypoints: "a",
        kapture.Descriptors: "a",
        kapture.GlobalFeatures: "r",
        kapture.Matches: "r",
    },
) as tar_handlers:
    kdata = kapture_from_dir(
        args.kapture_root,
        skip_list=[
            kapture.GlobalFeatures,
            kapture.Matches,
            kapture.Points3d,
            kapture.Observations,
        ],
        tar_handlers=tar_handlers,
    )
    if kdata.keypoints is None:
        kdata.keypoints = {}
    if kdata.descriptors is None:
        kdata.descriptors = {}
    assert kdata.records_camera is not None
    image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)]

    # Both feature type names default to the model file's base name
    # (without its extension).
    if args.keypoints_type is None:
        args.keypoints_type = path.splitext(path.basename(args.model_file))[0]
        print(f"keypoints_type set to {args.keypoints_type}")
    if args.descriptors_type is None:
        args.descriptors_type = path.splitext(path.basename(args.model_file))[0]
        print(f"descriptors_type set to {args.descriptors_type}")

    # When both feature types already exist, only process the images that are
    # missing from at least one of them.
    if (
        args.keypoints_type in kdata.keypoints
        and args.descriptors_type in kdata.descriptors
    ):
        image_list = [
            name
            for name in image_list
            if name not in kdata.keypoints[args.keypoints_type]
            or name not in kdata.descriptors[args.descriptors_type]
        ]

    if not image_list:
        print("All features were already extracted")
        # Raise SystemExit directly instead of calling the `exit` builtin, which
        # is injected by the `site` module for interactive use and is not
        # guaranteed to exist. The exception still unwinds through the `with`.
        raise SystemExit(0)
    print(f"Extracting d2net features for {len(image_list)} images")

    # Creating CNN model
    model = D2Net(model_file=args.model_file, use_relu=args.use_relu, use_cuda=use_cuda)

    # dtype/dsize of already-registered feature types; None means the type does
    # not exist yet and is created from the first extracted image.
    if args.keypoints_type not in kdata.keypoints:
        keypoints_dtype = None
        keypoints_dsize = None
    else:
        keypoints_dtype = kdata.keypoints[args.keypoints_type].dtype
        keypoints_dsize = kdata.keypoints[args.keypoints_type].dsize
    if args.descriptors_type not in kdata.descriptors:
        descriptors_dtype = None
        descriptors_dsize = None
    else:
        descriptors_dtype = kdata.descriptors[args.descriptors_type].dtype
        descriptors_dsize = kdata.descriptors[args.descriptors_type].dsize

    # Process the files
    for image_name in tqdm(image_list, total=len(image_list)):
        img_path = get_image_fullpath(args.kapture_root, image_name)
        image = Image.open(img_path).convert("RGB")
        width, height = image.size

        # Work out the final network input size first, then resize once from
        # the original image. Each dimension is clamped to at least 1 pixel so
        # extreme aspect ratios cannot floor to an empty image (and a division
        # by zero below).
        resized_width, resized_height = width, height
        if max(resized_width, resized_height) > args.max_edge:
            scale_multiplier = args.max_edge / max(resized_width, resized_height)
            resized_width = max(1, math.floor(resized_width * scale_multiplier))
            resized_height = max(1, math.floor(resized_height * scale_multiplier))
        if resized_width + resized_height > args.max_sum_edges:
            scale_multiplier = args.max_sum_edges / (resized_width + resized_height)
            resized_width = max(1, math.floor(resized_width * scale_multiplier))
            resized_height = max(1, math.floor(resized_height * scale_multiplier))
        if (resized_width, resized_height) != (width, height):
            resized_image = image.resize((resized_width, resized_height))
        else:
            resized_image = image

        # Factors mapping network-input coordinates back to the original image.
        # `i` is the row coordinate, so it scales with the height ratio; `j` is
        # the column coordinate and scales with the width ratio. The two ratios
        # differ slightly because each dimension is floored independently.
        fact_i = height / resized_height
        fact_j = width / resized_width

        resized_image = np.array(resized_image).astype("float")
        input_image = preprocess_image(resized_image, preprocessing=args.preprocessing)

        with torch.no_grad():
            network_input = torch.tensor(
                input_image[np.newaxis, :, :, :].astype(np.float32),
                device=device,
            )
            if args.multiscale:
                keypoints, scores, descriptors = process_multiscale(
                    network_input, model
                )
            else:
                keypoints, scores, descriptors = process_multiscale(
                    network_input, model, scales=[1]
                )

        # Input image coordinates
        keypoints[:, 0] *= fact_i
        keypoints[:, 1] *= fact_j
        # i, j -> u, v
        keypoints = keypoints[:, [1, 0, 2]]

        if args.max_keypoints != float("+inf"):
            # Keep the highest-scoring entries, in ascending score order. The
            # slice start is computed explicitly because a `[-n:]` slice with
            # n == 0 would keep everything instead of nothing.
            n_keep = max(0, min(len(keypoints), args.max_keypoints))
            order = scores.argsort()
            idx_keep = order[len(order) - n_keep :]
            keypoints = keypoints[idx_keep]
            descriptors = descriptors[idx_keep]

        if keypoints_dtype is None or descriptors_dtype is None:
            # First features of this type: register the feature types from the
            # arrays just produced and write their config files.
            keypoints_dtype = keypoints.dtype
            descriptors_dtype = descriptors.dtype
            keypoints_dsize = keypoints.shape[1]
            descriptors_dsize = descriptors.shape[1]
            kdata.keypoints[args.keypoints_type] = kapture.Keypoints(
                "d2net", keypoints_dtype, keypoints_dsize
            )
            kdata.descriptors[args.descriptors_type] = kapture.Descriptors(
                "d2net", descriptors_dtype, descriptors_dsize, args.keypoints_type, "L2"
            )
            keypoints_config_absolute_path = get_feature_csv_fullpath(
                kapture.Keypoints, args.keypoints_type, args.kapture_root
            )
            descriptors_config_absolute_path = get_feature_csv_fullpath(
                kapture.Descriptors, args.descriptors_type, args.kapture_root
            )
            keypoints_to_file(
                keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type]
            )
            descriptors_to_file(
                descriptors_config_absolute_path,
                kdata.descriptors[args.descriptors_type],
            )
        else:
            # Feature types already registered: the new arrays must match them.
            assert kdata.keypoints[args.keypoints_type].dtype == keypoints.dtype
            assert kdata.descriptors[args.descriptors_type].dtype == descriptors.dtype
            assert kdata.keypoints[args.keypoints_type].dsize == keypoints.shape[1]
            assert (
                kdata.descriptors[args.descriptors_type].dsize == descriptors.shape[1]
            )
            assert (
                kdata.descriptors[args.descriptors_type].keypoints_type
                == args.keypoints_type
            )
            assert kdata.descriptors[args.descriptors_type].metric_type == "L2"

        keypoints_fullpath = get_keypoints_fullpath(
            args.keypoints_type, args.kapture_root, image_name, tar_handlers
        )
        print(f"Saving {keypoints.shape[0]} keypoints to {keypoints_fullpath}")
        image_keypoints_to_file(keypoints_fullpath, keypoints)
        kdata.keypoints[args.keypoints_type].add(image_name)

        descriptors_fullpath = get_descriptors_fullpath(
            args.descriptors_type, args.kapture_root, image_name, tar_handlers
        )
        print(f"Saving {descriptors.shape[0]} descriptors to {descriptors_fullpath}")
        image_descriptors_to_file(descriptors_fullpath, descriptors)
        kdata.descriptors[args.descriptors_type].add(image_name)

    # Final consistency check between the in-memory feature records and what
    # is actually stored under the kapture root.
    keypoints_ok = keypoints_check_dir(
        kdata.keypoints[args.keypoints_type],
        args.keypoints_type,
        args.kapture_root,
        tar_handlers,
    )
    if not keypoints_ok or not descriptors_check_dir(
        kdata.descriptors[args.descriptors_type],
        args.descriptors_type,
        args.kapture_root,
        tar_handlers,
    ):
        print(
            "local feature extraction ended successfully but not all files were saved"
        )