Realcat commited on
Commit
e400e91
1 Parent(s): 0bc7901

update: omniglue

Browse files
common/utils.py CHANGED
@@ -642,7 +642,7 @@ def run_matching(
642
  ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
643
  choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
644
  matcher_zoo: Dict[str, Any] = None,
645
- use_cached_model: bool = True,
646
  ) -> Tuple[
647
  np.ndarray,
648
  np.ndarray,
@@ -696,19 +696,21 @@ def run_matching(
696
  f"Success! Please be patient and allow for about 2-3 minutes."
697
  f" Due to CPU inference, {key} is quiet slow."
698
  )
 
699
  model = matcher_zoo[key]
700
  match_conf = model["matcher"]
701
  # update match config
702
  match_conf["model"]["match_threshold"] = match_threshold
703
  match_conf["model"]["max_keypoints"] = extract_max_keypoints
704
- t0 = time.time()
705
  cache_key = "{}_{}".format(key, match_conf["model"]["name"])
706
- matcher = model_cache.cache_model(cache_key, get_model, match_conf)
707
  if use_cached_model:
 
 
708
  matcher.conf["max_keypoints"] = extract_max_keypoints
709
  matcher.conf["match_threshold"] = match_threshold
710
  logger.info(f"Loaded cached model {cache_key}")
711
-
 
712
  logger.info(f"Loading model using: {time.time()-t0:.3f}s")
713
  t1 = time.time()
714
 
@@ -725,13 +727,16 @@ def run_matching(
725
  extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
726
  cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
727
 
728
- extractor = model_cache.cache_model(
729
- cache_key, get_feature_model, extract_conf
730
- )
731
  if use_cached_model:
 
 
 
 
732
  extractor.conf["max_keypoints"] = extract_max_keypoints
733
  extractor.conf["keypoint_threshold"] = keypoint_threshold
734
  logger.info(f"Loaded cached model {cache_key}")
 
 
735
 
736
  pred0 = extract_features.extract(
737
  extractor, image0, extract_conf["preprocessing"]
 
642
  ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
643
  choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
644
  matcher_zoo: Dict[str, Any] = None,
645
+ use_cached_model: bool = False,
646
  ) -> Tuple[
647
  np.ndarray,
648
  np.ndarray,
 
696
  f"Success! Please be patient and allow for about 2-3 minutes."
697
  f" Due to CPU inference, {key} is quiet slow."
698
  )
699
+ t0 = time.time()
700
  model = matcher_zoo[key]
701
  match_conf = model["matcher"]
702
  # update match config
703
  match_conf["model"]["match_threshold"] = match_threshold
704
  match_conf["model"]["max_keypoints"] = extract_max_keypoints
 
705
  cache_key = "{}_{}".format(key, match_conf["model"]["name"])
 
706
  if use_cached_model:
707
+ # because of the model cache, we need to update the config
708
+ matcher = model_cache.cache_model(cache_key, get_model, match_conf)
709
  matcher.conf["max_keypoints"] = extract_max_keypoints
710
  matcher.conf["match_threshold"] = match_threshold
711
  logger.info(f"Loaded cached model {cache_key}")
712
+ else:
713
+ matcher = get_model(match_conf)
714
  logger.info(f"Loading model using: {time.time()-t0:.3f}s")
715
  t1 = time.time()
716
 
 
727
  extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
728
  cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
729
 
 
 
 
730
  if use_cached_model:
731
+ extractor = model_cache.cache_model(
732
+ cache_key, get_feature_model, extract_conf
733
+ )
734
+ # because of the model cache, we need to update the config
735
  extractor.conf["max_keypoints"] = extract_max_keypoints
736
  extractor.conf["keypoint_threshold"] = keypoint_threshold
737
  logger.info(f"Loaded cached model {cache_key}")
738
+ else:
739
+ extractor = get_feature_model(extract_conf)
740
 
741
  pred0 = extract_features.extract(
742
  extractor, image0, extract_conf["preprocessing"]
hloc/match_dense.py CHANGED
@@ -216,6 +216,7 @@ confs = {
216
  "model": {
217
  "name": "omniglue",
218
  "match_threshold": 0.2,
 
219
  "features": "null",
220
  },
221
  "preprocessing": {
 
216
  "model": {
217
  "name": "omniglue",
218
  "match_threshold": 0.2,
219
+ "max_keypoints": 2000,
220
  "features": "null",
221
  },
222
  "preprocessing": {
hloc/matchers/duster.py CHANGED
@@ -105,7 +105,7 @@ class Duster(BaseModel):
105
  reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
106
  *pts3d_list
107
  )
108
- print(f"found {num_matches} matches")
109
  mkpts1 = pts2d_list[1][reciprocal_in_P2]
110
  mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
111
 
@@ -114,7 +114,6 @@ class Duster(BaseModel):
114
  keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
115
  mkpts0 = mkpts0[keep]
116
  mkpts1 = mkpts1[keep]
117
- breakpoint()
118
  pred = {
119
  "keypoints0": torch.from_numpy(mkpts0),
120
  "keypoints1": torch.from_numpy(mkpts1),
 
105
  reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
106
  *pts3d_list
107
  )
108
+ logger.info(f"Found {num_matches} matches")
109
  mkpts1 = pts2d_list[1][reciprocal_in_P2]
110
  mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
111
 
 
114
  keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
115
  mkpts0 = mkpts0[keep]
116
  mkpts1 = mkpts1[keep]
 
117
  pred = {
118
  "keypoints0": torch.from_numpy(mkpts0),
119
  "keypoints1": torch.from_numpy(mkpts1),
hloc/matchers/omniglue.py CHANGED
@@ -39,7 +39,6 @@ class OmniGlue(BaseModel):
39
  subprocess.run(cmd, check=True)
40
  else:
41
  logger.error(f"Invalid dinov2 model: {dino_model_path.name}")
42
-
43
  self.net = omniglue.OmniGlue(
44
  og_export=str(og_model_path),
45
  sp_export=str(sp_model_path),
@@ -54,9 +53,8 @@ class OmniGlue(BaseModel):
54
  image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255
55
  image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255
56
  match_kp0, match_kp1, match_confidences = self.net.FindMatches(
57
- image0_rgb_np, image1_rgb_np
58
  )
59
-
60
  # filter matches
61
  match_threshold = self.conf["match_threshold"]
62
  keep_idx = []
 
39
  subprocess.run(cmd, check=True)
40
  else:
41
  logger.error(f"Invalid dinov2 model: {dino_model_path.name}")
 
42
  self.net = omniglue.OmniGlue(
43
  og_export=str(og_model_path),
44
  sp_export=str(sp_model_path),
 
53
  image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255
54
  image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255
55
  match_kp0, match_kp1, match_confidences = self.net.FindMatches(
56
+ image0_rgb_np, image1_rgb_np, self.conf["max_keypoints"]
57
  )
 
58
  # filter matches
59
  match_threshold = self.conf["match_threshold"]
60
  keep_idx = []
hloc/utils/viz.py CHANGED
@@ -65,9 +65,11 @@ def plot_keypoints(kpts, colors="lime", ps=4):
65
  if not isinstance(colors, list):
66
  colors = [colors] * len(kpts)
67
  axes = plt.gcf().axes
68
- for a, k, c in zip(axes, kpts, colors):
69
- a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
70
-
 
 
71
 
72
  def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
73
  """Plot matches for a pair of existing images.
 
65
  if not isinstance(colors, list):
66
  colors = [colors] * len(kpts)
67
  axes = plt.gcf().axes
68
+ try:
69
+ for a, k, c in zip(axes, kpts, colors):
70
+ a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
71
+ except IndexError as e:
72
+ pass
73
 
74
  def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
75
  """Plot matches for a pair of existing images.
third_party/omniglue/src/omniglue/omniglue_extract.py CHANGED
@@ -46,13 +46,18 @@ class OmniGlue:
46
  dino_export, feature_layer=1
47
  )
48
 
49
- def FindMatches(self, image0: np.ndarray, image1: np.ndarray):
 
 
 
 
 
50
  """TODO(omniglue): docstring."""
51
  height0, width0 = image0.shape[:2]
52
  height1, width1 = image1.shape[:2]
53
  # TODO: numpy to torch inputs
54
- sp_features0 = self.sp_extract(image0, num_features=self.max_keypoints)
55
- sp_features1 = self.sp_extract(image1, num_features=self.max_keypoints)
56
  dino_features0 = self.dino_extract(image0)
57
  dino_features1 = self.dino_extract(image1)
58
  dino_descriptors0 = dino_extract.get_dino_descriptors(
 
46
  dino_export, feature_layer=1
47
  )
48
 
49
+ def FindMatches(
50
+ self,
51
+ image0: np.ndarray,
52
+ image1: np.ndarray,
53
+ max_keypoints: int = 2048,
54
+ ):
55
  """TODO(omniglue): docstring."""
56
  height0, width0 = image0.shape[:2]
57
  height1, width1 = image1.shape[:2]
58
  # TODO: numpy to torch inputs
59
+ sp_features0 = self.sp_extract(image0, num_features=max_keypoints)
60
+ sp_features1 = self.sp_extract(image1, num_features=max_keypoints)
61
  dino_features0 = self.dino_extract(image0)
62
  dino_features1 = self.dino_extract(image1)
63
  dino_descriptors0 = dino_extract.get_dino_descriptors(