Realcat committed on
Commit 0bc7901
1 Parent(s): 82ee2a0

add: ModelCache

Files changed (3):
  1. common/utils.py +73 -11
  2. hloc/matchers/omniglue.py +1 -0
  3. test_app_cli.py +36 -5
common/utils.py CHANGED
@@ -1,7 +1,10 @@
 import os
 import cv2
+import sys
 import torch
 import random
+import psutil
+import shutil
 import numpy as np
 import gradio as gr
 from pathlib import Path
@@ -42,6 +45,66 @@ MATCHER_ZOO = None
 models_already_loaded = {}
 
 
+class ModelCache:
+    def __init__(self, max_memory_size: int = 8):
+        self.max_memory_size = max_memory_size
+        self.current_memory_size = 0
+        self.model_dict = {}
+        self.model_timestamps = []
+
+    def cache_model(self, model_key, model_loader_func, model_conf):
+        if model_key in self.model_dict:
+            self.model_timestamps.remove(model_key)
+            self.model_timestamps.append(model_key)
+            logger.info(f"Load cached {model_key}")
+            return self.model_dict[model_key]
+
+        model = self._load_model_from_disk(model_loader_func, model_conf)
+        while self._calculate_model_memory() > self.max_memory_size:
+            if len(self.model_timestamps) == 0:
+                logger.warn(
+                    "RAM: {}GB, MAX RAM: {}GB".format(
+                        self._calculate_model_memory(), self.max_memory_size
+                    )
+                )
+                break
+            oldest_model_key = self.model_timestamps.pop(0)
+            self.current_memory_size = self._calculate_model_memory()
+            logger.info(f"Del cached {oldest_model_key}")
+            del self.model_dict[oldest_model_key]
+
+        self.model_dict[model_key] = model
+        self.model_timestamps.append(model_key)
+
+        self.print_memory_usage()
+        logger.info(f"Total cached {list(self.model_dict.keys())}")
+
+        return model
+
+    def _load_model_from_disk(self, model_loader_func, model_conf):
+        return model_loader_func(model_conf)
+
+    def _calculate_model_memory(self, verbose=False):
+        host_colocation = int(os.environ.get("HOST_COLOCATION", "1"))
+        vm = psutil.virtual_memory()
+        du = shutil.disk_usage(".")
+        vm_ratio = host_colocation * vm.used / vm.total
+        if verbose:
+            logger.info(
+                f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB"
+            )
+            # logger.info(
+            #     f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB"
+            # )
+        return vm.used / 1e9
+
+    def print_memory_usage(self):
+        self._calculate_model_memory(verbose=True)
+
+
+model_cache = ModelCache()
+
+
 def load_config(config_name: str) -> Dict[str, Any]:
     """
     Load a YAML configuration file.
@@ -579,6 +642,7 @@ def run_matching(
     ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
     choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
     matcher_zoo: Dict[str, Any] = None,
+    use_cached_model: bool = True,
 ) -> Tuple[
     np.ndarray,
     np.ndarray,
@@ -639,15 +703,12 @@ def run_matching(
     match_conf["model"]["max_keypoints"] = extract_max_keypoints
     t0 = time.time()
     cache_key = "{}_{}".format(key, match_conf["model"]["name"])
-    if cache_key in models_already_loaded:
-        matcher = models_already_loaded[cache_key]
+    matcher = model_cache.cache_model(cache_key, get_model, match_conf)
+    if use_cached_model:
         matcher.conf["max_keypoints"] = extract_max_keypoints
         matcher.conf["match_threshold"] = match_threshold
         logger.info(f"Loaded cached model {cache_key}")
-    else:
-        matcher = get_model(match_conf)
-        models_already_loaded[cache_key] = matcher
-        # gr.Info(f"Loading model using: {time.time()-t0:.3f}s")
+
     logger.info(f"Loading model using: {time.time()-t0:.3f}s")
     t1 = time.time()
 
@@ -663,14 +724,15 @@ def run_matching(
     extract_conf["model"]["max_keypoints"] = extract_max_keypoints
     extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
    cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
-    if cache_key in models_already_loaded:
-        extractor = models_already_loaded[cache_key]
+
+    extractor = model_cache.cache_model(
+        cache_key, get_feature_model, extract_conf
+    )
+    if use_cached_model:
         extractor.conf["max_keypoints"] = extract_max_keypoints
         extractor.conf["keypoint_threshold"] = keypoint_threshold
         logger.info(f"Loaded cached model {cache_key}")
+
     pred0 = extract_features.extract(
         extractor, image0, extract_conf["preprocessing"]
     )
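
For orientation, a minimal usage sketch of the new ModelCache (a sketch, not part of the commit: dummy_loader, demo_key, and the toy conf dict are illustrative stand-ins for get_model/get_feature_model and a real matcher config; it assumes common.utils is importable and psutil is installed):

# Sketch: cache_model() returns the cached model on a repeated key and evicts
# least-recently-used entries while host RAM (psutil) exceeds max_memory_size (GB).
from common.utils import ModelCache

def dummy_loader(conf):
    # illustrative stand-in for get_model / get_feature_model
    return {"name": conf["model"]["name"]}

cache = ModelCache(max_memory_size=8)
conf = {"model": {"name": "demo"}}
first = cache.cache_model("demo_key", dummy_loader, conf)
again = cache.cache_model("demo_key", dummy_loader, conf)  # cache hit, same object
assert first is again

In run_matching, the same call replaces the old models_already_loaded lookups, and the new use_cached_model flag only gates whether the cached instance's conf thresholds are overwritten in place.
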
hloc/matchers/omniglue.py CHANGED
@@ -10,6 +10,7 @@ from ..utils.base_model import BaseModel
 thirdparty_path = Path(__file__).parent / "../../third_party"
 sys.path.append(str(thirdparty_path))
 from omniglue.src import omniglue
+
 omniglue_path = thirdparty_path / "omniglue"
 
 
test_app_cli.py CHANGED
@@ -12,11 +12,11 @@ from common.utils import (
 from common.api import ImageMatchingAPI
 
 
-def test_api(config: dict = None):
+def test_all(config: dict = None):
     img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
     img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
-    image0 = cv2.imread(str(img_path1))[:, :, ::-1]
-    image1 = cv2.imread(str(img_path2))[:, :, ::-1]
+    image0 = cv2.imread(str(img_path1))[:, :, ::-1]  # RGB
+    image1 = cv2.imread(str(img_path2))[:, :, ::-1]  # RGB
 
     matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"])
     for k, v in matcher_zoo_restored.items():
@@ -27,15 +27,46 @@ def test_api(config: dict = None):
             logger.info(f"Testing {k} ...")
             api = ImageMatchingAPI(conf=v, device=device)
             api(image0, image1)
-            log_path = ROOT / "experiments1"
+            log_path = ROOT / "experiments" / "all"
             log_path.mkdir(exist_ok=True, parents=True)
             api.visualize(log_path=log_path)
         else:
             logger.info(f"Skipping {k} ...")
 
 
+def test_one():
+    img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
+    img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
+    image0 = cv2.imread(str(img_path1))[:, :, ::-1]  # RGB
+    image1 = cv2.imread(str(img_path2))[:, :, ::-1]  # RGB
+
+    conf = {
+        "matcher": {
+            "output": "matches-omniglue",
+            "model": {
+                "name": "omniglue",
+                "match_threshold": 0.2,
+                "features": "null",
+            },
+            "preprocessing": {
+                "grayscale": False,
+                "resize_max": 1024,
+                "dfactor": 8,
+                "force_resize": False,
+            },
+        },
+        "dense": True,
+    }
+    api = ImageMatchingAPI(conf=conf, device=device)
+    api(image0, image1)
+    log_path = ROOT / "experiments" / "one"
+    log_path.mkdir(exist_ok=True, parents=True)
+    api.visualize(log_path=log_path)
+
+
 if __name__ == "__main__":
     import argparse
 
     config = load_config(ROOT / "common/config.yaml")
-    test_api(config)
+    test_one()
+    test_all(config)
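
For reference, a short sketch of driving the renamed tests programmatically rather than via the script's __main__ block (assumptions: the repository root is on sys.path and the sacre_coeur images hard-coded in the tests exist under datasets/):

# Sketch: run the single-model OmniGlue test, then the full matcher-zoo sweep.
from common.utils import load_config
from test_app_cli import ROOT, test_all, test_one

config = load_config(ROOT / "common/config.yaml")
test_one()          # writes visualizations to experiments/one
test_all(config)    # writes visualizations to experiments/all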