Spaces:
Runtime error
Runtime error
import matplotlib.cm as cm | |
import torch | |
import gradio as gr | |
from models.matching import Matching | |
from models.utils import (make_matching_plot_fast, process_image) | |
# Inference only: turn autograd off globally so no computation graphs are kept.
torch.set_grad_enabled(False)

# Load the SuperPoint and SuperGlue models.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Image pre-processing settings (passed to ``process_image`` in ``run``).
resize = [640, 640]
resize_float = False

# SuperPoint keypoint-detection settings.
max_keypoints = 1024
keypoint_threshold = 0.005
nms_radius = 4

# SuperGlue matching settings.
sinkhorn_iterations = 20
match_threshold = 0.2


def _build_matching_config(weights):
    """Return a ``Matching`` config dict using the named SuperGlue weights.

    The SuperPoint section and the remaining SuperGlue parameters are shared
    by every variant; only the ``'weights'`` entry differs.
    """
    superpoint_section = {
        'nms_radius': nms_radius,
        'keypoint_threshold': keypoint_threshold,
        'max_keypoints': max_keypoints,
    }
    superglue_section = {
        'weights': weights,
        'sinkhorn_iterations': sinkhorn_iterations,
        'match_threshold': match_threshold,
    }
    return {'superpoint': superpoint_section, 'superglue': superglue_section}


config_indoor = _build_matching_config("indoor")
config_outdoor = _build_matching_config("outdoor")

# Both model variants are built once at import time and put in eval mode.
matching_indoor = Matching(config_indoor).eval().to(device)
matching_outdoor = Matching(config_outdoor).eval().to(device)
def run(input0, input1, superglue):
    """Match keypoints between two images with SuperPoint + SuperGlue.

    Args:
        input0: first image as delivered by the Gradio ``Image`` input
            (``None`` when the user has not provided one); it is passed
            unchanged to ``process_image``.
        input1: second image, same format as ``input0``.
        superglue: ``"indoor"`` selects the indoor SuperGlue model; any other
            value selects the outdoor model.

    Returns:
        A 3-tuple matching the three Gradio output components:
        ``(plot, matched_count_text, match_percentage_text)``. When the image
        pair cannot be used, the plot is ``None`` and the texts describe the
        problem, so Gradio still receives one value per output.
    """
    matching = matching_indoor if superglue == "indoor" else matching_outdoor
    name0 = 'image1'
    name1 = 'image2'

    # Gradio hands over None for an empty image input; bail out before
    # attempting any pre-processing.
    if input0 is None or input1 is None:
        print('Problem reading image pair')
        return None, 'Problem reading image pair', ''

    # If a rotation integer is provided (e.g. from EXIF data), use it:
    rot0, rot1 = 0, 0

    # Load the image pair (the third return value, the scales, is unused here).
    image0, inp0, _ = process_image(input0, device, resize, rot0, resize_float)
    image1, inp1, _ = process_image(input1, device, resize, rot1, resize_float)
    if image0 is None or image1 is None:
        print('Problem reading image pair')
        # Three outputs are wired in the Interface, so return three values
        # instead of a bare None.
        return None, 'Problem reading image pair', ''

    # Perform the matching.
    pred = matching({'image0': inp0, 'image1': inp1})
    # Take entry 0 of every output and convert it to NumPy. ``.numpy()`` only
    # accepts CPU tensors, so ``.cpu()`` is required when running on CUDA.
    pred = {k: v[0].detach().cpu().numpy() for k, v in pred.items()}
    kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
    matches, conf = pred['matches0'], pred['matching_scores0']

    # Keep only image0 keypoints that have a partner (-1 means unmatched);
    # ``matches`` holds the index of the partner keypoint in image1.
    valid = matches > -1
    mkpts0 = kpts0[valid]
    mkpts1 = kpts1[matches[valid]]
    mconf = conf[valid]

    # Visualize the matches, colouring each one by its confidence.
    color = cm.jet(mconf)
    matched_text = '{}'.format(len(mkpts0))
    text = [
        'SuperGlue',
        'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
        matched_text,
    ]
    if rot0 != 0 or rot1 != 0:
        text.append('Rotation: {}:{}'.format(rot0, rot1))

    # Display extra parameter info.
    k_thresh = matching.superpoint.config['keypoint_threshold']
    m_thresh = matching.superglue.config['match_threshold']
    small_text = [
        'Keypoint Threshold: {:.4f}'.format(k_thresh),
        'Match Threshold: {:.2f}'.format(m_thresh),
        'Image Pair: {}:{}'.format(name0, name1),
    ]
    output = make_matching_plot_fast(
        image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,
        text, show_keypoints=True, small_text=small_text)

    # Share of image0 keypoints that were matched; if SuperPoint found no
    # keypoints at all, report 0 instead of raising ZeroDivisionError.
    ratio = len(mkpts0) / len(kpts0) if len(kpts0) else 0.0
    percentage = str(ratio * 100.0) + '%'
    print('Source Image - {}, Destination Image - {}, {}, Match Percentage - {}'.format(
        name0, name1, matched_text, percentage))
    return output, matched_text, percentage
if __name__ == '__main__':
    glue = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(label='Input Image'),
            gr.Image(label='Match Image'),
            # The default must be one of ``choices`` verbatim. "Indoor" is not a
            # valid choice, and it would not equal "indoor" in ``run``, so the
            # outdoor model would silently be used.
            gr.Radio(
                choices=["indoor", "outdoor"],
                value="indoor",
                type="value",
                label="SuperGlueType",
                interactive=True,
            ),
        ],
        # One output component per value returned by ``run``.
        outputs=[
            gr.Image(type="pil", label="Result"),
            gr.Textbox(label="Keypoints Matched"),
            gr.Textbox(label="Match Percentage"),
        ],
    )
    # Enable Gradio's request queue before starting the server.
    glue.queue()
    glue.launch()