# gim-online / app.py
# Vincentqyw
# update: app.py
# a19d7bd
# raw history blame
# No virus
# 10.1 kB
import argparse
import gradio as gr
from hloc import extract_features
from extra_utils.utils import (
matcher_zoo,
device,
match_dense,
match_features,
get_model,
get_feature_model,
display_matches
)
def run_matching(
    match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
):
    """Match two images with the matcher registered under ``key``.

    Args:
        match_threshold: written to the matcher config's
            ``model.match_threshold``.
        extract_max_keypoints: written to the matcher config's
            ``model.max_keypoints`` and, for sparse (non-dense) matchers, to
            the extractor config's ``model.max_keypoints``.
        keypoint_threshold: written to the extractor config's
            ``model.keypoint_threshold`` (sparse matchers only).
        key: name of an entry in ``matcher_zoo``.
        image0, image1: the two input images (RGB, as delivered by the
            ``gr.Image`` components); ``None`` when nothing was provided.

    Returns:
        A 3-tuple ``(fig, stats, confs)``: the figure returned by
        ``display_matches``, ``{"matches number": num_inliers}``, and
        ``{"match_conf": ..., "extractor_conf": ...}`` with the configs that
        were actually used (``extractor_conf`` is ``None`` for dense matchers).

    Raises:
        gr.Error: if either image is missing.
    """
    import copy

    # image0 and image1 are in RGB mode
    if image0 is None or image1 is None:
        raise gr.Error("Error: No images found! Please upload two images.")
    model = matcher_zoo[key]

    # Work on private copies: matcher_zoo is a module-level registry shared by
    # every call, so writing this request's thresholds into it in place would
    # persist them globally and could race with concurrent requests.
    match_conf = copy.deepcopy(model["config"])
    match_conf["model"]["match_threshold"] = match_threshold
    match_conf["model"]["max_keypoints"] = extract_max_keypoints
    matcher = get_model(match_conf)

    if model["dense"]:
        # Dense matchers consume the raw images directly.
        pred = match_dense.match_images(
            matcher, image0, image1, match_conf["preprocessing"], device=device
        )
        extract_conf = None
    else:
        # Sparse matchers: extract features from each image, then match them.
        extract_conf = copy.deepcopy(model["config_feature"])
        extract_conf["model"]["max_keypoints"] = extract_max_keypoints
        extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
        extractor = get_feature_model(extract_conf)
        pred0 = extract_features.extract(
            extractor, image0, extract_conf["preprocessing"]
        )
        pred1 = extract_features.extract(
            extractor, image1, extract_conf["preprocessing"]
        )
        pred = match_features.match_images(matcher, pred0, pred1)
        del extractor
    # Drop the model reference once matching is done (on both branches).
    del matcher

    fig, num_inliers = display_matches(pred)
    del pred
    return (
        fig,
        {"matches number": num_inliers},
        {"match_conf": match_conf, "extractor_conf": extract_conf},
    )
def ui_change_imagebox(choice):
    """Build an update that clears an image box and switches its input source.

    Args:
        choice: the image source selected in the UI (e.g. "upload").

    Returns:
        A raw Gradio update dict: value cleared, ``source`` set to ``choice``.
    """
    update = dict(value=None, source=choice)
    # The "__type__" marker tells Gradio this dict is an update, not a value.
    update["__type__"] = "update"
    return update
def ui_reset_state(
    match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
):
    """Return the values that put the UI back into its initial state.

    The arguments are the current values of the input components. They are
    accepted only so the signature matches the callback's ``inputs`` list and
    are otherwise ignored.

    Returns:
        A 12-tuple, in order: match threshold, max number of features,
        keypoint threshold, matcher key, image0 value, image1 value, two
        update dicts that clear the image boxes and reset their source to
        "upload", the image-source radio value, the match figure (cleared),
        and two empty dicts for the statistics/info panels.
    """
    # Reset to the same initial values the widgets are created with in run():
    # match threshold 0.1 (this used to reset to 0.2, which did not match the
    # slider's initial value), 1000 features, keypoint threshold 0.015, and
    # the "disk+lightglue" matcher.
    default_key = "disk+lightglue"
    if default_key not in matcher_zoo:
        # Fall back to the previous behaviour: the first registered matcher.
        default_key = list(matcher_zoo.keys())[0]
    return (
        0.1,
        1000,
        0.015,
        default_key,
        None,
        None,
        {"value": None, "source": "upload", "__type__": "update"},
        {"value": None, "source": "upload", "__type__": "update"},
        "upload",
        None,
        {},
        {},
    )
def run(config):
    """Build the "Image Matching WebUI" Gradio app and launch it.

    The page has two columns. The left column holds the matcher/source
    selectors, threshold sliders, the two input images, the Reset / Run Match
    buttons, and the examples. The right column holds the match
    visualisation and two JSON panels.

    Args:
        config: not read by this function; the ``__main__`` entry point
            passes ``None``. NOTE(review): presumably meant to carry settings
            loaded from ``--config_path``; not wired up yet.
    """
    with gr.Blocks(
        theme=gr.themes.Monochrome(), css="footer {visibility: hidden}"
    ) as app:
        # Page header.
        gr.Markdown(
            """
<p align="center">
<h1 align="center">Image Matching WebUI</h1>
</p>
"""
        )
        with gr.Row(equal_height=False):
            # Left column: settings, inputs, buttons and examples.
            with gr.Column():
                with gr.Row():
                    matcher_list = gr.Dropdown(
                        choices=list(matcher_zoo.keys()),
                        value="disk+lightglue",
                        label="Matching Model",
                        interactive=True,
                    )
                    # Changing this resets both image boxes (see callbacks).
                    match_image_src = gr.Radio(
                        ["upload", "webcam", "canvas"],
                        label="Image Source",
                        value="upload",
                    )
                with gr.Row():
                    match_setting_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1,
                        step=0.001,
                        label="Match threshold",
                        value=0.1,
                    )
                    match_setting_max_features = gr.Slider(
                        minimum=10,
                        maximum=10000,
                        step=10,
                        label="Max number of features",
                        value=1000,
                    )
                # TODO: add line settings
                with gr.Row():
                    detect_keypoints_threshold = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        label="Keypoint threshold",
                        value=0.015,
                    )
                    # Display only: this slider is not in `inputs` below, so
                    # its value never reaches run_matching.
                    detect_line_threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1,
                        step=0.01,
                        label="Line threshold",
                        value=0.2,
                    )
                # matcher_lists = gr.Radio(
                #     ["NN-mutual", "Dual-Softmax"],
                #     label="Matcher mode",
                #     value="NN-mutual",
                # )
                with gr.Row():
                    input_image0 = gr.Image(
                        label="Image 0",
                        type="numpy",
                        interactive=True,
                        image_mode="RGB",
                    )
                    input_image1 = gr.Image(
                        label="Image 1",
                        type="numpy",
                        interactive=True,
                        image_mode="RGB",
                    )
                with gr.Row():
                    button_reset = gr.Button(label="Reset", value="Reset")
                    button_run = gr.Button(
                        label="Run Match", value="Run Match", variant="primary"
                    )
                # Collapsible panel listing every registered matcher name.
                with gr.Accordion("Open for More!", open=False):
                    gr.Markdown(
                        f"""
<h3>Supported Algorithms</h3>
{", ".join(matcher_zoo.keys())}
"""
                    )
                # collect inputs
                # Order must match run_matching's positional parameters.
                inputs = [
                    match_setting_threshold,
                    match_setting_max_features,
                    detect_keypoints_threshold,
                    matcher_list,
                    input_image0,
                    input_image1,
                ]
                # Add some examples
                with gr.Row():
                    # Each row lists one value per entry of `inputs`, in the
                    # same order.
                    examples = [
                        [
                            0.1,
                            2000,
                            0.015,
                            "disk+lightglue",
                            "datasets/sacre_coeur/mapping/71295362_4051449754.jpg",
                            "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
                        ],
                        [
                            0.1,
                            2000,
                            0.015,
                            "loftr",
                            "datasets/sacre_coeur/mapping/03903474_1471484089.jpg",
                            "datasets/sacre_coeur/mapping/02928139_3448003521.jpg",
                        ],
                        [
                            0.1,
                            2000,
                            0.015,
                            "disk",
                            "datasets/sacre_coeur/mapping/10265353_3838484249.jpg",
                            "datasets/sacre_coeur/mapping/51091044_3486849416.jpg",
                        ],
                        [
                            0.1,
                            2000,
                            0.015,
                            "topicfm",
                            "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
                            "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
                        ],
                        [
                            0.1,
                            2000,
                            0.015,
                            "superpoint+superglue",
                            "datasets/sacre_coeur/mapping/17295357_9106075285.jpg",
                            "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
                        ],
                    ]
                    # Example inputs
                    # NOTE(review): with cache_examples=False and outputs=[],
                    # clicking an example presumably only fills the inputs and
                    # does not run matching, despite what the label says.
                    # Confirm against the installed Gradio version.
                    gr.Examples(
                        examples=examples,
                        inputs=inputs,
                        outputs=[],
                        fn=run_matching,
                        cache_examples=False,
                        label="Examples (click one of the images below to Run Match)",
                    )
            # Right column: match visualisation and statistics.
            with gr.Column():
                output_mkpts = gr.Image(label="Keypoints Matching", type="numpy")
                matches_result_info = gr.JSON(label="Matches Statistics")
                matcher_info = gr.JSON(label="Match info")
        # callbacks
        # Switching the image source clears both image boxes and changes
        # their source.
        match_image_src.change(
            fn=ui_change_imagebox, inputs=match_image_src, outputs=input_image0
        )
        match_image_src.change(
            fn=ui_change_imagebox, inputs=match_image_src, outputs=input_image1
        )
        # collect outputs
        # Order matches run_matching's 3-tuple return value.
        outputs = [
            output_mkpts,
            matches_result_info,
            matcher_info,
        ]
        # button callbacks
        button_run.click(fn=run_matching, inputs=inputs, outputs=outputs)
        # Reset images
        # One entry per element of ui_reset_state's 12-tuple. input_image0/1
        # appear twice: first receiving None (clear the value), then receiving
        # an update dict that also resets their source to "upload".
        reset_outputs = [
            match_setting_threshold,
            match_setting_max_features,
            detect_keypoints_threshold,
            matcher_list,
            input_image0,
            input_image1,
            input_image0,
            input_image1,
            match_image_src,
            output_mkpts,
            matches_result_info,
            matcher_info,
        ]
        button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
    app.launch(share=False)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, then build and launch the UI.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config_path",
        type=str,
        default="config.yaml",
        help="configuration file path",
    )
    # Parsing still matters on its own: it handles --help and rejects
    # unknown flags.
    cli_args = cli.parse_args()
    # NOTE(review): --config_path is parsed but not loaded yet; run() is
    # always given None.
    run(None)