import argparse
import os
from os.path import join

import cv2
import torch
from matplotlib import pyplot as plt

from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT

from .drawing import (
    plot_images,
    plot_lines,
    plot_color_line_matches,
    plot_keypoints,
    plot_matches,
)
from .models.two_view_pipeline import TwoViewPipeline


def _load_grayscale(path, parser):
    """Read ``path`` as a single-channel image, aborting via ``parser`` on failure.

    ``cv2.imread`` signals a missing/unreadable file by returning ``None``
    rather than raising, so the check is needed to get a clear error message.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        parser.error(f"could not read image: {path}")
    return image


def _select_matches(items0, items1, matches):
    """Return the rows of ``items0`` and ``items1`` paired by ``matches``.

    ``matches[i]`` is the index into ``items1`` matched to ``items0[i]``, or
    -1 when ``items0[i]`` is unmatched. Presumably all three are numpy arrays
    as produced by ``batch_to_np`` — TODO confirm.
    """
    valid = matches != -1
    return items0[valid], items1[matches[valid]]


def _finish_figure(window_title, filename):
    """Set the current figure's window title and save it to ``filename``."""
    plt.gcf().canvas.manager.set_window_title(window_title)
    plt.savefig(filename)


def main():
    """Run GlueStick on two images and plot the detected and matched features.

    Four figures are produced and saved in the current working directory:
    ``detected_lines.png``, ``detected_points.png``, ``line_matches.png`` and
    ``point_matches.png``. They are then shown interactively unless
    ``--skip-imshow`` is passed.
    """
    # Parse input parameters
    parser = argparse.ArgumentParser(
        prog="GlueStick Demo",
        description="Demo app to show the point and line matches obtained by GlueStick",
    )
    # Default image paths are relative to the current working directory.
    parser.add_argument(
        "-img1", default=join("resources", "img1.jpg"), help="first image path"
    )
    parser.add_argument(
        "-img2", default=join("resources", "img2.jpg"), help="second image path"
    )
    parser.add_argument(
        "--max_pts", type=int, default=1000, help="maximum number of keypoints"
    )
    parser.add_argument(
        "--max_lines", type=int, default=300, help="maximum number of line segments"
    )
    parser.add_argument(
        "--skip-imshow",
        default=False,
        action="store_true",
        help="save the figures without opening a window",
    )
    args = parser.parse_args()

    # Evaluation config
    conf = {
        "name": "two_view_pipeline",
        "use_lines": True,
        "extractor": {
            "name": "wireframe",
            "sp_params": {
                "force_num_keypoints": False,
                "max_num_keypoints": args.max_pts,
            },
            "wireframe_params": {
                "merge_points": True,
                "merge_line_endpoints": True,
            },
            "max_n_lines": args.max_lines,
        },
        "matcher": {
            "name": "gluestick",
            "weights": str(
                GLUESTICK_ROOT
                / "resources"
                / "weights"
                / "checkpoint_GlueStick_MD.tar"
            ),
            "trainable": False,
        },
        "ground_truth": {
            "from_pose_depth": False,
        },
    }

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline_model = TwoViewPipeline(conf).to(device).eval()

    gray0 = _load_grayscale(args.img1, parser)
    gray1 = _load_grayscale(args.img2, parser)

    # `[None]` prepends a batch axis of size 1 to each image tensor.
    x = {
        "image0": numpy_image_to_torch(gray0).to(device)[None],
        "image1": numpy_image_to_torch(gray1).to(device)[None],
    }
    # Pure inference: don't record an autograd graph.
    with torch.no_grad():
        pred = pipeline_model(x)
    pred = batch_to_np(pred)

    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]

    matched_kps0, matched_kps1 = _select_matches(kp0, kp1, pred["matches0"])
    matched_lines0, matched_lines1 = _select_matches(
        line_seg0, line_seg1, pred["line_matches0"]
    )

    # Plot the matches on 3-channel copies of the grayscale inputs.
    img0 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)
    img1 = cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)

    plot_images(
        [img0, img1],
        ["Image 1 - detected lines", "Image 2 - detected lines"],
        dpi=200,
        pad=2.0,
    )
    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
    _finish_figure("Detected Lines", "detected_lines.png")

    plot_images(
        [img0, img1],
        ["Image 1 - detected points", "Image 2 - detected points"],
        dpi=200,
        pad=2.0,
    )
    plot_keypoints([kp0, kp1], colors="c")
    _finish_figure("Detected Points", "detected_points.png")

    plot_images(
        [img0, img1],
        ["Image 1 - line matches", "Image 2 - line matches"],
        dpi=200,
        pad=2.0,
    )
    plot_color_line_matches([matched_lines0, matched_lines1], lw=2)
    _finish_figure("Line Matches", "line_matches.png")

    plot_images(
        [img0, img1],
        ["Image 1 - point matches", "Image 2 - point matches"],
        dpi=200,
        pad=2.0,
    )
    plot_matches(matched_kps0, matched_kps1, "green", lw=1, ps=0)
    # Previously saved as "detected_points.png", clobbering the figure above.
    _finish_figure("Point Matches", "point_matches.png")

    if not args.skip_imshow:
        plt.show()


if __name__ == "__main__":
    main()