import os
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import requests
from matplotlib.figure import Figure
from numpy import ndarray

from demo import main, matcher_configs
from src.models.utils import make_matching_figure

# Hugging Face access token; when set, checkpoints are downloaded at startup.
HF_TOKEN = os.getenv("HF_TOKEN")

# Centers everything inside the description banner (elem_id="desc").
CSS = """
#desc, #desc * {
    text-align: center !important;
    justify-content: center !important;
    align-items: center !important;
}
"""

# Banner rendered through gr.HTML at the top of the page.
DESCRIPTION = """

CasP 🪜

Improving Semi-Dense Feature Matching Pipeline Leveraging
Cascaded Correspondence Priors for Guidance

ICCV 2025

Peiqi Chen1* · Lei Yu2* · Yi Wan1† Yingying Pei1 · Xinyi Liu1 · Yongxiang Yao1
Yingying Zhang2 · Lixiang Ru2 · Liheng Zhong2 · Jingdong Chen2 · Ming Yang2 · Yongjun Zhang1†

1Wuhan University   2Ant Group
*Equal contribution   †Corresponding author

"""

# Each row fills, in order: image 0, image 1, matching model, RANSAC model
# (the same order as the `inputs` list given to gr.Examples below).
examples = [
    [
        "assets/example_pairs/pair1-1.png",
        "assets/example_pairs/pair1-2.png",
        "casp_outdoor",
        "fundamental",
    ],
    [
        "assets/example_pairs/pair2-1.png",
        "assets/example_pairs/pair2-2.png",
        "casp_outdoor",
        "fundamental",
    ],
    [
        "assets/example_pairs/pair3-1.png",
        "assets/example_pairs/pair3-2.png",
        "casp_outdoor",
        "fundamental",
    ],
    [
        "assets/example_pairs/pair4-1.jpg",
        "assets/example_pairs/pair4-2.jpg",
        "casp_minima",
        "homography",
    ],
    [
        "assets/example_pairs/pair5-1.jpg",
        "assets/example_pairs/pair5-2.jpg",
        "casp_minima",
        "homography",
    ],
    [
        "assets/example_pairs/pair6-1.jpg",
        "assets/example_pairs/pair6-2.jpg",
        "casp_minima",
        "homography",
    ],
]


def fig_to_ndarray(fig: Figure) -> ndarray:
    """Rasterise *fig* and return its pixels as an ``(h, w, 4)`` uint8 array.

    The figure's canvas is drawn, then its RGBA buffer is reinterpreted as
    a NumPy array shaped by the canvas width and height. The result is a
    view onto the canvas buffer, not a copy.
    """
    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()
    buffer = fig.canvas.buffer_rgba()
    out = np.frombuffer(buffer, dtype=np.uint8).reshape(h, w, 4)
    return out


def _render_matching_image(**figure_kwargs) -> ndarray:
    """Build one matching figure, rasterise it, and release the figure."""
    # Function-scope import keeps pyplot out of module import time.
    import matplotlib.pyplot as plt

    fig = make_matching_figure(**figure_kwargs)
    try:
        # Copy so the pixels no longer depend on the canvas buffer.
        return fig_to_ndarray(fig).copy()
    finally:
        # NOTE(review): presumably make_matching_figure creates its figure
        # through pyplot, in which case every request would otherwise keep
        # four dpi-300 figures registered for the life of the process --
        # confirm. Closing a figure pyplot does not manage does nothing.
        plt.close(fig)


def run_matching(
    method: str,
    path0: str,
    path1: str,
    image_size: int,
    matching_threshold: float,
    ransac: str,
    estimator: str,
    inlier_threshold: float,
) -> Tuple[ndarray, ndarray, Optional[ndarray], Optional[ndarray]]:
    """Match two images and render the result panels.

    The arguments are forwarded to ``demo.main`` as a dict keyed by the
    parameter names; a ``ransac`` value of ``"none"`` is passed as ``None``.

    Returns:
        ``(raw_keypoints, raw_matches, ransac_keypoints, ransac_matches)``
        as RGBA image arrays. The two RANSAC entries are ``None`` when
        ``main`` returns no inlier mask.
    """
    ransac = None if ransac == "none" else ransac
    matching_args = {
        "method": method,
        "path0": path0,
        "path1": path1,
        "image_size": image_size,
        "matching_threshold": matching_threshold,
        "ransac": ransac,
        "estimator": estimator,
        "inlier_threshold": inlier_threshold,
    }
    points0, points1, scores, inlier_mask = main(matching_args)

    # The plot is driven by an error value, taken here as 1 - score.
    errors = 1 - scores
    text = [f"{method} (raw)", f"#matches: {len(points0)}"]
    plotting_args = {
        "path0": path0,
        "path1": path1,
        "points0": points0,
        "points1": points1,
        "errors": errors,
        "threshold": 0.5,
        "text": text,
        "dpi": 300,
    }
    raw_keypoint_fig = _render_matching_image(
        **plotting_args, enable_line=False
    )
    raw_matching_fig = _render_matching_image(**plotting_args)

    ransac_keypoint_fig = ransac_matching_fig = None
    if inlier_mask is not None:
        # Keep only the matches flagged by the inlier mask.
        for key in ["points0", "points1", "errors"]:
            plotting_args[key] = plotting_args[key][inlier_mask]
        plotting_args["text"] = [
            f"{method} (RANSAC)",
            f"#matches: {inlier_mask.sum()}",
        ]
        ransac_keypoint_fig = _render_matching_image(
            **plotting_args, enable_line=False
        )
        ransac_matching_fig = _render_matching_image(**plotting_args)

    return (
        raw_keypoint_fig,
        raw_matching_fig,
        ransac_keypoint_fig,
        ransac_matching_fig,
    )


def _download_checkpoints(token: str) -> None:
    """Fetch every checkpoint named in ``matcher_configs`` from the Hub.

    Each file is streamed to a temporary ``.part`` file and moved into
    place only after the download finished, so an HTTP error or a dropped
    connection never leaves an error page or a truncated file at
    ``ckpt_path``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    headers = {"Authorization": f"Bearer {token}"}
    for name, config in matcher_configs.items():
        url = f"https://huggingface.co/pq-chen/CasP/resolve/main/{name}.pth"
        target = os.fspath(config["ckpt_path"])
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        partial = f"{target}.part"
        with requests.get(
            url, headers=headers, stream=True, timeout=60
        ) as response:
            response.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial, target)


with gr.Blocks(css=CSS) as demo:
    with gr.Tab("Image Matching"):
        with gr.Row():
            with gr.Column(scale=3):
                gr.HTML(DESCRIPTION, elem_id="desc")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input Panels:")
                with gr.Row():
                    method = gr.Dropdown(
                        choices=["casp_outdoor", "casp_minima"],
                        value="casp_outdoor",
                        label="Matching Model",
                    )
                with gr.Row():
                    path0 = gr.Image(
                        height=300,
                        image_mode="RGB",
                        type="filepath",
                        label="Image 0",
                    )
                    path1 = gr.Image(
                        height=300,
                        image_mode="RGB",
                        type="filepath",
                        label="Image 1",
                    )
                with gr.Row():
                    stop = gr.Button(value="Stop", variant="stop")
                    run = gr.Button(value="Run", variant="primary")
                with gr.Accordion("Advanced Setting", open=False):
                    with gr.Accordion("Image Setting"):
                        with gr.Row():
                            force_resize = gr.Checkbox(
                                label="Force Resize", value=True
                            )
                            image_size = gr.Slider(
                                minimum=512,
                                maximum=1408,
                                value=1152,
                                step=32,
                                label="Longer Side (pixels)",
                            )
                    with gr.Accordion("Matching Setting"):
                        with gr.Row():
                            matching_threshold = gr.Slider(
                                minimum=0.0,
                                maximum=1,
                                value=0.2,
                                step=0.05,
                                label="Matching Threshold",
                            )
                    with gr.Accordion("RANSAC Setting"):
                        with gr.Row():
                            ransac = gr.Dropdown(
                                choices=["none", "fundamental", "homography"],
                                value="none",
                                label="Model",
                            )
                        with gr.Row():
                            estimator = gr.Dropdown(
                                choices=["CV2_RANSAC", "CV2_USAC_MAGSAC"],
                                value="CV2_USAC_MAGSAC",
                                label="Estimator",
                                visible=False,
                            )
                        with gr.Row():
                            inlier_threshold = gr.Slider(
                                minimum=0.0,
                                maximum=10.0,
                                value=3.0,
                                step=0.5,
                                label="Inlier Threshold",
                                visible=False,
                            )
                with gr.Row():
                    with gr.Accordion("Example Pairs"):
                        gr.Examples(
                            examples=examples,
                            inputs=[path0, path1, method, ransac],
                            label="Click an example pair below",
                        )
            with gr.Column():
                gr.Markdown(
                    "### Output Panels: 🟢▲ High confidence | 🔴▼ Low confidence"
                )
                with gr.Accordion("Raw Keypoints", open=False):
                    raw_keypoint_fig = gr.Image(
                        format="png", type="numpy", label="Raw Keypoints"
                    )
                with gr.Accordion("Raw Matches"):
                    raw_matching_fig = gr.Image(
                        format="png", type="numpy", label="Raw Matches"
                    )
                with gr.Accordion("RANSAC Keypoints", open=False):
                    ransac_keypoint_fig = gr.Image(
                        format="png", type="numpy", label="RANSAC Keypoints"
                    )
                with gr.Accordion("RANSAC Matches"):
                    ransac_matching_fig = gr.Image(
                        format="png", type="numpy", label="RANSAC Matches"
                    )

        # Component order must match the parameter order of run_matching.
        inputs = [
            method,
            path0,
            path1,
            image_size,
            matching_threshold,
            ransac,
            estimator,
            inlier_threshold,
        ]
        # Component order must match the tuple returned by run_matching.
        outputs = [
            raw_keypoint_fig,
            raw_matching_fig,
            ransac_keypoint_fig,
            ransac_matching_fig,
        ]

        running_event = run.click(
            fn=run_matching, inputs=inputs, outputs=outputs
        )
        # "Stop" runs no function of its own; it only cancels a pending run.
        stop.click(
            fn=None, inputs=None, outputs=None, cancels=[running_event]
        )
        # Show the size slider only while "Force Resize" is ticked; ticking
        # it resets the slider to 1152, unticking it clears the value.
        force_resize.select(
            fn=lambda checked: gr.update(
                visible=checked, value=1152 if checked else None
            ),
            inputs=force_resize,
            outputs=image_size,
        )
        # The estimator and inlier-threshold controls are only shown when a
        # RANSAC model other than "none" is selected.
        ransac.change(
            fn=lambda model: (
                gr.update(visible=model != "none"),
                gr.update(visible=model != "none"),
            ),
            inputs=ransac,
            outputs=[estimator, inlier_threshold],
        )


if __name__ == "__main__":
    if HF_TOKEN:
        _download_checkpoints(HF_TOKEN)
    demo.launch()