Realcat commited on
Commit
8d7004c
·
1 Parent(s): 1dbd087

fix: crash when setting too less max features

Browse files
Files changed (3) hide show
  1. common/utils.py +18 -4
  2. common/viz.py +0 -1
  3. hloc/matchers/cotr.py +2 -2
common/utils.py CHANGED
@@ -209,6 +209,18 @@ def gen_examples():
209
  return input_lists
210
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  def filter_matches(
213
  pred: Dict[str, Any],
214
  ransac_method: str = DEFAULT_RANSAC_METHOD,
@@ -246,14 +258,14 @@ def filter_matches(
246
  mkpts1 = pred["line_keypoints1_orig"]
247
  feature_type = "LINE"
248
  else:
249
- return pred
250
  if mkpts0 is None or mkpts0 is None:
251
- return pred
252
  if ransac_method not in ransac_zoo.keys():
253
  ransac_method = DEFAULT_RANSAC_METHOD
254
 
255
  if len(mkpts0) < DEFAULT_MIN_NUM_MATCHES:
256
- return pred
257
  H, mask = cv2.findHomography(
258
  mkpts0,
259
  mkpts1,
@@ -272,6 +284,8 @@ def filter_matches(
272
  pred["mline_keypoints0_orig"] = mkpts0[mask]
273
  pred["mline_keypoints1_orig"] = mkpts1[mask]
274
  pred["H"] = H
 
 
275
  return pred
276
 
277
 
@@ -344,7 +358,7 @@ def compute_geometry(
344
  geo_info["H1"] = H1.tolist()
345
  geo_info["H2"] = H2.tolist()
346
  except cv2.error as e:
347
- logger.error(f"{e}, skip")
348
  return geo_info
349
  else:
350
  return {}
 
209
  return input_lists
210
 
211
 
212
+ def set_null_pred(feature_type: str, pred: dict):
213
+ if feature_type == "KEYPOINT":
214
+ pred["mkeypoints0_orig"] = np.array([])
215
+ pred["mkeypoints1_orig"] = np.array([])
216
+ pred["mmconf"] = np.array([])
217
+ elif feature_type == "LINE":
218
+ pred["mline_keypoints0_orig"] = np.array([])
219
+ pred["mline_keypoints1_orig"] = np.array([])
220
+ pred["H"] = None
221
+ return pred
222
+
223
+
224
  def filter_matches(
225
  pred: Dict[str, Any],
226
  ransac_method: str = DEFAULT_RANSAC_METHOD,
 
258
  mkpts1 = pred["line_keypoints1_orig"]
259
  feature_type = "LINE"
260
  else:
261
+ return set_null_pred(feature_type, pred)
262
  if mkpts0 is None or mkpts0 is None:
263
+ return set_null_pred(feature_type, pred)
264
  if ransac_method not in ransac_zoo.keys():
265
  ransac_method = DEFAULT_RANSAC_METHOD
266
 
267
  if len(mkpts0) < DEFAULT_MIN_NUM_MATCHES:
268
+ return set_null_pred(feature_type, pred)
269
  H, mask = cv2.findHomography(
270
  mkpts0,
271
  mkpts1,
 
284
  pred["mline_keypoints0_orig"] = mkpts0[mask]
285
  pred["mline_keypoints1_orig"] = mkpts1[mask]
286
  pred["H"] = H
287
+ else:
288
+ set_null_pred(feature_type, pred)
289
  return pred
290
 
291
 
 
358
  geo_info["H1"] = H1.tolist()
359
  geo_info["H2"] = H2.tolist()
360
  except cv2.error as e:
361
+ logger.error(f"StereoRectifyUncalibrated failed, skip!")
362
  return geo_info
363
  else:
364
  return {}
common/viz.py CHANGED
@@ -396,7 +396,6 @@ def display_matches(
396
  """
397
  img0 = pred["image0_orig"]
398
  img1 = pred["image1_orig"]
399
-
400
  num_inliers = 0
401
  # draw raw matches
402
  if (
 
396
  """
397
  img0 = pred["image0_orig"]
398
  img1 = pred["image1_orig"]
 
399
  num_inliers = 0
400
  # draw raw matches
401
  if (
hloc/matchers/cotr.py CHANGED
@@ -71,7 +71,7 @@ class COTR(BaseModel):
71
  queries_a=None,
72
  )
73
  pred = {
74
- "keypoints0": torch.from_numpy(corrs[:,:2]),
75
- "keypoints1": torch.from_numpy(corrs[:,2:]),
76
  }
77
  return pred
 
71
  queries_a=None,
72
  )
73
  pred = {
74
+ "keypoints0": torch.from_numpy(corrs[:, :2]),
75
+ "keypoints1": torch.from_numpy(corrs[:, 2:]),
76
  }
77
  return pred