Vincentqyw committed
Commit 9223079
1 Parent(s): 71bbcb3

add: files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. .gitattributes +1 -0
  2. .gitignore +19 -0
  3. .gitmodules +45 -0
  4. README.md +107 -12
  5. app.py +291 -0
  6. assets/demo.gif +3 -0
  7. assets/gui.jpg +0 -0
  8. datasets/.gitignore +0 -0
  9. datasets/lines/terrace0.JPG +0 -0
  10. datasets/lines/terrace1.JPG +0 -0
  11. datasets/sacre_coeur/README.md +3 -0
  12. datasets/sacre_coeur/mapping/02928139_3448003521.jpg +0 -0
  13. datasets/sacre_coeur/mapping/03903474_1471484089.jpg +0 -0
  14. datasets/sacre_coeur/mapping/10265353_3838484249.jpg +0 -0
  15. datasets/sacre_coeur/mapping/17295357_9106075285.jpg +0 -0
  16. datasets/sacre_coeur/mapping/32809961_8274055477.jpg +0 -0
  17. datasets/sacre_coeur/mapping/44120379_8371960244.jpg +0 -0
  18. datasets/sacre_coeur/mapping/51091044_3486849416.jpg +0 -0
  19. datasets/sacre_coeur/mapping/60584745_2207571072.jpg +0 -0
  20. datasets/sacre_coeur/mapping/71295362_4051449754.jpg +0 -0
  21. datasets/sacre_coeur/mapping/93341989_396310999.jpg +0 -0
  22. extra_utils/__init__.py +0 -0
  23. extra_utils/plotting.py +504 -0
  24. extra_utils/utils.py +182 -0
  25. extra_utils/visualize_util.py +642 -0
  26. hloc/__init__.py +31 -0
  27. hloc/extract_features.py +516 -0
  28. hloc/extractors/__init__.py +0 -0
  29. hloc/extractors/alike.py +52 -0
  30. hloc/extractors/cosplace.py +44 -0
  31. hloc/extractors/d2net.py +57 -0
  32. hloc/extractors/darkfeat.py +57 -0
  33. hloc/extractors/dedode.py +102 -0
  34. hloc/extractors/dir.py +76 -0
  35. hloc/extractors/disk.py +32 -0
  36. hloc/extractors/dog.py +131 -0
  37. hloc/extractors/example.py +58 -0
  38. hloc/extractors/fire.py +73 -0
  39. hloc/extractors/fire_local.py +90 -0
  40. hloc/extractors/lanet.py +53 -0
  41. hloc/extractors/netvlad.py +147 -0
  42. hloc/extractors/openibl.py +26 -0
  43. hloc/extractors/r2d2.py +61 -0
  44. hloc/extractors/rekd.py +53 -0
  45. hloc/extractors/superpoint.py +44 -0
  46. hloc/match_dense.py +384 -0
  47. hloc/match_features.py +389 -0
  48. hloc/matchers/__init__.py +3 -0
  49. hloc/matchers/adalam.py +69 -0
  50. hloc/matchers/aspanformer.py +76 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
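The added `*.gif` rule is the filter line Git LFS writes when a file pattern is tracked. A minimal sketch of how such a rule is usually produced (assuming Git LFS is installed; these commands are illustrative, not part of this commit):

``` bash
# Track GIFs with Git LFS; this appends the same filter line to .gitattributes
git lfs install
git lfs track "*.gif"
git add .gitattributes assets/demo.gif
```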
.gitignore ADDED
@@ -0,0 +1,19 @@
+ build/
+
+ lib/
+ bin/
+
+ cmake_modules/
+ cmake-build-debug/
+ .idea/
+ .vscode/
+ *.pyc
+ flagged
+ .ipynb_checkpoints
+ __pycache__
+ Untitled*
+ experiments
+ third_party/REKD
+ Dockerfile
+ hloc/matchers/dedode.py
+ gradio_cached_examples
.gitmodules ADDED
@@ -0,0 +1,45 @@
+ [submodule "third_party/Roma"]
+ path = third_party/Roma
+ url = https://github.com/Vincentqyw/RoMa.git
+ [submodule "third_party/SuperGluePretrainedNetwork"]
+ path = third_party/SuperGluePretrainedNetwork
+ url = https://github.com/magicleap/SuperGluePretrainedNetwork.git
+ [submodule "third_party/SOLD2"]
+ path = third_party/SOLD2
+ url = https://github.com/cvg/SOLD2.git
+ [submodule "third_party/GlueStick"]
+ path = third_party/GlueStick
+ url = https://github.com/cvg/GlueStick.git
+ [submodule "third_party/ASpanFormer"]
+ path = third_party/ASpanFormer
+ url = https://github.com/Vincentqyw/ml-aspanformer.git
+ [submodule "third_party/TopicFM"]
+ path = third_party/TopicFM
+ url = https://github.com/Vincentqyw/TopicFM.git
+ [submodule "third_party/d2net"]
+ path = third_party/d2net
+ url = https://github.com/Vincentqyw/d2-net.git
+ [submodule "third_party/r2d2"]
+ path = third_party/r2d2
+ url = https://github.com/naver/r2d2.git
+ [submodule "third_party/DKM"]
+ path = third_party/DKM
+ url = https://github.com/Vincentqyw/DKM.git
+ [submodule "third_party/ALIKE"]
+ path = third_party/ALIKE
+ url = https://github.com/Shiaoming/ALIKE.git
+ [submodule "third_party/lanet"]
+ path = third_party/lanet
+ url = https://github.com/wangch-g/lanet.git
+ [submodule "third_party/LightGlue"]
+ path = third_party/LightGlue
+ url = https://github.com/cvg/LightGlue.git
+ [submodule "third_party/SGMNet"]
+ path = third_party/SGMNet
+ url = https://github.com/vdvchen/SGMNet.git
+ [submodule "third_party/DarkFeat"]
+ path = third_party/DarkFeat
+ url = https://github.com/THU-LYJ-Lab/DarkFeat.git
+ [submodule "third_party/DeDoDe"]
+ path = third_party/DeDoDe
+ url = https://github.com/Parskatt/DeDoDe.git
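With this many submodules, a fresh checkout needs them initialized before the matchers can be imported. A short sketch of the usual workflow (standard Git commands, nothing specific to this repository):

``` bash
# Clone together with all third_party submodules
git clone --recursive https://github.com/Vincentqyw/image-matching-webui.git

# Or, if the repository was already cloned without --recursive:
git submodule update --init --recursive
```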
README.md CHANGED
@@ -1,12 +1,107 @@
- ---
- title: Image Matching Webui
- emoji: 🏆
- colorFrom: indigo
- colorTo: gray
- sdk: gradio
- sdk_version: 3.35.2
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ [![Contributors][contributors-shield]][contributors-url]
+ [![Forks][forks-shield]][forks-url]
+ [![Stargazers][stars-shield]][stars-url]
+ [![Issues][issues-shield]][issues-url]
+
+ <p align="center">
+ <h1 align="center"><br><ins>Image Matching WebUI</ins><br>find matches between two images</h1>
+ </p>
+
+ ## Description
+
+ This tool matches image pairs with a collection of well-known image matching algorithms. It provides a graphical user interface (GUI) built with [gradio](https://gradio.app/): select two images and a matching algorithm, then run to obtain the matching result.
+ **Note**: the input images can come from local files or a webcam.
+
+ Here is a demo of the tool:
+
+ ![demo](assets/demo.gif)
+
+ The tool currently supports the following image matching algorithms:
+ - [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
+ - [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), TBD
+ - [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
+ - [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
+ - [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023
+ - [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023
+ - [x] [GlueStick](https://github.com/cvg/GlueStick), arXiv 2023
+ - [ ] [ConvMatch](https://github.com/SuhZhang/ConvMatch), AAAI 2023
+ - [x] [SOLD2](https://github.com/cvg/SOLD2), CVPR 2021
+ - [ ] [LineTR](https://github.com/yosungho/LineTR), RA-L 2021
+ - [x] [DKM](https://github.com/Parskatt/DKM), CVPR 2023
+ - [x] [RoMa](https://github.com/Vincentqyw/RoMa), arXiv 2023
+ - [ ] [NCMNet](https://github.com/xinliu29/NCMNet), CVPR 2023
+ - [x] [TopicFM](https://github.com/Vincentqyw/TopicFM), AAAI 2023
+ - [x] [AspanFormer](https://github.com/Vincentqyw/ml-aspanformer), ECCV 2022
+ - [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022
+ - [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022
+ - [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022
+ - [x] [ALIKE](https://github.com/Shiaoming/ALIKE), arXiv 2022
+ - [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021
+ - [x] [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork), CVPRW 2018
+ - [x] [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork), CVPR 2020
+ - [x] [D2Net](https://github.com/Vincentqyw/d2-net), CVPR 2019
+ - [x] [R2D2](https://github.com/naver/r2d2), NeurIPS 2019
+ - [x] [DISK](https://github.com/cvlab-epfl/disk), NeurIPS 2020
+ - [ ] [Key.Net](https://github.com/axelBarroso/Key.Net), ICCV 2019
+ - [ ] [OANet](https://github.com/zjhthu/OANet), ICCV 2019
+ - [ ] [SOSNet](https://github.com/scape-research/SOSNet), CVPR 2019
+ - [x] [SIFT](https://docs.opencv.org/4.x/da/df5/tutorial_py_sift_intro.html), IJCV 2004
+
+ ## How to use
+
+ ### Requirements
+ ``` bash
+ git clone --recursive https://github.com/Vincentqyw/image-matching-webui.git
+ cd image-matching-webui
+ conda env create -f environment.yaml
+ conda activate imw
+ ```
+
+ ### Run the demo
+ ``` bash
+ python3 ./app.py
+ ```
+ Then open http://localhost:7860 in your browser.
+
+ ![](assets/gui.jpg)
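If port 7860 is already in use, note that `app.py` calls `launch(share=True)` without a fixed port, so Gradio falls back to its `GRADIO_SERVER_PORT` environment variable. A hedged example (the port number is arbitrary and this relies on standard Gradio behaviour, not on anything configured in this commit):

``` bash
GRADIO_SERVER_PORT=7861 python3 ./app.py
```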
+
+ ### Add your own feature / matcher
+
+ An example of adding a local feature extractor is provided in [hloc/extractors/example.py](hloc/extractors/example.py). Then add the feature's settings to `confs` in [hloc/extract_features.py](hloc/extract_features.py). The last step is to register the model in `matcher_zoo` in [extra_utils/utils.py](extra_utils/utils.py). A rough sketch of such an extractor is shown below.
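As a companion to the steps above, here is a rough sketch of what a new extractor file in `hloc/extractors/` can look like. It follows the `BaseModel` interface used by the existing extractors; the class name `MyFeature`, the configuration values, and the random outputs are purely illustrative placeholders, and the real template is the `example.py` file referenced above.

``` python
import torch

from ..utils.base_model import BaseModel  # shared interface of the hloc extractors


class MyFeature(BaseModel):
    # Illustrative configuration; the real keys depend on your model.
    default_conf = {
        "max_keypoints": 1024,
        "keypoint_threshold": 0.015,
    }
    required_inputs = ["image"]

    def _init(self, conf):
        # Build the network and load pretrained weights here (placeholder).
        self.net = torch.nn.Identity()

    def _forward(self, data):
        # data["image"] is a (B, C, H, W) tensor; return keypoints, scores and
        # descriptors (shapes below are only illustrative random placeholders).
        b, _, h, w = data["image"].shape
        n = self.conf["max_keypoints"]
        keypoints = torch.rand(b, n, 2) * torch.tensor([w - 1.0, h - 1.0])
        return {
            "keypoints": keypoints,
            "scores": torch.ones(b, n),
            "descriptors": torch.rand(b, 128, n),
        }
```

After dropping such a file into `hloc/extractors/`, the remaining two steps are exactly the ones the paragraph above describes: a config entry in `confs` and a `matcher_zoo` entry.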
+
+ ## Contributions welcome!
+
+ External contributions are very welcome. Please follow the [PEP8 style guidelines](https://www.python.org/dev/peps/pep-0008/), check your code with a linter such as flake8, and reformat it with `python -m black .`. This is a non-exhaustive list of features that would be valuable additions:
+
+ - [x] add webcam support
+ - [x] add [line feature matching](https://github.com/Vincentqyw/LineSegmentsDetection) algorithms
+ - [x] example to add a new feature extractor / matcher
+ - [ ] RANSAC to filter outliers
+ - [ ] support exporting matches to COLMAP ([#issue 6](https://github.com/Vincentqyw/image-matching-webui/issues/6))
+ - [ ] add a config file to set default parameters
+ - [ ] dynamically load models to reduce GPU memory usage
+
+ Adding local features / matchers as submodules is straightforward. For example, to add [GlueStick](https://github.com/cvg/GlueStick):
+
+ ``` bash
+ git submodule add https://github.com/cvg/GlueStick.git third_party/GlueStick
+ ```
+
+ If the remote submodule repositories are updated, pull the submodules with `git submodule update --remote`. To update only one submodule, use `git submodule update --remote third_party/GlueStick`.
+
+ ## Resources
+ - [Image Matching: Local Features & Beyond](https://image-matching-workshop.github.io)
+ - [Long-term Visual Localization](https://www.visuallocalization.net)
+
+ ## Acknowledgement
+
+ This code builds on [Hierarchical-Localization](https://github.com/cvg/Hierarchical-Localization). We are grateful to the authors for their valuable source code.
+
+ [contributors-shield]: https://img.shields.io/github/contributors/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+ [contributors-url]: https://github.com/Vincentqyw/image-matching-webui/graphs/contributors
+ [forks-shield]: https://img.shields.io/github/forks/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+ [forks-url]: https://github.com/Vincentqyw/image-matching-webui/network/members
+ [stars-shield]: https://img.shields.io/github/stars/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+ [stars-url]: https://github.com/Vincentqyw/image-matching-webui/stargazers
+ [issues-shield]: https://img.shields.io/github/issues/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+ [issues-url]: https://github.com/Vincentqyw/image-matching-webui/issues
app.py ADDED
@@ -0,0 +1,291 @@
1
+ import argparse
2
+ import gradio as gr
3
+
4
+ from hloc import extract_features
5
+ from extra_utils.utils import (
6
+ matcher_zoo,
7
+ device,
8
+ match_dense,
9
+ match_features,
10
+ get_model,
11
+ get_feature_model,
12
+ display_matches
13
+ )
14
+
15
+ def run_matching(
16
+ match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
17
+ ):
18
+ # image0 and image1 is RGB mode
19
+ if image0 is None or image1 is None:
20
+ raise gr.Error("Error: No images found! Please upload two images.")
21
+
22
+ model = matcher_zoo[key]
23
+ match_conf = model["config"]
24
+ # update match config
25
+ match_conf["model"]["match_threshold"] = match_threshold
26
+ match_conf["model"]["max_keypoints"] = extract_max_keypoints
27
+
28
+ matcher = get_model(match_conf)
29
+ if model["dense"]:
30
+ pred = match_dense.match_images(
31
+ matcher, image0, image1, match_conf["preprocessing"], device=device
32
+ )
33
+ del matcher
34
+ extract_conf = None
35
+ else:
36
+ extract_conf = model["config_feature"]
37
+ # update extract config
38
+ extract_conf["model"]["max_keypoints"] = extract_max_keypoints
39
+ extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
40
+ extractor = get_feature_model(extract_conf)
41
+ pred0 = extract_features.extract(
42
+ extractor, image0, extract_conf["preprocessing"]
43
+ )
44
+ pred1 = extract_features.extract(
45
+ extractor, image1, extract_conf["preprocessing"]
46
+ )
47
+ pred = match_features.match_images(matcher, pred0, pred1)
48
+ del extractor
49
+ fig, num_inliers = display_matches(pred)
50
+ del pred
51
+ return (
52
+ fig,
53
+ {"matches number": num_inliers},
54
+ {"match_conf": match_conf, "extractor_conf": extract_conf},
55
+ )
56
+
57
+
58
+ def ui_change_imagebox(choice):
59
+ return {"value": None, "source": choice, "__type__": "update"}
60
+
61
+
62
+ def ui_reset_state(
63
+ match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
64
+ ):
65
+ match_threshold = 0.2
66
+ extract_max_keypoints = 1000
67
+ keypoint_threshold = 0.015
68
+ key = list(matcher_zoo.keys())[0]
69
+ image0 = None
70
+ image1 = None
71
+ return (
72
+ match_threshold,
73
+ extract_max_keypoints,
74
+ keypoint_threshold,
75
+ key,
76
+ image0,
77
+ image1,
78
+ {"value": None, "source": "upload", "__type__": "update"},
79
+ {"value": None, "source": "upload", "__type__": "update"},
80
+ "upload",
81
+ None,
82
+ {},
83
+ {},
84
+ )
85
+
86
+
87
+ def run(config):
88
+ with gr.Blocks(
89
+ theme=gr.themes.Monochrome(), css="footer {visibility: hidden}"
90
+ ) as app:
91
+ gr.Markdown(
92
+ """
93
+ <p align="center">
94
+ <h1 align="center">Image Matching WebUI</h1>
95
+ </p>
96
+ """
97
+ )
98
+
99
+ with gr.Row(equal_height=False):
100
+ with gr.Column():
101
+ with gr.Row():
102
+ matcher_list = gr.Dropdown(
103
+ choices=list(matcher_zoo.keys()),
104
+ value="disk+lightglue",
105
+ label="Matching Model",
106
+ interactive=True,
107
+ )
108
+ match_image_src = gr.Radio(
109
+ ["upload", "webcam", "canvas"],
110
+ label="Image Source",
111
+ value="upload",
112
+ )
113
+
114
+ with gr.Row():
115
+ match_setting_threshold = gr.Slider(
116
+ minimum=0.0,
117
+ maximum=1,
118
+ step=0.001,
119
+ label="Match threshold",
120
+ value=0.1,
121
+ )
122
+ match_setting_max_features = gr.Slider(
123
+ minimum=10,
124
+ maximum=10000,
125
+ step=10,
126
+ label="Max number of features",
127
+ value=1000,
128
+ )
129
+ # TODO: add line settings
130
+ with gr.Row():
131
+ detect_keypoints_threshold = gr.Slider(
132
+ minimum=0,
133
+ maximum=1,
134
+ step=0.001,
135
+ label="Keypoint threshold",
136
+ value=0.015,
137
+ )
138
+ detect_line_threshold = gr.Slider(
139
+ minimum=0.1,
140
+ maximum=1,
141
+ step=0.01,
142
+ label="Line threshold",
143
+ value=0.2,
144
+ )
145
+ # matcher_lists = gr.Radio(
146
+ # ["NN-mutual", "Dual-Softmax"],
147
+ # label="Matcher mode",
148
+ # value="NN-mutual",
149
+ # )
150
+ with gr.Row():
151
+ input_image0 = gr.Image(
152
+ label="Image 0",
153
+ type="numpy",
154
+ interactive=True,
155
+ image_mode="RGB",
156
+ )
157
+ input_image1 = gr.Image(
158
+ label="Image 1",
159
+ type="numpy",
160
+ interactive=True,
161
+ image_mode="RGB",
162
+ )
163
+
164
+ with gr.Row():
165
+ button_reset = gr.Button(label="Reset", value="Reset")
166
+ button_run = gr.Button(
167
+ label="Run Match", value="Run Match", variant="primary"
168
+ )
169
+
170
+ with gr.Accordion("Open for More!", open=False):
171
+ gr.Markdown(
172
+ f"""
173
+ <h3>Supported Algorithms</h3>
174
+ {", ".join(matcher_zoo.keys())}
175
+ """
176
+ )
177
+
178
+ # collect inputs
179
+ inputs = [
180
+ match_setting_threshold,
181
+ match_setting_max_features,
182
+ detect_keypoints_threshold,
183
+ matcher_list,
184
+ input_image0,
185
+ input_image1,
186
+ ]
187
+
188
+ # Add some examples
189
+ with gr.Row():
190
+ examples = [
191
+ [
192
+ 0.1,
193
+ 2000,
194
+ 0.015,
195
+ "disk+lightglue",
196
+ "datasets/sacre_coeur/mapping/71295362_4051449754.jpg",
197
+ "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
198
+ ],
199
+ [
200
+ 0.1,
201
+ 2000,
202
+ 0.015,
203
+ "loftr",
204
+ "datasets/sacre_coeur/mapping/03903474_1471484089.jpg",
205
+ "datasets/sacre_coeur/mapping/02928139_3448003521.jpg",
206
+ ],
207
+ [
208
+ 0.1,
209
+ 2000,
210
+ 0.015,
211
+ "disk",
212
+ "datasets/sacre_coeur/mapping/10265353_3838484249.jpg",
213
+ "datasets/sacre_coeur/mapping/51091044_3486849416.jpg",
214
+ ],
215
+ [
216
+ 0.1,
217
+ 2000,
218
+ 0.015,
219
+ "topicfm",
220
+ "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
221
+ "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
222
+ ],
223
+ [
224
+ 0.1,
225
+ 2000,
226
+ 0.015,
227
+ "superpoint+superglue",
228
+ "datasets/sacre_coeur/mapping/17295357_9106075285.jpg",
229
+ "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
230
+ ],
231
+ ]
232
+ # Example inputs
233
+ gr.Examples(
234
+ examples=examples,
235
+ inputs=inputs,
236
+ outputs=[],
237
+ fn=run_matching,
238
+ cache_examples=False,
239
+ label="Examples (click one of the images below to Run Match)",
240
+ )
241
+
242
+ with gr.Column():
243
+ output_mkpts = gr.Image(label="Keypoints Matching", type="numpy")
244
+ matches_result_info = gr.JSON(label="Matches Statistics")
245
+ matcher_info = gr.JSON(label="Match info")
246
+
247
+ # callbacks
248
+ match_image_src.change(
249
+ fn=ui_change_imagebox, inputs=match_image_src, outputs=input_image0
250
+ )
251
+ match_image_src.change(
252
+ fn=ui_change_imagebox, inputs=match_image_src, outputs=input_image1
253
+ )
254
+
255
+ # collect outputs
256
+ outputs = [
257
+ output_mkpts,
258
+ matches_result_info,
259
+ matcher_info,
260
+ ]
261
+ # button callbacks
262
+ button_run.click(fn=run_matching, inputs=inputs, outputs=outputs)
263
+
264
+ # Reset images
265
+ reset_outputs = [
266
+ match_setting_threshold,
267
+ match_setting_max_features,
268
+ detect_keypoints_threshold,
269
+ matcher_list,
270
+ input_image0,
271
+ input_image1,
272
+ input_image0,
273
+ input_image1,
274
+ match_image_src,
275
+ output_mkpts,
276
+ matches_result_info,
277
+ matcher_info,
278
+ ]
279
+ button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
280
+
281
+ app.launch(share=True)
282
+
283
+
284
+ if __name__ == "__main__":
285
+ parser = argparse.ArgumentParser()
286
+ parser.add_argument(
287
+ "--config_path", type=str, default="config.yaml", help="configuration file path"
288
+ )
289
+ args = parser.parse_args()
290
+ config = None
291
+ run(config)
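For completeness, the matching pipeline defined in `app.py` above can also be driven without the Gradio UI. A minimal sketch, assuming it is run from the repository root with the environment from the README and using the demo images shipped in `datasets/` (the keyword names simply mirror the positional parameters of `run_matching`):

``` python
import cv2

from app import run_matching

# Load two of the bundled demo images as RGB arrays, matching what the
# Gradio image components would provide.
img0 = cv2.cvtColor(
    cv2.imread("datasets/sacre_coeur/mapping/71295362_4051449754.jpg"), cv2.COLOR_BGR2RGB
)
img1 = cv2.cvtColor(
    cv2.imread("datasets/sacre_coeur/mapping/93341989_396310999.jpg"), cv2.COLOR_BGR2RGB
)

fig, stats, confs = run_matching(
    match_threshold=0.1,
    extract_max_keypoints=2000,
    keypoint_threshold=0.015,
    key="disk+lightglue",  # any key from extra_utils.utils.matcher_zoo
    image0=img0,
    image1=img1,
)
print(stats)  # e.g. {"matches number": ...}
```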
assets/demo.gif ADDED

Git LFS Details

  • SHA256: 3f163c0e2699181897c81c68e01c60fa4289e886a2a40932d53dd529262d3735
  • Pointer size: 132 Bytes
  • Size of remote file: 8.91 MB
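Since `assets/demo.gif` (and, per the updated `.gitattributes`, every `*.gif`) is stored through Git LFS, a clone made without LFS contains only the pointer file described above. Fetching the real file uses the standard LFS commands (nothing repository-specific):

``` bash
git lfs install
git lfs pull
```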
assets/gui.jpg ADDED
datasets/.gitignore ADDED
File without changes
datasets/lines/terrace0.JPG ADDED
datasets/lines/terrace1.JPG ADDED
datasets/sacre_coeur/README.md ADDED
@@ -0,0 +1,3 @@
+ # Sacre Coeur demo
+
+ We provide a small subset of images depicting the Sacre Coeur. They were obtained from the [Image Matching Challenge 2021](https://www.cs.ubc.ca/research/image-matching-challenge/2021/data/) and originally come from the [Yahoo Flickr Creative Commons 100M (YFCC100M) dataset](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).
datasets/sacre_coeur/mapping/02928139_3448003521.jpg ADDED
datasets/sacre_coeur/mapping/03903474_1471484089.jpg ADDED
datasets/sacre_coeur/mapping/10265353_3838484249.jpg ADDED
datasets/sacre_coeur/mapping/17295357_9106075285.jpg ADDED
datasets/sacre_coeur/mapping/32809961_8274055477.jpg ADDED
datasets/sacre_coeur/mapping/44120379_8371960244.jpg ADDED
datasets/sacre_coeur/mapping/51091044_3486849416.jpg ADDED
datasets/sacre_coeur/mapping/60584745_2207571072.jpg ADDED
datasets/sacre_coeur/mapping/71295362_4051449754.jpg ADDED
datasets/sacre_coeur/mapping/93341989_396310999.jpg ADDED
extra_utils/__init__.py ADDED
File without changes
extra_utils/plotting.py ADDED
@@ -0,0 +1,504 @@
1
+ import bisect
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib, os, cv2
5
+ import matplotlib.cm as cm
6
+ from PIL import Image
7
+ import torch.nn.functional as F
8
+ import torch
9
+
10
+
11
+ def _compute_conf_thresh(data):
12
+ dataset_name = data["dataset_name"][0].lower()
13
+ if dataset_name == "scannet":
14
+ thr = 5e-4
15
+ elif dataset_name == "megadepth":
16
+ thr = 1e-4
17
+ else:
18
+ raise ValueError(f"Unknown dataset: {dataset_name}")
19
+ return thr
20
+
21
+
22
+ # --- VISUALIZATION --- #
23
+
24
+
25
+ def make_matching_figure(
26
+ img0,
27
+ img1,
28
+ mkpts0,
29
+ mkpts1,
30
+ color,
31
+ titles=None,
32
+ kpts0=None,
33
+ kpts1=None,
34
+ text=[],
35
+ dpi=75,
36
+ path=None,
37
+ pad=0,
38
+ ):
39
+ # draw image pair
40
+ # assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
41
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
42
+ axes[0].imshow(img0) # , cmap='gray')
43
+ axes[1].imshow(img1) # , cmap='gray')
44
+ for i in range(2): # clear all frames
45
+ axes[i].get_yaxis().set_ticks([])
46
+ axes[i].get_xaxis().set_ticks([])
47
+ for spine in axes[i].spines.values():
48
+ spine.set_visible(False)
49
+ if titles is not None:
50
+ axes[i].set_title(titles[i])
51
+
52
+ plt.tight_layout(pad=pad)
53
+
54
+ if kpts0 is not None:
55
+ assert kpts1 is not None
56
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5)
57
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
58
+
59
+ # draw matches
60
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
61
+ fig.canvas.draw()
62
+ transFigure = fig.transFigure.inverted()
63
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
64
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
65
+ fig.lines = [
66
+ matplotlib.lines.Line2D(
67
+ (fkpts0[i, 0], fkpts1[i, 0]),
68
+ (fkpts0[i, 1], fkpts1[i, 1]),
69
+ transform=fig.transFigure,
70
+ c=color[i],
71
+ linewidth=2,
72
+ )
73
+ for i in range(len(mkpts0))
74
+ ]
75
+
76
+ # freeze the axes to prevent the transform to change
77
+ axes[0].autoscale(enable=False)
78
+ axes[1].autoscale(enable=False)
79
+
80
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4)
81
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4)
82
+
83
+ # put txts
84
+ txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
85
+ fig.text(
86
+ 0.01,
87
+ 0.99,
88
+ "\n".join(text),
89
+ transform=fig.axes[0].transAxes,
90
+ fontsize=15,
91
+ va="top",
92
+ ha="left",
93
+ color=txt_color,
94
+ )
95
+
96
+ # save or return figure
97
+ if path:
98
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
99
+ plt.close()
100
+ else:
101
+ return fig
102
+
103
+
104
+ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
105
+ b_mask = data["m_bids"] == b_id
106
+ conf_thr = _compute_conf_thresh(data)
107
+
108
+ img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
109
+ img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
110
+ kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
111
+ kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
112
+
113
+ # for megadepth, we visualize matches on the resized image
114
+ if "scale0" in data:
115
+ kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
116
+ kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
117
+
118
+ epi_errs = data["epi_errs"][b_mask].cpu().numpy()
119
+ correct_mask = epi_errs < conf_thr
120
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
121
+ n_correct = np.sum(correct_mask)
122
+ n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
123
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
124
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
125
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
126
+
127
+ # matching info
128
+ if alpha == "dynamic":
129
+ alpha = dynamic_alpha(len(correct_mask))
130
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
131
+
132
+ text = [
133
+ f"#Matches {len(kpts0)}",
134
+ f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
135
+ f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
136
+ ]
137
+
138
+ # make the figure
139
+ figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
140
+ return figure
141
+
142
+
143
+ def _make_confidence_figure(data, b_id):
144
+ # TODO: Implement confidence figure
145
+ raise NotImplementedError()
146
+
147
+
148
+ def make_matching_figures(data, config, mode="evaluation"):
149
+ """Make matching figures for a batch.
150
+
151
+ Args:
152
+ data (Dict): a batch updated by PL_LoFTR.
153
+ config (Dict): matcher config
154
+ Returns:
155
+ figures (Dict[str, List[plt.figure]]
156
+ """
157
+ assert mode in ["evaluation", "confidence"] # 'confidence'
158
+ figures = {mode: []}
159
+ for b_id in range(data["image0"].size(0)):
160
+ if mode == "evaluation":
161
+ fig = _make_evaluation_figure(
162
+ data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
163
+ )
164
+ elif mode == "confidence":
165
+ fig = _make_confidence_figure(data, b_id)
166
+ else:
167
+ raise ValueError(f"Unknown plot mode: {mode}")
168
+ figures[mode].append(fig)
169
+ return figures
170
+
171
+
172
+ def dynamic_alpha(
173
+ n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
174
+ ):
175
+ if n_matches == 0:
176
+ return 1.0
177
+ ranges = list(zip(alphas, alphas[1:] + [None]))
178
+ loc = bisect.bisect_right(milestones, n_matches) - 1
179
+ _range = ranges[loc]
180
+ if _range[1] is None:
181
+ return _range[0]
182
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
183
+ milestones[loc + 1] - milestones[loc]
184
+ ) * (_range[0] - _range[1])
185
+
186
+
187
+ def error_colormap(err, thr, alpha=1.0):
188
+ assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
189
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
190
+ return np.clip(
191
+ np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
192
+ 0,
193
+ 1,
194
+ )
195
+
196
+
197
+ np.random.seed(1995)
198
+ color_map = np.arange(100)
199
+ np.random.shuffle(color_map)
200
+
201
+
202
+ def draw_topics(
203
+ data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None
204
+ ):
205
+
206
+ topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
207
+ hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
208
+ hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
209
+ # print(hw0_i, hw1_i)
210
+ scale0, scale1 = hw0_i[0] // hw0_c[0], hw1_i[0] // hw1_c[0]
211
+ if "scale0" in data:
212
+ scale0 *= data["scale0"][0]
213
+ else:
214
+ scale0 = (scale0, scale0)
215
+ if "scale1" in data:
216
+ scale1 *= data["scale1"][0]
217
+ else:
218
+ scale1 = (scale1, scale1)
219
+
220
+ n_topics = topic0.shape[-1]
221
+ # mask0_nonzero = topic0[0].sum(dim=-1, keepdim=True) > 0
222
+ # mask1_nonzero = topic1[0].sum(dim=-1, keepdim=True) > 0
223
+ theta0 = topic0[0].sum(dim=0)
224
+ theta0 /= theta0.sum().float()
225
+ theta1 = topic1[0].sum(dim=0)
226
+ theta1 /= theta1.sum().float()
227
+ # top_topic0 = torch.argsort(theta0, descending=True)[:show_n_topics]
228
+ # top_topic1 = torch.argsort(theta1, descending=True)[:show_n_topics]
229
+ top_topics = torch.argsort(theta0 * theta1, descending=True)[:show_n_topics]
230
+ # print(sum_topic0, sum_topic1)
231
+
232
+ topic0 = topic0[0].argmax(
233
+ dim=-1, keepdim=True
234
+ ) # .float() / (n_topics - 1) #* 255 + 1 #
235
+ # topic0[~mask0_nonzero] = -1
236
+ topic1 = topic1[0].argmax(
237
+ dim=-1, keepdim=True
238
+ ) # .float() / (n_topics - 1) #* 255 + 1
239
+ # topic1[~mask1_nonzero] = -1
240
+ label_img0, label_img1 = torch.zeros_like(topic0) - 1, torch.zeros_like(topic1) - 1
241
+ for i, k in enumerate(top_topics):
242
+ label_img0[topic0 == k] = color_map[k]
243
+ label_img1[topic1 == k] = color_map[k]
244
+
245
+ # print(hw0_c, scale0)
246
+ # print(hw1_c, scale1)
247
+ # map_topic0 = F.fold(label_img0.unsqueeze(0), hw0_i, kernel_size=scale0, stride=scale0)
248
+ map_topic0 = (
249
+ label_img0.float().view(hw0_c).cpu().numpy()
250
+ ) # map_topic0.squeeze(0).squeeze(0).cpu().numpy()
251
+ map_topic0 = cv2.resize(
252
+ map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1]))
253
+ )
254
+ # map_topic1 = F.fold(label_img1.unsqueeze(0), hw1_i, kernel_size=scale1, stride=scale1)
255
+ map_topic1 = (
256
+ label_img1.float().view(hw1_c).cpu().numpy()
257
+ ) # map_topic1.squeeze(0).squeeze(0).cpu().numpy()
258
+ map_topic1 = cv2.resize(
259
+ map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1]))
260
+ )
261
+
262
+ # show image0
263
+ if saved_name is None:
264
+ return map_topic0, map_topic1
265
+
266
+ if not os.path.exists(saved_folder):
267
+ os.makedirs(saved_folder)
268
+ path_saved_img0 = os.path.join(saved_folder, "{}_0.png".format(saved_name))
269
+ plt.imshow(img0)
270
+ masked_map_topic0 = np.ma.masked_where(map_topic0 < 0, map_topic0)
271
+ plt.imshow(
272
+ masked_map_topic0,
273
+ cmap=plt.cm.jet,
274
+ vmin=0,
275
+ vmax=n_topics - 1,
276
+ alpha=0.3,
277
+ interpolation="bilinear",
278
+ )
279
+ # plt.show()
280
+ plt.axis("off")
281
+ plt.savefig(path_saved_img0, bbox_inches="tight", pad_inches=0, dpi=250)
282
+ plt.close()
283
+
284
+ path_saved_img1 = os.path.join(saved_folder, "{}_1.png".format(saved_name))
285
+ plt.imshow(img1)
286
+ masked_map_topic1 = np.ma.masked_where(map_topic1 < 0, map_topic1)
287
+ plt.imshow(
288
+ masked_map_topic1,
289
+ cmap=plt.cm.jet,
290
+ vmin=0,
291
+ vmax=n_topics - 1,
292
+ alpha=0.3,
293
+ interpolation="bilinear",
294
+ )
295
+ plt.axis("off")
296
+ plt.savefig(path_saved_img1, bbox_inches="tight", pad_inches=0, dpi=250)
297
+ plt.close()
298
+
299
+
300
+ def draw_topicfm_demo(
301
+ data,
302
+ img0,
303
+ img1,
304
+ mkpts0,
305
+ mkpts1,
306
+ mcolor,
307
+ text,
308
+ show_n_topics=8,
309
+ topic_alpha=0.3,
310
+ margin=5,
311
+ path=None,
312
+ opencv_display=False,
313
+ opencv_title="",
314
+ ):
315
+ topic_map0, topic_map1 = draw_topics(data, img0, img1, show_n_topics=show_n_topics)
316
+
317
+ mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(
318
+ topic_map1 >= 0, axis=-1
319
+ )
320
+
321
+ topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
322
+ topic_cm0 = cv2.cvtColor(topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
323
+ topic_cm1 = cv2.cvtColor(topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
324
+ overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
325
+ overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
326
+
327
+ cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
328
+ cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
329
+
330
+ overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(
331
+ np.uint8
332
+ )
333
+
334
+ h0, w0 = img0.shape[:2]
335
+ h1, w1 = img1.shape[:2]
336
+ h, w = h0 * 2 + margin * 2, w0 * 2 + margin
337
+ out_fig = 255 * np.ones((h, w, 3), dtype=np.uint8)
338
+ out_fig[:h0, :w0] = overlay0
339
+ if h0 >= h1:
340
+ start = (h0 - h1) // 2
341
+ out_fig[start : (start + h1), (w0 + margin) : (w0 + margin + w1)] = overlay1
342
+ else:
343
+ start = (h1 - h0) // 2
344
+ out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
345
+ start : (start + h0)
346
+ ]
347
+
348
+ step_h = h0 + margin * 2
349
+ out_fig[step_h : step_h + h0, :w0] = (img0 * 255).astype(np.uint8)
350
+ if h0 >= h1:
351
+ start = step_h + (h0 - h1) // 2
352
+ out_fig[start : start + h1, (w0 + margin) : (w0 + margin + w1)] = (
353
+ img1 * 255
354
+ ).astype(np.uint8)
355
+ else:
356
+ start = (h1 - h0) // 2
357
+ out_fig[step_h : step_h + h0, (w0 + margin) : (w0 + margin + w1)] = (
358
+ img1[start : start + h0] * 255
359
+ ).astype(np.uint8)
360
+
361
+ # draw matching lines, this is inspried from https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
362
+ mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
363
+ mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
364
+
365
+ for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, mcolor):
366
+ c = c.tolist()
367
+ cv2.line(
368
+ out_fig,
369
+ (x0, y0 + step_h),
370
+ (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2),
371
+ color=c,
372
+ thickness=1,
373
+ lineType=cv2.LINE_AA,
374
+ )
375
+ # display line end-points as circles
376
+ cv2.circle(out_fig, (x0, y0 + step_h), 2, c, -1, lineType=cv2.LINE_AA)
377
+ cv2.circle(
378
+ out_fig,
379
+ (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2),
380
+ 2,
381
+ c,
382
+ -1,
383
+ lineType=cv2.LINE_AA,
384
+ )
385
+
386
+ # Scale factor for consistent visualization across scales.
387
+ sc = min(h / 960.0, 2.0)
388
+
389
+ # Big text.
390
+ Ht = int(30 * sc) # text height
391
+ txt_color_fg = (255, 255, 255)
392
+ txt_color_bg = (0, 0, 0)
393
+ for i, t in enumerate(text):
394
+ cv2.putText(
395
+ out_fig,
396
+ t,
397
+ (int(8 * sc), Ht + step_h * i),
398
+ cv2.FONT_HERSHEY_DUPLEX,
399
+ 1.0 * sc,
400
+ txt_color_bg,
401
+ 2,
402
+ cv2.LINE_AA,
403
+ )
404
+ cv2.putText(
405
+ out_fig,
406
+ t,
407
+ (int(8 * sc), Ht + step_h * i),
408
+ cv2.FONT_HERSHEY_DUPLEX,
409
+ 1.0 * sc,
410
+ txt_color_fg,
411
+ 1,
412
+ cv2.LINE_AA,
413
+ )
414
+
415
+ if path is not None:
416
+ cv2.imwrite(str(path), out_fig)
417
+
418
+ if opencv_display:
419
+ cv2.imshow(opencv_title, out_fig)
420
+ cv2.waitKey(1)
421
+
422
+ return out_fig
423
+
424
+
425
+ def fig2im(fig):
426
+ fig.canvas.draw()
427
+ w, h = fig.canvas.get_width_height()
428
+ buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
429
+ im = buf_ndarray.reshape(h, w, 3)
430
+ return im
431
+
432
+
433
+ def draw_matches(
434
+ mkpts0, mkpts1, img0, img1, conf, titles=None, dpi=150, path=None, pad=0.5
435
+ ):
436
+ thr = 5e-4
437
+ thr = 0.5
438
+ color = error_colormap(conf, thr, alpha=0.1)
439
+ text = [
440
+ f"image name",
441
+ f"#Matches: {len(mkpts0)}",
442
+ ]
443
+ if path:
444
+ fig2im(
445
+ make_matching_figure(
446
+ img0,
447
+ img1,
448
+ mkpts0,
449
+ mkpts1,
450
+ color,
451
+ titles=titles,
452
+ text=text,
453
+ path=path,
454
+ dpi=dpi,
455
+ pad=pad,
456
+ )
457
+ )
458
+ else:
459
+ return fig2im(
460
+ make_matching_figure(
461
+ img0,
462
+ img1,
463
+ mkpts0,
464
+ mkpts1,
465
+ color,
466
+ titles=titles,
467
+ text=text,
468
+ pad=pad,
469
+ dpi=dpi,
470
+ )
471
+ )
472
+
473
+
474
+ def draw_image_pairs(img0, img1, text=[], dpi=75, path=None, pad=0.5):
475
+ # draw image pair
476
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
477
+ axes[0].imshow(img0) # , cmap='gray')
478
+ axes[1].imshow(img1) # , cmap='gray')
479
+ for i in range(2): # clear all frames
480
+ axes[i].get_yaxis().set_ticks([])
481
+ axes[i].get_xaxis().set_ticks([])
482
+ for spine in axes[i].spines.values():
483
+ spine.set_visible(False)
484
+ plt.tight_layout(pad=pad)
485
+
486
+ # put txts
487
+ txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
488
+ fig.text(
489
+ 0.01,
490
+ 0.99,
491
+ "\n".join(text),
492
+ transform=fig.axes[0].transAxes,
493
+ fontsize=15,
494
+ va="top",
495
+ ha="left",
496
+ color=txt_color,
497
+ )
498
+
499
+ # save or return figure
500
+ if path:
501
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
502
+ plt.close()
503
+ else:
504
+ return fig2im(fig)
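A small, self-contained usage sketch for the helpers above. The images and matches are random placeholders that only demonstrate the expected inputs of `draw_matches` (RGB images as float arrays in [0, 1], matched keypoints as (N, 2) arrays in pixel coordinates, and one confidence value per match):

``` python
import numpy as np

from extra_utils.plotting import draw_matches

rng = np.random.default_rng(0)
img0 = rng.random((480, 640, 3)).astype(np.float32)  # stand-in for a real RGB image
img1 = rng.random((480, 640, 3)).astype(np.float32)

n = 50
mkpts0 = rng.random((n, 2)) * np.array([640, 480])  # (x, y) keypoints in image 0
mkpts1 = rng.random((n, 2)) * np.array([640, 480])  # corresponding keypoints in image 1
conf = rng.random(n)  # per-match confidence, mapped to colors by error_colormap

# Returns an H x W x 3 uint8 visualization (keypoints connected by colored lines).
vis = draw_matches(mkpts0, mkpts1, img0, img1, conf, dpi=150)
print(vis.shape)
```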
extra_utils/utils.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from hloc import matchers, extractors
5
+ from hloc.utils.base_model import dynamic_load
6
+ from hloc import match_dense, match_features, extract_features
7
+ from .plotting import draw_matches, fig2im
8
+ from .visualize_util import plot_images, plot_color_line_matches
9
+
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+
13
+ def get_model(match_conf):
14
+ Model = dynamic_load(matchers, match_conf["model"]["name"])
15
+ model = Model(match_conf["model"]).eval().to(device)
16
+ return model
17
+
18
+
19
+ def get_feature_model(conf):
20
+ Model = dynamic_load(extractors, conf["model"]["name"])
21
+ model = Model(conf["model"]).eval().to(device)
22
+ return model
23
+
24
+
25
+ def display_matches(pred: dict):
26
+ img0 = pred["image0_orig"]
27
+ img1 = pred["image1_orig"]
28
+
29
+ num_inliers = 0
30
+ if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
31
+ mkpts0 = pred["keypoints0_orig"]
32
+ mkpts1 = pred["keypoints1_orig"]
33
+ num_inliers = len(mkpts0)
34
+ if "mconf" in pred.keys():
35
+ mconf = pred["mconf"]
36
+ else:
37
+ mconf = np.ones(len(mkpts0))
38
+ fig_mkpts = draw_matches(
39
+ mkpts0,
40
+ mkpts1,
41
+ img0,
42
+ img1,
43
+ mconf,
44
+ dpi=300,
45
+ titles=["Image 0 - matched keypoints", "Image 1 - matched keypoints"],
46
+ )
47
+ fig = fig_mkpts
48
+ if "line0_orig" in pred.keys() and "line1_orig" in pred.keys():
49
+ # lines
50
+ mtlines0 = pred["line0_orig"]
51
+ mtlines1 = pred["line1_orig"]
52
+ num_inliers = len(mtlines0)
53
+ fig_lines = plot_images(
54
+ [img0.squeeze(), img1.squeeze()],
55
+ ["Image 0 - matched lines", "Image 1 - matched lines"],
56
+ dpi=300,
57
+ )
58
+ fig_lines = plot_color_line_matches([mtlines0, mtlines1], lw=2)
59
+ fig_lines = fig2im(fig_lines)
60
+
61
+ # keypoints
62
+ mkpts0 = pred["line_keypoints0_orig"]
63
+ mkpts1 = pred["line_keypoints1_orig"]
64
+
65
+ if mkpts0 is not None and mkpts1 is not None:
66
+ num_inliers = len(mkpts0)
67
+ if "mconf" in pred.keys():
68
+ mconf = pred["mconf"]
69
+ else:
70
+ mconf = np.ones(len(mkpts0))
71
+ fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
72
+ fig_lines = cv2.resize(fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0]))
73
+ fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
74
+ else:
75
+ fig = fig_lines
76
+ return fig, num_inliers
77
+
78
+
79
+ # Matchers collections
80
+ matcher_zoo = {
81
+ "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
82
+ "sold2": {"config": match_dense.confs["sold2"], "dense": True},
83
+ # 'dedode-sparse': {
84
+ # 'config': match_dense.confs['dedode_sparse'],
85
+ # 'dense': True # dense mode, we need 2 images
86
+ # },
87
+ "loftr": {"config": match_dense.confs["loftr"], "dense": True},
88
+ "topicfm": {"config": match_dense.confs["topicfm"], "dense": True},
89
+ "aspanformer": {"config": match_dense.confs["aspanformer"], "dense": True},
90
+ "dedode": {
91
+ "config": match_features.confs["Dual-Softmax"],
92
+ "config_feature": extract_features.confs["dedode"],
93
+ "dense": False,
94
+ },
95
+ "superpoint+superglue": {
96
+ "config": match_features.confs["superglue"],
97
+ "config_feature": extract_features.confs["superpoint_max"],
98
+ "dense": False,
99
+ },
100
+ "superpoint+lightglue": {
101
+ "config": match_features.confs["superpoint-lightglue"],
102
+ "config_feature": extract_features.confs["superpoint_max"],
103
+ "dense": False,
104
+ },
105
+ "disk": {
106
+ "config": match_features.confs["NN-mutual"],
107
+ "config_feature": extract_features.confs["disk"],
108
+ "dense": False,
109
+ },
110
+ "disk+dualsoftmax": {
111
+ "config": match_features.confs["Dual-Softmax"],
112
+ "config_feature": extract_features.confs["disk"],
113
+ "dense": False,
114
+ },
115
+ "superpoint+dualsoftmax": {
116
+ "config": match_features.confs["Dual-Softmax"],
117
+ "config_feature": extract_features.confs["superpoint_max"],
118
+ "dense": False,
119
+ },
120
+ "disk+lightglue": {
121
+ "config": match_features.confs["disk-lightglue"],
122
+ "config_feature": extract_features.confs["disk"],
123
+ "dense": False,
124
+ },
125
+ "superpoint+mnn": {
126
+ "config": match_features.confs["NN-mutual"],
127
+ "config_feature": extract_features.confs["superpoint_max"],
128
+ "dense": False,
129
+ },
130
+ "sift+sgmnet": {
131
+ "config": match_features.confs["sgmnet"],
132
+ "config_feature": extract_features.confs["sift"],
133
+ "dense": False,
134
+ },
135
+ "sosnet": {
136
+ "config": match_features.confs["NN-mutual"],
137
+ "config_feature": extract_features.confs["sosnet"],
138
+ "dense": False,
139
+ },
140
+ "hardnet": {
141
+ "config": match_features.confs["NN-mutual"],
142
+ "config_feature": extract_features.confs["hardnet"],
143
+ "dense": False,
144
+ },
145
+ "d2net": {
146
+ "config": match_features.confs["NN-mutual"],
147
+ "config_feature": extract_features.confs["d2net-ss"],
148
+ "dense": False,
149
+ },
150
+ "d2net-ms": {
151
+ "config": match_features.confs["NN-mutual"],
152
+ "config_feature": extract_features.confs["d2net-ms"],
153
+ "dense": False,
154
+ },
155
+ "alike": {
156
+ "config": match_features.confs["NN-mutual"],
157
+ "config_feature": extract_features.confs["alike"],
158
+ "dense": False,
159
+ },
160
+ "lanet": {
161
+ "config": match_features.confs["NN-mutual"],
162
+ "config_feature": extract_features.confs["lanet"],
163
+ "dense": False,
164
+ },
165
+ "r2d2": {
166
+ "config": match_features.confs["NN-mutual"],
167
+ "config_feature": extract_features.confs["r2d2"],
168
+ "dense": False,
169
+ },
170
+ "darkfeat": {
171
+ "config": match_features.confs["NN-mutual"],
172
+ "config_feature": extract_features.confs["darkfeat"],
173
+ "dense": False,
174
+ },
175
+ "sift": {
176
+ "config": match_features.confs["NN-mutual"],
177
+ "config_feature": extract_features.confs["sift"],
178
+ "dense": False,
179
+ },
180
+ "roma": {"config": match_dense.confs["roma"], "dense": True},
181
+ "DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
182
+ }
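To tie this back to the README's "add your own feature / matcher" instructions: registering a new sparse method is one more entry in `matcher_zoo`. A hedged sketch for a hypothetical `myfeature` extractor (the `NN-mutual` matcher config exists above; the `myfeature` extractor config is something you would first add to `confs` in `hloc/extract_features.py`):

``` python
# Hypothetical entry; "myfeature" must also be defined in extract_features.confs.
matcher_zoo["myfeature"] = {
    "config": match_features.confs["NN-mutual"],  # mutual nearest-neighbour matching
    "config_feature": extract_features.confs["myfeature"],  # your extractor config
    "dense": False,  # sparse keypoints + descriptors, not a dense matcher
}
```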
extra_utils/visualize_util.py ADDED
@@ -0,0 +1,642 @@
1
+ """ Organize some frequently used visualization functions. """
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ import copy
7
+ import seaborn as sns
8
+
9
+
10
+ # Plot junctions onto the image (return a separate copy)
11
+ def plot_junctions(input_image, junctions, junc_size=3, color=None):
12
+ """
13
+ input_image: can be 0~1 float or 0~255 uint8.
14
+ junctions: Nx2 or 2xN np array.
15
+ junc_size: the size of the plotted circles.
16
+ """
17
+ # Create image copy
18
+ image = copy.copy(input_image)
19
+ # Make sure the image is converted to 255 uint8
20
+ if image.dtype == np.uint8:
21
+ pass
22
+ # A float type image ranging from 0~1
23
+ elif image.dtype in [np.float32, np.float64] and image.max() <= 2.0:
24
+ image = (image * 255.0).astype(np.uint8)
25
+ # A float type image ranging from 0.~255.
26
+ elif image.dtype in [np.float32, np.float64] and image.mean() > 10.0:
27
+ image = image.astype(np.uint8)
28
+ else:
29
+ raise ValueError(
30
+ "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
31
+ )
32
+
33
+ # Check whether the image is single channel
34
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
35
+ # Squeeze to H*W first
36
+ image = image.squeeze()
37
+
38
+ # Stack to channle 3
39
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
40
+
41
+ # Junction dimensions should be N*2
42
+ if not len(junctions.shape) == 2:
43
+ raise ValueError("[Error] junctions should be 2-dim array.")
44
+
45
+ # Always convert to N*2
46
+ if junctions.shape[-1] != 2:
47
+ if junctions.shape[0] == 2:
48
+ junctions = junctions.T
49
+ else:
50
+ raise ValueError("[Error] At least one of the two dims should be 2.")
51
+
52
+ # Round and convert junctions to int (and check the boundary)
53
+ H, W = image.shape[:2]
54
+ junctions = (np.round(junctions)).astype(int)
55
+ junctions[junctions < 0] = 0
56
+ junctions[junctions[:, 0] >= H, 0] = H - 1 # (first dim) max bounded by H-1
57
+ junctions[junctions[:, 1] >= W, 1] = W - 1 # (second dim) max bounded by W-1
58
+
59
+ # Iterate through all the junctions
60
+ num_junc = junctions.shape[0]
61
+ if color is None:
62
+ color = (0, 255.0, 0)
63
+ for idx in range(num_junc):
64
+ # Fetch one junction
65
+ junc = junctions[idx, :]
66
+ cv2.circle(
67
+ image, tuple(np.flip(junc)), radius=junc_size, color=color, thickness=3
68
+ )
69
+
70
+ return image
71
+
72
+
73
+ # Plot line segments given junctions and a line adjacency map
74
+ def plot_line_segments(
75
+ input_image,
76
+ junctions,
77
+ line_map,
78
+ junc_size=3,
79
+ color=(0, 255.0, 0),
80
+ line_width=1,
81
+ plot_survived_junc=True,
82
+ ):
83
+ """
84
+ input_image: can be 0~1 float or 0~255 uint8.
85
+ junctions: Nx2 or 2xN np array.
86
+ line_map: NxN np array
87
+ junc_size: the size of the plotted circles.
88
+ color: color of the line segments (can be string "random")
89
+ line_width: width of the drawn segments.
90
+ plot_survived_junc: whether we only plot the survived junctions.
91
+ """
92
+ # Create image copy
93
+ image = copy.copy(input_image)
94
+ # Make sure the image is converted to 255 uint8
95
+ if image.dtype == np.uint8:
96
+ pass
97
+ # A float type image ranging from 0~1
98
+ elif image.dtype in [np.float32, np.float64] and image.max() <= 2.0:
99
+ image = (image * 255.0).astype(np.uint8)
100
+ # A float type image ranging from 0.~255.
101
+ elif image.dtype in [np.float32, np.float64] and image.mean() > 10.0:
102
+ image = image.astype(np.uint8)
103
+ else:
104
+ raise ValueError(
105
+ "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
106
+ )
107
+
108
+ # Check whether the image is single channel
109
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
110
+ # Squeeze to H*W first
111
+ image = image.squeeze()
112
+
113
+ # Stack to channle 3
114
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
115
+
116
+ # Junction dimensions should be 2
117
+ if not len(junctions.shape) == 2:
118
+ raise ValueError("[Error] junctions should be 2-dim array.")
119
+
120
+ # Always convert to N*2
121
+ if junctions.shape[-1] != 2:
122
+ if junctions.shape[0] == 2:
123
+ junctions = junctions.T
124
+ else:
125
+ raise ValueError("[Error] At least one of the two dims should be 2.")
126
+
127
+ # line_map dimension should be 2
128
+ if not len(line_map.shape) == 2:
129
+ raise ValueError("[Error] line_map should be 2-dim array.")
130
+
131
+ # Color should be "random" or a list or tuple with length 3
132
+ if color != "random":
133
+ if not (isinstance(color, tuple) or isinstance(color, list)):
134
+ raise ValueError("[Error] color should have type list or tuple.")
135
+ else:
136
+ if len(color) != 3:
137
+ raise ValueError(
138
+ "[Error] color should be a list or tuple with length 3."
139
+ )
140
+
141
+ # Make a copy of the line_map
142
+ line_map_tmp = copy.copy(line_map)
143
+
144
+ # Parse line_map back to segment pairs
145
+ segments = np.zeros([0, 4])
146
+ for idx in range(junctions.shape[0]):
147
+ # if no connectivity, just skip it
148
+ if line_map_tmp[idx, :].sum() == 0:
149
+ continue
150
+ # record the line segment
151
+ else:
152
+ for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
153
+ p1 = np.flip(junctions[idx, :]) # Convert to xy format
154
+ p2 = np.flip(junctions[idx2, :]) # Convert to xy format
155
+ segments = np.concatenate(
156
+ (segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
157
+ axis=0,
158
+ )
159
+
160
+ # Update line_map
161
+ line_map_tmp[idx, idx2] = 0
162
+ line_map_tmp[idx2, idx] = 0
163
+
164
+ # Draw segment pairs
165
+ for idx in range(segments.shape[0]):
166
+ seg = np.round(segments[idx, :]).astype(int)
167
+ # Decide the color
168
+ if color != "random":
169
+ color = tuple(color)
170
+ else:
171
+ color = tuple(
172
+ np.random.rand(
173
+ 3,
174
+ )
175
+ )
176
+ cv2.line(
177
+ image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width
178
+ )
179
+
180
+ # Also draw the junctions
181
+ if not plot_survived_junc:
182
+ num_junc = junctions.shape[0]
183
+ for idx in range(num_junc):
184
+ # Fetch one junction
185
+ junc = junctions[idx, :]
186
+ cv2.circle(
187
+ image,
188
+ tuple(np.flip(junc)),
189
+ radius=junc_size,
190
+ color=(0, 255.0, 0),
191
+ thickness=3,
192
+ )
193
+ # Only plot the junctions which are part of a line segment
194
+ else:
195
+ for idx in range(segments.shape[0]):
196
+ seg = np.round(segments[idx, :]).astype(int) # Already in HW format.
197
+ cv2.circle(
198
+ image,
199
+ tuple(seg[:2]),
200
+ radius=junc_size,
201
+ color=(0, 255.0, 0),
202
+ thickness=3,
203
+ )
204
+ cv2.circle(
205
+ image,
206
+ tuple(seg[2:]),
207
+ radius=junc_size,
208
+ color=(0, 255.0, 0),
209
+ thickness=3,
210
+ )
211
+
212
+ return image
213
+
214
+
215
+ # Plot line segments given Nx4 or Nx2x2 line segments
216
+ def plot_line_segments_from_segments(
217
+ input_image, line_segments, junc_size=3, color=(0, 255.0, 0), line_width=1
218
+ ):
219
+ # Create image copy
220
+ image = copy.copy(input_image)
221
+ # Make sure the image is converted to 255 uint8
222
+ if image.dtype == np.uint8:
223
+ pass
224
+ # A float type image ranging from 0~1
225
+ elif image.dtype in [np.float32, np.float64] and image.max() <= 2.0:
226
+ image = (image * 255.0).astype(np.uint8)
227
+ # A float type image ranging from 0.~255.
228
+ elif image.dtype in [np.float32, np.float64] and image.mean() > 10.0:
229
+ image = image.astype(np.uint8)
230
+ else:
231
+ raise ValueError(
232
+ "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
233
+ )
234
+
235
+ # Check whether the image is single channel
236
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
237
+ # Squeeze to H*W first
238
+ image = image.squeeze()
239
+
240
+ # Stack to channle 3
241
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
242
+
243
+ # Check the if line_segments are in (1) Nx4, or (2) Nx2x2.
244
+ H, W, _ = image.shape
245
+ # (1) Nx4 format
246
+ if len(line_segments.shape) == 2 and line_segments.shape[-1] == 4:
247
+ # Round to int32
248
+ line_segments = line_segments.astype(np.int32)
249
+
250
+ # Clip H dimension
251
+ line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H - 1)
252
+ line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H - 1)
253
+
254
+ # Clip W dimension
255
+ line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W - 1)
256
+ line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W - 1)
257
+
258
+ # Convert to Nx2x2 format
259
+ line_segments = np.concatenate(
260
+ [
261
+ np.expand_dims(line_segments[:, :2], axis=1),
262
+ np.expand_dims(line_segments[:, 2:], axis=1),
263
+ ],
264
+ axis=1,
265
+ )
266
+
267
+ # (2) Nx2x2 format
268
+ elif len(line_segments.shape) == 3 and line_segments.shape[-1] == 2:
269
+ # Round to int32
270
+ line_segments = line_segments.astype(np.int32)
271
+
272
+ # Clip H dimension
273
+ line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H - 1)
274
+ line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W - 1)
275
+
276
+ else:
277
+ raise ValueError(
278
+ "[Error] line_segments should be either Nx4 or Nx2x2 in HW format."
279
+ )
280
+
281
+ # Draw segment pairs (all segments should be in HW format)
282
+ image = image.copy()
283
+ for idx in range(line_segments.shape[0]):
284
+ seg = np.round(line_segments[idx, :, :]).astype(np.int32)
285
+ # Decide the color
286
+ if color != "random":
287
+ color = tuple(color)
288
+ else:
289
+ color = tuple(
290
+ np.random.rand(
291
+ 3,
292
+ )
293
+ )
294
+ cv2.line(
295
+ image,
296
+ tuple(np.flip(seg[0, :])),
297
+ tuple(np.flip(seg[1, :])),
298
+ color=color,
299
+ thickness=line_width,
300
+ )
301
+
302
+ # Also draw the junctions
303
+ cv2.circle(
304
+ image,
305
+ tuple(np.flip(seg[0, :])),
306
+ radius=junc_size,
307
+ color=(0, 255.0, 0),
308
+ thickness=3,
309
+ )
310
+ cv2.circle(
311
+ image,
312
+ tuple(np.flip(seg[1, :])),
313
+ radius=junc_size,
314
+ color=(0, 255.0, 0),
315
+ thickness=3,
316
+ )
317
+
318
+ return image
319
+
320
+
321
+ # Additional functions to visualize multiple images at the same time,
322
+ # e.g. for line matching
323
+ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
324
+ """Plot a set of images horizontally.
325
+ Args:
326
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
327
+ titles: a list of strings, as titles for each image.
328
+ cmaps: colormaps for monochrome images.
329
+ """
330
+ n = len(imgs)
331
+ if not isinstance(cmaps, (list, tuple)):
332
+ cmaps = [cmaps] * n
333
+ # figsize = (size*n, size*3/4) if size is not None else None
334
+ figsize = (size * n, size * 6 / 5) if size is not None else None
335
+ fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
336
+
337
+ if n == 1:
338
+ ax = [ax]
339
+ for i in range(n):
340
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
341
+ ax[i].get_yaxis().set_ticks([])
342
+ ax[i].get_xaxis().set_ticks([])
343
+ ax[i].set_axis_off()
344
+ for spine in ax[i].spines.values(): # remove frame
345
+ spine.set_visible(False)
346
+ if titles:
347
+ ax[i].set_title(titles[i])
348
+ fig.tight_layout(pad=pad)
349
+ return fig
350
+
351
+
352
+ def plot_keypoints(kpts, colors="lime", ps=4):
353
+ """Plot keypoints for existing images.
354
+ Args:
355
+ kpts: list of ndarrays of size (N, 2).
356
+ colors: string, or list of list of tuples (one for each keypoint).
357
+ ps: size of the keypoints as float.
358
+ """
359
+ if not isinstance(colors, list):
360
+ colors = [colors] * len(kpts)
361
+ axes = plt.gcf().axes
362
+ for a, k, c in zip(axes, kpts, colors):
363
+ a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
364
+
365
+
366
+ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
367
+ """Plot matches for a pair of existing images.
368
+ Args:
369
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
370
+ color: color of each match, string or RGB tuple. Random if not given.
371
+ lw: width of the lines.
372
+ ps: size of the end points (no endpoint if ps=0)
373
+ indices: indices of the images to draw the matches on.
374
+ a: alpha opacity of the match lines.
375
+ """
376
+ fig = plt.gcf()
377
+ ax = fig.axes
378
+ assert len(ax) > max(indices)
379
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
380
+ fig.canvas.draw()
381
+
382
+ assert len(kpts0) == len(kpts1)
383
+ if color is None:
384
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
385
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
386
+ color = [color] * len(kpts0)
387
+
388
+ if lw > 0:
389
+ # transform the points into the figure coordinate system
390
+ transFigure = fig.transFigure.inverted()
391
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
392
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
393
+ fig.lines += [
394
+ matplotlib.lines.Line2D(
395
+ (fkpts0[i, 0], fkpts1[i, 0]),
396
+ (fkpts0[i, 1], fkpts1[i, 1]),
397
+ zorder=1,
398
+ transform=fig.transFigure,
399
+ c=color[i],
400
+ linewidth=lw,
401
+ alpha=a,
402
+ )
403
+ for i in range(len(kpts0))
404
+ ]
405
+
406
+ # freeze the axes to prevent the transform to change
407
+ ax0.autoscale(enable=False)
408
+ ax1.autoscale(enable=False)
409
+
410
+ if ps > 0:
411
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2)
412
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)
413
+
414
+
415
+ def plot_lines(
416
+ lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
417
+ ):
418
+ """Plot lines and endpoints for existing images.
419
+ Args:
420
+ lines: list of ndarrays of size (N, 2, 2).
421
+ line_colors, point_colors: string, or list of colors (one per image).
422
+ ps: size of the keypoints as float pixels.
423
+ lw: line width as float pixels.
424
+ indices: indices of the images to draw the matches on.
425
+ """
426
+ if not isinstance(line_colors, list):
427
+ line_colors = [line_colors] * len(lines)
428
+ if not isinstance(point_colors, list):
429
+ point_colors = [point_colors] * len(lines)
430
+
431
+ fig = plt.gcf()
432
+ ax = fig.axes
433
+ assert len(ax) > max(indices)
434
+ axes = [ax[i] for i in indices]
435
+ fig.canvas.draw()
436
+
437
+ # Plot the lines and junctions
438
+ for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
439
+ for i in range(len(l)):
440
+ line = matplotlib.lines.Line2D(
441
+ (l[i, 0, 0], l[i, 1, 0]),
442
+ (l[i, 0, 1], l[i, 1, 1]),
443
+ zorder=1,
444
+ c=lc,
445
+ linewidth=lw,
446
+ )
447
+ a.add_line(line)
448
+ pts = l.reshape(-1, 2)
449
+ a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2)
450
+
451
+ return fig
452
+
453
+
454
+ def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.0):
455
+ """Plot matches for a pair of existing images, parametrized by their middle point.
456
+ Args:
457
+ kpts0, kpts1: corresponding middle points of the lines of size (N, 2).
458
+ color: color of each match, string or RGB tuple. Random if not given.
459
+ lw: width of the lines.
460
+ indices: indices of the images to draw the matches on.
461
+ a: alpha opacity of the match lines.
462
+ """
463
+ fig = plt.gcf()
464
+ ax = fig.axes
465
+ assert len(ax) > max(indices)
466
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
467
+ fig.canvas.draw()
468
+
469
+ assert len(kpts0) == len(kpts1)
470
+ if color is None:
471
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
472
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
473
+ color = [color] * len(kpts0)
474
+
475
+ if lw > 0:
476
+ # transform the points into the figure coordinate system
477
+ transFigure = fig.transFigure.inverted()
478
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
479
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
480
+ fig.lines += [
481
+ matplotlib.lines.Line2D(
482
+ (fkpts0[i, 0], fkpts1[i, 0]),
483
+ (fkpts0[i, 1], fkpts1[i, 1]),
484
+ zorder=1,
485
+ transform=fig.transFigure,
486
+ c=color[i],
487
+ linewidth=lw,
488
+ alpha=a,
489
+ )
490
+ for i in range(len(kpts0))
491
+ ]
492
+
493
+ # freeze the axes to prevent the transform to change
494
+ ax0.autoscale(enable=False)
495
+ ax1.autoscale(enable=False)
496
+
497
+
498
+ def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
499
+ """Plot line matches for existing images with multiple colors.
500
+ Args:
501
+ lines: list of ndarrays of size (N, 2, 2).
502
+ correct_matches: bool array of size (N,) indicating correct matches.
503
+ lw: line width as float pixels.
504
+ indices: indices of the images to draw the matches on.
505
+ """
506
+ n_lines = len(lines[0])
507
+ colors = sns.color_palette("husl", n_colors=n_lines)
508
+ np.random.shuffle(colors)
509
+ alphas = np.ones(n_lines)
510
+ # If correct_matches is not None, display wrong matches with a low alpha
511
+ if correct_matches is not None:
512
+ alphas[~np.array(correct_matches)] = 0.2
513
+
514
+ fig = plt.gcf()
515
+ ax = fig.axes
516
+ assert len(ax) > max(indices)
517
+ axes = [ax[i] for i in indices]
518
+ fig.canvas.draw()
519
+
520
+ # Plot the lines
521
+ for a, l in zip(axes, lines):
522
+ # Transform the points into the figure coordinate system
523
+ transFigure = fig.transFigure.inverted()
524
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
525
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
526
+ fig.lines += [
527
+ matplotlib.lines.Line2D(
528
+ (endpoint0[i, 0], endpoint1[i, 0]),
529
+ (endpoint0[i, 1], endpoint1[i, 1]),
530
+ zorder=1,
531
+ transform=fig.transFigure,
532
+ c=colors[i],
533
+ alpha=alphas[i],
534
+ linewidth=lw,
535
+ )
536
+ for i in range(n_lines)
537
+ ]
538
+
539
+ return fig
540
+
541
+
542
+ def plot_color_lines(lines, correct_matches, wrong_matches, lw=2, indices=(0, 1)):
543
+ """Plot line matches for existing images with multiple colors:
544
+ green for correct matches, red for wrong ones, and blue for the rest.
545
+ Args:
546
+ lines: list of ndarrays of size (N, 2, 2).
547
+ correct_matches: list of bool arrays of size N with correct matches.
548
+ wrong_matches: list of bool arrays of size (N,) with wrong matches.
549
+ lw: line width as float pixels.
550
+ indices: indices of the images to draw the matches on.
551
+ """
552
+ # palette = sns.color_palette()
553
+ palette = sns.color_palette("hls", 8)
554
+ blue = palette[5] # palette[0]
555
+ red = palette[0] # palette[3]
556
+ green = palette[2] # palette[2]
557
+ colors = [np.array([blue] * len(l)) for l in lines]
558
+ for i, c in enumerate(colors):
559
+ c[np.array(correct_matches[i])] = green
560
+ c[np.array(wrong_matches[i])] = red
561
+
562
+ fig = plt.gcf()
563
+ ax = fig.axes
564
+ assert len(ax) > max(indices)
565
+ axes = [ax[i] for i in indices]
566
+ fig.canvas.draw()
567
+
568
+ # Plot the lines
569
+ for a, l, c in zip(axes, lines, colors):
570
+ # Transform the points into the figure coordinate system
571
+ transFigure = fig.transFigure.inverted()
572
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
573
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
574
+ fig.lines += [
575
+ matplotlib.lines.Line2D(
576
+ (endpoint0[i, 0], endpoint1[i, 0]),
577
+ (endpoint0[i, 1], endpoint1[i, 1]),
578
+ zorder=1,
579
+ transform=fig.transFigure,
580
+ c=c[i],
581
+ linewidth=lw,
582
+ )
583
+ for i in range(len(l))
584
+ ]
585
+
586
+
587
+ def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
588
+ """Plot line matches for existing images with multiple colors and
589
+ highlight the actually matched subsegments.
590
+ Args:
591
+ lines: list of ndarrays of size (N, 2, 2).
592
+ subsegments: list of ndarrays of size (N, 2, 2).
593
+ lw: line width as float pixels.
594
+ indices: indices of the images to draw the matches on.
595
+ """
596
+ n_lines = len(lines[0])
597
+ colors = sns.cubehelix_palette(
598
+ start=2, rot=-0.2, dark=0.3, light=0.7, gamma=1.3, hue=1, n_colors=n_lines
599
+ )
600
+
601
+ fig = plt.gcf()
602
+ ax = fig.axes
603
+ assert len(ax) > max(indices)
604
+ axes = [ax[i] for i in indices]
605
+ fig.canvas.draw()
606
+
607
+ # Plot the lines
608
+ for a, l, ss in zip(axes, lines, subsegments):
609
+ # Transform the points into the figure coordinate system
610
+ transFigure = fig.transFigure.inverted()
611
+
612
+ # Draw full line
613
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
614
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
615
+ fig.lines += [
616
+ matplotlib.lines.Line2D(
617
+ (endpoint0[i, 0], endpoint1[i, 0]),
618
+ (endpoint0[i, 1], endpoint1[i, 1]),
619
+ zorder=1,
620
+ transform=fig.transFigure,
621
+ c="red",
622
+ alpha=0.7,
623
+ linewidth=lw,
624
+ )
625
+ for i in range(n_lines)
626
+ ]
627
+
628
+ # Draw matched subsegment
629
+ endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0]))
630
+ endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1]))
631
+ fig.lines += [
632
+ matplotlib.lines.Line2D(
633
+ (endpoint0[i, 0], endpoint1[i, 0]),
634
+ (endpoint0[i, 1], endpoint1[i, 1]),
635
+ zorder=1,
636
+ transform=fig.transFigure,
637
+ c=colors[i],
638
+ alpha=1,
639
+ linewidth=lw,
640
+ )
641
+ for i in range(n_lines)
642
+ ]
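Usage note for the plotting helpers above: they all draw onto the current matplotlib figure, so the intended pattern is to create the canvas with plot_images and then overlay keypoints, matches or lines on it. A minimal sketch with random arrays standing in for real images and correspondences (only the extra_utils.visualize_util import path comes from this commit; the data is illustrative):

    import numpy as np
    from extra_utils.visualize_util import plot_images, plot_keypoints, plot_matches

    # Stand-in data: two RGB images and 50 roughly corresponding keypoints.
    img0 = np.random.rand(480, 640, 3)
    img1 = np.random.rand(480, 640, 3)
    kpts0 = np.random.rand(50, 2) * np.array([640.0, 480.0])
    kpts1 = kpts0 + np.random.randn(50, 2)

    fig = plot_images([img0, img1], titles=["image 0", "image 1"])
    plot_keypoints([kpts0, kpts1], colors="lime", ps=6)  # overlay keypoints
    plot_matches(kpts0, kpts1, lw=0.5, a=0.8)            # draw match lines
    fig.savefig("matches.png", dpi=150)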
hloc/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ import logging
2
+ from packaging import version
3
+
4
+ __version__ = "1.3"
5
+
6
+ formatter = logging.Formatter(
7
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
8
+ )
9
+ handler = logging.StreamHandler()
10
+ handler.setFormatter(formatter)
11
+ handler.setLevel(logging.INFO)
12
+
13
+ logger = logging.getLogger("hloc")
14
+ logger.setLevel(logging.INFO)
15
+ logger.addHandler(handler)
16
+ logger.propagate = False
17
+
18
+ try:
19
+ import pycolmap
20
+ except ImportError:
21
+ logger.warning("pycolmap is not installed, some features may not work.")
22
+ else:
23
+ minimal_version = version.parse("0.3.0")
24
+ found_version = version.parse(getattr(pycolmap, "__version__"))
25
+ if found_version < minimal_version:
26
+ logger.warning(
27
+ "hloc now requires pycolmap>=%s but found pycolmap==%s, "
28
+ "please upgrade with `pip install --upgrade pycolmap`",
29
+ minimal_version,
30
+ found_version,
31
+ )
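Note on hloc/__init__.py above: the package configures its own "hloc" logger with propagation disabled, so downstream scripts are expected to reuse it rather than configure logging again. A tiny illustrative sketch (the timestamp only shows the format set up above):

    from hloc import logger

    logger.info("starting feature extraction")
    # [2024/01/01 12:00:00 hloc INFO] starting feature extraction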
hloc/extract_features.py ADDED
@@ -0,0 +1,516 @@
1
+ import argparse
2
+ import torch
3
+ from pathlib import Path
4
+ from typing import Dict, List, Union, Optional
5
+ import h5py
6
+ from types import SimpleNamespace
7
+ import cv2
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ import pprint
11
+ import collections.abc as collections
12
+ import PIL.Image
13
+ import torchvision.transforms.functional as F
14
+ from . import extractors, logger
15
+ from .utils.base_model import dynamic_load
16
+ from .utils.parsers import parse_image_lists
17
+ from .utils.io import read_image, list_h5_names
18
+
19
+
20
+ """
21
+ A set of standard configurations that can be directly selected from the command
22
+ line using their name. Each is a dictionary with the following entries:
23
+ - output: the name of the feature file that will be generated.
24
+ - model: the model configuration, as passed to a feature extractor.
25
+ - preprocessing: how to preprocess the images read from disk.
26
+ """
27
+ confs = {
28
+ "superpoint_aachen": {
29
+ "output": "feats-superpoint-n4096-r1024",
30
+ "model": {
31
+ "name": "superpoint",
32
+ "nms_radius": 3,
33
+ "max_keypoints": 4096,
34
+ "keypoint_threshold": 0.005,
35
+ },
36
+ "preprocessing": {
37
+ "grayscale": True,
38
+ "force_resize": True,
39
+ "resize_max": 1600,
40
+ "width": 640,
41
+ "height": 480,
42
+ "dfactor": 8,
43
+ },
44
+ },
45
+ # Resize images to 1600px even if they are originally smaller.
46
+ # Improves the keypoint localization if the images are of good quality.
47
+ "superpoint_max": {
48
+ "output": "feats-superpoint-n4096-rmax1600",
49
+ "model": {
50
+ "name": "superpoint",
51
+ "nms_radius": 3,
52
+ "max_keypoints": 4096,
53
+ "keypoint_threshold": 0.005,
54
+ },
55
+ "preprocessing": {
56
+ "grayscale": True,
57
+ "force_resize": True,
58
+ "resize_max": 1600,
59
+ "width": 640,
60
+ "height": 480,
61
+ "dfactor": 8,
62
+ },
63
+ },
64
+ "superpoint_inloc": {
65
+ "output": "feats-superpoint-n4096-r1600",
66
+ "model": {
67
+ "name": "superpoint",
68
+ "nms_radius": 4,
69
+ "max_keypoints": 4096,
70
+ "keypoint_threshold": 0.005,
71
+ },
72
+ "preprocessing": {
73
+ "grayscale": True,
74
+ "resize_max": 1600,
75
+ },
76
+ },
77
+ "r2d2": {
78
+ "output": "feats-r2d2-n5000-r1024",
79
+ "model": {
80
+ "name": "r2d2",
81
+ "max_keypoints": 5000,
82
+ "reliability_threshold": 0.7,
83
+ "repetability_threshold": 0.7,
84
+ },
85
+ "preprocessing": {
86
+ "grayscale": False,
87
+ "force_resize": True,
88
+ "resize_max": 1600,
89
+ "width": 640,
90
+ "height": 480,
91
+ "dfactor": 8,
92
+ },
93
+ },
94
+ "d2net-ss": {
95
+ "output": "feats-d2net-ss",
96
+ "model": {
97
+ "name": "d2net",
98
+ "multiscale": False,
99
+ },
100
+ "preprocessing": {
101
+ "grayscale": False,
102
+ "resize_max": 1600,
103
+ },
104
+ },
105
+ "d2net-ms": {
106
+ "output": "feats-d2net-ms",
107
+ "model": {
108
+ "name": "d2net",
109
+ "multiscale": True,
110
+ },
111
+ "preprocessing": {
112
+ "grayscale": False,
113
+ "resize_max": 1600,
114
+ },
115
+ },
116
+ "rootsift": {
117
+ "output": "feats-sift",
118
+ "model": {
119
+ "name": "dog",
120
+ "max_keypoints": 5000,
121
+ },
122
+ "preprocessing": {
123
+ "grayscale": True,
124
+ "force_resize": True,
125
+ "resize_max": 1600,
126
+ "width": 640,
127
+ "height": 480,
128
+ "dfactor": 8,
129
+ },
130
+ },
131
+ "sift": {
132
+ "output": "feats-sift",
133
+ "model": {
134
+ "name": "dog",
135
+ "descriptor": "sift",
136
+ "max_keypoints": 5000,
137
+ },
138
+ "preprocessing": {
139
+ "grayscale": True,
140
+ "force_resize": True,
141
+ "resize_max": 1600,
142
+ "width": 640,
143
+ "height": 480,
144
+ "dfactor": 8,
145
+ },
146
+ },
147
+ "sosnet": {
148
+ "output": "feats-sosnet",
149
+ "model": {"name": "dog", "descriptor": "sosnet"},
150
+ "preprocessing": {
151
+ "grayscale": True,
152
+ "resize_max": 1600,
153
+ "force_resize": True,
154
+ "width": 640,
155
+ "height": 480,
156
+ "dfactor": 8,
157
+ },
158
+ },
159
+ "hardnet": {
160
+ "output": "feats-hardnet",
161
+ "model": {"name": "dog", "descriptor": "hardnet"},
162
+ "preprocessing": {
163
+ "grayscale": True,
164
+ "resize_max": 1600,
165
+ "force_resize": True,
166
+ "width": 640,
167
+ "height": 480,
168
+ "dfactor": 8,
169
+ },
170
+ },
171
+ "disk": {
172
+ "output": "feats-disk",
173
+ "model": {
174
+ "name": "disk",
175
+ "max_keypoints": 5000,
176
+ },
177
+ "preprocessing": {
178
+ "grayscale": False,
179
+ "resize_max": 1600,
180
+ },
181
+ },
182
+ "alike": {
183
+ "output": "feats-alike",
184
+ "model": {
185
+ "name": "alike",
186
+ "max_keypoints": 5000,
187
+ "use_relu": True,
188
+ "multiscale": False,
189
+ "detection_threshold": 0.5,
190
+ "top_k": -1,
191
+ "sub_pixel": False,
192
+ },
193
+ "preprocessing": {
194
+ "grayscale": False,
195
+ "resize_max": 1600,
196
+ },
197
+ },
198
+ "lanet": {
199
+ "output": "feats-lanet",
200
+ "model": {
201
+ "name": "lanet",
202
+ "keypoint_threshold": 0.1,
203
+ "max_keypoints": 5000,
204
+ },
205
+ "preprocessing": {
206
+ "grayscale": False,
207
+ "resize_max": 1600,
208
+ },
209
+ },
210
+ "darkfeat": {
211
+ "output": "feats-darkfeat-n5000-r1024",
212
+ "model": {
213
+ "name": "darkfeat",
214
+ "max_keypoints": 5000,
215
+ "reliability_threshold": 0.7,
216
+ "repetability_threshold": 0.7,
217
+ },
218
+ "preprocessing": {
219
+ "grayscale": False,
220
+ "force_resize": True,
221
+ "resize_max": 1600,
222
+ "width": 640,
223
+ "height": 480,
224
+ "dfactor": 8,
225
+ },
226
+ },
227
+ "dedode": {
228
+ "output": "feats-dedode-n5000-r1024",
229
+ "model": {
230
+ "name": "dedode",
231
+ "max_keypoints": 5000,
232
+ },
233
+ "preprocessing": {
234
+ "grayscale": False,
235
+ "force_resize": True,
236
+ "resize_max": 1024,
237
+ "width": 768,
238
+ "height": 768,
239
+ "dfactor": 8,
240
+ },
241
+ },
242
+ "example": {
243
+ "output": "feats-example-n5000-r1024",
244
+ "model": {
245
+ "name": "example",
246
+ "keypoint_threshold": 0.1,
247
+ "max_keypoints": 2000,
248
+ "model_name": "model.pth",
249
+ },
250
+ "preprocessing": {
251
+ "grayscale": False,
252
+ "force_resize": True,
253
+ "resize_max": 1024,
254
+ "width": 768,
255
+ "height": 768,
256
+ "dfactor": 8,
257
+ },
258
+ },
259
+ # Global descriptors
260
+ "dir": {
261
+ "output": "global-feats-dir",
262
+ "model": {"name": "dir"},
263
+ "preprocessing": {"resize_max": 1024},
264
+ },
265
+ "netvlad": {
266
+ "output": "global-feats-netvlad",
267
+ "model": {"name": "netvlad"},
268
+ "preprocessing": {"resize_max": 1024},
269
+ },
270
+ "openibl": {
271
+ "output": "global-feats-openibl",
272
+ "model": {"name": "openibl"},
273
+ "preprocessing": {"resize_max": 1024},
274
+ },
275
+ "cosplace": {
276
+ "output": "global-feats-cosplace",
277
+ "model": {"name": "cosplace"},
278
+ "preprocessing": {"resize_max": 1024},
279
+ },
280
+ }
281
+
282
+
283
+ def resize_image(image, size, interp):
284
+ if interp.startswith("cv2_"):
285
+ interp = getattr(cv2, "INTER_" + interp[len("cv2_") :].upper())
286
+ h, w = image.shape[:2]
287
+ if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
288
+ interp = cv2.INTER_LINEAR
289
+ resized = cv2.resize(image, size, interpolation=interp)
290
+ elif interp.startswith("pil_"):
291
+ interp = getattr(PIL.Image, interp[len("pil_") :].upper())
292
+ resized = PIL.Image.fromarray(image.astype(np.uint8))
293
+ resized = resized.resize(size, resample=interp)
294
+ resized = np.asarray(resized, dtype=image.dtype)
295
+ else:
296
+ raise ValueError(f"Unknown interpolation {interp}.")
297
+ return resized
298
+
299
+
300
+ class ImageDataset(torch.utils.data.Dataset):
301
+ default_conf = {
302
+ "globs": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
303
+ "grayscale": False,
304
+ "resize_max": None,
305
+ "force_resize": False,
306
+ "interpolation": "cv2_area", # pil_linear is more accurate but slower
307
+ }
308
+
309
+ def __init__(self, root, conf, paths=None):
310
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
311
+ self.root = root
312
+
313
+ if paths is None:
314
+ paths = []
315
+ for g in conf.globs:
316
+ paths += list(Path(root).glob("**/" + g))
317
+ if len(paths) == 0:
318
+ raise ValueError(f"Could not find any image in root: {root}.")
319
+ paths = sorted(list(set(paths)))
320
+ self.names = [i.relative_to(root).as_posix() for i in paths]
321
+ logger.info(f"Found {len(self.names)} images in root {root}.")
322
+ else:
323
+ if isinstance(paths, (Path, str)):
324
+ self.names = parse_image_lists(paths)
325
+ elif isinstance(paths, collections.Iterable):
326
+ self.names = [p.as_posix() if isinstance(p, Path) else p for p in paths]
327
+ else:
328
+ raise ValueError(f"Unknown format for path argument {paths}.")
329
+
330
+ for name in self.names:
331
+ if not (root / name).exists():
332
+ raise ValueError(f"Image {name} does not exists in root: {root}.")
333
+
334
+ def __getitem__(self, idx):
335
+ name = self.names[idx]
336
+ image = read_image(self.root / name, self.conf.grayscale)
337
+ image = image.astype(np.float32)
338
+ size = image.shape[:2][::-1]
339
+
340
+ if self.conf.resize_max and (
341
+ self.conf.force_resize or max(size) > self.conf.resize_max
342
+ ):
343
+ scale = self.conf.resize_max / max(size)
344
+ size_new = tuple(int(round(x * scale)) for x in size)
345
+ image = resize_image(image, size_new, self.conf.interpolation)
346
+
347
+ if self.conf.grayscale:
348
+ image = image[None]
349
+ else:
350
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
351
+ image = image / 255.0
352
+
353
+ data = {
354
+ "image": image,
355
+ "original_size": np.array(size),
356
+ }
357
+ return data
358
+
359
+ def __len__(self):
360
+ return len(self.names)
361
+
362
+
363
+ def extract(model, image_0, conf):
364
+ default_conf = {
365
+ "grayscale": True,
366
+ "resize_max": 1024,
367
+ "dfactor": 8,
368
+ "cache_images": False,
369
+ "force_resize": False,
370
+ "width": 320,
371
+ "height": 240,
372
+ "interpolation": "cv2_area",
373
+ }
374
+ conf = SimpleNamespace(**{**default_conf, **conf})
375
+ device = "cuda" if torch.cuda.is_available() else "cpu"
376
+
377
+ def preprocess(image: np.ndarray, conf: SimpleNamespace):
378
+ image = image.astype(np.float32, copy=False)
379
+ size = image.shape[:2][::-1]
380
+ scale = np.array([1.0, 1.0])
381
+ if conf.resize_max:
382
+ scale = conf.resize_max / max(size)
383
+ if scale < 1.0:
384
+ size_new = tuple(int(round(x * scale)) for x in size)
385
+ image = resize_image(image, size_new, "cv2_area")
386
+ scale = np.array(size) / np.array(size_new)
387
+ if conf.force_resize:
388
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
389
+ size_new = (conf.width, conf.height)
390
+ scale = np.array(size) / np.array(size_new)
391
+ if conf.grayscale:
392
+ assert image.ndim == 2, image.shape
393
+ image = image[None]
394
+ else:
395
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
396
+ image = torch.from_numpy(image / 255.0).float()
397
+
398
+ # ensure that the size is divisible by dfactor
399
+ size_new = tuple(
400
+ map(lambda x: int(x // conf.dfactor * conf.dfactor), image.shape[-2:])
401
+ )
402
+ image = F.resize(image, size=size_new, antialias=True)
403
+ input_ = image.to(device, non_blocking=True)[None]
404
+ data = {
405
+ "image": input_,
406
+ "image_orig": image_0,
407
+ "original_size": np.array(size),
408
+ "size": np.array(image.shape[1:][::-1]),
409
+ }
410
+ return data
411
+
412
+ # convert to grayscale if needed
413
+ if len(image_0.shape) == 3 and conf.grayscale:
414
+ image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY)
415
+ else:
416
+ image0 = image_0
417
+ # The following lines are commented out because the image is always in RGB mode:
418
+ # if not conf.grayscale and len(image_0.shape) == 3:
419
+ # image0 = image_0[:, :, ::-1] # BGR to RGB
420
+ data = preprocess(image0, conf)
421
+ pred = model({"image": data["image"]})
422
+ pred["image_size"] = original_size = data["original_size"]
423
+ pred = {**pred, **data}
424
+ return pred
425
+
426
+
427
+ @torch.no_grad()
428
+ def main(
429
+ conf: Dict,
430
+ image_dir: Path,
431
+ export_dir: Optional[Path] = None,
432
+ as_half: bool = True,
433
+ image_list: Optional[Union[Path, List[str]]] = None,
434
+ feature_path: Optional[Path] = None,
435
+ overwrite: bool = False,
436
+ ) -> Path:
437
+ logger.info(
438
+ "Extracting local features with configuration:" f"\n{pprint.pformat(conf)}"
439
+ )
440
+
441
+ dataset = ImageDataset(image_dir, conf["preprocessing"], image_list)
442
+ if feature_path is None:
443
+ feature_path = Path(export_dir, conf["output"] + ".h5")
444
+ feature_path.parent.mkdir(exist_ok=True, parents=True)
445
+ skip_names = set(
446
+ list_h5_names(feature_path) if feature_path.exists() and not overwrite else ()
447
+ )
448
+ dataset.names = [n for n in dataset.names if n not in skip_names]
449
+ if len(dataset.names) == 0:
450
+ logger.info("Skipping the extraction.")
451
+ return feature_path
452
+
453
+ device = "cuda" if torch.cuda.is_available() else "cpu"
454
+ Model = dynamic_load(extractors, conf["model"]["name"])
455
+ model = Model(conf["model"]).eval().to(device)
456
+
457
+ loader = torch.utils.data.DataLoader(
458
+ dataset, num_workers=1, shuffle=False, pin_memory=True
459
+ )
460
+ for idx, data in enumerate(tqdm(loader)):
461
+ name = dataset.names[idx]
462
+ pred = model({"image": data["image"].to(device, non_blocking=True)})
463
+ pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
464
+
465
+ pred["image_size"] = original_size = data["original_size"][0].numpy()
466
+ if "keypoints" in pred:
467
+ size = np.array(data["image"].shape[-2:][::-1])
468
+ scales = (original_size / size).astype(np.float32)
469
+ pred["keypoints"] = (pred["keypoints"] + 0.5) * scales[None] - 0.5
470
+ if "scales" in pred:
471
+ pred["scales"] *= scales.mean()
472
+ # add keypoint uncertainties scaled to the original resolution
473
+ uncertainty = getattr(model, "detection_noise", 1) * scales.mean()
474
+
475
+ if as_half:
476
+ for k in pred:
477
+ dt = pred[k].dtype
478
+ if (dt == np.float32) and (dt != np.float16):
479
+ pred[k] = pred[k].astype(np.float16)
480
+
481
+ with h5py.File(str(feature_path), "a", libver="latest") as fd:
482
+ try:
483
+ if name in fd:
484
+ del fd[name]
485
+ grp = fd.create_group(name)
486
+ for k, v in pred.items():
487
+ grp.create_dataset(k, data=v)
488
+ if "keypoints" in pred:
489
+ grp["keypoints"].attrs["uncertainty"] = uncertainty
490
+ except OSError as error:
491
+ if "No space left on device" in error.args[0]:
492
+ logger.error(
493
+ "Out of disk space: storing features on disk can take "
494
+ "significant space, did you enable the as_half flag?"
495
+ )
496
+ del grp, fd[name]
497
+ raise error
498
+
499
+ del pred
500
+
501
+ logger.info("Finished exporting features.")
502
+ return feature_path
503
+
504
+
505
+ if __name__ == "__main__":
506
+ parser = argparse.ArgumentParser()
507
+ parser.add_argument("--image_dir", type=Path, required=True)
508
+ parser.add_argument("--export_dir", type=Path, required=True)
509
+ parser.add_argument(
510
+ "--conf", type=str, default="superpoint_aachen", choices=list(confs.keys())
511
+ )
512
+ parser.add_argument("--as_half", action="store_true")
513
+ parser.add_argument("--image_list", type=Path)
514
+ parser.add_argument("--feature_path", type=Path)
515
+ args = parser.parse_args()
516
+ main(confs[args.conf], args.image_dir, args.export_dir, args.as_half)
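Usage note for hloc/extract_features.py above: besides the CLI, the usual entry point is main() together with one of the confs presets. A minimal sketch on the demo images bundled with this commit (the output directory is illustrative):

    from pathlib import Path
    from hloc import extract_features

    conf = extract_features.confs["superpoint_aachen"]
    feature_path = extract_features.main(
        conf,
        image_dir=Path("datasets/sacre_coeur/mapping"),
        export_dir=Path("outputs/sacre_coeur"),
    )
    # feature_path -> outputs/sacre_coeur/feats-superpoint-n4096-r1024.h5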
hloc/extractors/__init__.py ADDED
File without changes
hloc/extractors/alike.py ADDED
@@ -0,0 +1,52 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import subprocess
4
+ import torch
5
+
6
+ from ..utils.base_model import BaseModel
7
+
8
+ alike_path = Path(__file__).parent / "../../third_party/ALIKE"
9
+ sys.path.append(str(alike_path))
10
+ from alike import ALike as Alike_
11
+ from alike import configs
12
+
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+
16
+ class Alike(BaseModel):
17
+ default_conf = {
18
+ "model_name": "alike-t", # 'alike-t', 'alike-s', 'alike-n', 'alike-l'
19
+ "use_relu": True,
20
+ "multiscale": False,
21
+ "max_keypoints": 1000,
22
+ "detection_threshold": 0.5,
23
+ "top_k": -1,
24
+ "sub_pixel": False,
25
+ }
26
+
27
+ required_inputs = ["image"]
28
+
29
+ def _init(self, conf):
30
+ self.net = Alike_(
31
+ **configs[conf["model_name"]],
32
+ device=device,
33
+ top_k=conf["top_k"],
34
+ scores_th=conf["detection_threshold"],
35
+ n_limit=conf["max_keypoints"],
36
+ )
37
+
38
+ def _forward(self, data):
39
+ image = data["image"]
40
+ image = image.permute(0, 2, 3, 1).squeeze()
41
+ image = image.cpu().numpy() * 255.0
42
+ pred = self.net(image, sub_pixel=self.conf["sub_pixel"])
43
+
44
+ keypoints = pred["keypoints"]
45
+ descriptors = pred["descriptors"]
46
+ scores = pred["scores"]
47
+
48
+ return {
49
+ "keypoints": torch.from_numpy(keypoints)[None],
50
+ "scores": torch.from_numpy(scores)[None],
51
+ "descriptors": torch.from_numpy(descriptors.T)[None],
52
+ }
hloc/extractors/cosplace.py ADDED
@@ -0,0 +1,44 @@
1
+ """
2
+ Code for loading models trained with CosPlace as a global features extractor
3
+ for geolocalization through image retrieval.
4
+ Multiple models are available with different backbones. Below is a summary of
5
+ models available (backbone : list of available output descriptors
6
+ dimensionality). For example you can use a model based on a ResNet50 with
7
+ descriptors dimensionality 1024.
8
+ ResNet18: [32, 64, 128, 256, 512]
9
+ ResNet50: [32, 64, 128, 256, 512, 1024, 2048]
10
+ ResNet101: [32, 64, 128, 256, 512, 1024, 2048]
11
+ ResNet152: [32, 64, 128, 256, 512, 1024, 2048]
12
+ VGG16: [ 64, 128, 256, 512]
13
+
14
+ CosPlace paper: https://arxiv.org/abs/2204.02287
15
+ """
16
+
17
+ import torch
18
+ import torchvision.transforms as tvf
19
+
20
+ from ..utils.base_model import BaseModel
21
+
22
+
23
+ class CosPlace(BaseModel):
24
+ default_conf = {"backbone": "ResNet50", "fc_output_dim": 2048}
25
+ required_inputs = ["image"]
26
+
27
+ def _init(self, conf):
28
+ self.net = torch.hub.load(
29
+ "gmberton/CosPlace",
30
+ "get_trained_model",
31
+ backbone=conf["backbone"],
32
+ fc_output_dim=conf["fc_output_dim"],
33
+ ).eval()
34
+
35
+ mean = [0.485, 0.456, 0.406]
36
+ std = [0.229, 0.224, 0.225]
37
+ self.norm_rgb = tvf.Normalize(mean=mean, std=std)
38
+
39
+ def _forward(self, data):
40
+ image = self.norm_rgb(data["image"])
41
+ desc = self.net(image)
42
+ return {
43
+ "global_descriptor": desc,
44
+ }
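Usage note for hloc/extractors/cosplace.py above: any backbone / fc_output_dim pair from the table in the docstring can be passed as the conf. A minimal sketch, assuming hloc's BaseModel accepts a partial conf dict (as the other extractors here do) and using a random tensor as a stand-in for an RGB batch in [0, 1]:

    import torch

    model = CosPlace({"backbone": "ResNet50", "fc_output_dim": 1024}).eval()
    with torch.no_grad():
        out = model({"image": torch.rand(1, 3, 480, 640)})
    print(out["global_descriptor"].shape)  # expected: torch.Size([1, 1024])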
hloc/extractors/d2net.py ADDED
@@ -0,0 +1,57 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import subprocess
4
+ import torch
5
+
6
+ from ..utils.base_model import BaseModel
7
+
8
+ d2net_path = Path(__file__).parent / "../../third_party/d2net"
9
+ sys.path.append(str(d2net_path))
10
+ from lib.model_test import D2Net as _D2Net
11
+ from lib.pyramid import process_multiscale
12
+
13
+
14
+ class D2Net(BaseModel):
15
+ default_conf = {
16
+ "model_name": "d2_tf.pth",
17
+ "checkpoint_dir": d2net_path / "models",
18
+ "use_relu": True,
19
+ "multiscale": False,
20
+ }
21
+ required_inputs = ["image"]
22
+
23
+ def _init(self, conf):
24
+ model_file = conf["checkpoint_dir"] / conf["model_name"]
25
+ if not model_file.exists():
26
+ model_file.parent.mkdir(exist_ok=True)
27
+ cmd = [
28
+ "wget",
29
+ "https://dsmn.ml/files/d2-net/" + conf["model_name"],
30
+ "-O",
31
+ str(model_file),
32
+ ]
33
+ subprocess.run(cmd, check=True)
34
+
35
+ self.net = _D2Net(
36
+ model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
37
+ )
38
+
39
+ def _forward(self, data):
40
+ image = data["image"]
41
+ image = image.flip(1) # RGB -> BGR
42
+ norm = image.new_tensor([103.939, 116.779, 123.68])
43
+ image = image * 255 - norm.view(1, 3, 1, 1) # caffe normalization
44
+
45
+ if self.conf["multiscale"]:
46
+ keypoints, scores, descriptors = process_multiscale(image, self.net)
47
+ else:
48
+ keypoints, scores, descriptors = process_multiscale(
49
+ image, self.net, scales=[1]
50
+ )
51
+ keypoints = keypoints[:, [1, 0]] # (x, y) and remove the scale
52
+
53
+ return {
54
+ "keypoints": torch.from_numpy(keypoints)[None],
55
+ "scores": torch.from_numpy(scores)[None],
56
+ "descriptors": torch.from_numpy(descriptors.T)[None],
57
+ }
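Note on hloc/extractors/d2net.py above: _forward reproduces D2-Net's Caffe-style preprocessing by flipping RGB to BGR and subtracting the per-channel ImageNet means in the 0-255 range. The same step in isolation, on a dummy tensor:

    import torch

    image = torch.rand(1, 3, 240, 320)           # RGB in [0, 1]
    image = image.flip(1)                        # RGB -> BGR
    mean = image.new_tensor([103.939, 116.779, 123.68])
    image = image * 255 - mean.view(1, 3, 1, 1)  # zero-centred BGR, Caffe convention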
hloc/extractors/darkfeat.py ADDED
@@ -0,0 +1,57 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import subprocess
4
+ import logging
5
+
6
+ from ..utils.base_model import BaseModel
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
11
+ sys.path.append(str(darkfeat_path))
12
+ from darkfeat import DarkFeat as DarkFeat_
13
+
14
+
15
+ class DarkFeat(BaseModel):
16
+ default_conf = {
17
+ "model_name": "DarkFeat.pth",
18
+ "max_keypoints": 1000,
19
+ "detection_threshold": 0.5,
20
+ "sub_pixel": False,
21
+ }
22
+ weight_urls = {
23
+ "DarkFeat.pth": "https://drive.google.com/uc?id=1Thl6m8NcmQ7zSAF-1_xaFs3F4H8UU6HX&confirm=t",
24
+ }
25
+ proxy = "http://localhost:1080"
26
+ required_inputs = ["image"]
27
+
28
+ def _init(self, conf):
29
+ model_path = darkfeat_path / "checkpoints" / conf["model_name"]
30
+ link = self.weight_urls[conf["model_name"]]
31
+ if not model_path.exists():
32
+ model_path.parent.mkdir(exist_ok=True)
33
+ cmd_wo_proxy = ["gdown", link, "-O", str(model_path)]
34
+ cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy]
35
+ logger.info(f"Downloading the DarkFeat model with `{cmd_wo_proxy}`.")
36
+ try:
37
+ subprocess.run(cmd_wo_proxy, check=True)
38
+ except subprocess.CalledProcessError as e:
39
+ logger.info(f"Downloading the DarkFeat model with `{cmd}`.")
40
+ try:
41
+ subprocess.run(cmd, check=True)
42
+ except subprocess.CalledProcessError as e:
43
+ logger.error("Failed to download the DarkFeat model.")
44
+ raise e
45
+
46
+ self.net = DarkFeat_(model_path)
47
+
48
+ def _forward(self, data):
49
+ pred = self.net({"image": data["image"]})
50
+ keypoints = pred["keypoints"]
51
+ descriptors = pred["descriptors"]
52
+ scores = pred["scores"]
53
+ return {
54
+ "keypoints": keypoints[None], # 1 x N x 2
55
+ "scores": scores[None], # 1 x N
56
+ "descriptors": descriptors[None], # 1 x 128 x N
57
+ }
hloc/extractors/dedode.py ADDED
@@ -0,0 +1,102 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import subprocess
4
+ import logging
5
+ import torch
6
+ from PIL import Image
7
+ from ..utils.base_model import BaseModel
8
+ import torchvision.transforms as transforms
9
+
10
+ dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
11
+ sys.path.append(str(dedode_path))
12
+
13
+ from DeDoDe import dedode_detector_L, dedode_descriptor_B
14
+ from DeDoDe.utils import to_pixel_coords
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class DeDoDe(BaseModel):
21
+ default_conf = {
22
+ "name": "dedode",
23
+ "model_detector_name": "dedode_detector_L.pth",
24
+ "model_descriptor_name": "dedode_descriptor_B.pth",
25
+ "max_keypoints": 2000,
26
+ "match_threshold": 0.2,
27
+ "dense": False, # Now fixed to be false
28
+ }
29
+ required_inputs = [
30
+ "image",
31
+ ]
32
+ weight_urls = {
33
+ "dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
34
+ "dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
35
+ }
36
+
37
+ # Initialize the line matcher
38
+ def _init(self, conf):
39
+ model_detector_path = dedode_path / "pretrained" / conf["model_detector_name"]
40
+ model_descriptor_path = (
41
+ dedode_path / "pretrained" / conf["model_descriptor_name"]
42
+ )
43
+
44
+ self.normalizer = transforms.Normalize(
45
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
46
+ )
47
+ # Download the model.
48
+ if not model_detector_path.exists():
49
+ model_detector_path.parent.mkdir(exist_ok=True)
50
+ link = self.weight_urls[conf["model_detector_name"]]
51
+ cmd = ["wget", link, "-O", str(model_detector_path)]
52
+ logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.")
53
+ subprocess.run(cmd, check=True)
54
+
55
+ if not model_descriptor_path.exists():
56
+ model_descriptor_path.parent.mkdir(exist_ok=True)
57
+ link = self.weight_urls[conf["model_descriptor_name"]]
58
+ cmd = ["wget", link, "-O", str(model_descriptor_path)]
59
+ logger.info(f"Downloading the DeDoDe descriptor model with `{cmd}`.")
60
+ subprocess.run(cmd, check=True)
61
+
62
+ logger.info(f"Loading DeDoDe model...")
63
+
64
+ # load the model
65
+ weights_detector = torch.load(model_detector_path, map_location="cpu")
66
+ weights_descriptor = torch.load(model_descriptor_path, map_location="cpu")
67
+ self.detector = dedode_detector_L(weights=weights_detector)
68
+ self.descriptor = dedode_descriptor_B(weights=weights_descriptor)
69
+ logger.info(f"Load DeDoDe model done.")
70
+
71
+ def _forward(self, data):
72
+ """
73
+ data: dict, keys: {'image'}
74
+ image shape: N x C x H x W
75
+ color mode: RGB
76
+ """
77
+ img0 = self.normalizer(data["image"].squeeze()).float()[None]
78
+ H_A, W_A = img0.shape[2:]
79
+
80
+ # step 1: detect keypoints
81
+ detections_A = None
82
+ batch_A = {"image": img0}
83
+ if self.conf["dense"]:
84
+ detections_A = self.detector.detect_dense(batch_A)
85
+ else:
86
+ detections_A = self.detector.detect(
87
+ batch_A, num_keypoints=self.conf["max_keypoints"]
88
+ )
89
+ keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
90
+
91
+ # step 2: describe keypoints
92
+ # dim: 1 x N x 256
93
+ description_A = self.descriptor.describe_keypoints(batch_A, keypoints_A)[
94
+ "descriptions"
95
+ ]
96
+ keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A)
97
+
98
+ return {
99
+ "keypoints": keypoints_A, # 1 x N x 2
100
+ "descriptors": description_A.permute(0, 2, 1), # 1 x 256 x N
101
+ "scores": P_A, # 1 x N
102
+ }
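Usage note for hloc/extractors/dedode.py above: detection and description are two separate networks, which _forward chains before converting keypoints to pixel coordinates. A minimal sketch of calling the wrapper directly; it assumes the pretrained checkpoints can be downloaded and that the third-party models run on the chosen device:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = DeDoDe({"max_keypoints": 1000}).eval().to(device)
    with torch.no_grad():
        pred = model({"image": torch.rand(1, 3, 768, 768).to(device)})
    # Shapes per the comments above: keypoints (1, N, 2) in pixels,
    # descriptors (1, 256, N), scores (1, N)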
hloc/extractors/dir.py ADDED
@@ -0,0 +1,76 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import torch
4
+ from zipfile import ZipFile
5
+ import os
6
+ import sklearn
7
+ import gdown
8
+
9
+ from ..utils.base_model import BaseModel
10
+
11
+ sys.path.append(str(Path(__file__).parent / "../../third_party/deep-image-retrieval"))
12
+ os.environ["DB_ROOT"] = "" # required by dirtorch
13
+
14
+ from dirtorch.utils import common # noqa: E402
15
+ from dirtorch.extract_features import load_model # noqa: E402
16
+
17
+ # The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
18
+ # which has been deprecated in sklearn v0.24
19
+ # and must be explicitly imported with `from sklearn.decomposition import PCA`.
20
+ # This is a hacky workaround to maintain forward compatibility.
21
+ sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca
22
+
23
+
24
+ class DIR(BaseModel):
25
+ default_conf = {
26
+ "model_name": "Resnet-101-AP-GeM",
27
+ "whiten_name": "Landmarks_clean",
28
+ "whiten_params": {
29
+ "whitenp": 0.25,
30
+ "whitenv": None,
31
+ "whitenm": 1.0,
32
+ },
33
+ "pooling": "gem",
34
+ "gemp": 3,
35
+ }
36
+ required_inputs = ["image"]
37
+
38
+ dir_models = {
39
+ "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
40
+ }
41
+
42
+ def _init(self, conf):
43
+ checkpoint = Path(torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt")
44
+ if not checkpoint.exists():
45
+ checkpoint.parent.mkdir(exist_ok=True, parents=True)
46
+ link = self.dir_models[conf["model_name"]]
47
+ gdown.download(str(link), str(checkpoint) + ".zip", quiet=False)
48
+ zf = ZipFile(str(checkpoint) + ".zip", "r")
49
+ zf.extractall(checkpoint.parent)
50
+ zf.close()
51
+ os.remove(str(checkpoint) + ".zip")
52
+
53
+ self.net = load_model(checkpoint, False) # first load on CPU
54
+ if conf["whiten_name"]:
55
+ assert conf["whiten_name"] in self.net.pca
56
+
57
+ def _forward(self, data):
58
+ image = data["image"]
59
+ assert image.shape[1] == 3
60
+ mean = self.net.preprocess["mean"]
61
+ std = self.net.preprocess["std"]
62
+ image = image - image.new_tensor(mean)[:, None, None]
63
+ image = image / image.new_tensor(std)[:, None, None]
64
+
65
+ desc = self.net(image)
66
+ desc = desc.unsqueeze(0) # batch dimension
67
+ if self.conf["whiten_name"]:
68
+ pca = self.net.pca[self.conf["whiten_name"]]
69
+ desc = common.whiten_features(
70
+ desc.cpu().numpy(), pca, **self.conf["whiten_params"]
71
+ )
72
+ desc = torch.from_numpy(desc)
73
+
74
+ return {
75
+ "global_descriptor": desc,
76
+ }
hloc/extractors/disk.py ADDED
@@ -0,0 +1,32 @@
1
+ import kornia
2
+
3
+ from ..utils.base_model import BaseModel
4
+
5
+
6
+ class DISK(BaseModel):
7
+ default_conf = {
8
+ "weights": "depth",
9
+ "max_keypoints": None,
10
+ "nms_window_size": 5,
11
+ "detection_threshold": 0.0,
12
+ "pad_if_not_divisible": True,
13
+ }
14
+ required_inputs = ["image"]
15
+
16
+ def _init(self, conf):
17
+ self.model = kornia.feature.DISK.from_pretrained(conf["weights"])
18
+
19
+ def _forward(self, data):
20
+ image = data["image"]
21
+ features = self.model(
22
+ image,
23
+ n=self.conf["max_keypoints"],
24
+ window_size=self.conf["nms_window_size"],
25
+ score_threshold=self.conf["detection_threshold"],
26
+ pad_if_not_divisible=self.conf["pad_if_not_divisible"],
27
+ )
28
+ return {
29
+ "keypoints": [f.keypoints for f in features],
30
+ "scores": [f.detection_scores for f in features],
31
+ "descriptors": [f.descriptors.t() for f in features],
32
+ }
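Note on hloc/extractors/disk.py above: the wrapper is a thin adapter around kornia's pretrained DISK; the sketch below issues the same kornia call that _forward makes, on a dummy batch (weights are downloaded by kornia on first use):

    import torch
    import kornia

    disk = kornia.feature.DISK.from_pretrained("depth").eval()
    with torch.no_grad():
        feats = disk(torch.rand(1, 3, 256, 256), n=512, pad_if_not_divisible=True)[0]
    print(feats.keypoints.shape, feats.descriptors.shape, feats.detection_scores.shape)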
hloc/extractors/dog.py ADDED
@@ -0,0 +1,131 @@
1
+ import kornia
2
+ from kornia.feature.laf import laf_from_center_scale_ori, extract_patches_from_pyramid
3
+ import numpy as np
4
+ import torch
5
+ import pycolmap
6
+
7
+ from ..utils.base_model import BaseModel
8
+
9
+
10
+ EPS = 1e-6
11
+
12
+
13
+ def sift_to_rootsift(x):
14
+ x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS)
15
+ x = np.sqrt(x.clip(min=EPS))
16
+ x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS)
17
+ return x
18
+
19
+
20
+ class DoG(BaseModel):
21
+ default_conf = {
22
+ "options": {
23
+ "first_octave": 0,
24
+ "peak_threshold": 0.01,
25
+ },
26
+ "descriptor": "rootsift",
27
+ "max_keypoints": -1,
28
+ "patch_size": 32,
29
+ "mr_size": 12,
30
+ }
31
+ required_inputs = ["image"]
32
+ detection_noise = 1.0
33
+ max_batch_size = 1024
34
+
35
+ def _init(self, conf):
36
+ if conf["descriptor"] == "sosnet":
37
+ self.describe = kornia.feature.SOSNet(pretrained=True)
38
+ elif conf["descriptor"] == "hardnet":
39
+ self.describe = kornia.feature.HardNet(pretrained=True)
40
+ elif conf["descriptor"] not in ["sift", "rootsift"]:
41
+ raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')
42
+
43
+ self.sift = None # lazily instantiated on the first image
44
+ self.device = torch.device("cpu")
45
+
46
+ def to(self, *args, **kwargs):
47
+ device = kwargs.get("device")
48
+ if device is None:
49
+ match = [a for a in args if isinstance(a, (torch.device, str))]
50
+ if len(match) > 0:
51
+ device = match[0]
52
+ if device is not None:
53
+ self.device = torch.device(device)
54
+ return super().to(*args, **kwargs)
55
+
56
+ def _forward(self, data):
57
+ image = data["image"]
58
+ image_np = image.cpu().numpy()[0, 0]
59
+ assert image.shape[1] == 1
60
+ assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS
61
+
62
+ if self.sift is None:
63
+ use_gpu = pycolmap.has_cuda and self.device.type == "cuda"
64
+ options = {**self.conf["options"]}
65
+ if self.conf["descriptor"] == "rootsift":
66
+ options["normalization"] = pycolmap.Normalization.L1_ROOT
67
+ else:
68
+ options["normalization"] = pycolmap.Normalization.L2
69
+ self.sift = pycolmap.Sift(
70
+ options=pycolmap.SiftExtractionOptions(options),
71
+ device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
72
+ )
73
+
74
+ keypoints, scores, descriptors = self.sift.extract(image_np)
75
+ scales = keypoints[:, 2]
76
+ oris = np.rad2deg(keypoints[:, 3])
77
+
78
+ if self.conf["descriptor"] in ["sift", "rootsift"]:
79
+ # We still renormalize because COLMAP does not normalize well,
80
+ # maybe due to numerical errors
81
+ if self.conf["descriptor"] == "rootsift":
82
+ descriptors = sift_to_rootsift(descriptors)
83
+ descriptors = torch.from_numpy(descriptors)
84
+ elif self.conf["descriptor"] in ("sosnet", "hardnet"):
85
+ center = keypoints[:, :2] + 0.5
86
+ laf_scale = scales * self.conf["mr_size"] / 2
87
+ laf_ori = -oris
88
+ lafs = laf_from_center_scale_ori(
89
+ torch.from_numpy(center)[None],
90
+ torch.from_numpy(laf_scale)[None, :, None, None],
91
+ torch.from_numpy(laf_ori)[None, :, None],
92
+ ).to(image.device)
93
+ patches = extract_patches_from_pyramid(
94
+ image, lafs, PS=self.conf["patch_size"]
95
+ )[0]
96
+ descriptors = patches.new_zeros((len(patches), 128))
97
+ if len(patches) > 0:
98
+ for start_idx in range(0, len(patches), self.max_batch_size):
99
+ end_idx = min(len(patches), start_idx + self.max_batch_size)
100
+ descriptors[start_idx:end_idx] = self.describe(
101
+ patches[start_idx:end_idx]
102
+ )
103
+ else:
104
+ raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')
105
+
106
+ keypoints = torch.from_numpy(keypoints[:, :2]) # keep only x, y
107
+ scales = torch.from_numpy(scales)
108
+ oris = torch.from_numpy(oris)
109
+ scores = torch.from_numpy(scores)
110
+ if self.conf["max_keypoints"] != -1:
111
+ # TODO: check that the scores from PyCOLMAP are 100% correct,
112
+ # follow https://github.com/mihaidusmanu/pycolmap/issues/8
113
+ max_number = (
114
+ scores.shape[0]
115
+ if scores.shape[0] < self.conf["max_keypoints"]
116
+ else self.conf["max_keypoints"]
117
+ )
118
+ values, indices = torch.topk(scores, max_number)
119
+ keypoints = keypoints[indices]
120
+ scales = scales[indices]
121
+ oris = oris[indices]
122
+ scores = scores[indices]
123
+ descriptors = descriptors[indices]
124
+
125
+ return {
126
+ "keypoints": keypoints[None],
127
+ "scales": scales[None],
128
+ "oris": oris[None],
129
+ "scores": scores[None],
130
+ "descriptors": descriptors.T[None],
131
+ }
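Note on hloc/extractors/dog.py above: sift_to_rootsift implements the usual RootSIFT transform (L1-normalize, element-wise square root, then L2-normalize), which is why the resulting descriptors have unit L2 norm. A small self-check with random stand-in descriptors (importing the helper assumes pycolmap and kornia are installed, since the module loads them at import time):

    import numpy as np
    from hloc.extractors.dog import sift_to_rootsift

    desc = np.random.rand(4, 128).astype(np.float32)  # stand-in SIFT descriptors
    rsift = sift_to_rootsift(desc)
    print(np.linalg.norm(rsift, axis=-1))  # ~1.0 for every row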
hloc/extractors/example.py ADDED
@@ -0,0 +1,58 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import subprocess
4
+ import torch
5
+ import logging
6
+
7
+ from ..utils.base_model import BaseModel
8
+
9
+ example_path = Path(__file__).parent / "../../third_party/example"
10
+ sys.path.append(str(example_path))
11
+
12
+ # import some modules here
13
+
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class Example(BaseModel):
19
+ # change to your default configs
20
+ default_conf = {
21
+ "name": "example",
22
+ "keypoint_threshold": 0.1,
23
+ "max_keypoints": 2000,
24
+ "model_name": "model.pth",
25
+ }
26
+ required_inputs = ["image"]
27
+
28
+ def _init(self, conf):
29
+
30
+ # set checkpoint paths if needed
31
+ model_path = example_path / "checkpoints" / f'{conf["model_name"]}'
32
+ if not model_path.exists():
33
+ logger.info(f"No model found at {model_path}")
34
+
35
+ # init model
36
+ self.net = callable
37
+ # self.net = ExampleNet(is_test=True)
38
+ state_dict = torch.load(model_path, map_location="cpu")
39
+ self.net.load_state_dict(state_dict["model_state"])
40
+ logger.info(f"Load example model done.")
41
+
42
+ def _forward(self, data):
43
+ # data: dict, keys: 'image'
44
+ # image color mode: RGB
45
+ # image value range in [0, 1]
46
+ image = data["image"]
47
+
48
+ # B: batch size, N: number of keypoints
49
+ # keypoints shape: B x N x 2, type: torch tensor
50
+ # scores shape: B x N, type: torch tensor
51
+ # descriptors shape: B x 128 x N, type: torch tensor
52
+ keypoints, scores, descriptors = self.net(image)
53
+
54
+ return {
55
+ "keypoints": keypoints,
56
+ "scores": scores,
57
+ "descriptors": descriptors,
58
+ }
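Note on hloc/extractors/example.py above: it is the template for adding new detectors, and the contract is a BaseModel subclass with default_conf, required_inputs, _init, and a _forward returning keypoints (B x N x 2), scores (B x N) and descriptors (B x D x N). A minimal sketch of such a subclass with placeholder outputs instead of a real network:

    import torch
    from hloc.utils.base_model import BaseModel

    class MyExtractor(BaseModel):
        default_conf = {"max_keypoints": 1024}
        required_inputs = ["image"]

        def _init(self, conf):
            pass  # load weights / build the network here

        def _forward(self, data):
            b, _, h, w = data["image"].shape
            n = self.conf["max_keypoints"]
            # Placeholder outputs with the shapes extract_features expects.
            kpts = torch.rand(b, n, 2) * torch.tensor([w - 1.0, h - 1.0])
            return {
                "keypoints": kpts,                     # B x N x 2
                "scores": torch.rand(b, n),            # B x N
                "descriptors": torch.rand(b, 128, n),  # B x 128 x N
            }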
hloc/extractors/fire.py ADDED
@@ -0,0 +1,73 @@
1
+ from pathlib import Path
2
+ import subprocess
3
+ import logging
4
+ import sys
5
+ import torch
6
+ import torchvision.transforms as tvf
7
+
8
+ from ..utils.base_model import BaseModel
9
+
10
+ logger = logging.getLogger(__name__)
11
+ fire_path = Path(__file__).parent / "../../third_party/fire"
12
+ sys.path.append(str(fire_path))
13
+
14
+
15
+ import fire_network
16
+
17
+
18
+ class FIRe(BaseModel):
19
+ default_conf = {
20
+ "global": True,
21
+ "asmk": False,
22
+ "model_name": "fire_SfM_120k.pth",
23
+ "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params
24
+ "features_num": 1000, # TODO:not supported now
25
+ "asmk_name": "asmk_codebook.bin", # TODO:not supported now
26
+ "config_name": "eval_fire.yml",
27
+ }
28
+ required_inputs = ["image"]
29
+
30
+ # Models exported using
31
+ fire_models = {
32
+ "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth",
33
+ "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth",
34
+ }
35
+
36
+ def _init(self, conf):
37
+
38
+ assert conf["model_name"] in self.fire_models.keys()
39
+ # Config paths
40
+ model_path = fire_path / "model" / conf["model_name"]
41
+
42
+ # Download the model.
43
+ if not model_path.exists():
44
+ model_path.parent.mkdir(exist_ok=True)
45
+ link = self.fire_models[conf["model_name"]]
46
+ cmd = ["wget", link, "-O", str(model_path)]
47
+ logger.info(f"Downloading the FIRe model with `{cmd}`.")
48
+ subprocess.run(cmd, check=True)
49
+
50
+ logger.info(f"Loading fire model...")
51
+
52
+ # Load net
53
+ state = torch.load(model_path)
54
+ state["net_params"]["pretrained"] = None
55
+ net = fire_network.init_network(**state["net_params"])
56
+ net.load_state_dict(state["state_dict"])
57
+ self.net = net
58
+
59
+ self.norm_rgb = tvf.Normalize(
60
+ **dict(zip(["mean", "std"], net.runtime["mean_std"]))
61
+ )
62
+
63
+ # params
64
+ self.scales = conf["scales"]
65
+
66
+ def _forward(self, data):
67
+
68
+ image = self.norm_rgb(data["image"])
69
+
70
+ # Feature extraction.
71
+ desc = self.net.forward_global(image, scales=self.scales)
72
+
73
+ return {"global_descriptor": desc}
hloc/extractors/fire_local.py ADDED
@@ -0,0 +1,90 @@
1
+ from pathlib import Path
2
+ import subprocess
3
+ import logging
4
+ import sys
5
+ import torch
6
+ import torchvision.transforms as tvf
7
+
8
+ from ..utils.base_model import BaseModel
9
+
10
+ logger = logging.getLogger(__name__)
11
+ fire_path = Path(__file__).parent / "../../third_party/fire"
12
+
13
+ sys.path.append(str(fire_path))
14
+
15
+
16
+ import fire_network
17
+ from lib.how.how.stages.evaluate import eval_asmk_fire, load_dataset_fire
18
+
19
+ from lib.asmk import asmk
20
+ from asmk import io_helpers, asmk_method, kernel as kern_pkg
21
+
22
+ EPS = 1e-6
23
+
24
+
25
+ class FIRe(BaseModel):
26
+ default_conf = {
27
+ "global": True,
28
+ "asmk": False,
29
+ "model_name": "fire_SfM_120k.pth",
30
+ "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params
31
+ "features_num": 1000,
32
+ "asmk_name": "asmk_codebook.bin",
33
+ "config_name": "eval_fire.yml",
34
+ }
35
+ required_inputs = ["image"]
36
+
37
+ # Models exported using
38
+ fire_models = {
39
+ "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth",
40
+ "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth",
41
+ }
42
+
43
+ def _init(self, conf):
44
+
45
+ assert conf["model_name"] in self.fire_models.keys()
46
+
47
+ # Config paths
48
+ model_path = fire_path / "model" / conf["model_name"]
49
+ config_path = fire_path / conf["config_name"]
50
+ asmk_bin_path = fire_path / "model" / conf["asmk_name"]
51
+
52
+ # Download the model.
53
+ if not model_path.exists():
54
+ model_path.parent.mkdir(exist_ok=True)
55
+ link = self.fire_models[conf["model_name"]]
56
+ cmd = ["wget", link, "-O", str(model_path)]
57
+ logger.info(f"Downloading the FIRe model with `{cmd}`.")
58
+ subprocess.run(cmd, check=True)
59
+
60
+ logger.info(f"Loading fire model...")
61
+
62
+ # Load net
63
+ state = torch.load(model_path)
64
+ state["net_params"]["pretrained"] = None
65
+ net = fire_network.init_network(**state["net_params"])
66
+ net.load_state_dict(state["state_dict"])
67
+ self.net = net
68
+
69
+ self.norm_rgb = tvf.Normalize(
70
+ **dict(zip(["mean", "std"], net.runtime["mean_std"]))
71
+ )
72
+
73
+ # params
74
+ self.scales = conf["scales"]
75
+ self.features_num = conf["features_num"]
76
+
77
+ def _forward(self, data):
78
+
79
+ image = self.norm_rgb(data["image"])
80
+
81
+ local_desc = self.net.forward_local(
82
+ image, features_num=self.features_num, scales=self.scales
83
+ )
84
+
85
+ logger.info(f"output[0].shape = {local_desc[0].shape}\n")
86
+
87
+ return {
88
+ # 'global_descriptor': desc
89
+ "local_descriptor": local_desc
90
+ }
hloc/extractors/lanet.py ADDED
@@ -0,0 +1,53 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import subprocess
4
+ import torch
5
+
6
+ from ..utils.base_model import BaseModel
7
+
8
+ lanet_path = Path(__file__).parent / "../../third_party/lanet"
9
+ sys.path.append(str(lanet_path))
10
+ from network_v0.model import PointModel
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+
15
+ class LANet(BaseModel):
16
+ default_conf = {
17
+ "model_name": "v0",
18
+ "keypoint_threshold": 0.1,
19
+ }
20
+ required_inputs = ["image"]
21
+
22
+ def _init(self, conf):
23
+ model_path = lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
24
+ if not model_path.exists():
25
+ print(f"No model found at {model_path}")
26
+ self.net = PointModel(is_test=True)
27
+ state_dict = torch.load(model_path, map_location="cpu")
28
+ self.net.load_state_dict(state_dict["model_state"])
29
+
30
+ def _forward(self, data):
31
+ image = data["image"]
32
+ keypoints, scores, descriptors = self.net(image)
33
+ _, _, Hc, Wc = descriptors.shape
34
+
35
+ # Scores & Descriptors
36
+ kpts_score = (
37
+ torch.cat([keypoints, scores], dim=1).view(3, -1).t().cpu().detach().numpy()
38
+ )
39
+ descriptors = (
40
+ descriptors.view(256, Hc, Wc).view(256, -1).t().cpu().detach().numpy()
41
+ )
42
+
43
+ # Filter based on confidence threshold
44
+ descriptors = descriptors[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
45
+ kpts_score = kpts_score[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
46
+ keypoints = kpts_score[:, 1:]
47
+ scores = kpts_score[:, 0]
48
+
49
+ return {
50
+ "keypoints": torch.from_numpy(keypoints)[None],
51
+ "scores": torch.from_numpy(scores)[None],
52
+ "descriptors": torch.from_numpy(descriptors.T)[None],
53
+ }
hloc/extractors/netvlad.py ADDED
@@ -0,0 +1,147 @@
1
+ from pathlib import Path
2
+ import subprocess
3
+ import logging
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision.models as models
9
+ from scipy.io import loadmat
10
+
11
+ from ..utils.base_model import BaseModel
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ EPS = 1e-6
16
+
17
+
18
+ class NetVLADLayer(nn.Module):
19
+ def __init__(self, input_dim=512, K=64, score_bias=False, intranorm=True):
20
+ super().__init__()
21
+ self.score_proj = nn.Conv1d(input_dim, K, kernel_size=1, bias=score_bias)
22
+ centers = nn.parameter.Parameter(torch.empty([input_dim, K]))
23
+ nn.init.xavier_uniform_(centers)
24
+ self.register_parameter("centers", centers)
25
+ self.intranorm = intranorm
26
+ self.output_dim = input_dim * K
27
+
28
+ def forward(self, x):
29
+ b = x.size(0)
30
+ scores = self.score_proj(x)
31
+ scores = F.softmax(scores, dim=1)
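+ # VLAD aggregation: soft-assign each local descriptor to the K centers and sum the weighted residuals.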
32
+ diff = x.unsqueeze(2) - self.centers.unsqueeze(0).unsqueeze(-1)
33
+ desc = (scores.unsqueeze(1) * diff).sum(dim=-1)
34
+ if self.intranorm:
35
+ # From the official MATLAB implementation.
36
+ desc = F.normalize(desc, dim=1)
37
+ desc = desc.view(b, -1)
38
+ desc = F.normalize(desc, dim=1)
39
+ return desc
40
+
41
+
42
+ class NetVLAD(BaseModel):
43
+ default_conf = {"model_name": "VGG16-NetVLAD-Pitts30K", "whiten": True}
44
+ required_inputs = ["image"]
45
+
46
+ # Models exported using
47
+ # https://github.com/uzh-rpg/netvlad_tf_open/blob/master/matlab/net_class2struct.m.
48
+ dir_models = {
49
+ "VGG16-NetVLAD-Pitts30K": "https://cvg-data.inf.ethz.ch/hloc/netvlad/Pitts30K_struct.mat",
50
+ "VGG16-NetVLAD-TokyoTM": "https://cvg-data.inf.ethz.ch/hloc/netvlad/TokyoTM_struct.mat",
51
+ }
52
+
53
+ def _init(self, conf):
54
+ assert conf["model_name"] in self.dir_models.keys()
55
+
56
+ # Download the checkpoint.
57
+ checkpoint = Path(torch.hub.get_dir(), "netvlad", conf["model_name"] + ".mat")
58
+ if not checkpoint.exists():
59
+ checkpoint.parent.mkdir(exist_ok=True, parents=True)
60
+ link = self.dir_models[conf["model_name"]]
61
+ cmd = ["wget", link, "-O", str(checkpoint)]
62
+ logger.info(f"Downloading the NetVLAD model with `{cmd}`.")
63
+ subprocess.run(cmd, check=True)
64
+
65
+ # Create the network.
66
+ # Remove classification head.
67
+ backbone = list(models.vgg16().children())[0]
68
+ # Remove last ReLU + MaxPool2d.
69
+ self.backbone = nn.Sequential(*list(backbone.children())[:-2])
70
+
71
+ self.netvlad = NetVLADLayer()
72
+
73
+ if conf["whiten"]:
74
+ self.whiten = nn.Linear(self.netvlad.output_dim, 4096)
75
+
76
+ # Parse MATLAB weights using https://github.com/uzh-rpg/netvlad_tf_open
77
+ mat = loadmat(checkpoint, struct_as_record=False, squeeze_me=True)
78
+
79
+ # CNN weights.
80
+ for layer, mat_layer in zip(self.backbone.children(), mat["net"].layers):
81
+ if isinstance(layer, nn.Conv2d):
82
+ w = mat_layer.weights[0] # Shape: S x S x IN x OUT
83
+ b = mat_layer.weights[1] # Shape: OUT
84
+ # Prepare for PyTorch - enforce float32 and right shape.
85
+ # w should have shape: OUT x IN x S x S
86
+ # b should have shape: OUT
87
+ w = torch.tensor(w).float().permute([3, 2, 0, 1])
88
+ b = torch.tensor(b).float()
89
+ # Update layer weights.
90
+ layer.weight = nn.Parameter(w)
91
+ layer.bias = nn.Parameter(b)
92
+
93
+ # NetVLAD weights.
94
+ score_w = mat["net"].layers[30].weights[0] # D x K
95
+ # centers are stored as opposite in official MATLAB code
96
+ center_w = -mat["net"].layers[30].weights[1] # D x K
97
+ # Prepare for PyTorch - make sure it is float32 and has right shape.
98
+ # score_w should have shape K x D x 1
99
+ # center_w should have shape D x K
100
+ score_w = torch.tensor(score_w).float().permute([1, 0]).unsqueeze(-1)
101
+ center_w = torch.tensor(center_w).float()
102
+ # Update layer weights.
103
+ self.netvlad.score_proj.weight = nn.Parameter(score_w)
104
+ self.netvlad.centers = nn.Parameter(center_w)
105
+
106
+ # Whitening weights.
107
+ if conf["whiten"]:
108
+ w = mat["net"].layers[33].weights[0] # Shape: 1 x 1 x IN x OUT
109
+ b = mat["net"].layers[33].weights[1] # Shape: OUT
110
+ # Prepare for PyTorch - make sure it is float32 and has right shape
111
+ w = torch.tensor(w).float().squeeze().permute([1, 0]) # OUT x IN
112
+ b = torch.tensor(b.squeeze()).float() # Shape: OUT
113
+ # Update layer weights.
114
+ self.whiten.weight = nn.Parameter(w)
115
+ self.whiten.bias = nn.Parameter(b)
116
+
117
+ # Preprocessing parameters.
118
+ self.preprocess = {
119
+ "mean": mat["net"].meta.normalization.averageImage[0, 0],
120
+ "std": np.array([1, 1, 1], dtype=np.float32),
121
+ }
122
+
123
+ def _forward(self, data):
124
+ image = data["image"]
125
+ assert image.shape[1] == 3
126
+ assert image.min() >= -EPS and image.max() <= 1 + EPS
127
+ image = torch.clamp(image * 255, 0.0, 255.0) # Input should be 0-255.
128
+ mean = self.preprocess["mean"]
129
+ std = self.preprocess["std"]
130
+ image = image - image.new_tensor(mean).view(1, -1, 1, 1)
131
+ image = image / image.new_tensor(std).view(1, -1, 1, 1)
132
+
133
+ # Feature extraction.
134
+ descriptors = self.backbone(image)
135
+ b, c, _, _ = descriptors.size()
136
+ descriptors = descriptors.view(b, c, -1)
137
+
138
+ # NetVLAD layer.
139
+ descriptors = F.normalize(descriptors, dim=1) # Pre-normalization.
140
+ desc = self.netvlad(descriptors)
141
+
142
+ # Whiten if needed.
143
+ if hasattr(self, "whiten"):
144
+ desc = self.whiten(desc)
145
+ desc = F.normalize(desc, dim=1) # Final L2 normalization.
146
+
147
+ return {"global_descriptor": desc}
hloc/extractors/openibl.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+ import torchvision.transforms as tvf
3
+
4
+ from ..utils.base_model import BaseModel
5
+
6
+
7
+ class OpenIBL(BaseModel):
8
+ default_conf = {
9
+ "model_name": "vgg16_netvlad",
10
+ }
11
+ required_inputs = ["image"]
12
+
13
+ def _init(self, conf):
14
+ self.net = torch.hub.load(
15
+ "yxgeee/OpenIBL", conf["model_name"], pretrained=True
16
+ ).eval()
17
+ mean = [0.48501960784313836, 0.4579568627450961, 0.4076039215686255]
18
+ std = [0.00392156862745098, 0.00392156862745098, 0.00392156862745098]
19
+ self.norm_rgb = tvf.Normalize(mean=mean, std=std)
20
+
21
+ def _forward(self, data):
22
+ image = self.norm_rgb(data["image"])
23
+ desc = self.net(image)
24
+ return {
25
+ "global_descriptor": desc,
26
+ }
hloc/extractors/r2d2.py ADDED
@@ -0,0 +1,61 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import torchvision.transforms as tvf
4
+
5
+ from ..utils.base_model import BaseModel
6
+
7
+ r2d2_path = Path(__file__).parent / "../../third_party/r2d2"
8
+ sys.path.append(str(r2d2_path))
9
+ from extract import load_network, NonMaxSuppression, extract_multiscale
10
+
11
+
12
+ class R2D2(BaseModel):
13
+ default_conf = {
14
+ "model_name": "r2d2_WASF_N16.pt",
15
+ "max_keypoints": 5000,
16
+ "scale_factor": 2**0.25,
17
+ "min_size": 256,
18
+ "max_size": 1024,
19
+ "min_scale": 0,
20
+ "max_scale": 1,
21
+ "reliability_threshold": 0.7,
22
+ "repetability_threshold": 0.7,
23
+ }
24
+ required_inputs = ["image"]
25
+
26
+ def _init(self, conf):
27
+ model_fn = r2d2_path / "models" / conf["model_name"]
28
+ self.norm_rgb = tvf.Normalize(
29
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
30
+ )
31
+ self.net = load_network(model_fn)
32
+ self.detector = NonMaxSuppression(
33
+ rel_thr=conf["reliability_threshold"],
34
+ rep_thr=conf["repetability_threshold"],
35
+ )
36
+
37
+ def _forward(self, data):
38
+ img = data["image"]
39
+ img = self.norm_rgb(img)
40
+
41
+ xys, desc, scores = extract_multiscale(
42
+ self.net,
43
+ img,
44
+ self.detector,
45
+ scale_f=self.conf["scale_factor"],
46
+ min_size=self.conf["min_size"],
47
+ max_size=self.conf["max_size"],
48
+ min_scale=self.conf["min_scale"],
49
+ max_scale=self.conf["max_scale"],
50
+ )
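+ # Keep the highest-scoring keypoints; with max_keypoints == 0 the slice keeps them all.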
51
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
52
+ xy = xys[idxs, :2]
53
+ desc = desc[idxs].t()
54
+ scores = scores[idxs]
55
+
56
+ pred = {
57
+ "keypoints": xy[None],
58
+ "descriptors": desc[None],
59
+ "scores": scores[None],
60
+ }
61
+ return pred
hloc/extractors/rekd.py ADDED
@@ -0,0 +1,53 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import subprocess
4
+ import torch
5
+
6
+ from ..utils.base_model import BaseModel
7
+
8
+ rekd_path = Path(__file__).parent / "../../third_party/REKD"
9
+ sys.path.append(str(rekd_path))
10
+ from training.model.REKD import REKD as REKD_
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+
15
+ class REKD(BaseModel):
16
+ default_conf = {
17
+ "model_name": "v0",
18
+ "keypoint_threshold": 0.1,
19
+ }
20
+ required_inputs = ["image"]
21
+
22
+ def _init(self, conf):
23
+ model_path = rekd_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
24
+ if not model_path.exists():
25
+ print(f"No model found at {model_path}")
26
+ self.net = REKD_(is_test=True)
27
+ state_dict = torch.load(model_path, map_location="cpu")
28
+ self.net.load_state_dict(state_dict["model_state"])
29
+
30
+ def _forward(self, data):
31
+ image = data["image"]
32
+ keypoints, scores, descriptors = self.net(image)
33
+ _, _, Hc, Wc = descriptors.shape
34
+
35
+ # Scores & Descriptors
36
+ kpts_score = (
37
+ torch.cat([keypoints, scores], dim=1).view(3, -1).t().cpu().detach().numpy()
38
+ )
39
+ descriptors = (
40
+ descriptors.view(256, Hc, Wc).view(256, -1).t().cpu().detach().numpy()
41
+ )
42
+
43
+ # Filter based on confidence threshold
44
+ descriptors = descriptors[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
45
+ kpts_score = kpts_score[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
46
+ keypoints = kpts_score[:, 1:]
47
+ scores = kpts_score[:, 0]
48
+
49
+ return {
50
+ "keypoints": torch.from_numpy(keypoints)[None],
51
+ "scores": torch.from_numpy(scores)[None],
52
+ "descriptors": torch.from_numpy(descriptors.T)[None],
53
+ }
hloc/extractors/superpoint.py ADDED
@@ -0,0 +1,44 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import torch
4
+
5
+ from ..utils.base_model import BaseModel
6
+
7
+ sys.path.append(str(Path(__file__).parent / "../../third_party"))
8
+ from SuperGluePretrainedNetwork.models import superpoint # noqa E402
9
+
10
+
11
+ # The original keypoint sampling is incorrect. We patch it here but
12
+ # we don't fix it upstream to not impact exisiting evaluations.
13
+ def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
14
+ """Interpolate descriptors at keypoint locations"""
15
+ b, c, h, w = descriptors.shape
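+ # Shift keypoints to pixel centers (+0.5) and normalize by the image extent (feature-map size times stride s).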
16
+ keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s)
17
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
18
+ descriptors = torch.nn.functional.grid_sample(
19
+ descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
20
+ )
21
+ descriptors = torch.nn.functional.normalize(
22
+ descriptors.reshape(b, c, -1), p=2, dim=1
23
+ )
24
+ return descriptors
25
+
26
+
27
+ class SuperPoint(BaseModel):
28
+ default_conf = {
29
+ "nms_radius": 4,
30
+ "keypoint_threshold": 0.005,
31
+ "max_keypoints": -1,
32
+ "remove_borders": 4,
33
+ "fix_sampling": False,
34
+ }
35
+ required_inputs = ["image"]
36
+ detection_noise = 2.0
37
+
38
+ def _init(self, conf):
39
+ if conf["fix_sampling"]:
40
+ superpoint.sample_descriptors = sample_descriptors_fix_sampling
41
+ self.net = superpoint.SuperPoint(conf)
42
+
43
+ def _forward(self, data):
44
+ return self.net(data)
hloc/match_dense.py ADDED
@@ -0,0 +1,384 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchvision.transforms.functional as F
4
+ from types import SimpleNamespace
5
+ from .extract_features import read_image, resize_image
6
+ import cv2
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ confs = {
11
+ # Best quality but loads of points. Only use for small scenes
12
+ "loftr": {
13
+ "output": "matches-loftr",
14
+ "model": {
15
+ "name": "loftr",
16
+ "weights": "outdoor",
17
+ "max_keypoints": 2000,
18
+ "match_threshold": 0.2,
19
+ },
20
+ "preprocessing": {
21
+ "grayscale": True,
22
+ "resize_max": 1024,
23
+ "dfactor": 8,
24
+ "width": 640,
25
+ "height": 480,
26
+ "force_resize": True,
27
+ },
28
+ "max_error": 1, # max error for assigned keypoints (in px)
29
+ "cell_size": 1, # size of quantization patch (max 1 kp/patch)
30
+ },
31
+ # Semi-scalable loftr which limits detected keypoints
32
+ "loftr_aachen": {
33
+ "output": "matches-loftr_aachen",
34
+ "model": {
35
+ "name": "loftr",
36
+ "weights": "outdoor",
37
+ "max_keypoints": 2000,
38
+ "match_threshold": 0.2,
39
+ },
40
+ "preprocessing": {"grayscale": True, "resize_max": 1024, "dfactor": 8},
41
+ "max_error": 2, # max error for assigned keypoints (in px)
42
+ "cell_size": 8, # size of quantization patch (max 1 kp/patch)
43
+ },
44
+ # Use for matching superpoint feats with loftr
45
+ "loftr_superpoint": {
46
+ "output": "matches-loftr_aachen",
47
+ "model": {
48
+ "name": "loftr",
49
+ "weights": "outdoor",
50
+ "max_keypoints": 2000,
51
+ "match_threshold": 0.2,
52
+ },
53
+ "preprocessing": {"grayscale": True, "resize_max": 1024, "dfactor": 8},
54
+ "max_error": 4, # max error for assigned keypoints (in px)
55
+ "cell_size": 4, # size of quantization patch (max 1 kp/patch)
56
+ },
57
+ # Use topicfm for matching feats
58
+ "topicfm": {
59
+ "output": "matches-topicfm",
60
+ "model": {
61
+ "name": "topicfm",
62
+ "weights": "outdoor",
63
+ "max_keypoints": 2000,
64
+ "match_threshold": 0.2,
65
+ },
66
+ "preprocessing": {
67
+ "grayscale": True,
68
+ "force_resize": True,
69
+ "resize_max": 1024,
70
+ "dfactor": 8,
71
+ "width": 640,
72
+ "height": 480,
73
+ },
74
+ },
75
+ # Use topicfm for matching feats
76
+ "aspanformer": {
77
+ "output": "matches-aspanformer",
78
+ "model": {
79
+ "name": "aspanformer",
80
+ "weights": "outdoor",
81
+ "max_keypoints": 2000,
82
+ "match_threshold": 0.2,
83
+ },
84
+ "preprocessing": {
85
+ "grayscale": True,
86
+ "force_resize": True,
87
+ "resize_max": 1024,
88
+ "width": 640,
89
+ "height": 480,
90
+ "dfactor": 8,
91
+ },
92
+ },
93
+ "dkm": {
94
+ "output": "matches-dkm",
95
+ "model": {
96
+ "name": "dkm",
97
+ "weights": "outdoor",
98
+ "max_keypoints": 2000,
99
+ "match_threshold": 0.2,
100
+ },
101
+ "preprocessing": {
102
+ "grayscale": False,
103
+ "force_resize": True,
104
+ "resize_max": 1024,
105
+ "width": 80,
106
+ "height": 60,
107
+ "dfactor": 8,
108
+ },
109
+ },
110
+ "roma": {
111
+ "output": "matches-roma",
112
+ "model": {
113
+ "name": "roma",
114
+ "weights": "outdoor",
115
+ "max_keypoints": 2000,
116
+ "match_threshold": 0.2,
117
+ },
118
+ "preprocessing": {
119
+ "grayscale": False,
120
+ "force_resize": True,
121
+ "resize_max": 1024,
122
+ "width": 320,
123
+ "height": 240,
124
+ "dfactor": 8,
125
+ },
126
+ },
127
+ "dedode_sparse": {
128
+ "output": "matches-dedode",
129
+ "model": {
130
+ "name": "dedode",
131
+ "max_keypoints": 2000,
132
+ "match_threshold": 0.2,
133
+ "dense": False,
134
+ },
135
+ "preprocessing": {
136
+ "grayscale": False,
137
+ "force_resize": True,
138
+ "resize_max": 1024,
139
+ "width": 768,
140
+ "height": 768,
141
+ "dfactor": 8,
142
+ },
143
+ },
144
+ "sold2": {
145
+ "output": "matches-sold2",
146
+ "model": {
147
+ "name": "sold2",
148
+ "max_keypoints": 2000,
149
+ "match_threshold": 0.2,
150
+ },
151
+ "preprocessing": {
152
+ "grayscale": True,
153
+ "force_resize": True,
154
+ "resize_max": 1024,
155
+ "width": 640,
156
+ "height": 480,
157
+ "dfactor": 8,
158
+ },
159
+ },
160
+ "gluestick": {
161
+ "output": "matches-gluestick",
162
+ "model": {
163
+ "name": "gluestick",
164
+ "use_lines": True,
165
+ "max_keypoints": 1000,
166
+ "max_lines": 300,
167
+ "force_num_keypoints": False,
168
+ },
169
+ "preprocessing": {
170
+ "grayscale": True,
171
+ "force_resize": True,
172
+ "resize_max": 1024,
173
+ "width": 640,
174
+ "height": 480,
175
+ "dfactor": 8,
176
+ },
177
+ },
178
+ }
179
+
180
+
181
+ def scale_keypoints(kpts, scale):
182
+ if np.any(scale != 1.0):
183
+ kpts *= kpts.new_tensor(scale)
184
+ return kpts
185
+
186
+
187
+ def scale_lines(lines, scale):
188
+ if np.any(scale != 1.0):
189
+ lines *= lines.new_tensor(scale)
190
+ return lines
191
+
192
+
193
+ def match(model, path_0, path_1, conf):
194
+ default_conf = {
195
+ "grayscale": True,
196
+ "resize_max": 1024,
197
+ "dfactor": 8,
198
+ "cache_images": False,
199
+ "force_resize": False,
200
+ "width": 320,
201
+ "height": 240,
202
+ }
203
+
204
+ def preprocess(image: np.ndarray):
205
+ image = image.astype(np.float32, copy=False)
206
+ size = image.shape[:2][::-1]
207
+ scale = np.array([1.0, 1.0])
208
+ if conf.resize_max:
209
+ scale = conf.resize_max / max(size)
210
+ if scale < 1.0:
211
+ size_new = tuple(int(round(x * scale)) for x in size)
212
+ image = resize_image(image, size_new, "cv2_area")
213
+ scale = np.array(size) / np.array(size_new)
214
+ if conf.force_resize:
215
+ size = image.shape[:2][::-1]
216
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
217
+ size_new = (conf.width, conf.height)
218
+ scale = np.array(size) / np.array(size_new)
219
+ if conf.grayscale:
220
+ assert image.ndim == 2, image.shape
221
+ image = image[None]
222
+ else:
223
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
224
+ image = torch.from_numpy(image / 255.0).float()
225
+ # assure that the size is divisible by dfactor
226
+ size_new = tuple(
227
+ map(lambda x: int(x // conf.dfactor * conf.dfactor), image.shape[-2:])
228
+ )
229
+ image = F.resize(image, size=size_new, antialias=True)
230
+ scale = np.array(size) / np.array(size_new)[::-1]
231
+ return image, scale
232
+
233
+ conf = SimpleNamespace(**{**default_conf, **conf})
234
+ image0 = read_image(path_0, conf.grayscale)
235
+ image1 = read_image(path_1, conf.grayscale)
236
+ image0, scale0 = preprocess(image0)
237
+ image1, scale1 = preprocess(image1)
238
+ image0 = image0.to(device)[None]
239
+ image1 = image1.to(device)[None]
240
+ pred = model({"image0": image0, "image1": image1})
241
+
242
+ # Rescale keypoints and move to cpu
243
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
244
+ kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5
245
+ kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5
246
+
247
+ ret = {
248
+ "image0": image0.squeeze().cpu().numpy(),
249
+ "image1": image1.squeeze().cpu().numpy(),
250
+ "keypoints0": kpts0.cpu().numpy(),
251
+ "keypoints1": kpts1.cpu().numpy(),
252
+ }
253
+ if "mconf" in pred.keys():
254
+ ret["mconf"] = pred["mconf"].cpu().numpy()
255
+ return ret
256
+
257
+
258
+ @torch.no_grad()
259
+ def match_images(model, image_0, image_1, conf, device="cpu"):
260
+ default_conf = {
261
+ "grayscale": True,
262
+ "resize_max": 1024,
263
+ "dfactor": 8,
264
+ "cache_images": False,
265
+ "force_resize": False,
266
+ "width": 320,
267
+ "height": 240,
268
+ }
269
+
270
+ def preprocess(image: np.ndarray):
271
+ image = image.astype(np.float32, copy=False)
272
+ size = image.shape[:2][::-1]
273
+ scale = np.array([1.0, 1.0])
274
+ if conf.resize_max:
275
+ scale = conf.resize_max / max(size)
276
+ if scale < 1.0:
277
+ size_new = tuple(int(round(x * scale)) for x in size)
278
+ image = resize_image(image, size_new, "cv2_area")
279
+ scale = np.array(size) / np.array(size_new)
280
+ if conf.force_resize:
281
+ size = image.shape[:2][::-1]
282
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
283
+ size_new = (conf.width, conf.height)
284
+ scale = np.array(size) / np.array(size_new)
285
+ if conf.grayscale:
286
+ assert image.ndim == 2, image.shape
287
+ image = image[None]
288
+ else:
289
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
290
+ image = torch.from_numpy(image / 255.0).float()
291
+
292
+ # assure that the size is divisible by dfactor
293
+ size_new = tuple(
294
+ map(lambda x: int(x // conf.dfactor * conf.dfactor), image.shape[-2:])
295
+ )
296
+ image = F.resize(image, size=size_new)
297
+ scale = np.array(size) / np.array(size_new)[::-1]
298
+ return image, scale
299
+
300
+ conf = SimpleNamespace(**{**default_conf, **conf})
301
+
302
+ if len(image_0.shape) == 3 and conf.grayscale:
303
+ image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY)
304
+ else:
305
+ image0 = image_0
306
+ if len(image_1.shape) == 3 and conf.grayscale:
307
+ image1 = cv2.cvtColor(image_1, cv2.COLOR_RGB2GRAY)
308
+ else:
309
+ image1 = image_1
310
+
311
+ # comment following lines, image is always RGB mode
312
+ # if not conf.grayscale and len(image0.shape) == 3:
313
+ # image0 = image0[:, :, ::-1] # BGR to RGB
314
+ # if not conf.grayscale and len(image1.shape) == 3:
315
+ # image1 = image1[:, :, ::-1] # BGR to RGB
316
+
317
+ image0, scale0 = preprocess(image0)
318
+ image1, scale1 = preprocess(image1)
319
+ image0 = image0.to(device)[None]
320
+ image1 = image1.to(device)[None]
321
+ pred = model({"image0": image0, "image1": image1})
322
+
323
+ s0 = np.array(image_0.shape[:2][::-1]) / np.array(image0.shape[-2:][::-1])
324
+ s1 = np.array(image_1.shape[:2][::-1]) / np.array(image1.shape[-2:][::-1])
325
+
326
+ # Rescale keypoints and move to cpu
327
+ if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
328
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
329
+ kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5
330
+ kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5
331
+
332
+ ret = {
333
+ "image0": image0.squeeze().cpu().numpy(),
334
+ "image1": image1.squeeze().cpu().numpy(),
335
+ "image0_orig": image_0,
336
+ "image1_orig": image_1,
337
+ "keypoints0": kpts0.cpu().numpy(),
338
+ "keypoints1": kpts1.cpu().numpy(),
339
+ "keypoints0_orig": kpts0_origin.cpu().numpy(),
340
+ "keypoints1_orig": kpts1_origin.cpu().numpy(),
341
+ "original_size0": np.array(image_0.shape[:2][::-1]),
342
+ "original_size1": np.array(image_1.shape[:2][::-1]),
343
+ "new_size0": np.array(image0.shape[-2:][::-1]),
344
+ "new_size1": np.array(image1.shape[-2:][::-1]),
345
+ "scale0": s0,
346
+ "scale1": s1,
347
+ }
348
+ if "mconf" in pred.keys():
349
+ ret["mconf"] = pred["mconf"].cpu().numpy()
350
+ if "lines0" in pred.keys() and "lines1" in pred.keys():
351
+ if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
352
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
353
+ kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5
354
+ kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5
355
+ kpts0_origin = kpts0_origin.cpu().numpy()
356
+ kpts1_origin = kpts1_origin.cpu().numpy()
357
+ else:
358
+ kpts0_origin, kpts1_origin = None, None # np.zeros([0]), np.zeros([0])
359
+ lines0, lines1 = pred["lines0"], pred["lines1"]
360
+ lines0_raw, lines1_raw = pred["raw_lines0"], pred["raw_lines1"]
361
+
362
+ lines0_raw = torch.from_numpy(lines0_raw.copy())
363
+ lines1_raw = torch.from_numpy(lines1_raw.copy())
364
+ lines0_raw = scale_lines(lines0_raw + 0.5, s0) - 0.5
365
+ lines1_raw = scale_lines(lines1_raw + 0.5, s1) - 0.5
366
+
367
+ lines0 = torch.from_numpy(lines0.copy())
368
+ lines1 = torch.from_numpy(lines1.copy())
369
+ lines0 = scale_lines(lines0 + 0.5, s0) - 0.5
370
+ lines1 = scale_lines(lines1 + 0.5, s1) - 0.5
371
+
372
+ ret = {
373
+ "image0_orig": image_0,
374
+ "image1_orig": image_1,
375
+ "line0": lines0_raw.cpu().numpy(),
376
+ "line1": lines1_raw.cpu().numpy(),
377
+ "line0_orig": lines0.cpu().numpy(),
378
+ "line1_orig": lines1.cpu().numpy(),
379
+ "line_keypoints0_orig": kpts0_origin,
380
+ "line_keypoints1_orig": kpts1_origin,
381
+ }
382
+ del pred
383
+ torch.cuda.empty_cache()
384
+ return ret
hloc/match_features.py ADDED
@@ -0,0 +1,389 @@
1
+ import argparse
2
+ from typing import Union, Optional, Dict, List, Tuple
3
+ from pathlib import Path
4
+ import pprint
5
+ from queue import Queue
6
+ from threading import Thread
7
+ from functools import partial
8
+ from tqdm import tqdm
9
+ import h5py
10
+ import torch
11
+
12
+ from . import matchers, logger
13
+ from .utils.base_model import dynamic_load
14
+ from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
15
+ import numpy as np
16
+
17
+ """
18
+ A set of standard configurations that can be directly selected from the command
19
+ line using their name. Each is a dictionary with the following entries:
20
+ - output: the name of the match file that will be generated.
21
+ - model: the model configuration, as passed to a feature matcher.
22
+ """
23
+ confs = {
24
+ "superglue": {
25
+ "output": "matches-superglue",
26
+ "model": {
27
+ "name": "superglue",
28
+ "weights": "outdoor",
29
+ "sinkhorn_iterations": 50,
30
+ "match_threshold": 0.2,
31
+ },
32
+ "preprocessing": {
33
+ "grayscale": True,
34
+ "resize_max": 1024,
35
+ "dfactor": 8,
36
+ "force_resize": False,
37
+ },
38
+ },
39
+ "superglue-fast": {
40
+ "output": "matches-superglue-it5",
41
+ "model": {
42
+ "name": "superglue",
43
+ "weights": "outdoor",
44
+ "sinkhorn_iterations": 5,
45
+ "match_threshold": 0.2,
46
+ },
47
+ },
48
+ "superpoint-lightglue": {
49
+ "output": "matches-lightglue",
50
+ "model": {
51
+ "name": "lightglue",
52
+ "match_threshold": 0.2,
53
+ "width_confidence": 0.99, # for point pruning
54
+ "depth_confidence": 0.95, # for early stopping,
55
+ "features": "superpoint",
56
+ "model_name": "superpoint_lightglue.pth",
57
+ },
58
+ "preprocessing": {
59
+ "grayscale": True,
60
+ "resize_max": 1024,
61
+ "dfactor": 8,
62
+ "force_resize": False,
63
+ },
64
+ },
65
+ "disk-lightglue": {
66
+ "output": "matches-lightglue",
67
+ "model": {
68
+ "name": "lightglue",
69
+ "match_threshold": 0.2,
70
+ "width_confidence": 0.99, # for point pruning
71
+ "depth_confidence": 0.95, # for early stopping,
72
+ "features": "disk",
73
+ "model_name": "disk_lightglue.pth",
74
+ },
75
+ "preprocessing": {
76
+ "grayscale": True,
77
+ "resize_max": 1024,
78
+ "dfactor": 8,
79
+ "force_resize": False,
80
+ },
81
+ },
82
+ "sgmnet": {
83
+ "output": "matches-sgmnet",
84
+ "model": {
85
+ "name": "sgmnet",
86
+ "seed_top_k": [256, 256],
87
+ "seed_radius_coe": 0.01,
88
+ "net_channels": 128,
89
+ "layer_num": 9,
90
+ "head": 4,
91
+ "seedlayer": [0, 6],
92
+ "use_mc_seeding": True,
93
+ "use_score_encoding": False,
94
+ "conf_bar": [1.11, 0.1],
95
+ "sink_iter": [10, 100],
96
+ "detach_iter": 1000000,
97
+ "match_threshold": 0.2,
98
+ },
99
+ "preprocessing": {
100
+ "grayscale": True,
101
+ "resize_max": 1024,
102
+ "dfactor": 8,
103
+ "force_resize": False,
104
+ },
105
+ },
106
+ "NN-superpoint": {
107
+ "output": "matches-NN-mutual-dist.7",
108
+ "model": {
109
+ "name": "nearest_neighbor",
110
+ "do_mutual_check": True,
111
+ "distance_threshold": 0.7,
112
+ "match_threshold": 0.2,
113
+ },
114
+ },
115
+ "NN-ratio": {
116
+ "output": "matches-NN-mutual-ratio.8",
117
+ "model": {
118
+ "name": "nearest_neighbor",
119
+ "do_mutual_check": True,
120
+ "ratio_threshold": 0.8,
121
+ "match_threshold": 0.2,
122
+ },
123
+ },
124
+ "NN-mutual": {
125
+ "output": "matches-NN-mutual",
126
+ "model": {
127
+ "name": "nearest_neighbor",
128
+ "do_mutual_check": True,
129
+ "match_threshold": 0.2,
130
+ },
131
+ },
132
+ "Dual-Softmax": {
133
+ "output": "matches-Dual-Softmax",
134
+ "model": {
135
+ "name": "dual_softmax",
136
+ "do_mutual_check": True,
137
+ "match_threshold": 0.2, # TODO
138
+ },
139
+ },
140
+ "adalam": {
141
+ "output": "matches-adalam",
142
+ "model": {
143
+ "name": "adalam",
144
+ "match_threshold": 0.2,
145
+ },
146
+ },
147
+ }
148
+
149
+
150
+ class WorkQueue:
151
+ def __init__(self, work_fn, num_threads=1):
152
+ self.queue = Queue(num_threads)
153
+ self.threads = [
154
+ Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads)
155
+ ]
156
+ for thread in self.threads:
157
+ thread.start()
158
+
159
+ def join(self):
160
+ for thread in self.threads:
161
+ self.queue.put(None)
162
+ for thread in self.threads:
163
+ thread.join()
164
+
165
+ def thread_fn(self, work_fn):
166
+ item = self.queue.get()
167
+ while item is not None:
168
+ work_fn(item)
169
+ item = self.queue.get()
170
+
171
+ def put(self, data):
172
+ self.queue.put(data)
173
+
174
+
175
+ class FeaturePairsDataset(torch.utils.data.Dataset):
176
+ def __init__(self, pairs, feature_path_q, feature_path_r):
177
+ self.pairs = pairs
178
+ self.feature_path_q = feature_path_q
179
+ self.feature_path_r = feature_path_r
180
+
181
+ def __getitem__(self, idx):
182
+ name0, name1 = self.pairs[idx]
183
+ data = {}
184
+ with h5py.File(self.feature_path_q, "r") as fd:
185
+ grp = fd[name0]
186
+ for k, v in grp.items():
187
+ data[k + "0"] = torch.from_numpy(v.__array__()).float()
188
+ # some matchers might expect an image but only use its size
189
+ data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
190
+ with h5py.File(self.feature_path_r, "r") as fd:
191
+ grp = fd[name1]
192
+ for k, v in grp.items():
193
+ data[k + "1"] = torch.from_numpy(v.__array__()).float()
194
+ data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
195
+ return data
196
+
197
+ def __len__(self):
198
+ return len(self.pairs)
199
+
200
+
201
+ def writer_fn(inp, match_path):
202
+ pair, pred = inp
203
+ with h5py.File(str(match_path), "a", libver="latest") as fd:
204
+ if pair in fd:
205
+ del fd[pair]
206
+ grp = fd.create_group(pair)
207
+ matches = pred["matches0"][0].cpu().short().numpy()
208
+ grp.create_dataset("matches0", data=matches)
209
+ if "matching_scores0" in pred:
210
+ scores = pred["matching_scores0"][0].cpu().half().numpy()
211
+ grp.create_dataset("matching_scores0", data=scores)
212
+
213
+
214
+ def main(
215
+ conf: Dict,
216
+ pairs: Path,
217
+ features: Union[Path, str],
218
+ export_dir: Optional[Path] = None,
219
+ matches: Optional[Path] = None,
220
+ features_ref: Optional[Path] = None,
221
+ overwrite: bool = False,
222
+ ) -> Path:
223
+
224
+ if isinstance(features, Path) or Path(features).exists():
225
+ features_q = features
226
+ if matches is None:
227
+ raise ValueError(
228
+ "Either provide both features and matches as Path" " or both as names."
229
+ )
230
+ else:
231
+ if export_dir is None:
232
+ raise ValueError(
233
+ "Provide an export_dir if features is not" f" a file path: {features}."
234
+ )
235
+ features_q = Path(export_dir, features + ".h5")
236
+ if matches is None:
237
+ matches = Path(export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5')
238
+
239
+ if features_ref is None:
240
+ features_ref = features_q
241
+ match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite)
242
+
243
+ return matches
244
+
245
+
246
+ def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None):
247
+ """Avoid to recompute duplicates to save time."""
248
+ pairs = set()
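+ # Keep only one orientation of each pair; (j, i) is skipped when (i, j) is already queued.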
249
+ for i, j in pairs_all:
250
+ if (j, i) not in pairs:
251
+ pairs.add((i, j))
252
+ pairs = list(pairs)
253
+ if match_path is not None and match_path.exists():
254
+ with h5py.File(str(match_path), "r", libver="latest") as fd:
255
+ pairs_filtered = []
256
+ for i, j in pairs:
257
+ if (
258
+ names_to_pair(i, j) in fd
259
+ or names_to_pair(j, i) in fd
260
+ or names_to_pair_old(i, j) in fd
261
+ or names_to_pair_old(j, i) in fd
262
+ ):
263
+ continue
264
+ pairs_filtered.append((i, j))
265
+ return pairs_filtered
266
+ return pairs
267
+
268
+
269
+ @torch.no_grad()
270
+ def match_from_paths(
271
+ conf: Dict,
272
+ pairs_path: Path,
273
+ match_path: Path,
274
+ feature_path_q: Path,
275
+ feature_path_ref: Path,
276
+ overwrite: bool = False,
277
+ ) -> Path:
278
+ logger.info(
279
+ "Matching local features with configuration:" f"\n{pprint.pformat(conf)}"
280
+ )
281
+
282
+ if not feature_path_q.exists():
283
+ raise FileNotFoundError(f"Query feature file {feature_path_q}.")
284
+ if not feature_path_ref.exists():
285
+ raise FileNotFoundError(f"Reference feature file {feature_path_ref}.")
286
+ match_path.parent.mkdir(exist_ok=True, parents=True)
287
+
288
+ assert pairs_path.exists(), pairs_path
289
+ pairs = parse_retrieval(pairs_path)
290
+ pairs = [(q, r) for q, rs in pairs.items() for r in rs]
291
+ pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
292
+ if len(pairs) == 0:
293
+ logger.info("Skipping the matching.")
294
+ return
295
+
296
+ device = "cuda" if torch.cuda.is_available() else "cpu"
297
+ Model = dynamic_load(matchers, conf["model"]["name"])
298
+ model = Model(conf["model"]).eval().to(device)
299
+
300
+ dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
301
+ loader = torch.utils.data.DataLoader(
302
+ dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True
303
+ )
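+ # Write matches from 5 worker threads so HDF5 I/O overlaps with matching on the device.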
304
+ writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)
305
+
306
+ for idx, data in enumerate(tqdm(loader, smoothing=0.1)):
307
+ data = {
308
+ k: v if k.startswith("image") else v.to(device, non_blocking=True)
309
+ for k, v in data.items()
310
+ }
311
+ pred = model(data)
312
+ pair = names_to_pair(*pairs[idx])
313
+ writer_queue.put((pair, pred))
314
+ writer_queue.join()
315
+ logger.info("Finished exporting matches.")
316
+
317
+
318
+ def scale_keypoints(kpts, scale):
319
+ if np.any(scale != 1.0):
320
+ kpts *= kpts.new_tensor(scale)
321
+ return kpts
322
+
323
+
324
+ @torch.no_grad()
325
+ def match_images(model, feat0, feat1):
326
+ # forward pass to match keypoints
327
+ desc0 = feat0["descriptors"][0]
328
+ desc1 = feat1["descriptors"][0]
329
+ if len(desc0.shape) == 2:
330
+ desc0 = desc0.unsqueeze(0)
331
+ if len(desc1.shape) == 2:
332
+ desc1 = desc1.unsqueeze(0)
333
+ pred = model(
334
+ {
335
+ "image0": feat0["image"],
336
+ "keypoints0": feat0["keypoints"][0],
337
+ "scores0": feat0["scores"][0].unsqueeze(0),
338
+ "descriptors0": desc0,
339
+ "image1": feat1["image"],
340
+ "keypoints1": feat1["keypoints"][0],
341
+ "scores1": feat1["scores"][0].unsqueeze(0),
342
+ "descriptors1": desc1,
343
+ }
344
+ )
345
+ pred = {
346
+ k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v
347
+ for k, v in pred.items()
348
+ }
349
+ kpts0, kpts1 = (
350
+ feat0["keypoints"][0].cpu().numpy(),
351
+ feat1["keypoints"][0].cpu().numpy(),
352
+ )
353
+ matches, confid = pred["matches0"], pred["matching_scores0"]
354
+ # Keep the matching keypoints.
355
+ valid = matches > -1
356
+ mkpts0 = kpts0[valid]
357
+ mkpts1 = kpts1[matches[valid]]
358
+ mconfid = confid[valid]
359
+ # rescale the keypoints to their original size
360
+ s0 = feat0["original_size"] / feat0["size"]
361
+ s1 = feat1["original_size"] / feat1["size"]
362
+ kpts0_origin = scale_keypoints(torch.from_numpy(mkpts0 + 0.5), s0) - 0.5
363
+ kpts1_origin = scale_keypoints(torch.from_numpy(mkpts1 + 0.5), s1) - 0.5
364
+ ret = {
365
+ "image0_orig": feat0["image_orig"],
366
+ "image1_orig": feat1["image_orig"],
367
+ "keypoints0": kpts0,
368
+ "keypoints1": kpts1,
369
+ "keypoints0_orig": kpts0_origin.numpy(),
370
+ "keypoints1_orig": kpts1_origin.numpy(),
371
+ "mconf": mconfid,
372
+ }
373
+ del feat0, feat1, desc0, desc1, kpts0, kpts1, kpts0_origin, kpts1_origin
374
+ torch.cuda.empty_cache()
375
+
376
+ return ret
377
+
378
+
379
+ if __name__ == "__main__":
380
+ parser = argparse.ArgumentParser()
381
+ parser.add_argument("--pairs", type=Path, required=True)
382
+ parser.add_argument("--export_dir", type=Path)
383
+ parser.add_argument("--features", type=str, default="feats-superpoint-n4096-r1024")
384
+ parser.add_argument("--matches", type=Path)
385
+ parser.add_argument(
386
+ "--conf", type=str, default="superglue", choices=list(confs.keys())
387
+ )
388
+ args = parser.parse_args()
389
+ main(confs[args.conf], args.pairs, args.features, args.export_dir, args.matches)
hloc/matchers/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ def get_matcher(matcher):
2
+ mod = __import__(f"{__name__}.{matcher}", fromlist=[""])
3
+ return getattr(mod, "Model")
hloc/matchers/adalam.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+
3
+ from ..utils.base_model import BaseModel
4
+
5
+ from kornia.feature.adalam import AdalamFilter
6
+ from kornia.utils.helpers import get_cuda_device_if_available
7
+
8
+
9
+ class AdaLAM(BaseModel):
10
+ # See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html.
11
+ default_conf = {
12
+ "area_ratio": 100,
13
+ "search_expansion": 4,
14
+ "ransac_iters": 128,
15
+ "min_inliers": 6,
16
+ "min_confidence": 200,
17
+ "orientation_difference_threshold": 30,
18
+ "scale_rate_threshold": 1.5,
19
+ "detected_scale_rate_threshold": 5,
20
+ "refit": True,
21
+ "force_seed_mnn": True,
22
+ "device": get_cuda_device_if_available(),
23
+ }
24
+ required_inputs = [
25
+ "image0",
26
+ "image1",
27
+ "descriptors0",
28
+ "descriptors1",
29
+ "keypoints0",
30
+ "keypoints1",
31
+ "scales0",
32
+ "scales1",
33
+ "oris0",
34
+ "oris1",
35
+ ]
36
+
37
+ def _init(self, conf):
38
+ self.adalam = AdalamFilter(conf)
39
+
40
+ def _forward(self, data):
41
+ assert data["keypoints0"].size(0) == 1
42
+ if data["keypoints0"].size(1) < 2 or data["keypoints1"].size(1) < 2:
43
+ matches = torch.zeros(
44
+ (0, 2), dtype=torch.int64, device=data["keypoints0"].device
45
+ )
46
+ else:
47
+ matches = self.adalam.match_and_filter(
48
+ data["keypoints0"][0],
49
+ data["keypoints1"][0],
50
+ data["descriptors0"][0].T,
51
+ data["descriptors1"][0].T,
52
+ data["image0"].shape[2:],
53
+ data["image1"].shape[2:],
54
+ data["oris0"][0],
55
+ data["oris1"][0],
56
+ data["scales0"][0],
57
+ data["scales1"][0],
58
+ )
59
+ matches_new = torch.full(
60
+ (data["keypoints0"].size(1),),
61
+ -1,
62
+ dtype=torch.int64,
63
+ device=data["keypoints0"].device,
64
+ )
65
+ matches_new[matches[:, 0]] = matches[:, 1]
66
+ return {
67
+ "matches0": matches_new.unsqueeze(0),
68
+ "matching_scores0": torch.zeros(matches_new.size(0)).unsqueeze(0),
69
+ }
hloc/matchers/aspanformer.py ADDED
@@ -0,0 +1,76 @@
1
+ import sys
2
+ import torch
3
+ from ..utils.base_model import BaseModel
4
+ from ..utils import do_system
5
+ from pathlib import Path
6
+ import subprocess
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ sys.path.append(str(Path(__file__).parent / "../../third_party"))
12
+ from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
13
+ from ASpanFormer.src.config.default import get_cfg_defaults
14
+ from ASpanFormer.src.utils.misc import lower_config
15
+ from ASpanFormer.demo import demo_utils
16
+
17
+ aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
18
+
19
+
20
+ class ASpanFormer(BaseModel):
21
+ default_conf = {
22
+ "weights": "outdoor",
23
+ "match_threshold": 0.2,
24
+ "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
25
+ "model_name": "weights_aspanformer.tar",
26
+ }
27
+ required_inputs = ["image0", "image1"]
28
+ proxy = "http://localhost:1080"
29
+ aspanformer_models = {
30
+ "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
31
+ }
32
+
33
+ def _init(self, conf):
34
+ model_path = aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt")
35
+ # Download the model.
36
+ if not model_path.exists():
37
+ # model_path.parent.mkdir(exist_ok=True)
38
+ tar_path = aspanformer_path / conf["model_name"]
39
+ if not tar_path.exists():
40
+ link = self.aspanformer_models[conf["model_name"]]
41
+ cmd = ["gdown", link, "-O", str(tar_path), "--proxy", self.proxy]
42
+ cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
43
+ logger.info(f"Downloading the Aspanformer model with `{cmd_wo_proxy}`.")
44
+ try:
45
+ subprocess.run(cmd_wo_proxy, check=True)
46
+ except subprocess.CalledProcessError as e:
47
+ logger.info(f"Downloading the Aspanformer model with `{cmd}`.")
48
+ try:
49
+ subprocess.run(cmd, check=True)
50
+ except subprocess.CalledProcessError as e:
51
+ logger.error(f"Failed to download the Aspanformer model.")
52
+ raise e
53
+
54
+ do_system(f"cd {str(aspanformer_path)} & tar -xvf {str(tar_path)}")
55
+
56
+ logger.info(f"Loading Aspanformer model...")
57
+
58
+ config = get_cfg_defaults()
59
+ config.merge_from_file(conf["config_path"])
60
+ _config = lower_config(config)
61
+ self.net = _ASpanFormer(config=_config["aspan"])
62
+ weight_path = model_path
63
+ state_dict = torch.load(str(weight_path), map_location="cpu")["state_dict"]
64
+ self.net.load_state_dict(state_dict, strict=False)
65
+
66
+ def _forward(self, data):
67
+ data_ = {
68
+ "image0": data["image0"],
69
+ "image1": data["image1"],
70
+ }
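+ # The forward pass writes the matched keypoints (mkpts0_f / mkpts1_f) back into data_ in place.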
71
+ self.net(data_, online_resize=True)
72
+ corr0 = data_["mkpts0_f"]
73
+ corr1 = data_["mkpts1_f"]
74
+ pred = {}
75
+ pred["keypoints0"], pred["keypoints1"] = corr0, corr1
76
+ return pred