import copy
import datetime
import os
import random
from itertools import combinations

import numpy as np
import torch
import cv2
import gradio as gr

from hloc import matchers, extractors
from hloc.utils.base_model import dynamic_load
from hloc import match_dense, match_features, extract_features
from hloc.utils.viz import add_text, plot_keypoints

from .viz import draw_matches, fig2im, plot_images, plot_color_line_matches

device = "cuda" if torch.cuda.is_available() else "cpu"

DEFAULT_SETTING_THRESHOLD = 0.1
DEFAULT_SETTING_MAX_FEATURES = 4096
# NOTE: the doubled "DEFAULT_" prefix is kept so existing importers keep working.
DEFAULT_DEFAULT_KEYPOINT_THRESHOLD = 0.01
# Not read in this module; presumably imported elsewhere — verify before removing.
DEFAULT_ENABLE_RANSAC = True
DEFAULT_RANSAC_METHOD = "USAC_MAGSAC"
DEFAULT_RANSAC_REPROJ_THRESHOLD = 8
DEFAULT_RANSAC_CONFIDENCE = 0.999
DEFAULT_RANSAC_MAX_ITER = 10000
# Minimum number of matches before a homography fit is attempted; the
# geometry estimation in compute_geom requires twice this many.
DEFAULT_MIN_NUM_MATCHES = 4
DEFAULT_MATCHING_THRESHOLD = 0.2
DEFAULT_SETTING_GEOMETRY = "Homography"

_LOG_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

# @ref: https://docs.opencv.org/4.x/d0/d74/md__build_4_x-contrib_docs-lin64_opencv_doc_tutorials_calib3d_usac.html
# AND: https://opencv.org/blog/2021/06/09/evaluating-opencvs-new-ransacs
ransac_zoo = {
    "RANSAC": cv2.RANSAC,
    "USAC_MAGSAC": cv2.USAC_MAGSAC,
    "USAC_DEFAULT": cv2.USAC_DEFAULT,
    "USAC_FM_8PTS": cv2.USAC_FM_8PTS,
    "USAC_PROSAC": cv2.USAC_PROSAC,
    "USAC_FAST": cv2.USAC_FAST,
    "USAC_ACCURATE": cv2.USAC_ACCURATE,
    "USAC_PARALLEL": cv2.USAC_PARALLEL,
}

# Matchers collections. "dense" matchers take the two images directly;
# the others first run the extractor described by "config_feature".
matcher_zoo = {
    "gim": {"config": match_dense.confs["gim"], "dense": True},
    "loftr": {"config": match_dense.confs["loftr"], "dense": True},
    "superpoint+superglue": {
        "config": match_features.confs["superglue"],
        "config_feature": extract_features.confs["superpoint_max"],
        "dense": False,
    },
    "superpoint+lightglue": {
        "config": match_features.confs["superpoint-lightglue"],
        "config_feature": extract_features.confs["superpoint_max"],
        "dense": False,
    },
    "d2net": {
        "config": match_features.confs["NN-mutual"],
        "config_feature": extract_features.confs["d2net-ss"],
        "dense": False,
    },
    "hardnet": {
        "config": match_features.confs["NN-mutual"],
        "config_feature": extract_features.confs["hardnet"],
        "dense": False,
    },
    "sift": {
        "config": match_features.confs["NN-mutual"],
        "config_feature": extract_features.confs["sift"],
        "dense": False,
    },
}


def _log(message):
    """Print ``message`` prefixed with the current local timestamp."""
    print(datetime.datetime.now().strftime(_LOG_TIME_FORMAT), message)


def _resolve_ransac_method(ransac_method):
    """Return ``ransac_method`` if it is a key of ``ransac_zoo``, else the default."""
    if ransac_method in ransac_zoo:
        return ransac_method
    return DEFAULT_RANSAC_METHOD


def get_model(match_conf):
    """Instantiate the matcher described by ``match_conf`` in eval mode on ``device``.

    ``match_conf["model"]["name"]`` selects the class from ``hloc.matchers``;
    the whole ``match_conf["model"]`` dict is passed to its constructor.
    """
    Model = dynamic_load(matchers, match_conf["model"]["name"])
    model = Model(match_conf["model"]).eval().to(device)
    return model


def get_feature_model(conf):
    """Instantiate the feature extractor described by ``conf`` in eval mode on ``device``.

    ``conf["model"]["name"]`` selects the class from ``hloc.extractors``;
    the whole ``conf["model"]`` dict is passed to its constructor.
    """
    Model = dynamic_load(extractors, conf["model"]["name"])
    model = Model(conf["model"]).eval().to(device)
    return model


def gen_examples():
    """Build example input rows for the matching UI.

    Each row holds, in order, the leading positional parameters of
    ``run_matching`` up to ``ransac_max_iter``: the two image paths, the
    match threshold, max keypoints, keypoint threshold, matcher key, and the
    RANSAC method / reprojection threshold / confidence / max iterations.
    All settings are the module defaults.

    Returns:
        list of lists, one per example pair.
    """
    _log("gen_examples start")
    example_matchers = ["gim", "gim", "gim", "gim"]
    example_pairs = [
        ("datasets/gim/0a.png", "datasets/gim/0b.png"),
        ("datasets/gim/1a.png", "datasets/gim/1b.png"),
        ("datasets/gim/2a.png", "datasets/gim/2b.png"),
        ("datasets/gim/3a.png", "datasets/gim/3b.png"),
    ]
    input_lists = [
        [
            image0_path,
            image1_path,
            DEFAULT_SETTING_THRESHOLD,
            DEFAULT_SETTING_MAX_FEATURES,
            DEFAULT_DEFAULT_KEYPOINT_THRESHOLD,
            matcher_key,
            DEFAULT_RANSAC_METHOD,
            DEFAULT_RANSAC_REPROJ_THRESHOLD,
            DEFAULT_RANSAC_CONFIDENCE,
            DEFAULT_RANSAC_MAX_ITER,
        ]
        for (image0_path, image1_path), matcher_key in zip(
            example_pairs, example_matchers
        )
    ]
    _log("gen_examples end")
    return input_lists


def filter_matches(
    pred,
    ransac_method=DEFAULT_RANSAC_METHOD,
    ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
    ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
    ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
):
    """Keep only matches that are inliers of a RANSAC homography fit.

    Point matches (``keypoints{0,1}_orig``) take precedence over line-endpoint
    matches (``line_keypoints{0,1}_orig``). ``pred`` is modified in place and
    also returned. It is left unchanged when no match arrays are present,
    when there are fewer than ``DEFAULT_MIN_NUM_MATCHES`` matches, or when no
    homography is found. Unknown ``ransac_method`` names fall back to
    ``DEFAULT_RANSAC_METHOD``.

    Args:
        pred: prediction dict holding the matched point arrays.
        ransac_method: key into ``ransac_zoo``.
        ransac_reproj_threshold, ransac_confidence, ransac_max_iter:
            forwarded to ``cv2.findHomography``.

    Returns:
        The same ``pred`` dict.
    """
    if "keypoints0_orig" in pred and "keypoints1_orig" in pred:
        key0, key1 = "keypoints0_orig", "keypoints1_orig"
        is_keypoint = True
    elif "line_keypoints0_orig" in pred and "line_keypoints1_orig" in pred:
        key0, key1 = "line_keypoints0_orig", "line_keypoints1_orig"
        is_keypoint = False
    else:
        return pred

    mkpts0 = pred[key0]
    mkpts1 = pred[key1]
    if mkpts0 is None or mkpts1 is None:
        return pred
    if len(mkpts0) < DEFAULT_MIN_NUM_MATCHES:
        return pred

    H, mask = cv2.findHomography(
        mkpts0,
        mkpts1,
        method=ransac_zoo[_resolve_ransac_method(ransac_method)],
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    # On failure OpenCV yields no homography; the mask is then unusable too.
    if H is None or mask is None:
        return pred

    mask = mask.ravel().astype(bool)
    pred[key0] = mkpts0[mask]
    pred[key1] = mkpts1[mask]
    # NOTE(review): "mconf" is only filtered for point matches, as before;
    # confirm whether line matchers emit an "mconf" aligned with line keypoints.
    if is_keypoint and "mconf" in pred:
        pred["mconf"] = pred["mconf"][mask]
    return pred


def compute_geom(
    pred,
    ransac_method=DEFAULT_RANSAC_METHOD,
    ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
    ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
    ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
) -> dict:
    """Estimate two-view geometry from the matched points in ``pred``.

    Line-endpoint matches take precedence over point matches when both are
    present. The returned dict may contain (each as a nested list):

    - ``"Fundamental"``: from ``cv2.findFundamentalMat(mkpts0, mkpts1)``.
    - ``"Homography"``: from ``cv2.findHomography(mkpts1, mkpts0)``, i.e.
      mapping image-1 points onto image 0.
    - ``"H1"``, ``"H2"``: rectifying homographies from
      ``cv2.stereoRectifyUncalibrated``; only present when F was found.

    Returns ``{}`` when there are no matches or fewer than
    ``2 * DEFAULT_MIN_NUM_MATCHES``. Unknown ``ransac_method`` names fall back
    to ``DEFAULT_RANSAC_METHOD``.
    """
    mkpts0 = None
    mkpts1 = None
    if "keypoints0_orig" in pred and "keypoints1_orig" in pred:
        mkpts0 = pred["keypoints0_orig"]
        mkpts1 = pred["keypoints1_orig"]
    if "line_keypoints0_orig" in pred and "line_keypoints1_orig" in pred:
        mkpts0 = pred["line_keypoints0_orig"]
        mkpts1 = pred["line_keypoints1_orig"]

    if mkpts0 is None or mkpts1 is None:
        return {}
    if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
        return {}

    method = ransac_zoo[_resolve_ransac_method(ransac_method)]
    h1, w1 = pred["image0_orig"].shape[:2]
    geo_info = {}

    F, _ = cv2.findFundamentalMat(
        mkpts0,
        mkpts1,
        method=method,
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if F is not None:
        geo_info["Fundamental"] = F.tolist()

    H, _ = cv2.findHomography(
        mkpts1,
        mkpts0,
        method=method,
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if H is not None:
        geo_info["Homography"] = H.tolist()

    # Rectification needs a fundamental matrix; skip it when none was found.
    if F is not None:
        _, H1, H2 = cv2.stereoRectifyUncalibrated(
            mkpts0.reshape(-1, 2),
            mkpts1.reshape(-1, 2),
            F,
            imgSize=(w1, h1),
        )
        geo_info["H1"] = H1.tolist()
        geo_info["H2"] = H2.tolist()
    return geo_info


def wrap_images(img0, img1, geo_info, geom_type):
    """Warp an image pair with the geometry chosen by ``geom_type``.

    - ``"Homography"``: ``img0`` is kept; ``img1`` is warped into ``img0``'s
      frame with ``geo_info["Homography"]``. The homography is reported.
    - ``"Fundamental"``: both images are warped by the rectifying homographies
      ``geo_info["H1"]`` / ``geo_info["H2"]``. The fundamental matrix is reported.

    Returns:
        ``(image, rows)`` where ``image`` is the rendered side-by-side figure and
        ``rows`` maps ``"row1"``..``"row3"`` to the reported matrix rows, or
        ``(None, None)`` when ``geo_info`` is empty, ``geom_type`` is unknown,
        or the entries it needs are missing.
    """
    if not geo_info:
        return None, None

    h1, w1 = img0.shape[:2]
    h2, w2 = img1.shape[:2]

    if geom_type == "Homography":
        if "Homography" not in geo_info:
            return None, None
        H = np.array(geo_info["Homography"])
        rectified_image0 = img0
        rectified_image1 = cv2.warpPerspective(img1, H, (w1, h1))
        result_matrix = H
        title = ["Image 0", "Image 1 - warped"]
    elif geom_type == "Fundamental":
        if not all(k in geo_info for k in ("Fundamental", "H1", "H2")):
            return None, None
        H1 = np.array(geo_info["H1"])
        H2 = np.array(geo_info["H2"])
        rectified_image0 = cv2.warpPerspective(img0, H1, (w1, h1))
        rectified_image1 = cv2.warpPerspective(img1, H2, (w2, h2))
        result_matrix = np.array(geo_info["Fundamental"])
        title = ["Image 0 - warped", "Image 1 - warped"]
    else:
        print("Error: Unknown geometry type")
        return None, None

    fig = plot_images(
        [rectified_image0.squeeze(), rectified_image1.squeeze()],
        title,
        dpi=300,
    )
    dictionary = {
        "row1": result_matrix[0].tolist(),
        "row2": result_matrix[1].tolist(),
        "row3": result_matrix[2].tolist(),
    }
    return fig2im(fig), dictionary


def change_estimate_geom(input_image0, input_image1, matches_info, choice):
    """Re-render the warped image pair for a newly selected geometry type.

    Args:
        input_image0, input_image1: the two images to warp.
        matches_info: dict expected to carry ``"geom_info"`` (as produced by
            ``compute_geom``).
        choice: ``"No"`` to disable warping, otherwise a ``geom_type`` accepted
            by ``wrap_images``.

    Returns:
        ``(warped_image_or_None, matches_info)``, or ``(None, None)`` when
        ``matches_info`` lacks ``"geom_info"`` or ``choice == "No"``.
    """
    if not matches_info or "geom_info" not in matches_info:
        return None, None
    if choice == "No":
        return None, None
    wrapped_images, _ = wrap_images(
        input_image0, input_image1, matches_info["geom_info"], choice
    )
    return wrapped_images, matches_info


def display_matches(pred: dict, titles=None, dpi=300):
    """Render the matches in ``pred`` as an image.

    Point matches are drawn with ``draw_matches``. If matched lines are
    present they are drawn too; when line-endpoint matches also exist, the
    point-match rendering of those endpoints is stacked above the line plot.

    Args:
        pred: prediction dict with ``image0_orig``/``image1_orig`` and optional
            ``keypoints{0,1}_orig``, ``line{0,1}_orig``,
            ``line_keypoints{0,1}_orig`` and ``mconf`` entries.
        titles: titles passed to ``draw_matches`` for point matches
            (defaults to an empty list).
        dpi: rendering resolution.

    Returns:
        ``(figure, num_inliers)``; ``figure`` is ``None`` when ``pred`` holds
        neither point nor line matches.
    """
    if titles is None:
        titles = []
    img0 = pred["image0_orig"]
    img1 = pred["image1_orig"]
    fig = None
    num_inliers = 0

    if "keypoints0_orig" in pred and "keypoints1_orig" in pred:
        mkpts0 = pred["keypoints0_orig"]
        mkpts1 = pred["keypoints1_orig"]
        num_inliers = len(mkpts0)
        mconf = pred["mconf"] if "mconf" in pred else np.ones(len(mkpts0))
        fig = draw_matches(
            mkpts0,
            mkpts1,
            img0,
            img1,
            mconf,
            dpi=dpi,
            titles=titles,
        )

    if "line0_orig" in pred and "line1_orig" in pred:
        mtlines0 = pred["line0_orig"]
        mtlines1 = pred["line1_orig"]
        num_inliers = len(mtlines0)
        # plot_images sets up the figure that plot_color_line_matches draws on.
        plot_images(
            [img0.squeeze(), img1.squeeze()],
            ["Image 0 - matched lines", "Image 1 - matched lines"],
            dpi=dpi,
        )
        fig_lines = fig2im(plot_color_line_matches([mtlines0, mtlines1], lw=2))

        mkpts0 = pred.get("line_keypoints0_orig")
        mkpts1 = pred.get("line_keypoints1_orig")
        if mkpts0 is not None and mkpts1 is not None:
            num_inliers = len(mkpts0)
            mconf = pred["mconf"] if "mconf" in pred else np.ones(len(mkpts0))
            fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=dpi)
            # Match widths/heights so the two renderings can be stacked.
            fig_lines = cv2.resize(
                fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
            )
            fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
        else:
            fig = fig_lines
    return fig, num_inliers


def run_matching(
    image0,
    image1,
    match_threshold,
    extract_max_keypoints,
    keypoint_threshold,
    key,
    ransac_method=DEFAULT_RANSAC_METHOD,
    ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
    ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
    ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
    choice_estimate_geom=DEFAULT_SETTING_GEOMETRY,
):
    """Match two images with the matcher ``key`` and render every stage.

    The matcher/extractor configs from ``matcher_zoo`` are copied before the
    per-request settings are applied, so concurrent or later requests do not
    see each other's thresholds.

    Args:
        image0, image1: input images (RGB per the original note).
        match_threshold: written to the matcher config's ``match_threshold``.
        extract_max_keypoints: written as ``max_keypoints`` to the matcher
            config and, for sparse matchers, to the extractor config.
        keypoint_threshold: extractor ``keypoint_threshold`` (sparse only).
        key: key into ``matcher_zoo``.
        ransac_method, ransac_reproj_threshold, ransac_confidence,
        ransac_max_iter: RANSAC settings used for both match filtering and
            geometry estimation.
        choice_estimate_geom: geometry used to warp the images (see
            ``change_estimate_geom``).

    Returns:
        A 7-tuple: keypoint image, raw-match image, RANSAC-match image,
        match-count dict, config dict, ``{"geom_info": ...}`` dict, warped image.

    Raises:
        gr.Error: if either image is missing.
    """
    _log("run_matching start")
    # image0 and image1 is RGB mode
    if image0 is None or image1 is None:
        raise gr.Error("Error: No images found! Please upload two images.")

    model = matcher_zoo[key]
    # Deep-copy so per-request settings never mutate the shared conf dicts.
    match_conf = copy.deepcopy(model["config"])
    match_conf["model"]["match_threshold"] = match_threshold
    match_conf["model"]["max_keypoints"] = extract_max_keypoints
    matcher = get_model(match_conf)

    if model["dense"]:
        pred = match_dense.match_images(
            matcher, image0, image1, match_conf["preprocessing"], device=device
        )
        extract_conf = None
    else:
        extract_conf = copy.deepcopy(model["config_feature"])
        extract_conf["model"]["max_keypoints"] = extract_max_keypoints
        extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
        extractor = get_feature_model(extract_conf)
        pred0 = extract_features.extract(
            extractor, image0, extract_conf["preprocessing"]
        )
        pred1 = extract_features.extract(
            extractor, image1, extract_conf["preprocessing"]
        )
        pred = match_features.match_images(matcher, pred0, pred1)
        del extractor
    # Release the model before the (memory-hungry) plotting stages.
    del matcher

    # plot images with keypoints
    _log("plot_images start")
    output_keypoints = plot_images(
        [image0, image1],
        titles=["Image 0 - Keypoints", "Image 1 - Keypoints"],
        dpi=300,
    )
    _log("plot_images end")
    _log("plot_keypoints start")
    if "keypoints0" in pred and "keypoints1" in pred:
        plot_keypoints([pred["keypoints0"], pred["keypoints1"]])
        text = (
            f"# keypoints0: {len(pred['keypoints0'])} \n"
            f"# keypoints1: {len(pred['keypoints1'])}"
        )
        add_text(0, text, fs=15)
    output_keypoints = fig2im(output_keypoints)
    _log("plot_keypoints end")

    # plot images with raw matches
    _log("plot images with raw matches start")
    output_matches_raw, num_matches_raw = display_matches(
        pred,
        titles=[
            "Image 0 - Raw matched keypoints",
            "Image 1 - Raw matched keypoints",
        ],
    )
    _log("plot images with raw matches end")

    # filter matches in place with RANSAC
    _log("filter_matches start")
    filter_matches(
        pred,
        ransac_method=ransac_method,
        ransac_reproj_threshold=ransac_reproj_threshold,
        ransac_confidence=ransac_confidence,
        ransac_max_iter=ransac_max_iter,
    )
    _log("filter_matches end")

    # plot images with ransac matches
    _log("plot images with ransac matches start")
    output_matches_ransac, num_matches_ransac = display_matches(
        pred,
        titles=[
            "Image 0 - Ransac matched keypoints",
            "Image 1 - Ransac matched keypoints",
        ],
    )
    _log("plot images with ransac matches end")

    # plot wrapped images
    _log("plot wrapped images start")
    geom_info = compute_geom(
        pred,
        ransac_method=ransac_method,
        ransac_reproj_threshold=ransac_reproj_threshold,
        ransac_confidence=ransac_confidence,
        ransac_max_iter=ransac_max_iter,
    )
    output_wrapped, _ = change_estimate_geom(
        pred["image0_orig"],
        pred["image1_orig"],
        {"geom_info": geom_info},
        choice_estimate_geom,
    )
    _log("plot wrapped images end")
    del pred

    return (
        output_keypoints,
        output_matches_raw,
        output_matches_ransac,
        {
            "number raw matches": num_matches_raw,
            "number ransac matches": num_matches_ransac,
        },
        {
            "match_conf": match_conf,
            "extractor_conf": extract_conf,
        },
        {
            "geom_info": geom_info,
        },
        output_wrapped,
    )