Vincentqyw commited on
Commit
8e76240
1 Parent(s): 6cb641c

update: default params

Browse files
Files changed (4) hide show
  1. app.py +72 -54
  2. common/utils.py +35 -26
  3. common/viz.py +1 -345
  4. style.css +1 -0
app.py CHANGED
@@ -2,10 +2,18 @@ import argparse
2
  import gradio as gr
3
  from common.utils import (
4
  matcher_zoo,
 
5
  change_estimate_geom,
6
  run_matching,
7
- ransac_zoo,
8
  gen_examples,
 
 
 
 
 
 
 
 
9
  )
10
 
11
  DESCRIPTION = """
@@ -21,58 +29,66 @@ This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/ima
21
 
22
 
23
  def ui_change_imagebox(choice):
24
- return {"value": None, "source": choice, "__type__": "update"}
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
- def ui_reset_state(
28
- image0,
29
- image1,
30
- match_threshold,
31
- extract_max_keypoints,
32
- keypoint_threshold,
33
- key,
34
- # enable_ransac=False,
35
- ransac_method="RANSAC",
36
- ransac_reproj_threshold=8,
37
- ransac_confidence=0.999,
38
- ransac_max_iter=10000,
39
- choice_estimate_geom="Homography",
40
- ):
41
- match_threshold = 0.2
42
- extract_max_keypoints = 1000
43
- keypoint_threshold = 0.015
44
- key = list(matcher_zoo.keys())[0]
45
- image0 = None
46
- image1 = None
47
- # enable_ransac = False
48
  return (
49
- image0,
50
- image1,
51
- match_threshold,
52
- extract_max_keypoints,
53
- keypoint_threshold,
54
- key,
55
- ui_change_imagebox("upload"),
56
- ui_change_imagebox("upload"),
57
- "upload",
58
  None, # keypoints
59
  None, # raw matches
60
  None, # ransac matches
61
- {},
62
- {},
63
- None,
64
- {},
65
- # False,
66
- "RANSAC",
67
- 8,
68
- 0.999,
69
- 10000,
70
- "Homography",
71
  )
72
 
73
 
74
  # "footer {visibility: hidden}"
75
  def run(config):
 
 
 
 
 
 
 
 
 
76
  with gr.Blocks(css="style.css") as app:
77
  gr.Markdown(DESCRIPTION)
78
 
@@ -94,21 +110,21 @@ def run(config):
94
  input_image0 = gr.Image(
95
  label="Image 0",
96
  type="numpy",
97
- interactive=True,
98
  image_mode="RGB",
 
 
99
  )
100
  input_image1 = gr.Image(
101
  label="Image 1",
102
  type="numpy",
103
- interactive=True,
104
  image_mode="RGB",
 
 
105
  )
106
 
107
  with gr.Row():
108
  button_reset = gr.Button(value="Reset")
109
- button_run = gr.Button(
110
- value="Run Match", variant="primary"
111
- )
112
 
113
  with gr.Accordion("Advanced Setting", open=False):
114
  with gr.Accordion("Matching Setting", open=True):
@@ -153,7 +169,7 @@ def run(config):
153
  # enable_ransac = gr.Checkbox(label="Enable RANSAC")
154
  ransac_method = gr.Dropdown(
155
  choices=ransac_zoo.keys(),
156
- value="RANSAC",
157
  label="RANSAC Method",
158
  interactive=True,
159
  )
@@ -185,7 +201,7 @@ def run(config):
185
  choice_estimate_geom = gr.Radio(
186
  ["Fundamental", "Homography"],
187
  label="Reconstruct Geometry",
188
- value="Homography",
189
  )
190
 
191
  # with gr.Column():
@@ -197,7 +213,6 @@ def run(config):
197
  match_setting_max_features,
198
  detect_keypoints_threshold,
199
  matcher_list,
200
- # enable_ransac,
201
  ransac_method,
202
  ransac_reproj_threshold,
203
  ransac_confidence,
@@ -243,9 +258,13 @@ def run(config):
243
  output_wrapped = gr.Image(
244
  label="Wrapped Pair", type="numpy"
245
  )
246
- with gr.Accordion("Open for More: Geometry info", open=False):
247
- geometry_result = gr.JSON(label="Reconstructed Geometry")
248
-
 
 
 
 
249
  # callbacks
250
  match_image_src.change(
251
  fn=ui_change_imagebox,
@@ -289,7 +308,6 @@ def run(config):
289
  matcher_info,
290
  output_wrapped,
291
  geometry_result,
292
- # enable_ransac,
293
  ransac_method,
294
  ransac_reproj_threshold,
295
  ransac_confidence,
 
2
  import gradio as gr
3
  from common.utils import (
4
  matcher_zoo,
5
+ ransac_zoo,
6
  change_estimate_geom,
7
  run_matching,
 
8
  gen_examples,
9
+ DEFAULT_RANSAC_METHOD,
10
+ DEFAULT_SETTING_GEOMETRY,
11
+ DEFAULT_RANSAC_REPROJ_THRESHOLD,
12
+ DEFAULT_RANSAC_CONFIDENCE,
13
+ DEFAULT_RANSAC_MAX_ITER,
14
+ DEFAULT_MATCHING_THRESHOLD,
15
+ DEFAULT_SETTING_MAX_FEATURES,
16
+ DEFAULT_DEFAULT_KEYPOINT_THRESHOLD,
17
  )
18
 
19
  DESCRIPTION = """
 
29
 
30
 
31
  def ui_change_imagebox(choice):
32
+ """
33
+ Updates the image box with the given choice.
34
+
35
+ Args:
36
+ choice (list): The list of image sources to be displayed in the image box.
37
+
38
+ Returns:
39
+ dict: A dictionary containing the updated value, sources, and type for the image box.
40
+ """
41
+ return {
42
+ "value": None, # The updated value of the image box
43
+ "sources": choice, # The list of image sources to be displayed
44
+ "__type__": "update", # The type of update for the image box
45
+ }
46
 
47
 
48
+ def ui_reset_state(*args):
49
+ """
50
+ Reset the state of the UI.
51
+
52
+ Returns:
53
+ tuple: A tuple containing the initial values for the UI state.
54
+ """
55
+ key = list(matcher_zoo.keys())[0] # Get the first key from matcher_zoo
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  return (
57
+ None, # image0
58
+ None, # image1
59
+ DEFAULT_MATCHING_THRESHOLD, # matching_threshold
60
+ DEFAULT_SETTING_MAX_FEATURES, # max_features
61
+ DEFAULT_DEFAULT_KEYPOINT_THRESHOLD, # keypoint_threshold
62
+ key, # matcher
63
+ ui_change_imagebox("upload"), # input image0
64
+ ui_change_imagebox("upload"), # input image1
65
+ "upload", # match_image_src
66
  None, # keypoints
67
  None, # raw matches
68
  None, # ransac matches
69
+ {}, # matches result info
70
+ {}, # matcher config
71
+ None, # warped image
72
+ {}, # geometry result
73
+ DEFAULT_RANSAC_METHOD, # ransac_method
74
+ DEFAULT_RANSAC_REPROJ_THRESHOLD, # ransac_reproj_threshold
75
+ DEFAULT_RANSAC_CONFIDENCE, # ransac_confidence
76
+ DEFAULT_RANSAC_MAX_ITER, # ransac_max_iter
77
+ DEFAULT_SETTING_GEOMETRY, # geometry
 
78
  )
79
 
80
 
81
  # "footer {visibility: hidden}"
82
  def run(config):
83
+ """
84
+ Runs the application.
85
+
86
+ Args:
87
+ config (dict): A dictionary containing configuration parameters for the application.
88
+
89
+ Returns:
90
+ None
91
+ """
92
  with gr.Blocks(css="style.css") as app:
93
  gr.Markdown(DESCRIPTION)
94
 
 
110
  input_image0 = gr.Image(
111
  label="Image 0",
112
  type="numpy",
 
113
  image_mode="RGB",
114
+ height=300,
115
+ interactive=True,
116
  )
117
  input_image1 = gr.Image(
118
  label="Image 1",
119
  type="numpy",
 
120
  image_mode="RGB",
121
+ height=300,
122
+ interactive=True,
123
  )
124
 
125
  with gr.Row():
126
  button_reset = gr.Button(value="Reset")
127
+ button_run = gr.Button(value="Run Match", variant="primary")
 
 
128
 
129
  with gr.Accordion("Advanced Setting", open=False):
130
  with gr.Accordion("Matching Setting", open=True):
 
169
  # enable_ransac = gr.Checkbox(label="Enable RANSAC")
170
  ransac_method = gr.Dropdown(
171
  choices=ransac_zoo.keys(),
172
+ value=DEFAULT_RANSAC_METHOD,
173
  label="RANSAC Method",
174
  interactive=True,
175
  )
 
201
  choice_estimate_geom = gr.Radio(
202
  ["Fundamental", "Homography"],
203
  label="Reconstruct Geometry",
204
+ value=DEFAULT_SETTING_GEOMETRY,
205
  )
206
 
207
  # with gr.Column():
 
213
  match_setting_max_features,
214
  detect_keypoints_threshold,
215
  matcher_list,
 
216
  ransac_method,
217
  ransac_reproj_threshold,
218
  ransac_confidence,
 
258
  output_wrapped = gr.Image(
259
  label="Wrapped Pair", type="numpy"
260
  )
261
+ with gr.Accordion(
262
+ "Open for More: Geometry info", open=False
263
+ ):
264
+ geometry_result = gr.JSON(
265
+ label="Reconstructed Geometry"
266
+ )
267
+
268
  # callbacks
269
  match_image_src.change(
270
  fn=ui_change_imagebox,
 
308
  matcher_info,
309
  output_wrapped,
310
  geometry_result,
 
311
  ransac_method,
312
  ransac_reproj_threshold,
313
  ransac_confidence,
common/utils.py CHANGED
@@ -13,6 +13,18 @@ from .viz import draw_matches, fig2im, plot_images, plot_color_line_matches
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def get_model(match_conf):
18
  Model = dynamic_load(matchers, match_conf["model"]["name"])
@@ -52,14 +64,13 @@ def gen_examples():
52
  # image pair path
53
  path = "datasets/sacre_coeur/mapping"
54
  pairs = gen_images_pairs(path, len(example_matchers))
55
- match_setting_threshold = 0.1
56
- match_setting_max_features = 2000
57
- detect_keypoints_threshold = 0.01
58
- enable_ransac = True
59
- ransac_method = "RANSAC"
60
- ransac_reproj_threshold = 8
61
- ransac_confidence = 0.999
62
- ransac_max_iter = 10000
63
  input_lists = []
64
  for pair, mt in zip(pairs, example_matchers):
65
  input_lists.append(
@@ -82,10 +93,10 @@ def gen_examples():
82
 
83
  def filter_matches(
84
  pred,
85
- ransac_method="RANSAC",
86
- ransac_reproj_threshold=8,
87
- ransac_confidence=0.999,
88
- ransac_max_iter=10000,
89
  ):
90
  mkpts0 = None
91
  mkpts1 = None
@@ -106,9 +117,9 @@ def filter_matches(
106
  if mkpts0 is None or mkpts0 is None:
107
  return pred
108
  if ransac_method not in ransac_zoo.keys():
109
- ransac_method = "RANSAC"
110
 
111
- if len(mkpts0) < 4:
112
  return pred
113
  H, mask = cv2.findHomography(
114
  mkpts0,
@@ -132,10 +143,10 @@ def filter_matches(
132
 
133
  def compute_geom(
134
  pred,
135
- ransac_method="RANSAC",
136
- ransac_reproj_threshold=8,
137
- ransac_confidence=0.999,
138
- ransac_max_iter=10000,
139
  ) -> dict:
140
  mkpts0 = None
141
  mkpts1 = None
@@ -152,7 +163,7 @@ def compute_geom(
152
  mkpts1 = pred["line_keypoints1_orig"]
153
 
154
  if mkpts0 is not None and mkpts1 is not None:
155
- if len(mkpts0) < 8:
156
  return {}
157
  h1, w1, _ = pred["image0_orig"].shape
158
  geo_info = {}
@@ -309,12 +320,11 @@ def run_matching(
309
  extract_max_keypoints,
310
  keypoint_threshold,
311
  key,
312
- # enable_ransac=False,
313
- ransac_method="RANSAC",
314
- ransac_reproj_threshold=8,
315
- ransac_confidence=0.999,
316
- ransac_max_iter=10000,
317
- choice_estimate_geom="Homography",
318
  ):
319
  # image0 and image1 is RGB mode
320
  if image0 is None or image1 is None:
@@ -420,7 +430,6 @@ def run_matching(
420
  "geom_info": geom_info,
421
  },
422
  output_wrapped,
423
- # geometry_result,
424
  )
425
 
426
 
 
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ DEFAULT_SETTING_THRESHOLD = 0.1
17
+ DEFAULT_SETTING_MAX_FEATURES = 2000
18
+ DEFAULT_DEFAULT_KEYPOINT_THRESHOLD = 0.01
19
+ DEFAULT_ENABLE_RANSAC = True
20
+ DEFAULT_RANSAC_METHOD = "USAC_MAGSAC"
21
+ DEFAULT_RANSAC_REPROJ_THRESHOLD = 8
22
+ DEFAULT_RANSAC_CONFIDENCE = 0.999
23
+ DEFAULT_RANSAC_MAX_ITER = 10000
24
+ DEFAULT_MIN_NUM_MATCHES = 4
25
+ DEFAULT_MATCHING_THRESHOLD = 0.2
26
+ DEFAULT_SETTING_GEOMETRY = "Homography"
27
+
28
 
29
  def get_model(match_conf):
30
  Model = dynamic_load(matchers, match_conf["model"]["name"])
 
64
  # image pair path
65
  path = "datasets/sacre_coeur/mapping"
66
  pairs = gen_images_pairs(path, len(example_matchers))
67
+ match_setting_threshold = DEFAULT_SETTING_THRESHOLD
68
+ match_setting_max_features = DEFAULT_SETTING_MAX_FEATURES
69
+ detect_keypoints_threshold = DEFAULT_DEFAULT_KEYPOINT_THRESHOLD
70
+ ransac_method = DEFAULT_RANSAC_METHOD
71
+ ransac_reproj_threshold = DEFAULT_RANSAC_REPROJ_THRESHOLD
72
+ ransac_confidence = DEFAULT_RANSAC_CONFIDENCE
73
+ ransac_max_iter = DEFAULT_RANSAC_MAX_ITER
 
74
  input_lists = []
75
  for pair, mt in zip(pairs, example_matchers):
76
  input_lists.append(
 
93
 
94
  def filter_matches(
95
  pred,
96
+ ransac_method=DEFAULT_RANSAC_METHOD,
97
+ ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
98
+ ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
99
+ ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
100
  ):
101
  mkpts0 = None
102
  mkpts1 = None
 
117
  if mkpts0 is None or mkpts0 is None:
118
  return pred
119
  if ransac_method not in ransac_zoo.keys():
120
+ ransac_method = DEFAULT_RANSAC_METHOD
121
 
122
+ if len(mkpts0) < DEFAULT_MIN_NUM_MATCHES:
123
  return pred
124
  H, mask = cv2.findHomography(
125
  mkpts0,
 
143
 
144
  def compute_geom(
145
  pred,
146
+ ransac_method=DEFAULT_RANSAC_METHOD,
147
+ ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
148
+ ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
149
+ ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
150
  ) -> dict:
151
  mkpts0 = None
152
  mkpts1 = None
 
163
  mkpts1 = pred["line_keypoints1_orig"]
164
 
165
  if mkpts0 is not None and mkpts1 is not None:
166
+ if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
167
  return {}
168
  h1, w1, _ = pred["image0_orig"].shape
169
  geo_info = {}
 
320
  extract_max_keypoints,
321
  keypoint_threshold,
322
  key,
323
+ ransac_method=DEFAULT_RANSAC_METHOD,
324
+ ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
325
+ ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
326
+ ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
327
+ choice_estimate_geom=DEFAULT_SETTING_GEOMETRY,
 
328
  ):
329
  # image0 and image1 is RGB mode
330
  if image0 is None or image1 is None:
 
430
  "geom_info": geom_info,
431
  },
432
  output_wrapped,
 
433
  )
434
 
435
 
common/viz.py CHANGED
@@ -1,25 +1,9 @@
1
- import bisect
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
- import matplotlib, os, cv2
5
- import matplotlib.cm as cm
6
- from PIL import Image
7
- import torch.nn.functional as F
8
- import torch
9
  import seaborn as sns
10
 
11
 
12
- def _compute_conf_thresh(data):
13
- dataset_name = data["dataset_name"][0].lower()
14
- if dataset_name == "scannet":
15
- thr = 5e-4
16
- elif dataset_name == "megadepth":
17
- thr = 1e-4
18
- else:
19
- raise ValueError(f"Unknown dataset: {dataset_name}")
20
- return thr
21
-
22
-
23
  def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
24
  """Plot a set of images horizontally.
25
  Args:
@@ -172,95 +156,6 @@ def make_matching_figure(
172
  return fig
173
 
174
 
175
- def _make_evaluation_figure(data, b_id, alpha="dynamic"):
176
- b_mask = data["m_bids"] == b_id
177
- conf_thr = _compute_conf_thresh(data)
178
-
179
- img0 = (
180
- (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
181
- )
182
- img1 = (
183
- (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
184
- )
185
- kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
186
- kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
187
-
188
- # for megadepth, we visualize matches on the resized image
189
- if "scale0" in data:
190
- kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
191
- kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
192
-
193
- epi_errs = data["epi_errs"][b_mask].cpu().numpy()
194
- correct_mask = epi_errs < conf_thr
195
- precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
196
- n_correct = np.sum(correct_mask)
197
- n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
198
- recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
199
- # recall might be larger than 1, since the calculation of conf_matrix_gt
200
- # uses groundtruth depths and camera poses, but epipolar distance is used here.
201
-
202
- # matching info
203
- if alpha == "dynamic":
204
- alpha = dynamic_alpha(len(correct_mask))
205
- color = error_colormap(epi_errs, conf_thr, alpha=alpha)
206
-
207
- text = [
208
- f"#Matches {len(kpts0)}",
209
- f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%):"
210
- f" {n_correct}/{len(kpts0)}",
211
- f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%):"
212
- f" {n_correct}/{n_gt_matches}",
213
- ]
214
-
215
- # make the figure
216
- figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
217
- return figure
218
-
219
-
220
- def _make_confidence_figure(data, b_id):
221
- # TODO: Implement confidence figure
222
- raise NotImplementedError()
223
-
224
-
225
- def make_matching_figures(data, config, mode="evaluation"):
226
- """Make matching figures for a batch.
227
-
228
- Args:
229
- data (Dict): a batch updated by PL_LoFTR.
230
- config (Dict): matcher config
231
- Returns:
232
- figures (Dict[str, List[plt.figure]]
233
- """
234
- assert mode in ["evaluation", "confidence"] # 'confidence'
235
- figures = {mode: []}
236
- for b_id in range(data["image0"].size(0)):
237
- if mode == "evaluation":
238
- fig = _make_evaluation_figure(
239
- data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
240
- )
241
- elif mode == "confidence":
242
- fig = _make_confidence_figure(data, b_id)
243
- else:
244
- raise ValueError(f"Unknown plot mode: {mode}")
245
- figures[mode].append(fig)
246
- return figures
247
-
248
-
249
- def dynamic_alpha(
250
- n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
251
- ):
252
- if n_matches == 0:
253
- return 1.0
254
- ranges = list(zip(alphas, alphas[1:] + [None]))
255
- loc = bisect.bisect_right(milestones, n_matches) - 1
256
- _range = ranges[loc]
257
- if _range[1] is None:
258
- return _range[0]
259
- return _range[1] + (milestones[loc + 1] - n_matches) / (
260
- milestones[loc + 1] - milestones[loc]
261
- ) * (_range[0] - _range[1])
262
-
263
-
264
  def error_colormap(err, thr, alpha=1.0):
265
  assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
266
  x = 1 - np.clip(err / (thr * 2), 0, 1)
@@ -278,245 +173,6 @@ color_map = np.arange(100)
278
  np.random.shuffle(color_map)
279
 
280
 
281
- def draw_topics(
282
- data,
283
- img0,
284
- img1,
285
- saved_folder="viz_topics",
286
- show_n_topics=8,
287
- saved_name=None,
288
- ):
289
- topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
290
- hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
291
- hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
292
- # print(hw0_i, hw1_i)
293
- scale0, scale1 = hw0_i[0] // hw0_c[0], hw1_i[0] // hw1_c[0]
294
- if "scale0" in data:
295
- scale0 *= data["scale0"][0]
296
- else:
297
- scale0 = (scale0, scale0)
298
- if "scale1" in data:
299
- scale1 *= data["scale1"][0]
300
- else:
301
- scale1 = (scale1, scale1)
302
-
303
- n_topics = topic0.shape[-1]
304
- # mask0_nonzero = topic0[0].sum(dim=-1, keepdim=True) > 0
305
- # mask1_nonzero = topic1[0].sum(dim=-1, keepdim=True) > 0
306
- theta0 = topic0[0].sum(dim=0)
307
- theta0 /= theta0.sum().float()
308
- theta1 = topic1[0].sum(dim=0)
309
- theta1 /= theta1.sum().float()
310
- # top_topic0 = torch.argsort(theta0, descending=True)[:show_n_topics]
311
- # top_topic1 = torch.argsort(theta1, descending=True)[:show_n_topics]
312
- top_topics = torch.argsort(theta0 * theta1, descending=True)[:show_n_topics]
313
- # print(sum_topic0, sum_topic1)
314
-
315
- topic0 = topic0[0].argmax(
316
- dim=-1, keepdim=True
317
- ) # .float() / (n_topics - 1) #* 255 + 1 #
318
- # topic0[~mask0_nonzero] = -1
319
- topic1 = topic1[0].argmax(
320
- dim=-1, keepdim=True
321
- ) # .float() / (n_topics - 1) #* 255 + 1
322
- # topic1[~mask1_nonzero] = -1
323
- label_img0, label_img1 = (
324
- torch.zeros_like(topic0) - 1,
325
- torch.zeros_like(topic1) - 1,
326
- )
327
- for i, k in enumerate(top_topics):
328
- label_img0[topic0 == k] = color_map[k]
329
- label_img1[topic1 == k] = color_map[k]
330
-
331
- # print(hw0_c, scale0)
332
- # print(hw1_c, scale1)
333
- # map_topic0 = F.fold(label_img0.unsqueeze(0), hw0_i, kernel_size=scale0, stride=scale0)
334
- map_topic0 = (
335
- label_img0.float().view(hw0_c).cpu().numpy()
336
- ) # map_topic0.squeeze(0).squeeze(0).cpu().numpy()
337
- map_topic0 = cv2.resize(
338
- map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1]))
339
- )
340
- # map_topic1 = F.fold(label_img1.unsqueeze(0), hw1_i, kernel_size=scale1, stride=scale1)
341
- map_topic1 = (
342
- label_img1.float().view(hw1_c).cpu().numpy()
343
- ) # map_topic1.squeeze(0).squeeze(0).cpu().numpy()
344
- map_topic1 = cv2.resize(
345
- map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1]))
346
- )
347
-
348
- # show image0
349
- if saved_name is None:
350
- return map_topic0, map_topic1
351
-
352
- if not os.path.exists(saved_folder):
353
- os.makedirs(saved_folder)
354
- path_saved_img0 = os.path.join(saved_folder, "{}_0.png".format(saved_name))
355
- plt.imshow(img0)
356
- masked_map_topic0 = np.ma.masked_where(map_topic0 < 0, map_topic0)
357
- plt.imshow(
358
- masked_map_topic0,
359
- cmap=plt.cm.jet,
360
- vmin=0,
361
- vmax=n_topics - 1,
362
- alpha=0.3,
363
- interpolation="bilinear",
364
- )
365
- # plt.show()
366
- plt.axis("off")
367
- plt.savefig(path_saved_img0, bbox_inches="tight", pad_inches=0, dpi=250)
368
- plt.close()
369
-
370
- path_saved_img1 = os.path.join(saved_folder, "{}_1.png".format(saved_name))
371
- plt.imshow(img1)
372
- masked_map_topic1 = np.ma.masked_where(map_topic1 < 0, map_topic1)
373
- plt.imshow(
374
- masked_map_topic1,
375
- cmap=plt.cm.jet,
376
- vmin=0,
377
- vmax=n_topics - 1,
378
- alpha=0.3,
379
- interpolation="bilinear",
380
- )
381
- plt.axis("off")
382
- plt.savefig(path_saved_img1, bbox_inches="tight", pad_inches=0, dpi=250)
383
- plt.close()
384
-
385
-
386
- def draw_topicfm_demo(
387
- data,
388
- img0,
389
- img1,
390
- mkpts0,
391
- mkpts1,
392
- mcolor,
393
- text,
394
- show_n_topics=8,
395
- topic_alpha=0.3,
396
- margin=5,
397
- path=None,
398
- opencv_display=False,
399
- opencv_title="",
400
- ):
401
- topic_map0, topic_map1 = draw_topics(
402
- data, img0, img1, show_n_topics=show_n_topics
403
- )
404
-
405
- mask_tm0, mask_tm1 = np.expand_dims(
406
- topic_map0 >= 0, axis=-1
407
- ), np.expand_dims(topic_map1 >= 0, axis=-1)
408
-
409
- topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
410
- topic_cm0 = cv2.cvtColor(
411
- topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
412
- )
413
- topic_cm1 = cv2.cvtColor(
414
- topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
415
- )
416
- overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
417
- overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
418
-
419
- cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
420
- cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
421
-
422
- overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (
423
- overlay1 * 255
424
- ).astype(np.uint8)
425
-
426
- h0, w0 = img0.shape[:2]
427
- h1, w1 = img1.shape[:2]
428
- h, w = h0 * 2 + margin * 2, w0 * 2 + margin
429
- out_fig = 255 * np.ones((h, w, 3), dtype=np.uint8)
430
- out_fig[:h0, :w0] = overlay0
431
- if h0 >= h1:
432
- start = (h0 - h1) // 2
433
- out_fig[
434
- start : (start + h1), (w0 + margin) : (w0 + margin + w1)
435
- ] = overlay1
436
- else:
437
- start = (h1 - h0) // 2
438
- out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
439
- start : (start + h0)
440
- ]
441
-
442
- step_h = h0 + margin * 2
443
- out_fig[step_h : step_h + h0, :w0] = (img0 * 255).astype(np.uint8)
444
- if h0 >= h1:
445
- start = step_h + (h0 - h1) // 2
446
- out_fig[start : start + h1, (w0 + margin) : (w0 + margin + w1)] = (
447
- img1 * 255
448
- ).astype(np.uint8)
449
- else:
450
- start = (h1 - h0) // 2
451
- out_fig[step_h : step_h + h0, (w0 + margin) : (w0 + margin + w1)] = (
452
- img1[start : start + h0] * 255
453
- ).astype(np.uint8)
454
-
455
- # draw matching lines, this is inspried from
456
- # https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
457
- mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
458
- mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
459
-
460
- for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, mcolor):
461
- c = c.tolist()
462
- cv2.line(
463
- out_fig,
464
- (x0, y0 + step_h),
465
- (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2),
466
- color=c,
467
- thickness=1,
468
- lineType=cv2.LINE_AA,
469
- )
470
- # display line end-points as circles
471
- cv2.circle(out_fig, (x0, y0 + step_h), 2, c, -1, lineType=cv2.LINE_AA)
472
- cv2.circle(
473
- out_fig,
474
- (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2),
475
- 2,
476
- c,
477
- -1,
478
- lineType=cv2.LINE_AA,
479
- )
480
-
481
- # Scale factor for consistent visualization across scales.
482
- sc = min(h / 960.0, 2.0)
483
-
484
- # Big text.
485
- Ht = int(30 * sc) # text height
486
- txt_color_fg = (255, 255, 255)
487
- txt_color_bg = (0, 0, 0)
488
- for i, t in enumerate(text):
489
- cv2.putText(
490
- out_fig,
491
- t,
492
- (int(8 * sc), Ht + step_h * i),
493
- cv2.FONT_HERSHEY_DUPLEX,
494
- 1.0 * sc,
495
- txt_color_bg,
496
- 2,
497
- cv2.LINE_AA,
498
- )
499
- cv2.putText(
500
- out_fig,
501
- t,
502
- (int(8 * sc), Ht + step_h * i),
503
- cv2.FONT_HERSHEY_DUPLEX,
504
- 1.0 * sc,
505
- txt_color_fg,
506
- 1,
507
- cv2.LINE_AA,
508
- )
509
-
510
- if path is not None:
511
- cv2.imwrite(str(path), out_fig)
512
-
513
- if opencv_display:
514
- cv2.imshow(opencv_title, out_fig)
515
- cv2.waitKey(1)
516
-
517
- return out_fig
518
-
519
-
520
  def fig2im(fig):
521
  fig.canvas.draw()
522
  w, h = fig.canvas.get_width_height()
 
 
1
  import numpy as np
2
  import matplotlib.pyplot as plt
3
+ import matplotlib
 
 
 
 
4
  import seaborn as sns
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
7
  def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
8
  """Plot a set of images horizontally.
9
  Args:
 
156
  return fig
157
 
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def error_colormap(err, thr, alpha=1.0):
160
  assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
161
  x = 1 - np.clip(err / (thr * 2), 0, 1)
 
173
  np.random.shuffle(color_map)
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def fig2im(fig):
177
  fig.canvas.draw()
178
  w, h = fig.canvas.get_width_height()
style.css CHANGED
@@ -1,5 +1,6 @@
1
  h1 {
2
  text-align: center;
 
3
  }
4
 
5
  #duplicate-button {
 
1
  h1 {
2
  text-align: center;
3
+ display:block;
4
  }
5
 
6
  #duplicate-button {