Vincentqyw
fix: roma
8b973ee
raw
history blame
No virus
4.23 kB
import argparse
import os
from os.path import join
import cv2
import torch
from matplotlib import pyplot as plt
from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
from .drawing import (
plot_images,
plot_lines,
plot_color_line_matches,
plot_keypoints,
plot_matches,
)
from .models.two_view_pipeline import TwoViewPipeline
def main():
# Parse input parameters
parser = argparse.ArgumentParser(
prog="GlueStick Demo",
description="Demo app to show the point and line matches obtained by GlueStick",
)
parser.add_argument("-img1", default=join("resources" + os.path.sep + "img1.jpg"))
parser.add_argument("-img2", default=join("resources" + os.path.sep + "img2.jpg"))
parser.add_argument("--max_pts", type=int, default=1000)
parser.add_argument("--max_lines", type=int, default=300)
parser.add_argument("--skip-imshow", default=False, action="store_true")
args = parser.parse_args()
# Evaluation config
conf = {
"name": "two_view_pipeline",
"use_lines": True,
"extractor": {
"name": "wireframe",
"sp_params": {
"force_num_keypoints": False,
"max_num_keypoints": args.max_pts,
},
"wireframe_params": {
"merge_points": True,
"merge_line_endpoints": True,
},
"max_n_lines": args.max_lines,
},
"matcher": {
"name": "gluestick",
"weights": str(
GLUESTICK_ROOT / "resources" / "weights" / "checkpoint_GlueStick_MD.tar"
),
"trainable": False,
},
"ground_truth": {
"from_pose_depth": False,
},
}
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline_model = TwoViewPipeline(conf).to(device).eval()
gray0 = cv2.imread(args.img1, 0)
gray1 = cv2.imread(args.img2, 0)
torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
torch_gray0, torch_gray1 = (
torch_gray0.to(device)[None],
torch_gray1.to(device)[None],
)
x = {"image0": torch_gray0, "image1": torch_gray1}
pred = pipeline_model(x)
pred = batch_to_np(pred)
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0 = pred["matches0"]
line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
line_matches = pred["line_matches0"]
valid_matches = m0 != -1
match_indices = m0[valid_matches]
matched_kps0 = kp0[valid_matches]
matched_kps1 = kp1[match_indices]
valid_matches = line_matches != -1
match_indices = line_matches[valid_matches]
matched_lines0 = line_seg0[valid_matches]
matched_lines1 = line_seg1[match_indices]
# Plot the matches
img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(
gray1, cv2.COLOR_GRAY2BGR
)
plot_images(
[img0, img1],
["Image 1 - detected lines", "Image 2 - detected lines"],
dpi=200,
pad=2.0,
)
plot_lines([line_seg0, line_seg1], ps=4, lw=2)
plt.gcf().canvas.manager.set_window_title("Detected Lines")
plt.savefig("detected_lines.png")
plot_images(
[img0, img1],
["Image 1 - detected points", "Image 2 - detected points"],
dpi=200,
pad=2.0,
)
plot_keypoints([kp0, kp1], colors="c")
plt.gcf().canvas.manager.set_window_title("Detected Points")
plt.savefig("detected_points.png")
plot_images(
[img0, img1],
["Image 1 - line matches", "Image 2 - line matches"],
dpi=200,
pad=2.0,
)
plot_color_line_matches([matched_lines0, matched_lines1], lw=2)
plt.gcf().canvas.manager.set_window_title("Line Matches")
plt.savefig("line_matches.png")
plot_images(
[img0, img1],
["Image 1 - point matches", "Image 2 - point matches"],
dpi=200,
pad=2.0,
)
plot_matches(matched_kps0, matched_kps1, "green", lw=1, ps=0)
plt.gcf().canvas.manager.set_window_title("Point Matches")
plt.savefig("detected_points.png")
if not args.skip_imshow:
plt.show()
if __name__ == "__main__":
main()