Vincentqyw committed • Commit 8b9ccdd • Parent(s): 1928ea3

update: ui
Files changed:
- app.py  +186 -162
- common/utils.py  +323 -12
- common/visualize_util.py  +0 -642 (deleted)
- common/{plotting.py → viz.py}  +116 -21 (renamed)
- style.css  +18 -0
app.py CHANGED

@@ -1,59 +1,20 @@
 import argparse
 import gradio as gr
-
-from hloc import extract_features
 from common.utils import (
     matcher_zoo,
-    device,
-    match_dense,
-    match_features,
-    get_model,
-    get_feature_model,
-    display_matches,
+    change_estimate_geom,
+    run_matching,
+    ransac_zoo,
+    gen_examples,
 )
 
-
-def run_matching(
-    match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
-):
-    # image0 and image1 is RGB mode
-    if image0 is None or image1 is None:
-        raise gr.Error("Error: No images found! Please upload two images.")
-
-    model = matcher_zoo[key]
-    match_conf = model["config"]
-    # update match config
-    match_conf["model"]["match_threshold"] = match_threshold
-    match_conf["model"]["max_keypoints"] = extract_max_keypoints
-
-    matcher = get_model(match_conf)
-    if model["dense"]:
-        pred = match_dense.match_images(
-            matcher, image0, image1, match_conf["preprocessing"], device=device
-        )
-        del matcher
-        extract_conf = None
-    else:
-        extract_conf = model["config_feature"]
-        # update extract config
-        extract_conf["model"]["max_keypoints"] = extract_max_keypoints
-        extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
-        extractor = get_feature_model(extract_conf)
-        pred0 = extract_features.extract(
-            extractor, image0, extract_conf["preprocessing"]
-        )
-        pred1 = extract_features.extract(
-            extractor, image1, extract_conf["preprocessing"]
-        )
-        pred = match_features.match_images(matcher, pred0, pred1)
-        del extractor
-    fig, num_inliers = display_matches(pred)
-    del pred
-    return (
-        fig,
-        {"matches number": num_inliers},
-        {"match_conf": match_conf, "extractor_conf": extract_conf},
-    )
+DESCRIPTION = """
+# Image Matching WebUI
+This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
+
+🔎 For more details about supported local features and matchers, please refer to https://github.com/Vincentqyw/image-matching-webui
+
+"""
 
 
 def ui_change_imagebox(choice):
@@ -61,7 +22,18 @@ def ui_change_imagebox(choice):
 
 
 def ui_reset_state(
-    …
+    image0,
+    image1,
+    match_threshold,
+    extract_max_keypoints,
+    keypoint_threshold,
+    key,
+    enable_ransac=False,
+    ransac_method="RANSAC",
+    ransac_reproj_threshold=8,
+    ransac_confidence=0.999,
+    ransac_max_iter=10000,
+    choice_estimate_geom="Homography",
 ):
     match_threshold = 0.2
     extract_max_keypoints = 1000
@@ -69,31 +41,35 @@ def ui_reset_state(
     key = list(matcher_zoo.keys())[0]
     image0 = None
     image1 = None
+    enable_ransac = False
     return (
+        image0,
+        image1,
         match_threshold,
         extract_max_keypoints,
         keypoint_threshold,
         key,
-        None,
-        None,
-        {"value": None, "source": "upload", "__type__": "update"},
-        {"value": None, "source": "upload", "__type__": "update"},
+        ui_change_imagebox("upload"),
+        ui_change_imagebox("upload"),
         "upload",
         None,
         {},
         {},
+        None,
+        {},
+        False,
+        "RANSAC",
+        8,
+        0.999,
+        10000,
+        "Homography",
    )
 
 
+# "footer {visibility: hidden}"
 def run(config):
-    with gr.Blocks(css="…
-        gr.Markdown(
-            """
-            <p align="center">
-            <h1 align="center">Image Matching WebUI</h1>
-            </p>
-            """
-        )
+    with gr.Blocks(css="style.css") as app:
+        gr.Markdown(DESCRIPTION)
 
         with gr.Row(equal_height=False):
             with gr.Column():
@@ -109,43 +85,6 @@ def run(config):
                     label="Image Source",
                     value="upload",
                 )
-
-                with gr.Row():
-                    match_setting_threshold = gr.Slider(
-                        minimum=0.0,
-                        maximum=1,
-                        step=0.001,
-                        label="Match threshold",
-                        value=0.1,
-                    )
-                    match_setting_max_features = gr.Slider(
-                        minimum=10,
-                        maximum=10000,
-                        step=10,
-                        label="Max number of features",
-                        value=1000,
-                    )
-                # TODO: add line settings
-                with gr.Row():
-                    detect_keypoints_threshold = gr.Slider(
-                        minimum=0,
-                        maximum=1,
-                        step=0.001,
-                        label="Keypoint threshold",
-                        value=0.015,
-                    )
-                    detect_line_threshold = gr.Slider(
-                        minimum=0.1,
-                        maximum=1,
-                        step=0.01,
-                        label="Line threshold",
-                        value=0.2,
-                    )
-                # matcher_lists = gr.Radio(
-                #     ["NN-mutual", "Dual-Softmax"],
-                #     label="Matcher mode",
-                #     value="NN-mutual",
-                # )
                 with gr.Row():
                     input_image0 = gr.Image(
                         label="Image 0",
@@ -166,89 +105,147 @@ def run(config):
                         label="Run Match", value="Run Match", variant="primary"
                     )
 
-                with gr.Accordion("…
-                    gr.…
-                    …
+                with gr.Accordion("Advanced Setting", open=False):
+                    with gr.Accordion("Matching Setting", open=True):
+                        with gr.Row():
+                            match_setting_threshold = gr.Slider(
+                                minimum=0.0,
+                                maximum=1,
+                                step=0.001,
+                                label="Match thres.",
+                                value=0.1,
+                            )
+                            match_setting_max_features = gr.Slider(
+                                minimum=10,
+                                maximum=10000,
+                                step=10,
+                                label="Max features",
+                                value=1000,
+                            )
+                        # TODO: add line settings
+                        with gr.Row():
+                            detect_keypoints_threshold = gr.Slider(
+                                minimum=0,
+                                maximum=1,
+                                step=0.001,
+                                label="Keypoint thres.",
+                                value=0.015,
+                            )
+                            detect_line_threshold = gr.Slider(
+                                minimum=0.1,
+                                maximum=1,
+                                step=0.01,
+                                label="Line thres.",
+                                value=0.2,
+                            )
+                        # matcher_lists = gr.Radio(
+                        #     ["NN-mutual", "Dual-Softmax"],
+                        #     label="Matcher mode",
+                        #     value="NN-mutual",
+                        # )
+                    with gr.Accordion("RANSAC Setting", open=False):
+                        with gr.Row(equal_height=False):
+                            enable_ransac = gr.Checkbox(label="Enable RANSAC")
+                            ransac_method = gr.Dropdown(
+                                choices=ransac_zoo.keys(),
+                                value="RANSAC",
+                                label="RANSAC Method",
+                                interactive=True,
+                            )
+                            ransac_reproj_threshold = gr.Slider(
+                                minimum=0.0,
+                                maximum=12,
+                                step=0.01,
+                                label="Ransac Reproj threshold",
+                                value=8.0,
+                            )
+                            ransac_confidence = gr.Slider(
+                                minimum=0.0,
+                                maximum=1,
+                                step=0.00001,
+                                label="Ransac Confidence",
+                                value=0.99999,
+                            )
+                            ransac_max_iter = gr.Slider(
+                                minimum=0.0,
+                                maximum=100000,
+                                step=100,
+                                label="Ransac Iterations",
+                                value=10000,
+                            )
+
+                    with gr.Accordion("Geometry Setting", open=True):
+                        with gr.Row(equal_height=False):
+                            # show_geom = gr.Checkbox(label="Show Geometry")
+                            choice_estimate_geom = gr.Radio(
+                                ["Fundamental", "Homography"],
+                                label="Reconstruct Geometry",
+                                value="Homography",
+                            )
 
+                # with gr.Column():
                 # collect inputs
                 inputs = [
+                    input_image0,
+                    input_image1,
                     match_setting_threshold,
                     match_setting_max_features,
                     detect_keypoints_threshold,
                     matcher_list,
-                    input_image0,
-                    input_image1,
+                    enable_ransac,
+                    ransac_method,
+                    ransac_reproj_threshold,
+                    ransac_confidence,
+                    ransac_max_iter,
+                    choice_estimate_geom,
                 ]
 
                 # Add some examples
                 with gr.Row():
-                    examples = [
-                        [
-                            0.1,
-                            2000,
-                            0.015,
-                            "disk+lightglue",
-                            "datasets/sacre_coeur/mapping/71295362_4051449754.jpg",
-                            "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
-                        ],
-                        [
-                            0.1,
-                            2000,
-                            0.015,
-                            "loftr",
-                            "datasets/sacre_coeur/mapping/03903474_1471484089.jpg",
-                            "datasets/sacre_coeur/mapping/02928139_3448003521.jpg",
-                        ],
-                        [
-                            0.1,
-                            2000,
-                            0.015,
-                            "disk",
-                            "datasets/sacre_coeur/mapping/10265353_3838484249.jpg",
-                            "datasets/sacre_coeur/mapping/51091044_3486849416.jpg",
-                        ],
-                        [
-                            0.1,
-                            2000,
-                            0.015,
-                            "topicfm",
-                            "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
-                            "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
-                        ],
-                        [
-                            0.1,
-                            2000,
-                            0.015,
-                            "superpoint+superglue",
-                            "datasets/sacre_coeur/mapping/17295357_9106075285.jpg",
-                            "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
-                        ],
-                    ]
                     # Example inputs
                     gr.Examples(
-                        examples=examples,
+                        examples=gen_examples(),
                         inputs=inputs,
                         outputs=[],
                         fn=run_matching,
-                        cache_examples=…
-                        label=…
+                        cache_examples=False,
+                        label=(
+                            "Examples (click one of the images below to Run"
+                            " Match)"
+                        ),
+                    )
+                with gr.Accordion("Open for More!", open=False):
+                    gr.Markdown(
+                        f"""
+                        <h3>Supported Algorithms</h3>
+                        {", ".join(matcher_zoo.keys())}
+                        """
                     )
 
             with gr.Column():
-                output_mkpts = gr.Image(
-                    …
+                output_mkpts = gr.Image(
+                    label="Keypoints Matching", type="numpy"
+                )
+                with gr.Accordion(
+                    "Open for More: Matches Statistics", open=False
+                ):
+                    matches_result_info = gr.JSON(label="Matches Statistics")
+                    matcher_info = gr.JSON(label="Match info")
+
+                output_wrapped = gr.Image(label="Wrapped Pair", type="numpy")
+                with gr.Accordion("Open for More: Geometry info", open=False):
+                    geometry_result = gr.JSON(label="Reconstructed Geometry")
 
         # callbacks
         match_image_src.change(
-            fn=ui_change_imagebox,
-            …
+            fn=ui_change_imagebox,
+            inputs=match_image_src,
+            outputs=input_image0,
         )
         match_image_src.change(
-            fn=ui_change_imagebox,
-            …
+            fn=ui_change_imagebox,
+            inputs=match_image_src,
+            outputs=input_image1,
        )
 
         # collect outputs
@@ -256,34 +253,61 @@ def run(config):
             output_mkpts,
             matches_result_info,
             matcher_info,
+            geometry_result,
+            output_wrapped,
         ]
         # button callbacks
         button_run.click(fn=run_matching, inputs=inputs, outputs=outputs)
 
         # Reset images
         reset_outputs = [
+            input_image0,
+            input_image1,
            match_setting_threshold,
            match_setting_max_features,
            detect_keypoints_threshold,
            matcher_list,
            input_image0,
            input_image1,
-            input_image0,
-            input_image1,
            match_image_src,
            output_mkpts,
            matches_result_info,
            matcher_info,
+            output_wrapped,
+            geometry_result,
+            enable_ransac,
+            ransac_method,
+            ransac_reproj_threshold,
+            ransac_confidence,
+            ransac_max_iter,
+            choice_estimate_geom,
        ]
-        button_reset.click(
-            …
+        button_reset.click(
+            fn=ui_reset_state, inputs=inputs, outputs=reset_outputs
+        )
+
+        # estimate geo
+        choice_estimate_geom.change(
+            fn=change_estimate_geom,
+            inputs=[
+                input_image0,
+                input_image1,
+                geometry_result,
+                choice_estimate_geom,
+            ],
+            outputs=[output_wrapped, geometry_result],
+        )
+
        app.launch(share=False)
 
 
 if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
-        "--config_path",
-        …
+        "--config_path",
+        type=str,
+        default="config.yaml",
+        help="configuration file path",
    )
    args = parser.parse_args()
    config = None
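Note on the callback wiring above: each `match_image_src.change()` call now names its `inputs` and `outputs` explicitly. A minimal, self-contained sketch of the same pattern, assuming the Gradio 3.x API this commit targets (component names here are illustrative, not the app's):

```python
import gradio as gr

def ui_change_imagebox(choice):
    # Same update-dict style the old ui_reset_state returned:
    # clear the current value and switch the image source in place.
    return {"value": None, "source": choice, "__type__": "update"}

with gr.Blocks() as demo:
    src = gr.Radio(["upload", "webcam", "canvas"], value="upload", label="Image Source")
    img0 = gr.Image(label="Image 0", source="upload")
    img1 = gr.Image(label="Image 1", source="upload")
    # one .change() per output component, exactly as app.py does it
    src.change(fn=ui_change_imagebox, inputs=src, outputs=img0)
    src.change(fn=ui_change_imagebox, inputs=src, outputs=img1)

demo.launch()
```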
common/utils.py CHANGED

@@ -1,11 +1,14 @@
-import torch
+import os
+import random
 import numpy as np
+import torch
+from itertools import combinations
 import cv2
+import gradio as gr
 from hloc import matchers, extractors
 from hloc.utils.base_model import dynamic_load
 from hloc import match_dense, match_features, extract_features
-from .plotting import …
-from .visualize_util import plot_images, plot_color_line_matches
+from .viz import draw_matches, fig2im, plot_images, plot_color_line_matches
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -22,6 +25,217 @@ def get_feature_model(conf):
     return model
 
 
+def gen_examples():
+    random.seed(1)
+    example_matchers = [
+        "disk+lightglue",
+        "loftr",
+        "disk",
+        "d2net",
+        "topicfm",
+        "superpoint+superglue",
+        "disk+dualsoftmax",
+        "lanet",
+    ]
+
+    def gen_images_pairs(path: str, count: int = 5):
+        imgs_list = [
+            os.path.join(path, file)
+            for file in os.listdir(path)
+            if file.lower().endswith((".jpg", ".jpeg", ".png"))
+        ]
+        pairs = list(combinations(imgs_list, 2))
+        selected = random.sample(range(len(pairs)), count)
+        return [pairs[i] for i in selected]
+
+    # image pair path
+    path = "datasets/sacre_coeur/mapping"
+    pairs = gen_images_pairs(path, len(example_matchers))
+    match_setting_threshold = 0.1
+    match_setting_max_features = 2000
+    detect_keypoints_threshold = 0.01
+    enable_ransac = False
+    ransac_method = "RANSAC"
+    ransac_reproj_threshold = 8
+    ransac_confidence = 0.999
+    ransac_max_iter = 10000
+    input_lists = []
+    for pair, mt in zip(pairs, example_matchers):
+        input_lists.append(
+            [
+                pair[0],
+                pair[1],
+                match_setting_threshold,
+                match_setting_max_features,
+                detect_keypoints_threshold,
+                mt,
+                enable_ransac,
+                ransac_method,
+                ransac_reproj_threshold,
+                ransac_confidence,
+                ransac_max_iter,
+            ]
+        )
+    return input_lists
+
+
+def filter_matches(
+    pred,
+    ransac_method="RANSAC",
+    ransac_reproj_threshold=8,
+    ransac_confidence=0.999,
+    ransac_max_iter=10000,
+):
+    mkpts0 = None
+    mkpts1 = None
+    feature_type = None
+    if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
+        mkpts0 = pred["keypoints0_orig"]
+        mkpts1 = pred["keypoints1_orig"]
+        feature_type = "KEYPOINT"
+    elif (
+        "line_keypoints0_orig" in pred.keys()
+        and "line_keypoints1_orig" in pred.keys()
+    ):
+        mkpts0 = pred["line_keypoints0_orig"]
+        mkpts1 = pred["line_keypoints1_orig"]
+        feature_type = "LINE"
+    else:
+        return pred
+    if mkpts0 is None or mkpts0 is None:
+        return pred
+    if ransac_method not in ransac_zoo.keys():
+        ransac_method = "RANSAC"
+    H, mask = cv2.findHomography(
+        mkpts0,
+        mkpts1,
+        method=ransac_zoo[ransac_method],
+        ransacReprojThreshold=ransac_reproj_threshold,
+        confidence=ransac_confidence,
+        maxIters=ransac_max_iter,
+    )
+    mask = np.array(mask.ravel().astype("bool"), dtype="bool")
+    if H is not None:
+        if feature_type == "KEYPOINT":
+            pred["keypoints0_orig"] = mkpts0[mask]
+            pred["keypoints1_orig"] = mkpts1[mask]
+            pred["mconf"] = pred["mconf"][mask]
+        elif feature_type == "LINE":
+            pred["line_keypoints0_orig"] = mkpts0[mask]
+            pred["line_keypoints1_orig"] = mkpts1[mask]
+    return pred
+
+
+def compute_geom(
+    pred,
+    ransac_method="RANSAC",
+    ransac_reproj_threshold=8,
+    ransac_confidence=0.999,
+    ransac_max_iter=10000,
+) -> dict:
+    mkpts0 = None
+    mkpts1 = None
+
+    if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
+        mkpts0 = pred["keypoints0_orig"]
+        mkpts1 = pred["keypoints1_orig"]
+
+    if (
+        "line_keypoints0_orig" in pred.keys()
+        and "line_keypoints1_orig" in pred.keys()
+    ):
+        mkpts0 = pred["line_keypoints0_orig"]
+        mkpts1 = pred["line_keypoints1_orig"]
+
+    if mkpts0 is not None and mkpts1 is not None:
+        if len(mkpts0) < 8:
+            return {}
+        h1, w1, _ = pred["image0_orig"].shape
+        geo_info = {}
+        F, inliers = cv2.findFundamentalMat(
+            mkpts0,
+            mkpts1,
+            method=ransac_zoo[ransac_method],
+            ransacReprojThreshold=ransac_reproj_threshold,
+            confidence=ransac_confidence,
+            maxIters=ransac_max_iter,
+        )
+        geo_info["Fundamental"] = F.tolist()
+        H, _ = cv2.findHomography(
+            mkpts1,
+            mkpts0,
+            method=ransac_zoo[ransac_method],
+            ransacReprojThreshold=ransac_reproj_threshold,
+            confidence=ransac_confidence,
+            maxIters=ransac_max_iter,
+        )
+        geo_info["Homography"] = H.tolist()
+        _, H1, H2 = cv2.stereoRectifyUncalibrated(
+            mkpts0.reshape(-1, 2), mkpts1.reshape(-1, 2), F, imgSize=(w1, h1)
+        )
+        geo_info["H1"] = H1.tolist()
+        geo_info["H2"] = H2.tolist()
+        return geo_info
+    else:
+        return {}
+
+
+def wrap_images(img0, img1, geo_info, geom_type):
+    h1, w1, _ = img0.shape
+    h2, w2, _ = img1.shape
+    result_matrix = None
+    if geo_info is not None and len(geo_info) != 0:
+        rectified_image0 = img0
+        rectified_image1 = None
+        H = np.array(geo_info["Homography"])
+        F = np.array(geo_info["Fundamental"])
+        title = []
+        if geom_type == "Homography":
+            rectified_image1 = cv2.warpPerspective(
+                img1, H, (img0.shape[1] + img1.shape[1], img0.shape[0])
+            )
+            result_matrix = H
+            title = ["Image 0", "Image 1 - warped"]
+        elif geom_type == "Fundamental":
+            H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
+            rectified_image0 = cv2.warpPerspective(img0, H1, (w1, h1))
+            rectified_image1 = cv2.warpPerspective(img1, H2, (w2, h2))
+            result_matrix = F
+            title = ["Image 0 - warped", "Image 1 - warped"]
+        else:
+            print("Error: Unknown geometry type")
+        fig = plot_images(
+            [rectified_image0.squeeze(), rectified_image1.squeeze()],
+            title,
+            dpi=300,
+        )
+        dictionary = {
+            "row1": result_matrix[0].tolist(),
+            "row2": result_matrix[1].tolist(),
+            "row3": result_matrix[2].tolist(),
+        }
+        return fig2im(fig), dictionary
+    else:
+        return None, None
+
+
+def change_estimate_geom(input_image0, input_image1, matches_info, choice):
+    if (
+        matches_info is None
+        or len(matches_info) < 1
+        or "geom_info" not in matches_info.keys()
+    ):
+        return None, None
+    geom_info = matches_info["geom_info"]
+    wrapped_images = None
+    if choice != "No":
+        wrapped_images, _ = wrap_images(
+            input_image0, input_image1, geom_info, choice
+        )
+        return wrapped_images, matches_info
+    else:
+        return None, None
+
+
 def display_matches(pred: dict):
     img0 = pred["image0_orig"]
     img1 = pred["image1_orig"]
@@ -42,7 +256,10 @@ def display_matches(pred: dict):
         img1,
         mconf,
         dpi=300,
-        titles=[…],
+        titles=[
+            "Image 0 - matched keypoints",
+            "Image 1 - matched keypoints",
+        ],
     )
     fig = fig_mkpts
     if "line0_orig" in pred.keys() and "line1_orig" in pred.keys():
@@ -69,13 +286,107 @@ def display_matches(pred: dict):
         else:
             mconf = np.ones(len(mkpts0))
         fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
-        fig_lines = cv2.resize(fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0]))
+        fig_lines = cv2.resize(
+            fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
+        )
         fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
     else:
         fig = fig_lines
     return fig, num_inliers
 
 
+def run_matching(
+    image0,
+    image1,
+    match_threshold,
+    extract_max_keypoints,
+    keypoint_threshold,
+    key,
+    enable_ransac=False,
+    ransac_method="RANSAC",
+    ransac_reproj_threshold=8,
+    ransac_confidence=0.999,
+    ransac_max_iter=10000,
+    choice_estimate_geom="Homography",
+):
+    # image0 and image1 is RGB mode
+    if image0 is None or image1 is None:
+        raise gr.Error("Error: No images found! Please upload two images.")
+
+    model = matcher_zoo[key]
+    match_conf = model["config"]
+    # update match config
+    match_conf["model"]["match_threshold"] = match_threshold
+    match_conf["model"]["max_keypoints"] = extract_max_keypoints
+
+    matcher = get_model(match_conf)
+    if model["dense"]:
+        pred = match_dense.match_images(
+            matcher, image0, image1, match_conf["preprocessing"], device=device
+        )
+        del matcher
+        extract_conf = None
+    else:
+        extract_conf = model["config_feature"]
+        # update extract config
+        extract_conf["model"]["max_keypoints"] = extract_max_keypoints
+        extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
+        extractor = get_feature_model(extract_conf)
+        pred0 = extract_features.extract(
+            extractor, image0, extract_conf["preprocessing"]
+        )
+        pred1 = extract_features.extract(
+            extractor, image1, extract_conf["preprocessing"]
+        )
+        pred = match_features.match_images(matcher, pred0, pred1)
+        del extractor
+
+    if enable_ransac:
+        filter_matches(
+            pred,
+            ransac_method=ransac_method,
+            ransac_reproj_threshold=ransac_reproj_threshold,
+            ransac_confidence=ransac_confidence,
+            ransac_max_iter=ransac_max_iter,
+        )
+
+    fig, num_inliers = display_matches(pred)
+    geom_info = compute_geom(pred)
+    output_wrapped, _ = change_estimate_geom(
+        pred["image0_orig"],
+        pred["image1_orig"],
+        {"geom_info": geom_info},
+        choice_estimate_geom,
+    )
+    del pred
+    return (
+        fig,
+        {"matches number": num_inliers},
+        {
+            "match_conf": match_conf,
+            "extractor_conf": extract_conf,
+        },
+        {
+            "geom_info": geom_info,
+        },
+        output_wrapped,
+        # geometry_result,
+    )
+
+
+# @ref: https://docs.opencv.org/4.x/d0/d74/md__build_4_x-contrib_docs-lin64_opencv_doc_tutorials_calib3d_usac.html
+# AND: https://opencv.org/blog/2021/06/09/evaluating-opencvs-new-ransacs
+ransac_zoo = {
+    "RANSAC": cv2.RANSAC,
+    "USAC_MAGSAC": cv2.USAC_MAGSAC,
+    "USAC_DEFAULT": cv2.USAC_DEFAULT,
+    "USAC_FM_8PTS": cv2.USAC_FM_8PTS,
+    "USAC_PROSAC": cv2.USAC_PROSAC,
+    "USAC_FAST": cv2.USAC_FAST,
+    "USAC_ACCURATE": cv2.USAC_ACCURATE,
+    "USAC_PARALLEL": cv2.USAC_PARALLEL,
+}
+
 # Matchers collections
 matcher_zoo = {
     "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
@@ -147,11 +458,11 @@ matcher_zoo = {
         "config_feature": extract_features.confs["d2net-ss"],
         "dense": False,
     },
-    …
+    "d2net-ms": {
+        "config": match_features.confs["NN-mutual"],
+        "config_feature": extract_features.confs["d2net-ms"],
+        "dense": False,
+    },
     "alike": {
         "config": match_features.confs["NN-mutual"],
         "config_feature": extract_features.confs["alike"],
@@ -177,6 +488,6 @@ matcher_zoo = {
         "config_feature": extract_features.confs["sift"],
         "dense": False,
     },
-    …
+    "roma": {"config": match_dense.confs["roma"], "dense": True},
+    "DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
 }
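Review note on the new filter_matches(): the guard `if mkpts0 is None or mkpts0 is None` tests mkpts0 twice (mkpts1 was presumably intended). A standalone sketch of the inlier-filtering step it implements, with synthetic correspondences and OpenCV ≥ 4.5 assumed for the USAC flags:

```python
import cv2
import numpy as np

rng = np.random.default_rng(0)
# 80 homography-consistent points plus 20 gross outliers
mkpts0 = rng.uniform(0, 640, (100, 2)).astype(np.float32)
H_true = np.array([[1.0, 0.02, 5.0], [0.01, 1.0, -3.0], [0.0, 0.0, 1.0]], np.float32)
pts = np.hstack([mkpts0, np.ones((100, 1), np.float32)]) @ H_true.T
mkpts1 = pts[:, :2] / pts[:, 2:]
mkpts1[80:] += rng.uniform(-60.0, 60.0, (20, 2)).astype(np.float32)

H, mask = cv2.findHomography(
    mkpts0,
    mkpts1,
    method=cv2.USAC_MAGSAC,          # any ransac_zoo entry works here
    ransacReprojThreshold=8,
    confidence=0.999,
    maxIters=10000,
)
inliers = mask.ravel().astype(bool)   # same mask filter_matches applies to pred
print(f"kept {inliers.sum()} / {len(inliers)} matches")
```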
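The compute_geom() / wrap_images() pair feeds the new "Wrapped Pair" output: H maps image 1 into image 0's frame, while F yields rectifying homographies H1/H2 via cv2.stereoRectifyUncalibrated. A minimal sketch of both branches under synthetic inputs (a real pair would come from the matcher):

```python
import cv2
import numpy as np

rng = np.random.default_rng(1)
h, w = 480, 640
img0 = rng.integers(0, 255, (h, w, 3), dtype=np.uint8)
img1 = rng.integers(0, 255, (h, w, 3), dtype=np.uint8)
mkpts0 = rng.uniform([0, 0], [w, h], (60, 2)).astype(np.float32)
mkpts1 = mkpts0 + np.float32([25, 0]) + rng.normal(0, 1, (60, 2)).astype(np.float32)

F, _ = cv2.findFundamentalMat(mkpts0, mkpts1, method=cv2.RANSAC,
                              ransacReprojThreshold=8, confidence=0.999)
H, _ = cv2.findHomography(mkpts1, mkpts0, method=cv2.RANSAC,  # 1 -> 0, as committed
                          ransacReprojThreshold=8, confidence=0.999)

if H is not None:
    # "Homography" branch: warp image 1 onto a widened canvas, as wrap_images does
    warped1 = cv2.warpPerspective(img1, H, (img0.shape[1] + img1.shape[1], img0.shape[0]))
if F is not None:
    # "Fundamental" branch: rectify both views with the uncalibrated-stereo homographies
    _, H1, H2 = cv2.stereoRectifyUncalibrated(
        mkpts0.reshape(-1, 2), mkpts1.reshape(-1, 2), F, imgSize=(w, h)
    )
    rect0 = cv2.warpPerspective(img0, H1, (w, h))
    rect1 = cv2.warpPerspective(img1, H2, (w, h))
```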
common/visualize_util.py DELETED

@@ -1,642 +0,0 @@
The whole module ("Organize some frequently used visualization functions.") is removed. It contained plot_junctions, plot_line_segments, plot_line_segments_from_segments, plot_images, plot_keypoints, plot_matches, plot_lines, plot_line_matches, plot_color_line_matches, plot_color_lines, and plot_subsegment_matches. The two helpers common/utils.py still uses, plot_images and plot_color_line_matches, are now imported from common/viz.py instead.
common/{plotting.py → viz.py}
RENAMED
@@ -6,6 +6,7 @@ import matplotlib.cm as cm
 from PIL import Image
 import torch.nn.functional as F
 import torch
+import seaborn as sns
 
 
 def _compute_conf_thresh(data):
@@ -19,7 +20,77 @@ def _compute_conf_thresh(data):
     return thr
 
 
-
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
+    """Plot a set of images horizontally.
+    Args:
+        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+        titles: a list of strings, as titles for each image.
+        cmaps: colormaps for monochrome images.
+    """
+    n = len(imgs)
+    if not isinstance(cmaps, (list, tuple)):
+        cmaps = [cmaps] * n
+    # figsize = (size*n, size*3/4) if size is not None else None
+    figsize = (size * n, size * 6 / 5) if size is not None else None
+    fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
+
+    if n == 1:
+        ax = [ax]
+    for i in range(n):
+        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+        ax[i].get_yaxis().set_ticks([])
+        ax[i].get_xaxis().set_ticks([])
+        ax[i].set_axis_off()
+        for spine in ax[i].spines.values():  # remove frame
+            spine.set_visible(False)
+        if titles:
+            ax[i].set_title(titles[i])
+    fig.tight_layout(pad=pad)
+    return fig
+
+
+def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
+    """Plot line matches for existing images with multiple colors.
+    Args:
+        lines: list of ndarrays of size (N, 2, 2).
+        correct_matches: bool array of size (N,) indicating correct matches.
+        lw: line width as float pixels.
+        indices: indices of the images to draw the matches on.
+    """
+    n_lines = len(lines[0])
+    colors = sns.color_palette("husl", n_colors=n_lines)
+    np.random.shuffle(colors)
+    alphas = np.ones(n_lines)
+    # If correct_matches is not None, display wrong matches with a low alpha
+    if correct_matches is not None:
+        alphas[~np.array(correct_matches)] = 0.2
+
+    fig = plt.gcf()
+    ax = fig.axes
+    assert len(ax) > max(indices)
+    axes = [ax[i] for i in indices]
+    fig.canvas.draw()
+
+    # Plot the lines
+    for a, l in zip(axes, lines):
+        # Transform the points into the figure coordinate system
+        transFigure = fig.transFigure.inverted()
+        endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+        endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+        fig.lines += [
+            matplotlib.lines.Line2D(
+                (endpoint0[i, 0], endpoint1[i, 0]),
+                (endpoint0[i, 1], endpoint1[i, 1]),
+                zorder=1,
+                transform=fig.transFigure,
+                c=colors[i],
+                alpha=alphas[i],
+                linewidth=lw,
+            )
+            for i in range(n_lines)
+        ]
+
+    return fig
 
 
 def make_matching_figure(
@@ -57,7 +128,7 @@ def make_matching_figure(
     axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
 
     # draw matches
-    if mkpts0.shape[0]
+    if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
         fig.canvas.draw()
         transFigure = fig.transFigure.inverted()
         fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
@@ -105,8 +176,12 @@ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
     b_mask = data["m_bids"] == b_id
     conf_thr = _compute_conf_thresh(data)
 
-    img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
-    img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    img0 = (
+        (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    )
+    img1 = (
+        (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    )
     kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
     kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
 
@@ -131,8 +206,10 @@ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
 
     text = [
         f"#Matches {len(kpts0)}",
-        f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
-        f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
+        f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%):"
+        f" {n_correct}/{len(kpts0)}",
+        f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%):"
+        f" {n_correct}/{n_gt_matches}",
     ]
 
     # make the figure
@@ -188,7 +265,9 @@ def error_colormap(err, thr, alpha=1.0):
     assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
     x = 1 - np.clip(err / (thr * 2), 0, 1)
     return np.clip(
-        np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
+        np.stack(
+            [2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1
+        ),
         0,
         1,
     )
@@ -200,9 +279,13 @@ np.random.shuffle(color_map)
 
 
 def draw_topics(
-    data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None
+    data,
+    img0,
+    img1,
+    saved_folder="viz_topics",
+    show_n_topics=8,
+    saved_name=None,
 ):
-
     topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
     hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
     hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
@@ -237,7 +320,10 @@ def draw_topics(
         dim=-1, keepdim=True
     )  # .float() / (n_topics - 1) #* 255 + 1
     # topic1[~mask1_nonzero] = -1
-    label_img0, label_img1 = torch.zeros_like(topic0) - 1, torch.zeros_like(topic1) - 1
+    label_img0, label_img1 = (
+        torch.zeros_like(topic0) - 1,
+        torch.zeros_like(topic1) - 1,
+    )
    for i, k in enumerate(top_topics):
        label_img0[topic0 == k] = color_map[k]
        label_img1[topic1 == k] = color_map[k]
@@ -312,24 +398,30 @@ def draw_topicfm_demo(
     opencv_display=False,
     opencv_title="",
 ):
-    topic_map0, topic_map1 = draw_topics(data, img0, img1, show_n_topics=show_n_topics)
-
-    mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(
-        topic_map1 >= 0, axis=-1
+    topic_map0, topic_map1 = draw_topics(
+        data, img0, img1, show_n_topics=show_n_topics
     )
 
+    mask_tm0, mask_tm1 = np.expand_dims(
+        topic_map0 >= 0, axis=-1
+    ), np.expand_dims(topic_map1 >= 0, axis=-1)
+
     topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
-    topic_cm0 = cv2.cvtColor(topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
-    topic_cm1 = cv2.cvtColor(topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
+    topic_cm0 = cv2.cvtColor(
+        topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
+    )
+    topic_cm1 = cv2.cvtColor(
+        topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
+    )
     overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
     overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
 
     cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
     cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
 
-    overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(
-        np.uint8
-    )
+    overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (
+        overlay1 * 255
+    ).astype(np.uint8)
 
     h0, w0 = img0.shape[:2]
     h1, w1 = img1.shape[:2]
@@ -338,7 +430,9 @@ def draw_topicfm_demo(
     out_fig[:h0, :w0] = overlay0
     if h0 >= h1:
         start = (h0 - h1) // 2
-        out_fig[start : (start + h1), (w0 + margin) : (w0 + margin + w1)] = overlay1
+        out_fig[
+            start : (start + h1), (w0 + margin) : (w0 + margin + w1)
+        ] = overlay1
     else:
         start = (h1 - h0) // 2
         out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
@@ -358,7 +452,8 @@ def draw_topicfm_demo(
         img1[start : start + h0] * 255
     ).astype(np.uint8)
 
-    # draw matching lines, this is inspried from https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
+    # draw matching lines, this is inspried from
+    # https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
     mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
     mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
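The two functions added above are designed to compose: `plot_images` builds the side-by-side figure, and `plot_color_line_matches` then draws one `Line2D` per match across it via `plt.gcf()`. A hedged usage sketch with synthetic images and line endpoints (all values below are illustrative):

```python
import numpy as np
from common.viz import plot_images, plot_color_line_matches

# Synthetic data for illustration: two gray images and 5 fake line matches,
# each line given as (2 endpoints) x (x, y) in pixel coordinates.
img0, img1 = np.random.rand(240, 320), np.random.rand(240, 320)
lines0 = np.random.rand(5, 2, 2) * [320, 240]
lines1 = np.random.rand(5, 2, 2) * [320, 240]

fig = plot_images([img0, img1], titles=["image 0", "image 1"])
# Line i gets the same (shuffled husl) color in both panels; with a
# correct_matches mask, wrong matches would be faded to alpha 0.2.
plot_color_line_matches([lines0, lines1], correct_matches=None, lw=2)
fig.savefig("line_matches.png")
```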
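The `error_colormap` change is pure reformatting, but the mapping is worth a gloss: with `x = 1 - clip(err / (2 * thr), 0, 1)`, the stacked channels `[2 - 2x, 2x, 0, alpha]` fade from green at `err = 0` through yellow at `err = thr` to red at `err >= 2 * thr`. A quick self-contained check of those endpoints:

```python
import numpy as np


def error_colormap(err, thr, alpha=1.0):
    # Same mapping as in common/viz.py: green at err=0, red from err >= 2*thr.
    x = 1 - np.clip(err / (thr * 2), 0, 1)
    return np.clip(
        np.stack(
            [2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1
        ),
        0,
        1,
    )


err = np.array([0.0, 1.0, 2.0, 5.0])
print(error_colormap(err, thr=1.0))
# err = 0      -> [0, 1, 0, 1]  green
# err = thr    -> [1, 1, 0, 1]  yellow (x = 0.5 gives R = G = 1)
# err = 2*thr  -> [1, 0, 0, 1]  red
# err > 2*thr  -> [1, 0, 0, 1]  clipped to red
```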
style.css
ADDED
@@ -0,0 +1,18 @@
+h1 {
+    text-align: center;
+}
+
+#duplicate-button {
+    margin: auto;
+    color: white;
+    background: #1565c0;
+    border-radius: 100vh;
+}
+
+#component-0 {
+    /* max-width: 900px; */
+    margin: auto;
+    padding-top: 1.5rem;
+}
+
+footer {visibility: hidden}
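The selectors in the new stylesheet target Gradio's generated DOM: `h1` and `footer` are plain tags, while `#duplicate-button` and `#component-0` match element ids that Gradio auto-assigns (`#component-0` is typically the root Blocks container) or that an app sets via `elem_id`. A minimal sketch of how such a stylesheet attaches to a Blocks app (the layout below is hypothetical, not this Space's actual UI):

```python
import gradio as gr

# Hypothetical minimal app. style.css is picked up through the css argument;
# the elem_id below is illustrative glue that makes the #duplicate-button
# rule apply, not a documented requirement of the library.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("<h1>Demo</h1>")  # centered by the h1 rule
    gr.Button("Duplicate Space", elem_id="duplicate-button")

if __name__ == "__main__":
    demo.launch()
```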