Spaces:
Sleeping
Sleeping
| #! /usr/bin/env python3 | |
| # | |
| # %BANNER_BEGIN% | |
| # --------------------------------------------------------------------- | |
| # %COPYRIGHT_BEGIN% | |
| # | |
| # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL | |
| # | |
| # Unpublished Copyright (c) 2020 | |
| # Magic Leap, Inc., All Rights Reserved. | |
| # | |
| # NOTICE: All information contained herein is, and remains the property | |
| # of COMPANY. The intellectual and technical concepts contained herein | |
| # are proprietary to COMPANY and may be covered by U.S. and Foreign | |
| # Patents, patents in process, and are protected by trade secret or | |
| # copyright law. Dissemination of this information or reproduction of | |
| # this material is strictly forbidden unless prior written permission is | |
| # obtained from COMPANY. Access to the source code contained herein is | |
| # hereby forbidden to anyone except current COMPANY employees, managers | |
| # or contractors who have executed Confidentiality and Non-disclosure | |
| # agreements explicitly covering such access. | |
| # | |
| # The copyright notice above does not evidence any actual or intended | |
| # publication or disclosure of this source code, which includes | |
| # information that is confidential and/or proprietary, and is a trade | |
| # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, | |
| # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS | |
| # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS | |
| # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND | |
| # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE | |
| # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS | |
| # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, | |
| # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. | |
| # | |
| # %COPYRIGHT_END% | |
| # ---------------------------------------------------------------------- | |
| # %AUTHORS_BEGIN% | |
| # | |
| # Originating Authors: Paul-Edouard Sarlin | |
| # Daniel DeTone | |
| # Tomasz Malisiewicz | |
| # | |
| # %AUTHORS_END% | |
| # --------------------------------------------------------------------*/ | |
| # %BANNER_END% | |
| from pathlib import Path | |
| import argparse | |
| import random | |
| import numpy as np | |
| import matplotlib.cm as cm | |
| import torch | |
| from models.matching import Matching | |
| from models.utils import ( | |
| compute_pose_error, | |
| compute_epipolar_error, | |
| estimate_pose, | |
| make_matching_plot, | |
| error_colormap, | |
| AverageTimer, | |
| pose_auc, | |
| read_image, | |
| rotate_intrinsics, | |
| rotate_pose_inplane, | |
| scale_intrinsics, | |
| ) | |
# Run SuperPoint + SuperGlue on a list of image pairs and write the matches
# (and, optionally, pose-evaluation results and visualizations) as files in
# --output_dir. With --eval, also print pose AUC / precision / matching score
# averaged over all pairs.
#
# Each row of --input_pairs is whitespace-separated:
#   name0 name1 [rot0 rot1 [K0 (9 values) K1 (9 values) T_0to1 (16 values)]]
# Evaluation requires exactly 38 entries per row (2 + 2 + 9 + 9 + 16).

# This script only runs inference, so autograd tracking is disabled globally.
torch.set_grad_enabled(False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Image pair matching and pose evaluation with SuperGlue",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input_pairs",
        type=str,
        default="assets/scannet_sample_pairs_with_gt.txt",
        help="Path to the list of image pairs",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        default="assets/scannet_sample_images/",
        help="Path to the directory that contains the images",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="dump_match_pairs/",
        help="Path to the directory in which the .npz results and optionally, "
        "the visualization images are written",
    )
    parser.add_argument(
        "--max_length", type=int, default=-1, help="Maximum number of pairs to evaluate"
    )
    parser.add_argument(
        "--resize",
        type=int,
        nargs="+",
        default=[640, 480],
        help="Resize the input image before running inference. If two numbers, "
        "resize to the exact dimensions, if one number, resize the max "
        "dimension, if -1, do not resize",
    )
    parser.add_argument(
        "--resize_float",
        action="store_true",
        help="Resize the image after casting uint8 to float",
    )
    parser.add_argument(
        "--superglue",
        # A list (not a set) keeps the order shown in --help stable.
        choices=["indoor", "outdoor"],
        default="indoor",
        help="SuperGlue weights",
    )
    parser.add_argument(
        "--max_keypoints",
        type=int,
        default=1024,
        help="Maximum number of keypoints detected by Superpoint"
        " ('-1' keeps all keypoints)",
    )
    parser.add_argument(
        "--keypoint_threshold",
        type=float,
        default=0.005,
        help="SuperPoint keypoint detector confidence threshold",
    )
    parser.add_argument(
        "--nms_radius",
        type=int,
        default=4,
        help="SuperPoint Non Maximum Suppression (NMS) radius" " (Must be positive)",
    )
    parser.add_argument(
        "--sinkhorn_iterations",
        type=int,
        default=20,
        help="Number of Sinkhorn iterations performed by SuperGlue",
    )
    parser.add_argument(
        "--match_threshold", type=float, default=0.2, help="SuperGlue match threshold"
    )
    parser.add_argument(
        "--viz", action="store_true", help="Visualize the matches and dump the plots"
    )
    parser.add_argument(
        "--eval",
        action="store_true",
        help="Perform the evaluation" " (requires ground truth pose and intrinsics)",
    )
    parser.add_argument(
        "--fast_viz",
        action="store_true",
        help="Use faster image visualization with OpenCV instead of Matplotlib",
    )
    parser.add_argument(
        "--cache",
        action="store_true",
        help="Skip the pair if output .npz files are already found",
    )
    parser.add_argument(
        "--show_keypoints",
        action="store_true",
        help="Plot the keypoints in addition to the matches",
    )
    parser.add_argument(
        "--viz_extension",
        type=str,
        default="png",
        choices=["png", "pdf"],
        help="Visualization file extension. Use pdf for highest-quality.",
    )
    parser.add_argument(
        "--opencv_display",
        action="store_true",
        help="Visualize via OpenCV before saving output images",
    )
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="Shuffle ordering of pairs before processing",
    )
    parser.add_argument(
        "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode."
    )

    opt = parser.parse_args()
    print(opt)

    # Validate flag combinations with parser.error() instead of assert: the
    # checks then survive `python -O` and are reported as usage errors.
    if opt.opencv_display and not opt.viz:
        parser.error("Must use --viz with --opencv_display")
    if opt.opencv_display and not opt.fast_viz:
        parser.error("Cannot use --opencv_display without --fast_viz")
    if opt.fast_viz and not opt.viz:
        parser.error("Must use --viz with --fast_viz")
    if opt.fast_viz and opt.viz_extension == "pdf":
        parser.error("Cannot use pdf extension with --fast_viz")

    # "--resize W -1" is treated like the single-number form.
    if len(opt.resize) == 2 and opt.resize[1] == -1:
        opt.resize = opt.resize[0:1]
    if len(opt.resize) == 2:
        print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1]))
    elif len(opt.resize) == 1 and opt.resize[0] > 0:
        print("Will resize max dimension to {}".format(opt.resize[0]))
    elif len(opt.resize) == 1:
        print("Will not resize images")
    else:
        raise ValueError("Cannot specify more than two integers for --resize")

    # One whitespace-separated row per pair. Blank lines (e.g. an extra empty
    # line at the end of the file) are skipped; they would otherwise fail to
    # unpack into two image names below.
    with open(opt.input_pairs, "r") as f:
        pairs = [row for row in (line.split() for line in f) if row]

    if opt.max_length > -1:
        pairs = pairs[: opt.max_length]

    if opt.shuffle:
        # Fixed seed so the shuffled order is reproducible across runs.
        random.Random(0).shuffle(pairs)

    if opt.eval and not all(len(p) == 38 for p in pairs):
        raise ValueError(
            "All pairs should have ground truth info for evaluation. "
            'File "{}" needs 38 valid entries per row'.format(opt.input_pairs)
        )

    # Load the SuperPoint and SuperGlue models.
    device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
    print('Running inference on device "{}"'.format(device))
    config = {
        "superpoint": {
            "nms_radius": opt.nms_radius,
            "keypoint_threshold": opt.keypoint_threshold,
            "max_keypoints": opt.max_keypoints,
        },
        "superglue": {
            "weights": opt.superglue,
            "sinkhorn_iterations": opt.sinkhorn_iterations,
            "match_threshold": opt.match_threshold,
        },
    }
    matching = Matching(config).eval().to(device)

    # Create the output directories if they do not exist already.
    input_dir = Path(opt.input_dir)
    print('Looking for data in directory "{}"'.format(input_dir))
    output_dir = Path(opt.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    print('Will write matches to directory "{}"'.format(output_dir))
    if opt.eval:
        print("Will write evaluation results", 'to directory "{}"'.format(output_dir))
    if opt.viz:
        print("Will write visualization images to", 'directory "{}"'.format(output_dir))
        # Parameter values printed in the plot footers; fixed for the whole run.
        k_thresh = matching.superpoint.config["keypoint_threshold"]
        m_thresh = matching.superglue.config["match_threshold"]

    timer = AverageTimer(newline=True)
    for i, pair in enumerate(pairs):
        name0, name1 = pair[:2]
        stem0, stem1 = Path(name0).stem, Path(name1).stem
        matches_path = output_dir / "{}_{}_matches.npz".format(stem0, stem1)
        eval_path = output_dir / "{}_{}_evaluation.npz".format(stem0, stem1)
        viz_path = output_dir / "{}_{}_matches.{}".format(
            stem0, stem1, opt.viz_extension
        )
        viz_eval_path = output_dir / "{}_{}_evaluation.{}".format(
            stem0, stem1, opt.viz_extension
        )

        # Handle --cache logic: reuse outputs that already exist on disk and
        # only recompute the stages whose results are missing.
        do_match = True
        do_eval = opt.eval
        do_viz = opt.viz
        do_viz_eval = opt.eval and opt.viz
        if opt.cache:
            if matches_path.exists():
                try:
                    results = np.load(matches_path)
                except Exception as err:
                    raise IOError(
                        "Cannot load matches .npz file: %s" % matches_path
                    ) from err
                # The archive stays open until closed; close it once read.
                with results:
                    kpts0, kpts1 = results["keypoints0"], results["keypoints1"]
                    matches, conf = results["matches"], results["match_confidence"]
                do_match = False
            if opt.eval and eval_path.exists():
                try:
                    results = np.load(eval_path)
                except Exception as err:
                    raise IOError(
                        "Cannot load eval .npz file: %s" % eval_path
                    ) from err
                with results:
                    err_R, err_t = results["error_R"], results["error_t"]
                    precision = results["precision"]
                    matching_score = results["matching_score"]
                    num_correct = results["num_correct"]
                    epi_errs = results["epipolar_errors"]
                do_eval = False
            if opt.viz and viz_path.exists():
                do_viz = False
            if opt.viz and opt.eval and viz_eval_path.exists():
                do_viz_eval = False
            timer.update("load_cache")

        if not (do_match or do_eval or do_viz or do_viz_eval):
            timer.print("Finished pair {:5} of {:5}".format(i, len(pairs)))
            continue

        # If rotation integers are provided (e.g. from EXIF data), use them.
        # They are columns 3-4 of the row, so any row with >= 4 entries has them.
        if len(pair) >= 4:
            rot0, rot1 = int(pair[2]), int(pair[3])
        else:
            rot0, rot1 = 0, 0

        # Load the image pair.
        image0, inp0, scales0 = read_image(
            input_dir / name0, device, opt.resize, rot0, opt.resize_float
        )
        image1, inp1, scales1 = read_image(
            input_dir / name1, device, opt.resize, rot1, opt.resize_float
        )
        if image0 is None or image1 is None:
            # SystemExit(message) prints the message to stderr and exits with
            # status 1; the exit() builtin is meant for the interactive shell.
            raise SystemExit(
                "Problem reading image pair: {} {}".format(
                    input_dir / name0, input_dir / name1
                )
            )
        timer.update("load_image")

        if do_match:
            # Perform the matching; take element 0 of each output as NumPy.
            pred = matching({"image0": inp0, "image1": inp1})
            pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
            kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
            matches, conf = pred["matches0"], pred["matching_scores0"]
            timer.update("matcher")

            # Write the matches to disk.
            out_matches = {
                "keypoints0": kpts0,
                "keypoints1": kpts1,
                "matches": matches,
                "match_confidence": conf,
            }
            np.savez(str(matches_path), **out_matches)

        # Keep the matching keypoints: matches[j] is the index into kpts1 for
        # keypoint j of image 0, and -1 marks "no match".
        valid = matches > -1
        mkpts0 = kpts0[valid]
        mkpts1 = kpts1[matches[valid]]
        mconf = conf[valid]

        if do_eval:
            # Estimate the pose and compute the pose error.
            assert len(pair) == 38, "Pair does not have ground truth info"
            K0 = np.array(pair[4:13]).astype(float).reshape(3, 3)
            K1 = np.array(pair[13:22]).astype(float).reshape(3, 3)
            T_0to1 = np.array(pair[22:]).astype(float).reshape(4, 4)

            # Scale the intrinsics to resized image.
            K0 = scale_intrinsics(K0, scales0)
            K1 = scale_intrinsics(K1, scales1)

            # Update the intrinsics + extrinsics if EXIF rotation was found.
            if rot0 != 0 or rot1 != 0:
                cam0_T_w = np.eye(4)
                cam1_T_w = T_0to1
                if rot0 != 0:
                    K0 = rotate_intrinsics(K0, image0.shape, rot0)
                    cam0_T_w = rotate_pose_inplane(cam0_T_w, rot0)
                if rot1 != 0:
                    K1 = rotate_intrinsics(K1, image1.shape, rot1)
                    cam1_T_w = rotate_pose_inplane(cam1_T_w, rot1)
                cam1_T_cam0 = cam1_T_w @ np.linalg.inv(cam0_T_w)
                T_0to1 = cam1_T_cam0

            # A match counts as correct when its epipolar error is below 5e-4
            # (units are whatever compute_epipolar_error returns; presumably
            # normalized image coordinates since K0/K1 are passed in - confirm).
            epi_errs = compute_epipolar_error(mkpts0, mkpts1, T_0to1, K0, K1)
            correct = epi_errs < 5e-4
            num_correct = np.sum(correct)
            precision = np.mean(correct) if len(correct) > 0 else 0
            matching_score = num_correct / len(kpts0) if len(kpts0) > 0 else 0

            thresh = 1.0  # In pixels relative to resized image size.
            ret = estimate_pose(mkpts0, mkpts1, K0, K1, thresh)
            if ret is None:
                err_t, err_R = np.inf, np.inf
            else:
                R, t, _ = ret  # inlier mask is not needed here
                err_t, err_R = compute_pose_error(T_0to1, R, t)

            # Write the evaluation results to disk.
            out_eval = {
                "error_t": err_t,
                "error_R": err_R,
                "precision": precision,
                "matching_score": matching_score,
                "num_correct": num_correct,
                "epipolar_errors": epi_errs,
            }
            np.savez(str(eval_path), **out_eval)
            timer.update("eval")

        if do_viz or do_viz_eval:
            # Extra parameter info shown in small print on both plots.
            small_text = [
                "Keypoint Threshold: {:.4f}".format(k_thresh),
                "Match Threshold: {:.2f}".format(m_thresh),
                "Image Pair: {}:{}".format(stem0, stem1),
            ]

        if do_viz:
            # Visualize the matches, colored by match confidence.
            color = cm.jet(mconf)
            text = [
                "SuperGlue",
                "Keypoints: {}:{}".format(len(kpts0), len(kpts1)),
                "Matches: {}".format(len(mkpts0)),
            ]
            if rot0 != 0 or rot1 != 0:
                text.append("Rotation: {}:{}".format(rot0, rot1))
            make_matching_plot(
                image0,
                image1,
                kpts0,
                kpts1,
                mkpts0,
                mkpts1,
                color,
                text,
                viz_path,
                opt.show_keypoints,
                opt.fast_viz,
                opt.opencv_display,
                "Matches",
                small_text,
            )
            timer.update("viz_match")

        if do_viz_eval:
            # Visualize the evaluation results for the image pair: each match
            # is colored by its epipolar error, saturating at 1e-3.
            color = np.clip(epi_errs / 1e-3, 0, 1)
            color = error_colormap(1 - color)
            # Plain-text labels for OpenCV; matplotlib can render the symbols.
            deg, delta = " deg", "Delta "
            if not opt.fast_viz:
                deg, delta = "°", "$\\Delta$"
            e_t = "FAIL" if np.isinf(err_t) else "{:.1f}{}".format(err_t, deg)
            e_R = "FAIL" if np.isinf(err_R) else "{:.1f}{}".format(err_R, deg)
            text = [
                "SuperGlue",
                "{}R: {}".format(delta, e_R),
                "{}t: {}".format(delta, e_t),
                "inliers: {}/{}".format(num_correct, valid.sum()),
            ]
            if rot0 != 0 or rot1 != 0:
                text.append("Rotation: {}:{}".format(rot0, rot1))
            make_matching_plot(
                image0,
                image1,
                kpts0,
                kpts1,
                mkpts0,
                mkpts1,
                color,
                text,
                viz_eval_path,
                opt.show_keypoints,
                opt.fast_viz,
                opt.opencv_display,
                "Relative Pose",
                small_text,
            )
            timer.update("viz_eval")

        timer.print("Finished pair {:5} of {:5}".format(i, len(pairs)))

    if opt.eval:
        # Collate the results into a final table and print to terminal.
        pose_errors = []
        precisions = []
        matching_scores = []
        for pair in pairs:
            stem0, stem1 = Path(pair[0]).stem, Path(pair[1]).stem
            eval_path = output_dir / "{}_{}_evaluation.npz".format(stem0, stem1)
            # Close each archive after reading so file handles do not pile up.
            with np.load(eval_path) as results:
                # A pair's pose error is the larger of its t and R errors.
                pose_errors.append(
                    np.maximum(results["error_t"], results["error_R"])
                )
                precisions.append(results["precision"])
                matching_scores.append(results["matching_score"])
        thresholds = [5, 10, 20]
        aucs = [100.0 * auc for auc in pose_auc(pose_errors, thresholds)]
        prec = 100.0 * np.mean(precisions)
        ms = 100.0 * np.mean(matching_scores)
        print("Evaluation Results (mean over {} pairs):".format(len(pairs)))
        print("AUC@5\t AUC@10\t AUC@20\t Prec\t MScore\t")
        print(
            "{:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t".format(
                aucs[0], aucs[1], aucs[2], prec, ms
            )
        )