diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..d22cef17abd65c6b9edcc26dfb84d6ae5fe6c6ac 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8cbe2e58034837efc5de2fa55d0963db97d83ad0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,19 @@
+build/
+
+lib/
+bin/
+
+cmake_modules/
+cmake-build-debug/
+.idea/
+.vscode/
+*.pyc
+flagged
+.ipynb_checkpoints
+__pycache__
+Untitled*
+experiments
+third_party/REKD
+Dockerfile
+hloc/matchers/dedode.py
+gradio_cached_examples
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..911aa1f7b9984ee40059a313c837ed8d52d86c44
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,45 @@
+[submodule "third_party/Roma"]
+ path = third_party/Roma
+ url = https://github.com/Vincentqyw/RoMa.git
+[submodule "third_party/SuperGluePretrainedNetwork"]
+ path = third_party/SuperGluePretrainedNetwork
+ url = https://github.com/magicleap/SuperGluePretrainedNetwork.git
+[submodule "third_party/SOLD2"]
+ path = third_party/SOLD2
+ url = https://github.com/cvg/SOLD2.git
+[submodule "third_party/GlueStick"]
+ path = third_party/GlueStick
+ url = https://github.com/cvg/GlueStick.git
+[submodule "third_party/ASpanFormer"]
+ path = third_party/ASpanFormer
+ url = https://github.com/Vincentqyw/ml-aspanformer.git
+[submodule "third_party/TopicFM"]
+ path = third_party/TopicFM
+ url = https://github.com/Vincentqyw/TopicFM.git
+[submodule "third_party/d2net"]
+ path = third_party/d2net
+ url = https://github.com/Vincentqyw/d2-net.git
+[submodule "third_party/r2d2"]
+ path = third_party/r2d2
+ url = https://github.com/naver/r2d2.git
+[submodule "third_party/DKM"]
+ path = third_party/DKM
+ url = https://github.com/Vincentqyw/DKM.git
+[submodule "third_party/ALIKE"]
+ path = third_party/ALIKE
+ url = https://github.com/Shiaoming/ALIKE.git
+[submodule "third_party/lanet"]
+ path = third_party/lanet
+ url = https://github.com/wangch-g/lanet.git
+[submodule "third_party/LightGlue"]
+ path = third_party/LightGlue
+ url = https://github.com/cvg/LightGlue.git
+[submodule "third_party/SGMNet"]
+ path = third_party/SGMNet
+ url = https://github.com/vdvchen/SGMNet.git
+[submodule "third_party/DarkFeat"]
+ path = third_party/DarkFeat
+ url = https://github.com/THU-LYJ-Lab/DarkFeat.git
+[submodule "third_party/DeDoDe"]
+ path = third_party/DeDoDe
+ url = https://github.com/Parskatt/DeDoDe.git
diff --git a/README.md b/README.md
index 1d15749a0bb99fdb5d2edeb948e735b262a3f021..a2fc83a35620a5b8dddda6864dd8fd1a04ca3c38 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,107 @@
----
-title: Image Matching Webui
-emoji: 🏆
-colorFrom: indigo
-colorTo: gray
-sdk: gradio
-sdk_version: 3.35.2
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+[![Contributors][contributors-shield]][contributors-url]
+[![Forks][forks-shield]][forks-url]
+[![Stargazers][stars-shield]][stars-url]
+[![Issues][issues-shield]][issues-url]
+
+
+
+# Image Matching WebUI
+
+Find matches between two images.
+
+
+## Description
+
+This simple tool matches image pairs using a collection of well-known image matching algorithms. It provides a graphical user interface (GUI) built with [gradio](https://gradio.app/): select two images and a matching algorithm, and the matches are computed and visualized for you.
+**Note**: input images can come from local files or a webcam.
+
+Here is a demo of the tool:
+
+![demo](assets/demo.gif)
+
+The tool currently supports various popular image matching algorithms, namely:
+- [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
+- [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), TBD
+- [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
+- [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
+- [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023
+- [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023
+- [x] [GlueStick](https://github.com/cvg/GlueStick), ArXiv 2023
+- [ ] [ConvMatch](https://github.com/SuhZhang/ConvMatch), AAAI 2023
+- [x] [SOLD2](https://github.com/cvg/SOLD2), CVPR 2021
+- [ ] [LineTR](https://github.com/yosungho/LineTR), RA-L 2021
+- [x] [DKM](https://github.com/Parskatt/DKM), CVPR 2023
+- [x] [RoMa](https://github.com/Vincentqyw/RoMa), ArXiv 2023
+- [ ] [NCMNet](https://github.com/xinliu29/NCMNet), CVPR 2023
+- [x] [TopicFM](https://github.com/Vincentqyw/TopicFM), AAAI 2023
+- [x] [AspanFormer](https://github.com/Vincentqyw/ml-aspanformer), ECCV 2022
+- [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022
+- [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022
+- [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022
+- [x] [ALIKE](https://github.com/Shiaoming/ALIKE), ArXiv 2022
+- [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021
+- [x] [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork), CVPRW 2018
+- [x] [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork), CVPR 2020
+- [x] [D2Net](https://github.com/Vincentqyw/d2-net), CVPR 2019
+- [x] [R2D2](https://github.com/naver/r2d2), NeurIPS 2019
+- [x] [DISK](https://github.com/cvlab-epfl/disk), NeurIPS 2020
+- [ ] [Key.Net](https://github.com/axelBarroso/Key.Net), ICCV 2019
+- [ ] [OANet](https://github.com/zjhthu/OANet), ICCV 2019
+- [ ] [SOSNet](https://github.com/scape-research/SOSNet), CVPR 2019
+- [x] [SIFT](https://docs.opencv.org/4.x/da/df5/tutorial_py_sift_intro.html), IJCV 2004
+
+## How to use
+
+### Requirements
+``` bash
+git clone --recursive https://github.com/Vincentqyw/image-matching-webui.git
+cd image-matching-webui
+conda env create -f environment.yaml
+conda activate imw
+```
+
+### Run demo
+``` bash
+python3 ./app.py
+```
+Then open http://localhost:7860 in your browser.
+
+![](assets/gui.jpg)
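+
+You can also call the matching pipeline from Python without the GUI. The snippet below is a minimal sketch (not part of the repository) that reuses `run_matching` from `app.py` on one of the bundled example pairs; it assumes you run it from the repository root with the environment above activated and that the required model weights can be downloaded:
+
+``` python
+import cv2
+from app import run_matching
+
+# Load one of the bundled Sacre Coeur example pairs as RGB arrays.
+image0 = cv2.cvtColor(
+    cv2.imread("datasets/sacre_coeur/mapping/71295362_4051449754.jpg"), cv2.COLOR_BGR2RGB
+)
+image1 = cv2.cvtColor(
+    cv2.imread("datasets/sacre_coeur/mapping/93341989_396310999.jpg"), cv2.COLOR_BGR2RGB
+)
+
+# Arguments: match threshold, max keypoints, keypoint threshold, model key, images.
+fig, stats, confs = run_matching(0.1, 2000, 0.015, "disk+lightglue", image0, image1)
+print(stats)  # {"matches number": ...}
+```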
+
+### Add your own feature / matcher
+
+An example of adding a local feature extractor is provided in [hloc/extractors/example.py](hloc/extractors/example.py). Then register the feature settings in `confs` in [hloc/extract_features.py](hloc/extract_features.py), and finally add an entry to `matcher_zoo` in [extra_utils/utils.py](extra_utils/utils.py), as sketched below.
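+
+For reference only, a sparse (non-dense) entry pasted into the `matcher_zoo` dict follows the pattern below; the `"example"` names are placeholders for the confs registered in the previous steps, not existing entries:
+
+``` python
+# Hypothetical entry pairing a new "example" extractor with the
+# mutual nearest-neighbor matcher already defined in match_features.confs.
+"example+mnn": {
+    "config": match_features.confs["NN-mutual"],
+    "config_feature": extract_features.confs["example"],
+    "dense": False,
+},
+```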
+
+## Contributions welcome!
+
+External contributions are very welcome. Please follow the [PEP8 style guidelines](https://www.python.org/dev/peps/pep-0008/), check your code with a linter such as flake8, and reformat it with `python -m black .`. Below is a non-exhaustive list of features that would be valuable additions:
+
+- [x] add webcam support
+- [x] add [line feature matching](https://github.com/Vincentqyw/LineSegmentsDetection) algorithms
+- [x] example to add a new feature extractor / matcher
+- [ ] ransac to filter outliers
+- [ ] support exporting matches to COLMAP ([#issue 6](https://github.com/Vincentqyw/image-matching-webui/issues/6))
+- [ ] add config file to set default parameters
+- [ ] dynamically load models and reduce GPU overload
+
+Adding local features / matchers as submodules is straightforward. For example, to add [GlueStick](https://github.com/cvg/GlueStick):
+
+``` bash
+git submodule add https://github.com/cvg/GlueStick.git third_party/GlueStick
+```
+
+If the remote submodule repositories are updated, pull the latest changes with `git submodule update --remote`. To update a single submodule, use `git submodule update --remote third_party/GlueStick`.
+
+## Resources
+- [Image Matching: Local Features & Beyond](https://image-matching-workshop.github.io)
+- [Long-term Visual Localization](https://www.visuallocalization.net)
+
+## Acknowledgement
+
+This code is built on top of [Hierarchical-Localization](https://github.com/cvg/Hierarchical-Localization). We are grateful to the authors for releasing their source code.
+
+[contributors-shield]: https://img.shields.io/github/contributors/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+[contributors-url]: https://github.com/Vincentqyw/image-matching-webui/graphs/contributors
+[forks-shield]: https://img.shields.io/github/forks/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+[forks-url]: https://github.com/Vincentqyw/image-matching-webui/network/members
+[stars-shield]: https://img.shields.io/github/stars/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+[stars-url]: https://github.com/Vincentqyw/image-matching-webui/stargazers
+[issues-shield]: https://img.shields.io/github/issues/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+[issues-url]: https://github.com/Vincentqyw/image-matching-webui/issues
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f40cd2015c277f53245ae7cce49437d2e64e838
--- /dev/null
+++ b/app.py
@@ -0,0 +1,291 @@
+import argparse
+import gradio as gr
+
+from hloc import extract_features
+from extra_utils.utils import (
+ matcher_zoo,
+ device,
+ match_dense,
+ match_features,
+ get_model,
+ get_feature_model,
+ display_matches
+)
+
+def run_matching(
+ match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
+):
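+    # Look up the selected model in matcher_zoo, apply the UI thresholds,
+    # run dense or sparse matching accordingly, and return the match
+    # visualization together with basic statistics and the resolved configs.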
+    # image0 and image1 are RGB images (numpy arrays).
+ if image0 is None or image1 is None:
+ raise gr.Error("Error: No images found! Please upload two images.")
+
+ model = matcher_zoo[key]
+ match_conf = model["config"]
+ # update match config
+ match_conf["model"]["match_threshold"] = match_threshold
+ match_conf["model"]["max_keypoints"] = extract_max_keypoints
+
+ matcher = get_model(match_conf)
+ if model["dense"]:
+ pred = match_dense.match_images(
+ matcher, image0, image1, match_conf["preprocessing"], device=device
+ )
+ del matcher
+ extract_conf = None
+ else:
+ extract_conf = model["config_feature"]
+ # update extract config
+ extract_conf["model"]["max_keypoints"] = extract_max_keypoints
+ extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
+ extractor = get_feature_model(extract_conf)
+ pred0 = extract_features.extract(
+ extractor, image0, extract_conf["preprocessing"]
+ )
+ pred1 = extract_features.extract(
+ extractor, image1, extract_conf["preprocessing"]
+ )
+ pred = match_features.match_images(matcher, pred0, pred1)
+ del extractor
+ fig, num_inliers = display_matches(pred)
+ del pred
+ return (
+ fig,
+ {"matches number": num_inliers},
+ {"match_conf": match_conf, "extractor_conf": extract_conf},
+ )
+
+
+def ui_change_imagebox(choice):
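+    # Switch the image input source (upload / webcam / canvas) and clear the current image.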
+ return {"value": None, "source": choice, "__type__": "update"}
+
+
+def ui_reset_state(
+ match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
+):
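+    # Ignore the incoming values and restore every control to its default.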
+ match_threshold = 0.2
+ extract_max_keypoints = 1000
+ keypoint_threshold = 0.015
+ key = list(matcher_zoo.keys())[0]
+ image0 = None
+ image1 = None
+ return (
+ match_threshold,
+ extract_max_keypoints,
+ keypoint_threshold,
+ key,
+ image0,
+ image1,
+ {"value": None, "source": "upload", "__type__": "update"},
+ {"value": None, "source": "upload", "__type__": "update"},
+ "upload",
+ None,
+ {},
+ {},
+ )
+
+
+def run(config):
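+    # Build the gradio Blocks UI: model and threshold controls plus two image
+    # inputs on the left, match visualization and statistics on the right.
+    # Note: the config argument is currently unused.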
+ with gr.Blocks(
+ theme=gr.themes.Monochrome(), css="footer {visibility: hidden}"
+ ) as app:
+ gr.Markdown(
+ """
+
+
+# Image Matching WebUI
+
+ """
+ )
+
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ with gr.Row():
+ matcher_list = gr.Dropdown(
+ choices=list(matcher_zoo.keys()),
+ value="disk+lightglue",
+ label="Matching Model",
+ interactive=True,
+ )
+ match_image_src = gr.Radio(
+ ["upload", "webcam", "canvas"],
+ label="Image Source",
+ value="upload",
+ )
+
+ with gr.Row():
+ match_setting_threshold = gr.Slider(
+ minimum=0.0,
+ maximum=1,
+ step=0.001,
+ label="Match threshold",
+ value=0.1,
+ )
+ match_setting_max_features = gr.Slider(
+ minimum=10,
+ maximum=10000,
+ step=10,
+ label="Max number of features",
+ value=1000,
+ )
+ # TODO: add line settings
+ with gr.Row():
+ detect_keypoints_threshold = gr.Slider(
+ minimum=0,
+ maximum=1,
+ step=0.001,
+ label="Keypoint threshold",
+ value=0.015,
+ )
+ detect_line_threshold = gr.Slider(
+ minimum=0.1,
+ maximum=1,
+ step=0.01,
+ label="Line threshold",
+ value=0.2,
+ )
+ # matcher_lists = gr.Radio(
+ # ["NN-mutual", "Dual-Softmax"],
+ # label="Matcher mode",
+ # value="NN-mutual",
+ # )
+ with gr.Row():
+ input_image0 = gr.Image(
+ label="Image 0",
+ type="numpy",
+ interactive=True,
+ image_mode="RGB",
+ )
+ input_image1 = gr.Image(
+ label="Image 1",
+ type="numpy",
+ interactive=True,
+ image_mode="RGB",
+ )
+
+ with gr.Row():
+ button_reset = gr.Button(label="Reset", value="Reset")
+ button_run = gr.Button(
+ label="Run Match", value="Run Match", variant="primary"
+ )
+
+ with gr.Accordion("Open for More!", open=False):
+ gr.Markdown(
+ f"""
+ Supported Algorithms
+ {", ".join(matcher_zoo.keys())}
+ """
+ )
+
+ # collect inputs
+ inputs = [
+ match_setting_threshold,
+ match_setting_max_features,
+ detect_keypoints_threshold,
+ matcher_list,
+ input_image0,
+ input_image1,
+ ]
+
+ # Add some examples
+ with gr.Row():
+ examples = [
+ [
+ 0.1,
+ 2000,
+ 0.015,
+ "disk+lightglue",
+ "datasets/sacre_coeur/mapping/71295362_4051449754.jpg",
+ "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
+ ],
+ [
+ 0.1,
+ 2000,
+ 0.015,
+ "loftr",
+ "datasets/sacre_coeur/mapping/03903474_1471484089.jpg",
+ "datasets/sacre_coeur/mapping/02928139_3448003521.jpg",
+ ],
+ [
+ 0.1,
+ 2000,
+ 0.015,
+ "disk",
+ "datasets/sacre_coeur/mapping/10265353_3838484249.jpg",
+ "datasets/sacre_coeur/mapping/51091044_3486849416.jpg",
+ ],
+ [
+ 0.1,
+ 2000,
+ 0.015,
+ "topicfm",
+ "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
+ "datasets/sacre_coeur/mapping/93341989_396310999.jpg",
+ ],
+ [
+ 0.1,
+ 2000,
+ 0.015,
+ "superpoint+superglue",
+ "datasets/sacre_coeur/mapping/17295357_9106075285.jpg",
+ "datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
+ ],
+ ]
+ # Example inputs
+ gr.Examples(
+ examples=examples,
+ inputs=inputs,
+ outputs=[],
+ fn=run_matching,
+ cache_examples=False,
+ label="Examples (click one of the images below to Run Match)",
+ )
+
+ with gr.Column():
+ output_mkpts = gr.Image(label="Keypoints Matching", type="numpy")
+ matches_result_info = gr.JSON(label="Matches Statistics")
+ matcher_info = gr.JSON(label="Match info")
+
+ # callbacks
+ match_image_src.change(
+ fn=ui_change_imagebox, inputs=match_image_src, outputs=input_image0
+ )
+ match_image_src.change(
+ fn=ui_change_imagebox, inputs=match_image_src, outputs=input_image1
+ )
+
+ # collect outputs
+ outputs = [
+ output_mkpts,
+ matches_result_info,
+ matcher_info,
+ ]
+ # button callbacks
+ button_run.click(fn=run_matching, inputs=inputs, outputs=outputs)
+
+ # Reset images
+ reset_outputs = [
+ match_setting_threshold,
+ match_setting_max_features,
+ detect_keypoints_threshold,
+ matcher_list,
+ input_image0,
+ input_image1,
+ input_image0,
+ input_image1,
+ match_image_src,
+ output_mkpts,
+ matches_result_info,
+ matcher_info,
+ ]
+ button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
+
+ app.launch(share=True)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config_path", type=str, default="config.yaml", help="configuration file path"
+ )
+ args = parser.parse_args()
+    config = None  # note: args.config_path is parsed but not loaded yet
+ run(config)
diff --git a/assets/demo.gif b/assets/demo.gif
new file mode 100644
index 0000000000000000000000000000000000000000..9af81d1ecded321bbd99ac7b84191518d6daf17d
--- /dev/null
+++ b/assets/demo.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f163c0e2699181897c81c68e01c60fa4289e886a2a40932d53dd529262d3735
+size 8907062
diff --git a/assets/gui.jpg b/assets/gui.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ac551d33c6df00bc011640b7434c98493a40076f
Binary files /dev/null and b/assets/gui.jpg differ
diff --git a/datasets/.gitignore b/datasets/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/datasets/lines/terrace0.JPG b/datasets/lines/terrace0.JPG
new file mode 100644
index 0000000000000000000000000000000000000000..bca7123ec9e472916cb873fbba8c077b5c44134c
Binary files /dev/null and b/datasets/lines/terrace0.JPG differ
diff --git a/datasets/lines/terrace1.JPG b/datasets/lines/terrace1.JPG
new file mode 100644
index 0000000000000000000000000000000000000000..7292c15f867476256ab1d34eff0e2623d899b291
Binary files /dev/null and b/datasets/lines/terrace1.JPG differ
diff --git a/datasets/sacre_coeur/README.md b/datasets/sacre_coeur/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d69115f7f262f6d97aa52bed9083bf3374249645
--- /dev/null
+++ b/datasets/sacre_coeur/README.md
@@ -0,0 +1,3 @@
+# Sacre Coeur demo
+
+We provide here a subset of images depicting the Sacre Coeur. These images were obtained from the [Image Matching Challenge 2021](https://www.cs.ubc.ca/research/image-matching-challenge/2021/data/) and were originally collected by the [Yahoo Flickr Creative Commons 100M (YFCC) dataset](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).
diff --git a/datasets/sacre_coeur/mapping/02928139_3448003521.jpg b/datasets/sacre_coeur/mapping/02928139_3448003521.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..639eb56d43580720a503930ff10835b25f103240
Binary files /dev/null and b/datasets/sacre_coeur/mapping/02928139_3448003521.jpg differ
diff --git a/datasets/sacre_coeur/mapping/03903474_1471484089.jpg b/datasets/sacre_coeur/mapping/03903474_1471484089.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6cc94daa30e953c7de3140c6548b6f2755ef3e7d
Binary files /dev/null and b/datasets/sacre_coeur/mapping/03903474_1471484089.jpg differ
diff --git a/datasets/sacre_coeur/mapping/10265353_3838484249.jpg b/datasets/sacre_coeur/mapping/10265353_3838484249.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bd2c0dcf8934fe80249e38fc78ff4a72ea4c6904
Binary files /dev/null and b/datasets/sacre_coeur/mapping/10265353_3838484249.jpg differ
diff --git a/datasets/sacre_coeur/mapping/17295357_9106075285.jpg b/datasets/sacre_coeur/mapping/17295357_9106075285.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f2db127eb8b00a467d383568d97250e37632ef71
Binary files /dev/null and b/datasets/sacre_coeur/mapping/17295357_9106075285.jpg differ
diff --git a/datasets/sacre_coeur/mapping/32809961_8274055477.jpg b/datasets/sacre_coeur/mapping/32809961_8274055477.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e5650dfda836ac0d357542d84cb111e2cc145a0c
Binary files /dev/null and b/datasets/sacre_coeur/mapping/32809961_8274055477.jpg differ
diff --git a/datasets/sacre_coeur/mapping/44120379_8371960244.jpg b/datasets/sacre_coeur/mapping/44120379_8371960244.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a2fa9c4c760b85b3cb8493fb550b015bb2e4b635
Binary files /dev/null and b/datasets/sacre_coeur/mapping/44120379_8371960244.jpg differ
diff --git a/datasets/sacre_coeur/mapping/51091044_3486849416.jpg b/datasets/sacre_coeur/mapping/51091044_3486849416.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..640e3351e47cdb6568ff25316aede0478c173f9b
Binary files /dev/null and b/datasets/sacre_coeur/mapping/51091044_3486849416.jpg differ
diff --git a/datasets/sacre_coeur/mapping/60584745_2207571072.jpg b/datasets/sacre_coeur/mapping/60584745_2207571072.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..788e25dc21c8e0a1d50148742c51c0f42555b7b6
Binary files /dev/null and b/datasets/sacre_coeur/mapping/60584745_2207571072.jpg differ
diff --git a/datasets/sacre_coeur/mapping/71295362_4051449754.jpg b/datasets/sacre_coeur/mapping/71295362_4051449754.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8a3a206489b987120423786478681a7209003881
Binary files /dev/null and b/datasets/sacre_coeur/mapping/71295362_4051449754.jpg differ
diff --git a/datasets/sacre_coeur/mapping/93341989_396310999.jpg b/datasets/sacre_coeur/mapping/93341989_396310999.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9f53cddd66d27aad9a77be1dcf7265cdc039dbce
Binary files /dev/null and b/datasets/sacre_coeur/mapping/93341989_396310999.jpg differ
diff --git a/extra_utils/__init__.py b/extra_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/extra_utils/plotting.py b/extra_utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9f204222d6bdea1459dbba4d94238ccaa31655d
--- /dev/null
+++ b/extra_utils/plotting.py
@@ -0,0 +1,504 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib, os, cv2
+import matplotlib.cm as cm
+from PIL import Image
+import torch.nn.functional as F
+import torch
+
+
+def _compute_conf_thresh(data):
+ dataset_name = data["dataset_name"][0].lower()
+ if dataset_name == "scannet":
+ thr = 5e-4
+ elif dataset_name == "megadepth":
+ thr = 1e-4
+ else:
+ raise ValueError(f"Unknown dataset: {dataset_name}")
+ return thr
+
+
+# --- VISUALIZATION --- #
+
+
+def make_matching_figure(
+ img0,
+ img1,
+ mkpts0,
+ mkpts1,
+ color,
+ titles=None,
+ kpts0=None,
+ kpts1=None,
+ text=[],
+ dpi=75,
+ path=None,
+ pad=0,
+):
+ # draw image pair
+ # assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+ axes[0].imshow(img0) # , cmap='gray')
+ axes[1].imshow(img1) # , cmap='gray')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ if titles is not None:
+ axes[i].set_title(titles[i])
+
+ plt.tight_layout(pad=pad)
+
+ if kpts0 is not None:
+ assert kpts1 is not None
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5)
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
+
+ # draw matches
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
+ fig.canvas.draw()
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+ fig.lines = [
+ matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure,
+ c=color[i],
+ linewidth=2,
+ )
+ for i in range(len(mkpts0))
+ ]
+
+ # freeze the axes to prevent the transform to change
+ axes[0].autoscale(enable=False)
+ axes[1].autoscale(enable=False)
+
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4)
+
+ # put txts
+ txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
+ fig.text(
+ 0.01,
+ 0.99,
+ "\n".join(text),
+ transform=fig.axes[0].transAxes,
+ fontsize=15,
+ va="top",
+ ha="left",
+ color=txt_color,
+ )
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
+
+def _make_evaluation_figure(data, b_id, alpha="dynamic"):
+ b_mask = data["m_bids"] == b_id
+ conf_thr = _compute_conf_thresh(data)
+
+ img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
+ kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
+
+ # for megadepth, we visualize matches on the resized image
+ if "scale0" in data:
+ kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
+ kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
+
+ epi_errs = data["epi_errs"][b_mask].cpu().numpy()
+ correct_mask = epi_errs < conf_thr
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+ n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
+
+ # matching info
+ if alpha == "dynamic":
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+ text = [
+ f"#Matches {len(kpts0)}",
+ f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
+ f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
+ ]
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
+ return figure
+
+
+def _make_confidence_figure(data, b_id):
+ # TODO: Implement confidence figure
+ raise NotImplementedError()
+
+
+def make_matching_figures(data, config, mode="evaluation"):
+ """Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_LoFTR.
+ config (Dict): matcher config
+ Returns:
+ figures (Dict[str, List[plt.figure]]
+ """
+ assert mode in ["evaluation", "confidence"] # 'confidence'
+ figures = {mode: []}
+ for b_id in range(data["image0"].size(0)):
+ if mode == "evaluation":
+ fig = _make_evaluation_figure(
+ data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
+ )
+ elif mode == "confidence":
+ fig = _make_confidence_figure(data, b_id)
+ else:
+ raise ValueError(f"Unknown plot mode: {mode}")
+ figures[mode].append(fig)
+ return figures
+
+
+def dynamic_alpha(
+ n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
+):
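+    # Interpolate the match-line alpha between the given milestones so that
+    # figures with more matches are drawn with more transparent lines.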
+ if n_matches == 0:
+ return 1.0
+ ranges = list(zip(alphas, alphas[1:] + [None]))
+ loc = bisect.bisect_right(milestones, n_matches) - 1
+ _range = ranges[loc]
+ if _range[1] is None:
+ return _range[0]
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
+ milestones[loc + 1] - milestones[loc]
+ ) * (_range[0] - _range[1])
+
+
+def error_colormap(err, thr, alpha=1.0):
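+    # Map errors to an RGBA red-green colormap: green for errors near zero,
+    # red for errors at or above 2 * thr.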
+    assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
+ return np.clip(
+ np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
+ 0,
+ 1,
+ )
+
+
+np.random.seed(1995)
+color_map = np.arange(100)
+np.random.shuffle(color_map)
+
+
+def draw_topics(
+ data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None
+):
+
+ topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
+ hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
+ hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
+ # print(hw0_i, hw1_i)
+ scale0, scale1 = hw0_i[0] // hw0_c[0], hw1_i[0] // hw1_c[0]
+ if "scale0" in data:
+ scale0 *= data["scale0"][0]
+ else:
+ scale0 = (scale0, scale0)
+ if "scale1" in data:
+ scale1 *= data["scale1"][0]
+ else:
+ scale1 = (scale1, scale1)
+
+ n_topics = topic0.shape[-1]
+ # mask0_nonzero = topic0[0].sum(dim=-1, keepdim=True) > 0
+ # mask1_nonzero = topic1[0].sum(dim=-1, keepdim=True) > 0
+ theta0 = topic0[0].sum(dim=0)
+ theta0 /= theta0.sum().float()
+ theta1 = topic1[0].sum(dim=0)
+ theta1 /= theta1.sum().float()
+ # top_topic0 = torch.argsort(theta0, descending=True)[:show_n_topics]
+ # top_topic1 = torch.argsort(theta1, descending=True)[:show_n_topics]
+ top_topics = torch.argsort(theta0 * theta1, descending=True)[:show_n_topics]
+ # print(sum_topic0, sum_topic1)
+
+ topic0 = topic0[0].argmax(
+ dim=-1, keepdim=True
+ ) # .float() / (n_topics - 1) #* 255 + 1 #
+ # topic0[~mask0_nonzero] = -1
+ topic1 = topic1[0].argmax(
+ dim=-1, keepdim=True
+ ) # .float() / (n_topics - 1) #* 255 + 1
+ # topic1[~mask1_nonzero] = -1
+ label_img0, label_img1 = torch.zeros_like(topic0) - 1, torch.zeros_like(topic1) - 1
+ for i, k in enumerate(top_topics):
+ label_img0[topic0 == k] = color_map[k]
+ label_img1[topic1 == k] = color_map[k]
+
+ # print(hw0_c, scale0)
+ # print(hw1_c, scale1)
+ # map_topic0 = F.fold(label_img0.unsqueeze(0), hw0_i, kernel_size=scale0, stride=scale0)
+ map_topic0 = (
+ label_img0.float().view(hw0_c).cpu().numpy()
+ ) # map_topic0.squeeze(0).squeeze(0).cpu().numpy()
+ map_topic0 = cv2.resize(
+ map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1]))
+ )
+ # map_topic1 = F.fold(label_img1.unsqueeze(0), hw1_i, kernel_size=scale1, stride=scale1)
+ map_topic1 = (
+ label_img1.float().view(hw1_c).cpu().numpy()
+ ) # map_topic1.squeeze(0).squeeze(0).cpu().numpy()
+ map_topic1 = cv2.resize(
+ map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1]))
+ )
+
+ # show image0
+ if saved_name is None:
+ return map_topic0, map_topic1
+
+ if not os.path.exists(saved_folder):
+ os.makedirs(saved_folder)
+ path_saved_img0 = os.path.join(saved_folder, "{}_0.png".format(saved_name))
+ plt.imshow(img0)
+ masked_map_topic0 = np.ma.masked_where(map_topic0 < 0, map_topic0)
+ plt.imshow(
+ masked_map_topic0,
+ cmap=plt.cm.jet,
+ vmin=0,
+ vmax=n_topics - 1,
+ alpha=0.3,
+ interpolation="bilinear",
+ )
+ # plt.show()
+ plt.axis("off")
+ plt.savefig(path_saved_img0, bbox_inches="tight", pad_inches=0, dpi=250)
+ plt.close()
+
+ path_saved_img1 = os.path.join(saved_folder, "{}_1.png".format(saved_name))
+ plt.imshow(img1)
+ masked_map_topic1 = np.ma.masked_where(map_topic1 < 0, map_topic1)
+ plt.imshow(
+ masked_map_topic1,
+ cmap=plt.cm.jet,
+ vmin=0,
+ vmax=n_topics - 1,
+ alpha=0.3,
+ interpolation="bilinear",
+ )
+ plt.axis("off")
+ plt.savefig(path_saved_img1, bbox_inches="tight", pad_inches=0, dpi=250)
+ plt.close()
+
+
+def draw_topicfm_demo(
+ data,
+ img0,
+ img1,
+ mkpts0,
+ mkpts1,
+ mcolor,
+ text,
+ show_n_topics=8,
+ topic_alpha=0.3,
+ margin=5,
+ path=None,
+ opencv_display=False,
+ opencv_title="",
+):
+ topic_map0, topic_map1 = draw_topics(data, img0, img1, show_n_topics=show_n_topics)
+
+ mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(
+ topic_map1 >= 0, axis=-1
+ )
+
+ topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
+ topic_cm0 = cv2.cvtColor(topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
+ topic_cm1 = cv2.cvtColor(topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
+ overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
+ overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
+
+ cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
+ cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
+
+ overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(
+ np.uint8
+ )
+
+ h0, w0 = img0.shape[:2]
+ h1, w1 = img1.shape[:2]
+ h, w = h0 * 2 + margin * 2, w0 * 2 + margin
+ out_fig = 255 * np.ones((h, w, 3), dtype=np.uint8)
+ out_fig[:h0, :w0] = overlay0
+ if h0 >= h1:
+ start = (h0 - h1) // 2
+ out_fig[start : (start + h1), (w0 + margin) : (w0 + margin + w1)] = overlay1
+ else:
+ start = (h1 - h0) // 2
+ out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
+ start : (start + h0)
+ ]
+
+ step_h = h0 + margin * 2
+ out_fig[step_h : step_h + h0, :w0] = (img0 * 255).astype(np.uint8)
+ if h0 >= h1:
+ start = step_h + (h0 - h1) // 2
+ out_fig[start : start + h1, (w0 + margin) : (w0 + margin + w1)] = (
+ img1 * 255
+ ).astype(np.uint8)
+ else:
+ start = (h1 - h0) // 2
+ out_fig[step_h : step_h + h0, (w0 + margin) : (w0 + margin + w1)] = (
+ img1[start : start + h0] * 255
+ ).astype(np.uint8)
+
+    # draw matching lines; this is inspired by https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
+ mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
+ mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
+
+ for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, mcolor):
+ c = c.tolist()
+ cv2.line(
+ out_fig,
+ (x0, y0 + step_h),
+ (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2),
+ color=c,
+ thickness=1,
+ lineType=cv2.LINE_AA,
+ )
+ # display line end-points as circles
+ cv2.circle(out_fig, (x0, y0 + step_h), 2, c, -1, lineType=cv2.LINE_AA)
+ cv2.circle(
+ out_fig,
+ (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2),
+ 2,
+ c,
+ -1,
+ lineType=cv2.LINE_AA,
+ )
+
+ # Scale factor for consistent visualization across scales.
+ sc = min(h / 960.0, 2.0)
+
+ # Big text.
+ Ht = int(30 * sc) # text height
+ txt_color_fg = (255, 255, 255)
+ txt_color_bg = (0, 0, 0)
+ for i, t in enumerate(text):
+ cv2.putText(
+ out_fig,
+ t,
+ (int(8 * sc), Ht + step_h * i),
+ cv2.FONT_HERSHEY_DUPLEX,
+ 1.0 * sc,
+ txt_color_bg,
+ 2,
+ cv2.LINE_AA,
+ )
+ cv2.putText(
+ out_fig,
+ t,
+ (int(8 * sc), Ht + step_h * i),
+ cv2.FONT_HERSHEY_DUPLEX,
+ 1.0 * sc,
+ txt_color_fg,
+ 1,
+ cv2.LINE_AA,
+ )
+
+ if path is not None:
+ cv2.imwrite(str(path), out_fig)
+
+ if opencv_display:
+ cv2.imshow(opencv_title, out_fig)
+ cv2.waitKey(1)
+
+ return out_fig
+
+
+def fig2im(fig):
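+    # Render a matplotlib figure into an RGB numpy array of shape (H, W, 3).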
+ fig.canvas.draw()
+ w, h = fig.canvas.get_width_height()
+ buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
+ im = buf_ndarray.reshape(h, w, 3)
+ return im
+
+
+def draw_matches(
+ mkpts0, mkpts1, img0, img1, conf, titles=None, dpi=150, path=None, pad=0.5
+):
+    thr = 0.5  # threshold used by the error colormap
+ color = error_colormap(conf, thr, alpha=0.1)
+ text = [
+        "image name",
+ f"#Matches: {len(mkpts0)}",
+ ]
+    if path:
+        # make_matching_figure saves the figure to disk and returns None,
+        # so it must not be passed through fig2im here.
+        make_matching_figure(
+            img0,
+            img1,
+            mkpts0,
+            mkpts1,
+            color,
+            titles=titles,
+            text=text,
+            path=path,
+            dpi=dpi,
+            pad=pad,
+        )
+ else:
+ return fig2im(
+ make_matching_figure(
+ img0,
+ img1,
+ mkpts0,
+ mkpts1,
+ color,
+ titles=titles,
+ text=text,
+ pad=pad,
+ dpi=dpi,
+ )
+ )
+
+
+def draw_image_pairs(img0, img1, text=[], dpi=75, path=None, pad=0.5):
+ # draw image pair
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+ axes[0].imshow(img0) # , cmap='gray')
+ axes[1].imshow(img1) # , cmap='gray')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=pad)
+
+ # put txts
+ txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
+ fig.text(
+ 0.01,
+ 0.99,
+ "\n".join(text),
+ transform=fig.axes[0].transAxes,
+ fontsize=15,
+ va="top",
+ ha="left",
+ color=txt_color,
+ )
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
+ plt.close()
+ else:
+ return fig2im(fig)
diff --git a/extra_utils/utils.py b/extra_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e80aa6e30468df32d7f1d49bf8cde92eb13279d
--- /dev/null
+++ b/extra_utils/utils.py
@@ -0,0 +1,182 @@
+import torch
+import numpy as np
+import cv2
+from hloc import matchers, extractors
+from hloc.utils.base_model import dynamic_load
+from hloc import match_dense, match_features, extract_features
+from .plotting import draw_matches, fig2im
+from .visualize_util import plot_images, plot_color_line_matches
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def get_model(match_conf):
+ Model = dynamic_load(matchers, match_conf["model"]["name"])
+ model = Model(match_conf["model"]).eval().to(device)
+ return model
+
+
+def get_feature_model(conf):
+ Model = dynamic_load(extractors, conf["model"]["name"])
+ model = Model(conf["model"]).eval().to(device)
+ return model
+
+
+def display_matches(pred: dict):
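+    # Build a visualization of point (and, if present, line) matches from a
+    # prediction dict and return it together with the number of matches.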
+ img0 = pred["image0_orig"]
+ img1 = pred["image1_orig"]
+
+ num_inliers = 0
+ if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
+ mkpts0 = pred["keypoints0_orig"]
+ mkpts1 = pred["keypoints1_orig"]
+ num_inliers = len(mkpts0)
+ if "mconf" in pred.keys():
+ mconf = pred["mconf"]
+ else:
+ mconf = np.ones(len(mkpts0))
+ fig_mkpts = draw_matches(
+ mkpts0,
+ mkpts1,
+ img0,
+ img1,
+ mconf,
+ dpi=300,
+ titles=["Image 0 - matched keypoints", "Image 1 - matched keypoints"],
+ )
+ fig = fig_mkpts
+ if "line0_orig" in pred.keys() and "line1_orig" in pred.keys():
+ # lines
+ mtlines0 = pred["line0_orig"]
+ mtlines1 = pred["line1_orig"]
+ num_inliers = len(mtlines0)
+ fig_lines = plot_images(
+ [img0.squeeze(), img1.squeeze()],
+ ["Image 0 - matched lines", "Image 1 - matched lines"],
+ dpi=300,
+ )
+ fig_lines = plot_color_line_matches([mtlines0, mtlines1], lw=2)
+ fig_lines = fig2im(fig_lines)
+
+ # keypoints
+ mkpts0 = pred["line_keypoints0_orig"]
+ mkpts1 = pred["line_keypoints1_orig"]
+
+ if mkpts0 is not None and mkpts1 is not None:
+ num_inliers = len(mkpts0)
+ if "mconf" in pred.keys():
+ mconf = pred["mconf"]
+ else:
+ mconf = np.ones(len(mkpts0))
+ fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
+ fig_lines = cv2.resize(fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0]))
+ fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
+ else:
+ fig = fig_lines
+ return fig, num_inliers
+
+
+# Matcher collection: maps a UI model name to its matching config (and, for
+# sparse methods, the feature extraction config).
+matcher_zoo = {
+ "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
+ "sold2": {"config": match_dense.confs["sold2"], "dense": True},
+ # 'dedode-sparse': {
+ # 'config': match_dense.confs['dedode_sparse'],
+ # 'dense': True # dense mode, we need 2 images
+ # },
+ "loftr": {"config": match_dense.confs["loftr"], "dense": True},
+ "topicfm": {"config": match_dense.confs["topicfm"], "dense": True},
+ "aspanformer": {"config": match_dense.confs["aspanformer"], "dense": True},
+ "dedode": {
+ "config": match_features.confs["Dual-Softmax"],
+ "config_feature": extract_features.confs["dedode"],
+ "dense": False,
+ },
+ "superpoint+superglue": {
+ "config": match_features.confs["superglue"],
+ "config_feature": extract_features.confs["superpoint_max"],
+ "dense": False,
+ },
+ "superpoint+lightglue": {
+ "config": match_features.confs["superpoint-lightglue"],
+ "config_feature": extract_features.confs["superpoint_max"],
+ "dense": False,
+ },
+ "disk": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["disk"],
+ "dense": False,
+ },
+ "disk+dualsoftmax": {
+ "config": match_features.confs["Dual-Softmax"],
+ "config_feature": extract_features.confs["disk"],
+ "dense": False,
+ },
+ "superpoint+dualsoftmax": {
+ "config": match_features.confs["Dual-Softmax"],
+ "config_feature": extract_features.confs["superpoint_max"],
+ "dense": False,
+ },
+ "disk+lightglue": {
+ "config": match_features.confs["disk-lightglue"],
+ "config_feature": extract_features.confs["disk"],
+ "dense": False,
+ },
+ "superpoint+mnn": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["superpoint_max"],
+ "dense": False,
+ },
+ "sift+sgmnet": {
+ "config": match_features.confs["sgmnet"],
+ "config_feature": extract_features.confs["sift"],
+ "dense": False,
+ },
+ "sosnet": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["sosnet"],
+ "dense": False,
+ },
+ "hardnet": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["hardnet"],
+ "dense": False,
+ },
+ "d2net": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["d2net-ss"],
+ "dense": False,
+ },
+ "d2net-ms": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["d2net-ms"],
+ "dense": False,
+ },
+ "alike": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["alike"],
+ "dense": False,
+ },
+ "lanet": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["lanet"],
+ "dense": False,
+ },
+ "r2d2": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["r2d2"],
+ "dense": False,
+ },
+ "darkfeat": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["darkfeat"],
+ "dense": False,
+ },
+ "sift": {
+ "config": match_features.confs["NN-mutual"],
+ "config_feature": extract_features.confs["sift"],
+ "dense": False,
+ },
+ "roma": {"config": match_dense.confs["roma"], "dense": True},
+ "DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
+}
diff --git a/extra_utils/visualize_util.py b/extra_utils/visualize_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1efe6489c2614abf6f6f3714fa5a912162582da5
--- /dev/null
+++ b/extra_utils/visualize_util.py
@@ -0,0 +1,642 @@
+""" Organize some frequently used visualization functions. """
+import cv2
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+import copy
+import seaborn as sns
+
+
+# Plot junctions onto the image (return a separate copy)
+def plot_junctions(input_image, junctions, junc_size=3, color=None):
+ """
+ input_image: can be 0~1 float or 0~255 uint8.
+ junctions: Nx2 or 2xN np array.
+ junc_size: the size of the plotted circles.
+ """
+ # Create image copy
+ image = copy.copy(input_image)
+ # Make sure the image is converted to 255 uint8
+ if image.dtype == np.uint8:
+ pass
+ # A float type image ranging from 0~1
+    elif image.dtype in [np.float32, np.float64] and image.max() <= 2.0:
+ image = (image * 255.0).astype(np.uint8)
+ # A float type image ranging from 0.~255.
+    elif image.dtype in [np.float32, np.float64] and image.mean() > 10.0:
+ image = image.astype(np.uint8)
+ else:
+ raise ValueError(
+ "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
+ )
+
+ # Check whether the image is single channel
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+ # Squeeze to H*W first
+ image = image.squeeze()
+
+        # Stack to 3 channels
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+ # Junction dimensions should be N*2
+ if not len(junctions.shape) == 2:
+ raise ValueError("[Error] junctions should be 2-dim array.")
+
+ # Always convert to N*2
+ if junctions.shape[-1] != 2:
+ if junctions.shape[0] == 2:
+ junctions = junctions.T
+ else:
+ raise ValueError("[Error] At least one of the two dims should be 2.")
+
+ # Round and convert junctions to int (and check the boundary)
+ H, W = image.shape[:2]
+    junctions = (np.round(junctions)).astype(np.int32)
+ junctions[junctions < 0] = 0
+ junctions[junctions[:, 0] >= H, 0] = H - 1 # (first dim) max bounded by H-1
+ junctions[junctions[:, 1] >= W, 1] = W - 1 # (second dim) max bounded by W-1
+
+ # Iterate through all the junctions
+ num_junc = junctions.shape[0]
+ if color is None:
+ color = (0, 255.0, 0)
+ for idx in range(num_junc):
+ # Fetch one junction
+ junc = junctions[idx, :]
+ cv2.circle(
+ image, tuple(np.flip(junc)), radius=junc_size, color=color, thickness=3
+ )
+
+ return image
+
+
+# Plot line segments given junctions and a line adjacency map
+def plot_line_segments(
+ input_image,
+ junctions,
+ line_map,
+ junc_size=3,
+ color=(0, 255.0, 0),
+ line_width=1,
+ plot_survived_junc=True,
+):
+ """
+ input_image: can be 0~1 float or 0~255 uint8.
+ junctions: Nx2 or 2xN np array.
+ line_map: NxN np array
+ junc_size: the size of the plotted circles.
+ color: color of the line segments (can be string "random")
+ line_width: width of the drawn segments.
+ plot_survived_junc: whether we only plot the survived junctions.
+ """
+ # Create image copy
+ image = copy.copy(input_image)
+ # Make sure the image is converted to 255 uint8
+ if image.dtype == np.uint8:
+ pass
+ # A float type image ranging from 0~1
+    elif image.dtype in [np.float32, np.float64] and image.max() <= 2.0:
+ image = (image * 255.0).astype(np.uint8)
+ # A float type image ranging from 0.~255.
+    elif image.dtype in [np.float32, np.float64] and image.mean() > 10.0:
+ image = image.astype(np.uint8)
+ else:
+ raise ValueError(
+ "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
+ )
+
+ # Check whether the image is single channel
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+ # Squeeze to H*W first
+ image = image.squeeze()
+
+        # Stack to 3 channels
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+ # Junction dimensions should be 2
+ if not len(junctions.shape) == 2:
+ raise ValueError("[Error] junctions should be 2-dim array.")
+
+ # Always convert to N*2
+ if junctions.shape[-1] != 2:
+ if junctions.shape[0] == 2:
+ junctions = junctions.T
+ else:
+ raise ValueError("[Error] At least one of the two dims should be 2.")
+
+ # line_map dimension should be 2
+ if not len(line_map.shape) == 2:
+ raise ValueError("[Error] line_map should be 2-dim array.")
+
+ # Color should be "random" or a list or tuple with length 3
+ if color != "random":
+ if not (isinstance(color, tuple) or isinstance(color, list)):
+ raise ValueError("[Error] color should have type list or tuple.")
+ else:
+ if len(color) != 3:
+ raise ValueError(
+ "[Error] color should be a list or tuple with length 3."
+ )
+
+ # Make a copy of the line_map
+ line_map_tmp = copy.copy(line_map)
+
+ # Parse line_map back to segment pairs
+ segments = np.zeros([0, 4])
+ for idx in range(junctions.shape[0]):
+ # if no connectivity, just skip it
+ if line_map_tmp[idx, :].sum() == 0:
+ continue
+ # record the line segment
+ else:
+ for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
+ p1 = np.flip(junctions[idx, :]) # Convert to xy format
+ p2 = np.flip(junctions[idx2, :]) # Convert to xy format
+ segments = np.concatenate(
+ (segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
+ axis=0,
+ )
+
+ # Update line_map
+ line_map_tmp[idx, idx2] = 0
+ line_map_tmp[idx2, idx] = 0
+
+ # Draw segment pairs
+ for idx in range(segments.shape[0]):
+        seg = np.round(segments[idx, :]).astype(np.int32)
+ # Decide the color
+ if color != "random":
+ color = tuple(color)
+ else:
+ color = tuple(
+ np.random.rand(
+ 3,
+ )
+ )
+ cv2.line(
+ image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width
+ )
+
+ # Also draw the junctions
+ if not plot_survived_junc:
+ num_junc = junctions.shape[0]
+ for idx in range(num_junc):
+ # Fetch one junction
+ junc = junctions[idx, :]
+ cv2.circle(
+ image,
+ tuple(np.flip(junc)),
+ radius=junc_size,
+ color=(0, 255.0, 0),
+ thickness=3,
+ )
+ # Only plot the junctions which are part of a line segment
+ else:
+ for idx in range(segments.shape[0]):
+            seg = np.round(segments[idx, :]).astype(np.int32)  # Already in HW format.
+ cv2.circle(
+ image,
+ tuple(seg[:2]),
+ radius=junc_size,
+ color=(0, 255.0, 0),
+ thickness=3,
+ )
+ cv2.circle(
+ image,
+ tuple(seg[2:]),
+ radius=junc_size,
+ color=(0, 255.0, 0),
+ thickness=3,
+ )
+
+ return image
+
+
+# Plot line segments given Nx4 or Nx2x2 line segments
+def plot_line_segments_from_segments(
+ input_image, line_segments, junc_size=3, color=(0, 255.0, 0), line_width=1
+):
+ # Create image copy
+ image = copy.copy(input_image)
+ # Make sure the image is converted to 255 uint8
+ if image.dtype == np.uint8:
+ pass
+ # A float type image ranging from 0~1
+    elif image.dtype in [np.float32, np.float64] and image.max() <= 2.0:
+ image = (image * 255.0).astype(np.uint8)
+ # A float type image ranging from 0.~255.
+    elif image.dtype in [np.float32, np.float64] and image.mean() > 10.0:
+ image = image.astype(np.uint8)
+ else:
+ raise ValueError(
+ "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
+ )
+
+ # Check whether the image is single channel
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+ # Squeeze to H*W first
+ image = image.squeeze()
+
+        # Stack to 3 channels
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+    # Check if line_segments are in (1) Nx4 or (2) Nx2x2 format.
+ H, W, _ = image.shape
+ # (1) Nx4 format
+ if len(line_segments.shape) == 2 and line_segments.shape[-1] == 4:
+ # Round to int32
+ line_segments = line_segments.astype(np.int32)
+
+ # Clip H dimension
+ line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H - 1)
+ line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H - 1)
+
+ # Clip W dimension
+ line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W - 1)
+ line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W - 1)
+
+ # Convert to Nx2x2 format
+ line_segments = np.concatenate(
+ [
+ np.expand_dims(line_segments[:, :2], axis=1),
+ np.expand_dims(line_segments[:, 2:], axis=1),
+ ],
+ axis=1,
+ )
+
+ # (2) Nx2x2 format
+ elif len(line_segments.shape) == 3 and line_segments.shape[-1] == 2:
+ # Round to int32
+ line_segments = line_segments.astype(np.int32)
+
+ # Clip H dimension
+ line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H - 1)
+ line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W - 1)
+
+ else:
+ raise ValueError(
+ "[Error] line_segments should be either Nx4 or Nx2x2 in HW format."
+ )
+
+ # Draw segment pairs (all segments should be in HW format)
+ image = image.copy()
+ for idx in range(line_segments.shape[0]):
+ seg = np.round(line_segments[idx, :, :]).astype(np.int32)
+ # Decide the color
+ if color != "random":
+ color = tuple(color)
+ else:
+ color = tuple(
+ np.random.rand(
+ 3,
+ )
+ )
+ cv2.line(
+ image,
+ tuple(np.flip(seg[0, :])),
+ tuple(np.flip(seg[1, :])),
+ color=color,
+ thickness=line_width,
+ )
+
+ # Also draw the junctions
+ cv2.circle(
+ image,
+ tuple(np.flip(seg[0, :])),
+ radius=junc_size,
+ color=(0, 255.0, 0),
+ thickness=3,
+ )
+ cv2.circle(
+ image,
+ tuple(np.flip(seg[1, :])),
+ radius=junc_size,
+ color=(0, 255.0, 0),
+ thickness=3,
+ )
+
+ return image
+
+
+# Additional functions to visualize multiple images at the same time,
+# e.g. for line matching
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
+ """Plot a set of images horizontally.
+ Args:
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+ titles: a list of strings, as titles for each image.
+ cmaps: colormaps for monochrome images.
+ """
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+ # figsize = (size*n, size*3/4) if size is not None else None
+ figsize = (size * n, size * 6 / 5) if size is not None else None
+ fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
+
+ if n == 1:
+ ax = [ax]
+ for i in range(n):
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+ ax[i].get_yaxis().set_ticks([])
+ ax[i].get_xaxis().set_ticks([])
+ ax[i].set_axis_off()
+ for spine in ax[i].spines.values(): # remove frame
+ spine.set_visible(False)
+ if titles:
+ ax[i].set_title(titles[i])
+ fig.tight_layout(pad=pad)
+ return fig
+
+
+def plot_keypoints(kpts, colors="lime", ps=4):
+ """Plot keypoints for existing images.
+ Args:
+ kpts: list of ndarrays of size (N, 2).
+        colors: string, or list of lists of tuples (one for each keypoint).
+ ps: size of the keypoints as float.
+ """
+ if not isinstance(colors, list):
+ colors = [colors] * len(kpts)
+ axes = plt.gcf().axes
+ for a, k, c in zip(axes, kpts, colors):
+ a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
+ """Plot matches for a pair of existing images.
+ Args:
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ ps: size of the end points (no endpoint if ps=0)
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1,
+ transform=fig.transFigure,
+ c=color[i],
+ linewidth=lw,
+ alpha=a,
+ )
+ for i in range(len(kpts0))
+ ]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+ if ps > 0:
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2)
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)
+
+
+def plot_lines(
+ lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
+):
+ """Plot lines and endpoints for existing images.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+        colors: string, or list of lists of tuples (one for each keypoint).
+ ps: size of the keypoints as float pixels.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ if not isinstance(line_colors, list):
+ line_colors = [line_colors] * len(lines)
+ if not isinstance(point_colors, list):
+ point_colors = [point_colors] * len(lines)
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines and junctions
+ for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
+ for i in range(len(l)):
+ line = matplotlib.lines.Line2D(
+ (l[i, 0, 0], l[i, 1, 0]),
+ (l[i, 0, 1], l[i, 1, 1]),
+ zorder=1,
+ c=lc,
+ linewidth=lw,
+ )
+ a.add_line(line)
+ pts = l.reshape(-1, 2)
+ a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2)
+
+ return fig
+
+
+def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.0):
+ """Plot matches for a pair of existing images, parametrized by their middle point.
+ Args:
+ kpts0, kpts1: corresponding middle points of the lines of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1,
+ transform=fig.transFigure,
+ c=color[i],
+ linewidth=lw,
+ alpha=a,
+ )
+ for i in range(len(kpts0))
+ ]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+
+def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
+ """Plot line matches for existing images with multiple colors.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ correct_matches: bool array of size (N,) indicating correct matches.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ n_lines = len(lines[0])
+ colors = sns.color_palette("husl", n_colors=n_lines)
+ np.random.shuffle(colors)
+ alphas = np.ones(n_lines)
+ # If correct_matches is not None, display wrong matches with a low alpha
+ if correct_matches is not None:
+ alphas[~np.array(correct_matches)] = 0.2
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l in zip(axes, lines):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1,
+ transform=fig.transFigure,
+ c=colors[i],
+ alpha=alphas[i],
+ linewidth=lw,
+ )
+ for i in range(n_lines)
+ ]
+
+ return fig
+
+
+def plot_color_lines(lines, correct_matches, wrong_matches, lw=2, indices=(0, 1)):
+ """Plot line matches for existing images with multiple colors:
+ green for correct matches, red for wrong ones, and blue for the rest.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ correct_matches: list of bool arrays of size N with correct matches.
+        wrong_matches: list of bool arrays of size (N,) with wrong matches.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ # palette = sns.color_palette()
+ palette = sns.color_palette("hls", 8)
+ blue = palette[5] # palette[0]
+ red = palette[0] # palette[3]
+ green = palette[2] # palette[2]
+ colors = [np.array([blue] * len(l)) for l in lines]
+ for i, c in enumerate(colors):
+ c[np.array(correct_matches[i])] = green
+ c[np.array(wrong_matches[i])] = red
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l, c in zip(axes, lines, colors):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1,
+ transform=fig.transFigure,
+ c=c[i],
+ linewidth=lw,
+ )
+ for i in range(len(l))
+ ]
+
+
+def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
+ """Plot line matches for existing images with multiple colors and
+ highlight the actually matched subsegments.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ subsegments: list of ndarrays of size (N, 2, 2).
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ n_lines = len(lines[0])
+ colors = sns.cubehelix_palette(
+ start=2, rot=-0.2, dark=0.3, light=0.7, gamma=1.3, hue=1, n_colors=n_lines
+ )
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l, ss in zip(axes, lines, subsegments):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+
+ # Draw full line
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1,
+ transform=fig.transFigure,
+ c="red",
+ alpha=0.7,
+ linewidth=lw,
+ )
+ for i in range(n_lines)
+ ]
+
+ # Draw matched subsegment
+ endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1]))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1,
+ transform=fig.transFigure,
+ c=colors[i],
+ alpha=1,
+ linewidth=lw,
+ )
+ for i in range(n_lines)
+ ]
diff --git a/hloc/__init__.py b/hloc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63ec0953efe032d45747027a60b4156729c974a8
--- /dev/null
+++ b/hloc/__init__.py
@@ -0,0 +1,31 @@
+import logging
+from packaging import version
+
+__version__ = "1.3"
+
+formatter = logging.Formatter(
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+)
+handler = logging.StreamHandler()
+handler.setFormatter(formatter)
+handler.setLevel(logging.INFO)
+
+logger = logging.getLogger("hloc")
+logger.setLevel(logging.INFO)
+logger.addHandler(handler)
+logger.propagate = False
+
+try:
+ import pycolmap
+except ImportError:
+ logger.warning("pycolmap is not installed, some features may not work.")
+else:
+ minimal_version = version.parse("0.3.0")
+ found_version = version.parse(getattr(pycolmap, "__version__"))
+ if found_version < minimal_version:
+ logger.warning(
+ "hloc now requires pycolmap>=%s but found pycolmap==%s, "
+ "please upgrade with `pip install --upgrade pycolmap`",
+ minimal_version,
+ found_version,
+ )
diff --git a/hloc/extract_features.py b/hloc/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..18ad7fb8d250dc4b48330c78ac866c5da50d6de4
--- /dev/null
+++ b/hloc/extract_features.py
@@ -0,0 +1,516 @@
+import argparse
+import torch
+from pathlib import Path
+from typing import Dict, List, Union, Optional
+import h5py
+from types import SimpleNamespace
+import cv2
+import numpy as np
+from tqdm import tqdm
+import pprint
+import collections.abc as collections
+import PIL.Image
+import torchvision.transforms.functional as F
+from . import extractors, logger
+from .utils.base_model import dynamic_load
+from .utils.parsers import parse_image_lists
+from .utils.io import read_image, list_h5_names
+
+
+"""
+A set of standard configurations that can be directly selected from the command
+line using their name. Each is a dictionary with the following entries:
+ - output: the name of the feature file that will be generated.
+ - model: the model configuration, as passed to a feature extractor.
+ - preprocessing: how to preprocess the images read from disk.
+"""
+confs = {
+ "superpoint_aachen": {
+ "output": "feats-superpoint-n4096-r1024",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 3,
+ "max_keypoints": 4096,
+ "keypoint_threshold": 0.005,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ # Resize images to 1600px even if they are originally smaller.
+ # Improves the keypoint localization if the images are of good quality.
+ "superpoint_max": {
+ "output": "feats-superpoint-n4096-rmax1600",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 3,
+ "max_keypoints": 4096,
+ "keypoint_threshold": 0.005,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "superpoint_inloc": {
+ "output": "feats-superpoint-n4096-r1600",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 4,
+ "max_keypoints": 4096,
+ "keypoint_threshold": 0.005,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1600,
+ },
+ },
+ "r2d2": {
+ "output": "feats-r2d2-n5000-r1024",
+ "model": {
+ "name": "r2d2",
+ "max_keypoints": 5000,
+ "reliability_threshold": 0.7,
+ "repetability_threshold": 0.7,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "d2net-ss": {
+ "output": "feats-d2net-ss",
+ "model": {
+ "name": "d2net",
+ "multiscale": False,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "d2net-ms": {
+ "output": "feats-d2net-ms",
+ "model": {
+ "name": "d2net",
+ "multiscale": True,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "rootsift": {
+ "output": "feats-sift",
+ "model": {
+ "name": "dog",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "sift": {
+ "output": "feats-sift",
+ "model": {
+ "name": "dog",
+ "descriptor": "sift",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "sosnet": {
+ "output": "feats-sosnet",
+ "model": {"name": "dog", "descriptor": "sosnet"},
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1600,
+ "force_resize": True,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "hardnet": {
+ "output": "feats-hardnet",
+ "model": {"name": "dog", "descriptor": "hardnet"},
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1600,
+ "force_resize": True,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "disk": {
+ "output": "feats-disk",
+ "model": {
+ "name": "disk",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "alike": {
+ "output": "feats-alike",
+ "model": {
+ "name": "alike",
+ "max_keypoints": 5000,
+ "use_relu": True,
+ "multiscale": False,
+ "detection_threshold": 0.5,
+ "top_k": -1,
+ "sub_pixel": False,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "lanet": {
+ "output": "feats-lanet",
+ "model": {
+ "name": "lanet",
+ "keypoint_threshold": 0.1,
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "darkfeat": {
+ "output": "feats-darkfeat-n5000-r1024",
+ "model": {
+ "name": "darkfeat",
+ "max_keypoints": 5000,
+ "reliability_threshold": 0.7,
+ "repetability_threshold": 0.7,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "dedode": {
+ "output": "feats-dedode-n5000-r1024",
+ "model": {
+ "name": "dedode",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 768,
+ "height": 768,
+ "dfactor": 8,
+ },
+ },
+ "example": {
+ "output": "feats-example-n5000-r1024",
+ "model": {
+ "name": "example",
+ "keypoint_threshold": 0.1,
+ "max_keypoints": 2000,
+ "model_name": "model.pth",
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 768,
+ "height": 768,
+ "dfactor": 8,
+ },
+ },
+ # Global descriptors
+ "dir": {
+ "output": "global-feats-dir",
+ "model": {"name": "dir"},
+ "preprocessing": {"resize_max": 1024},
+ },
+ "netvlad": {
+ "output": "global-feats-netvlad",
+ "model": {"name": "netvlad"},
+ "preprocessing": {"resize_max": 1024},
+ },
+ "openibl": {
+ "output": "global-feats-openibl",
+ "model": {"name": "openibl"},
+ "preprocessing": {"resize_max": 1024},
+ },
+ "cosplace": {
+ "output": "global-feats-cosplace",
+ "model": {"name": "cosplace"},
+ "preprocessing": {"resize_max": 1024},
+ },
+}
+
+
+def resize_image(image, size, interp):
+ if interp.startswith("cv2_"):
+ interp = getattr(cv2, "INTER_" + interp[len("cv2_") :].upper())
+ h, w = image.shape[:2]
+ if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
+ interp = cv2.INTER_LINEAR
+ resized = cv2.resize(image, size, interpolation=interp)
+ elif interp.startswith("pil_"):
+ interp = getattr(PIL.Image, interp[len("pil_") :].upper())
+ resized = PIL.Image.fromarray(image.astype(np.uint8))
+ resized = resized.resize(size, resample=interp)
+ resized = np.asarray(resized, dtype=image.dtype)
+ else:
+ raise ValueError(f"Unknown interpolation {interp}.")
+ return resized
+
+
+class ImageDataset(torch.utils.data.Dataset):
+ default_conf = {
+ "globs": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
+ "grayscale": False,
+ "resize_max": None,
+ "force_resize": False,
+ "interpolation": "cv2_area", # pil_linear is more accurate but slower
+ }
+
+ def __init__(self, root, conf, paths=None):
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
+ self.root = root
+
+ if paths is None:
+ paths = []
+ for g in conf.globs:
+ paths += list(Path(root).glob("**/" + g))
+ if len(paths) == 0:
+ raise ValueError(f"Could not find any image in root: {root}.")
+ paths = sorted(list(set(paths)))
+ self.names = [i.relative_to(root).as_posix() for i in paths]
+ logger.info(f"Found {len(self.names)} images in root {root}.")
+ else:
+ if isinstance(paths, (Path, str)):
+ self.names = parse_image_lists(paths)
+ elif isinstance(paths, collections.Iterable):
+ self.names = [p.as_posix() if isinstance(p, Path) else p for p in paths]
+ else:
+ raise ValueError(f"Unknown format for path argument {paths}.")
+
+ for name in self.names:
+ if not (root / name).exists():
+ raise ValueError(f"Image {name} does not exists in root: {root}.")
+
+ def __getitem__(self, idx):
+ name = self.names[idx]
+ image = read_image(self.root / name, self.conf.grayscale)
+ image = image.astype(np.float32)
+ size = image.shape[:2][::-1]
+
+ if self.conf.resize_max and (
+ self.conf.force_resize or max(size) > self.conf.resize_max
+ ):
+ scale = self.conf.resize_max / max(size)
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, self.conf.interpolation)
+
+ if self.conf.grayscale:
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = image / 255.0
+
+ data = {
+ "image": image,
+ "original_size": np.array(size),
+ }
+ return data
+
+ def __len__(self):
+ return len(self.names)
+
+
+def extract(model, image_0, conf):
+ default_conf = {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "cache_images": False,
+ "force_resize": False,
+ "width": 320,
+ "height": 240,
+ "interpolation": "cv2_area",
+ }
+ conf = SimpleNamespace(**{**default_conf, **conf})
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def preprocess(image: np.ndarray, conf: SimpleNamespace):
+ image = image.astype(np.float32, copy=False)
+ size = image.shape[:2][::-1]
+ scale = np.array([1.0, 1.0])
+ if conf.resize_max:
+ scale = conf.resize_max / max(size)
+ if scale < 1.0:
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, "cv2_area")
+ scale = np.array(size) / np.array(size_new)
+ if conf.force_resize:
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
+ size_new = (conf.width, conf.height)
+ scale = np.array(size) / np.array(size_new)
+ if conf.grayscale:
+ assert image.ndim == 2, image.shape
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = torch.from_numpy(image / 255.0).float()
+
+ # ensure the size is divisible by dfactor
+ size_new = tuple(
+ map(lambda x: int(x // conf.dfactor * conf.dfactor), image.shape[-2:])
+ )
+ image = F.resize(image, size=size_new, antialias=True)
+ input_ = image.to(device, non_blocking=True)[None]
+ data = {
+ "image": input_,
+ "image_orig": image_0,
+ "original_size": np.array(size),
+ "size": np.array(image.shape[1:][::-1]),
+ }
+ return data
+
+ # convert to grayscale if needed
+ if len(image_0.shape) == 3 and conf.grayscale:
+ image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY)
+ else:
+ image0 = image_0
+ # The lines below stay commented out because the input image is already in RGB mode.
+ # if not conf.grayscale and len(image_0.shape) == 3:
+ # image0 = image_0[:, :, ::-1] # BGR to RGB
+ data = preprocess(image0, conf)
+ pred = model({"image": data["image"]})
+ pred["image_size"] = original_size = data["original_size"]
+ pred = {**pred, **data}
+ return pred
+
+
+@torch.no_grad()
+def main(
+ conf: Dict,
+ image_dir: Path,
+ export_dir: Optional[Path] = None,
+ as_half: bool = True,
+ image_list: Optional[Union[Path, List[str]]] = None,
+ feature_path: Optional[Path] = None,
+ overwrite: bool = False,
+) -> Path:
+ logger.info(
+ "Extracting local features with configuration:" f"\n{pprint.pformat(conf)}"
+ )
+
+ dataset = ImageDataset(image_dir, conf["preprocessing"], image_list)
+ if feature_path is None:
+ feature_path = Path(export_dir, conf["output"] + ".h5")
+ feature_path.parent.mkdir(exist_ok=True, parents=True)
+ skip_names = set(
+ list_h5_names(feature_path) if feature_path.exists() and not overwrite else ()
+ )
+ dataset.names = [n for n in dataset.names if n not in skip_names]
+ if len(dataset.names) == 0:
+ logger.info("Skipping the extraction.")
+ return feature_path
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ Model = dynamic_load(extractors, conf["model"]["name"])
+ model = Model(conf["model"]).eval().to(device)
+
+ loader = torch.utils.data.DataLoader(
+ dataset, num_workers=1, shuffle=False, pin_memory=True
+ )
+ for idx, data in enumerate(tqdm(loader)):
+ name = dataset.names[idx]
+ pred = model({"image": data["image"].to(device, non_blocking=True)})
+ pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
+
+ pred["image_size"] = original_size = data["original_size"][0].numpy()
+ if "keypoints" in pred:
+ size = np.array(data["image"].shape[-2:][::-1])
+ scales = (original_size / size).astype(np.float32)
+ pred["keypoints"] = (pred["keypoints"] + 0.5) * scales[None] - 0.5
+ if "scales" in pred:
+ pred["scales"] *= scales.mean()
+ # add keypoint uncertainties scaled to the original resolution
+ uncertainty = getattr(model, "detection_noise", 1) * scales.mean()
+
+ if as_half:
+ for k in pred:
+ dt = pred[k].dtype
+ if dt == np.float32:
+ pred[k] = pred[k].astype(np.float16)
+
+ with h5py.File(str(feature_path), "a", libver="latest") as fd:
+ try:
+ if name in fd:
+ del fd[name]
+ grp = fd.create_group(name)
+ for k, v in pred.items():
+ grp.create_dataset(k, data=v)
+ if "keypoints" in pred:
+ grp["keypoints"].attrs["uncertainty"] = uncertainty
+ except OSError as error:
+ if "No space left on device" in error.args[0]:
+ logger.error(
+ "Out of disk space: storing features on disk can take "
+ "significant space, did you enable the as_half flag?"
+ )
+ del grp, fd[name]
+ raise error
+
+ del pred
+
+ logger.info("Finished exporting features.")
+ return feature_path
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--image_dir", type=Path, required=True)
+ parser.add_argument("--export_dir", type=Path, required=True)
+ parser.add_argument(
+ "--conf", type=str, default="superpoint_aachen", choices=list(confs.keys())
+ )
+ parser.add_argument("--as_half", action="store_true")
+ parser.add_argument("--image_list", type=Path)
+ parser.add_argument("--feature_path", type=Path)
+ args = parser.parse_args()
+ main(confs[args.conf], args.image_dir, args.export_dir, args.as_half)
diff --git a/hloc/extractors/__init__.py b/hloc/extractors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/extractors/alike.py b/hloc/extractors/alike.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7086186c7dcf828d81f19a8bdcc40214f9f7d21
--- /dev/null
+++ b/hloc/extractors/alike.py
@@ -0,0 +1,52 @@
+import sys
+from pathlib import Path
+import subprocess
+import torch
+
+from ..utils.base_model import BaseModel
+
+alike_path = Path(__file__).parent / "../../third_party/ALIKE"
+sys.path.append(str(alike_path))
+from alike import ALike as Alike_
+from alike import configs
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Alike(BaseModel):
+ default_conf = {
+ "model_name": "alike-t", # 'alike-t', 'alike-s', 'alike-n', 'alike-l'
+ "use_relu": True,
+ "multiscale": False,
+ "max_keypoints": 1000,
+ "detection_threshold": 0.5,
+ "top_k": -1,
+ "sub_pixel": False,
+ }
+
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.net = Alike_(
+ **configs[conf["model_name"]],
+ device=device,
+ top_k=conf["top_k"],
+ scores_th=conf["detection_threshold"],
+ n_limit=conf["max_keypoints"],
+ )
+
+ def _forward(self, data):
+ image = data["image"]
+ image = image.permute(0, 2, 3, 1).squeeze()
+ image = image.cpu().numpy() * 255.0
+ pred = self.net(image, sub_pixel=self.conf["sub_pixel"])
+
+ keypoints = pred["keypoints"]
+ descriptors = pred["descriptors"]
+ scores = pred["scores"]
+
+ return {
+ "keypoints": torch.from_numpy(keypoints)[None],
+ "scores": torch.from_numpy(scores)[None],
+ "descriptors": torch.from_numpy(descriptors.T)[None],
+ }
diff --git a/hloc/extractors/cosplace.py b/hloc/extractors/cosplace.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d13a84d57d80bee090709623cce74453784844b
--- /dev/null
+++ b/hloc/extractors/cosplace.py
@@ -0,0 +1,44 @@
+"""
+Code for loading models trained with CosPlace as a global features extractor
+for geolocalization through image retrieval.
+Multiple models are available with different backbones. Below is a summary of
+the available models (backbone: list of available output descriptor
+dimensions). For example, you can use a ResNet50-based model with a
+descriptor dimension of 1024.
+ ResNet18: [32, 64, 128, 256, 512]
+ ResNet50: [32, 64, 128, 256, 512, 1024, 2048]
+ ResNet101: [32, 64, 128, 256, 512, 1024, 2048]
+ ResNet152: [32, 64, 128, 256, 512, 1024, 2048]
+ VGG16: [ 64, 128, 256, 512]
+
+CosPlace paper: https://arxiv.org/abs/2204.02287
+"""
+
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+
+class CosPlace(BaseModel):
+ default_conf = {"backbone": "ResNet50", "fc_output_dim": 2048}
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.net = torch.hub.load(
+ "gmberton/CosPlace",
+ "get_trained_model",
+ backbone=conf["backbone"],
+ fc_output_dim=conf["fc_output_dim"],
+ ).eval()
+
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ self.norm_rgb = tvf.Normalize(mean=mean, std=std)
+
+ def _forward(self, data):
+ image = self.norm_rgb(data["image"])
+ desc = self.net(image)
+ return {
+ "global_descriptor": desc,
+ }
diff --git a/hloc/extractors/d2net.py b/hloc/extractors/d2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e94dac175ea3514b93c801ffdf63e9f5fbe327
--- /dev/null
+++ b/hloc/extractors/d2net.py
@@ -0,0 +1,57 @@
+import sys
+from pathlib import Path
+import subprocess
+import torch
+
+from ..utils.base_model import BaseModel
+
+d2net_path = Path(__file__).parent / "../../third_party/d2net"
+sys.path.append(str(d2net_path))
+from lib.model_test import D2Net as _D2Net
+from lib.pyramid import process_multiscale
+
+
+class D2Net(BaseModel):
+ default_conf = {
+ "model_name": "d2_tf.pth",
+ "checkpoint_dir": d2net_path / "models",
+ "use_relu": True,
+ "multiscale": False,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_file = conf["checkpoint_dir"] / conf["model_name"]
+ if not model_file.exists():
+ model_file.parent.mkdir(exist_ok=True)
+ cmd = [
+ "wget",
+ "https://dsmn.ml/files/d2-net/" + conf["model_name"],
+ "-O",
+ str(model_file),
+ ]
+ subprocess.run(cmd, check=True)
+
+ self.net = _D2Net(
+ model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
+ )
+
+ def _forward(self, data):
+ image = data["image"]
+ image = image.flip(1) # RGB -> BGR
+ norm = image.new_tensor([103.939, 116.779, 123.68])
+ image = image * 255 - norm.view(1, 3, 1, 1) # caffe normalization
+
+ if self.conf["multiscale"]:
+ keypoints, scores, descriptors = process_multiscale(image, self.net)
+ else:
+ keypoints, scores, descriptors = process_multiscale(
+ image, self.net, scales=[1]
+ )
+ keypoints = keypoints[:, [1, 0]] # (x, y) and remove the scale
+
+ return {
+ "keypoints": torch.from_numpy(keypoints)[None],
+ "scores": torch.from_numpy(scores)[None],
+ "descriptors": torch.from_numpy(descriptors.T)[None],
+ }
diff --git a/hloc/extractors/darkfeat.py b/hloc/extractors/darkfeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..1990a7b6e138683477380d26ecdc7db657fa7a3f
--- /dev/null
+++ b/hloc/extractors/darkfeat.py
@@ -0,0 +1,57 @@
+import sys
+from pathlib import Path
+import subprocess
+import logging
+
+from ..utils.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+
+darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
+sys.path.append(str(darkfeat_path))
+from darkfeat import DarkFeat as DarkFeat_
+
+
+class DarkFeat(BaseModel):
+ default_conf = {
+ "model_name": "DarkFeat.pth",
+ "max_keypoints": 1000,
+ "detection_threshold": 0.5,
+ "sub_pixel": False,
+ }
+ weight_urls = {
+ "DarkFeat.pth": "https://drive.google.com/uc?id=1Thl6m8NcmQ7zSAF-1_xaFs3F4H8UU6HX&confirm=t",
+ }
+ proxy = "http://localhost:1080"
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_path = darkfeat_path / "checkpoints" / conf["model_name"]
+ link = self.weight_urls[conf["model_name"]]
+ if not model_path.exists():
+ model_path.parent.mkdir(exist_ok=True)
+ cmd_wo_proxy = ["gdown", link, "-O", str(model_path)]
+ cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy]
+ logger.info(f"Downloading the DarkFeat model with `{cmd_wo_proxy}`.")
+ try:
+ subprocess.run(cmd_wo_proxy, check=True)
+ except subprocess.CalledProcessError as e:
+ logger.info(f"Downloading the DarkFeat model with `{cmd}`.")
+ try:
+ subprocess.run(cmd, check=True)
+ except subprocess.CalledProcessError as e:
+ logger.error(f"Failed to download the DarkFeat model.")
+ raise e
+
+ self.net = DarkFeat_(model_path)
+
+ def _forward(self, data):
+ pred = self.net({"image": data["image"]})
+ keypoints = pred["keypoints"]
+ descriptors = pred["descriptors"]
+ scores = pred["scores"]
+ return {
+ "keypoints": keypoints[None], # 1 x N x 2
+ "scores": scores[None], # 1 x N
+ "descriptors": descriptors[None], # 1 x 128 x N
+ }
diff --git a/hloc/extractors/dedode.py b/hloc/extractors/dedode.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f2736f013520e537fd5edd40bcb195b84d21db
--- /dev/null
+++ b/hloc/extractors/dedode.py
@@ -0,0 +1,102 @@
+import sys
+from pathlib import Path
+import subprocess
+import logging
+import torch
+from PIL import Image
+from ..utils.base_model import BaseModel
+import torchvision.transforms as transforms
+
+dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
+sys.path.append(str(dedode_path))
+
+from DeDoDe import dedode_detector_L, dedode_descriptor_B
+from DeDoDe.utils import to_pixel_coords
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logger = logging.getLogger(__name__)
+
+
+class DeDoDe(BaseModel):
+ default_conf = {
+ "name": "dedode",
+ "model_detector_name": "dedode_detector_L.pth",
+ "model_descriptor_name": "dedode_descriptor_B.pth",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ "dense": False, # Now fixed to be false
+ }
+ required_inputs = [
+ "image",
+ ]
+ weight_urls = {
+ "dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
+ "dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
+ }
+
+ # Initialize the line matcher
+ def _init(self, conf):
+ model_detector_path = dedode_path / "pretrained" / conf["model_detector_name"]
+ model_descriptor_path = (
+ dedode_path / "pretrained" / conf["model_descriptor_name"]
+ )
+
+ self.normalizer = transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+ # Download the model.
+ if not model_detector_path.exists():
+ model_detector_path.parent.mkdir(exist_ok=True)
+ link = self.weight_urls[conf["model_detector_name"]]
+ cmd = ["wget", link, "-O", str(model_detector_path)]
+ logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ if not model_descriptor_path.exists():
+ model_descriptor_path.parent.mkdir(exist_ok=True)
+ link = self.weight_urls[conf["model_descriptor_name"]]
+ cmd = ["wget", link, "-O", str(model_descriptor_path)]
+ logger.info(f"Downloading the DeDoDe descriptor model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ logger.info(f"Loading DeDoDe model...")
+
+ # load the model
+ weights_detector = torch.load(model_detector_path, map_location="cpu")
+ weights_descriptor = torch.load(model_descriptor_path, map_location="cpu")
+ self.detector = dedode_detector_L(weights=weights_detector)
+ self.descriptor = dedode_descriptor_B(weights=weights_descriptor)
+ logger.info(f"Load DeDoDe model done.")
+
+ def _forward(self, data):
+ """
+ data: dict, keys: {'image'}
+ image shape: N x C x H x W
+ color mode: RGB
+ """
+ img0 = self.normalizer(data["image"].squeeze()).float()[None]
+ H_A, W_A = img0.shape[2:]
+
+ # step 1: detect keypoints
+ detections_A = None
+ batch_A = {"image": img0}
+ if self.conf["dense"]:
+ detections_A = self.detector.detect_dense(batch_A)
+ else:
+ detections_A = self.detector.detect(
+ batch_A, num_keypoints=self.conf["max_keypoints"]
+ )
+ keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
+
+ # step 2: describe keypoints
+ # dim: 1 x N x 256
+ description_A = self.descriptor.describe_keypoints(batch_A, keypoints_A)[
+ "descriptions"
+ ]
+ keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A)
+
+ return {
+ "keypoints": keypoints_A, # 1 x N x 2
+ "descriptors": description_A.permute(0, 2, 1), # 1 x 256 x N
+ "scores": P_A, # 1 x N
+ }
diff --git a/hloc/extractors/dir.py b/hloc/extractors/dir.py
new file mode 100644
index 0000000000000000000000000000000000000000..30689c4ff8f8ef447bf7661bdfe471a4990226ff
--- /dev/null
+++ b/hloc/extractors/dir.py
@@ -0,0 +1,76 @@
+import sys
+from pathlib import Path
+import torch
+from zipfile import ZipFile
+import os
+import sklearn
+import gdown
+
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party/deep-image-retrieval"))
+os.environ["DB_ROOT"] = "" # required by dirtorch
+
+from dirtorch.utils import common # noqa: E402
+from dirtorch.extract_features import load_model # noqa: E402
+
+# The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
+# which was deprecated in sklearn v0.24; PCA must now be imported explicitly
+# with `from sklearn.decomposition import PCA`.
+# This is a hacky workaround so that old checkpoints still unpickle with newer sklearn.
+sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca
+
+
+class DIR(BaseModel):
+ default_conf = {
+ "model_name": "Resnet-101-AP-GeM",
+ "whiten_name": "Landmarks_clean",
+ "whiten_params": {
+ "whitenp": 0.25,
+ "whitenv": None,
+ "whitenm": 1.0,
+ },
+ "pooling": "gem",
+ "gemp": 3,
+ }
+ required_inputs = ["image"]
+
+ dir_models = {
+ "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
+ }
+
+ def _init(self, conf):
+ checkpoint = Path(torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt")
+ if not checkpoint.exists():
+ checkpoint.parent.mkdir(exist_ok=True, parents=True)
+ link = self.dir_models[conf["model_name"]]
+ gdown.download(str(link), str(checkpoint) + ".zip", quiet=False)
+ zf = ZipFile(str(checkpoint) + ".zip", "r")
+ zf.extractall(checkpoint.parent)
+ zf.close()
+ os.remove(str(checkpoint) + ".zip")
+
+ self.net = load_model(checkpoint, False) # first load on CPU
+ if conf["whiten_name"]:
+ assert conf["whiten_name"] in self.net.pca
+
+ def _forward(self, data):
+ image = data["image"]
+ assert image.shape[1] == 3
+ mean = self.net.preprocess["mean"]
+ std = self.net.preprocess["std"]
+ image = image - image.new_tensor(mean)[:, None, None]
+ image = image / image.new_tensor(std)[:, None, None]
+
+ desc = self.net(image)
+ desc = desc.unsqueeze(0) # batch dimension
+ if self.conf["whiten_name"]:
+ pca = self.net.pca[self.conf["whiten_name"]]
+ desc = common.whiten_features(
+ desc.cpu().numpy(), pca, **self.conf["whiten_params"]
+ )
+ desc = torch.from_numpy(desc)
+
+ return {
+ "global_descriptor": desc,
+ }
diff --git a/hloc/extractors/disk.py b/hloc/extractors/disk.py
new file mode 100644
index 0000000000000000000000000000000000000000..a80d8d7ff0a112dc049ebccdce199829747a11e2
--- /dev/null
+++ b/hloc/extractors/disk.py
@@ -0,0 +1,32 @@
+import kornia
+
+from ..utils.base_model import BaseModel
+
+
+class DISK(BaseModel):
+ default_conf = {
+ "weights": "depth",
+ "max_keypoints": None,
+ "nms_window_size": 5,
+ "detection_threshold": 0.0,
+ "pad_if_not_divisible": True,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.model = kornia.feature.DISK.from_pretrained(conf["weights"])
+
+ def _forward(self, data):
+ image = data["image"]
+ features = self.model(
+ image,
+ n=self.conf["max_keypoints"],
+ window_size=self.conf["nms_window_size"],
+ score_threshold=self.conf["detection_threshold"],
+ pad_if_not_divisible=self.conf["pad_if_not_divisible"],
+ )
+ return {
+ "keypoints": [f.keypoints for f in features],
+ "scores": [f.detection_scores for f in features],
+ "descriptors": [f.descriptors.t() for f in features],
+ }
diff --git a/hloc/extractors/dog.py b/hloc/extractors/dog.py
new file mode 100644
index 0000000000000000000000000000000000000000..96203c88db0aba5d284ed1556d4859671c1f8c01
--- /dev/null
+++ b/hloc/extractors/dog.py
@@ -0,0 +1,131 @@
+import kornia
+from kornia.feature.laf import laf_from_center_scale_ori, extract_patches_from_pyramid
+import numpy as np
+import torch
+import pycolmap
+
+from ..utils.base_model import BaseModel
+
+
+EPS = 1e-6
+
+
+def sift_to_rootsift(x):
+ x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS)
+ x = np.sqrt(x.clip(min=EPS))
+ x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS)
+ return x
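+# Worked example: for a descriptor row x = [1., 1., 2.], the L1 norm is 4, so
+# x -> [0.25, 0.25, 0.5]; the square root gives [0.5, 0.5, ~0.7071], which
+# already has unit L2 norm, so the final normalization is a no-op (up to EPS).
+# This is the usual RootSIFT (Hellinger-kernel) mapping.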
+
+
+class DoG(BaseModel):
+ default_conf = {
+ "options": {
+ "first_octave": 0,
+ "peak_threshold": 0.01,
+ },
+ "descriptor": "rootsift",
+ "max_keypoints": -1,
+ "patch_size": 32,
+ "mr_size": 12,
+ }
+ required_inputs = ["image"]
+ detection_noise = 1.0
+ max_batch_size = 1024
+
+ def _init(self, conf):
+ if conf["descriptor"] == "sosnet":
+ self.describe = kornia.feature.SOSNet(pretrained=True)
+ elif conf["descriptor"] == "hardnet":
+ self.describe = kornia.feature.HardNet(pretrained=True)
+ elif conf["descriptor"] not in ["sift", "rootsift"]:
+ raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')
+
+ self.sift = None # lazily instantiated on the first image
+ self.device = torch.device("cpu")
+
+ def to(self, *args, **kwargs):
+ device = kwargs.get("device")
+ if device is None:
+ match = [a for a in args if isinstance(a, (torch.device, str))]
+ if len(match) > 0:
+ device = match[0]
+ if device is not None:
+ self.device = torch.device(device)
+ return super().to(*args, **kwargs)
+
+ def _forward(self, data):
+ image = data["image"]
+ image_np = image.cpu().numpy()[0, 0]
+ assert image.shape[1] == 1
+ assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS
+
+ if self.sift is None:
+ use_gpu = pycolmap.has_cuda and self.device.type == "cuda"
+ options = {**self.conf["options"]}
+ if self.conf["descriptor"] == "rootsift":
+ options["normalization"] = pycolmap.Normalization.L1_ROOT
+ else:
+ options["normalization"] = pycolmap.Normalization.L2
+ self.sift = pycolmap.Sift(
+ options=pycolmap.SiftExtractionOptions(options),
+ device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
+ )
+
+ keypoints, scores, descriptors = self.sift.extract(image_np)
+ scales = keypoints[:, 2]
+ oris = np.rad2deg(keypoints[:, 3])
+
+ if self.conf["descriptor"] in ["sift", "rootsift"]:
+ # We still renormalize because COLMAP does not normalize well,
+ # maybe due to numerical errors
+ if self.conf["descriptor"] == "rootsift":
+ descriptors = sift_to_rootsift(descriptors)
+ descriptors = torch.from_numpy(descriptors)
+ elif self.conf["descriptor"] in ("sosnet", "hardnet"):
+ center = keypoints[:, :2] + 0.5
+ laf_scale = scales * self.conf["mr_size"] / 2
+ laf_ori = -oris
+ lafs = laf_from_center_scale_ori(
+ torch.from_numpy(center)[None],
+ torch.from_numpy(laf_scale)[None, :, None, None],
+ torch.from_numpy(laf_ori)[None, :, None],
+ ).to(image.device)
+ patches = extract_patches_from_pyramid(
+ image, lafs, PS=self.conf["patch_size"]
+ )[0]
+ descriptors = patches.new_zeros((len(patches), 128))
+ if len(patches) > 0:
+ for start_idx in range(0, len(patches), self.max_batch_size):
+ end_idx = min(len(patches), start_idx + self.max_batch_size)
+ descriptors[start_idx:end_idx] = self.describe(
+ patches[start_idx:end_idx]
+ )
+ else:
+ raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')
+
+ keypoints = torch.from_numpy(keypoints[:, :2]) # keep only x, y
+ scales = torch.from_numpy(scales)
+ oris = torch.from_numpy(oris)
+ scores = torch.from_numpy(scores)
+ if self.conf["max_keypoints"] != -1:
+ # TODO: check that the scores from PyCOLMAP are 100% correct,
+ # follow https://github.com/mihaidusmanu/pycolmap/issues/8
+ max_number = (
+ scores.shape[0]
+ if scores.shape[0] < self.conf["max_keypoints"]
+ else self.conf["max_keypoints"]
+ )
+ values, indices = torch.topk(scores, max_number)
+ keypoints = keypoints[indices]
+ scales = scales[indices]
+ oris = oris[indices]
+ scores = scores[indices]
+ descriptors = descriptors[indices]
+
+ return {
+ "keypoints": keypoints[None],
+ "scales": scales[None],
+ "oris": oris[None],
+ "scores": scores[None],
+ "descriptors": descriptors.T[None],
+ }
diff --git a/hloc/extractors/example.py b/hloc/extractors/example.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e44c0ba11f0f8aa48a02281915c6a05149ae5d
--- /dev/null
+++ b/hloc/extractors/example.py
@@ -0,0 +1,58 @@
+import sys
+from pathlib import Path
+import subprocess
+import torch
+import logging
+
+from ..utils.base_model import BaseModel
+
+example_path = Path(__file__).parent / "../../third_party/example"
+sys.path.append(str(example_path))
+
+# import some modules here
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logger = logging.getLogger(__name__)
+
+
+class Example(BaseModel):
+ # change to your default configs
+ default_conf = {
+ "name": "example",
+ "keypoint_threshold": 0.1,
+ "max_keypoints": 2000,
+ "model_name": "model.pth",
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+
+ # set checkpoints paths if needed
+ model_path = example_path / "checkpoints" / conf["model_name"]
+ if not model_path.exists():
+ logger.info(f"No model found at {model_path}")
+
+ # init model: replace the `callable` placeholder with your network class,
+ # e.g. self.net = ExampleNet(is_test=True)
+ self.net = callable
+ state_dict = torch.load(model_path, map_location="cpu")
+ self.net.load_state_dict(state_dict["model_state"])
+ logger.info("Example model loaded.")
+
+ def _forward(self, data):
+ # data: dict, keys: 'image'
+ # image color mode: RGB
+ # image value range in [0, 1]
+ image = data["image"]
+
+ # B: batch size, N: number of keypoints
+ # keypoints shape: B x N x 2, type: torch tensor
+ # scores shape: B x N, type: torch tensor
+ # descriptors shape: B x 128 x N, type: torch tensor
+ keypoints, scores, descriptors = self.net(image)
+
+ return {
+ "keypoints": keypoints,
+ "scores": scores,
+ "descriptors": descriptors,
+ }
diff --git a/hloc/extractors/fire.py b/hloc/extractors/fire.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0937225e0caa9d2aaae781dcc221c774d8c9cdf
--- /dev/null
+++ b/hloc/extractors/fire.py
@@ -0,0 +1,73 @@
+from pathlib import Path
+import subprocess
+import logging
+import sys
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+fire_path = Path(__file__).parent / "../../third_party/fire"
+sys.path.append(str(fire_path))
+
+
+import fire_network
+
+
+class FIRe(BaseModel):
+ default_conf = {
+ "global": True,
+ "asmk": False,
+ "model_name": "fire_SfM_120k.pth",
+ "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params
+ "features_num": 1000, # TODO:not supported now
+ "asmk_name": "asmk_codebook.bin", # TODO:not supported now
+ "config_name": "eval_fire.yml",
+ }
+ required_inputs = ["image"]
+
+ # Pretrained FIRe models
+ fire_models = {
+ "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth",
+ "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth",
+ }
+
+ def _init(self, conf):
+
+ assert conf["model_name"] in self.fire_models.keys()
+ # Config paths
+ model_path = fire_path / "model" / conf["model_name"]
+
+ # Download the model.
+ if not model_path.exists():
+ model_path.parent.mkdir(exist_ok=True)
+ link = self.fire_models[conf["model_name"]]
+ cmd = ["wget", link, "-O", str(model_path)]
+ logger.info(f"Downloading the FIRe model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ logger.info(f"Loading fire model...")
+
+ # Load net
+ state = torch.load(model_path)
+ state["net_params"]["pretrained"] = None
+ net = fire_network.init_network(**state["net_params"])
+ net.load_state_dict(state["state_dict"])
+ self.net = net
+
+ self.norm_rgb = tvf.Normalize(
+ **dict(zip(["mean", "std"], net.runtime["mean_std"]))
+ )
+
+ # params
+ self.scales = conf["scales"]
+
+ def _forward(self, data):
+
+ image = self.norm_rgb(data["image"])
+
+ # Feature extraction.
+ desc = self.net.forward_global(image, scales=self.scales)
+
+ return {"global_descriptor": desc}
diff --git a/hloc/extractors/fire_local.py b/hloc/extractors/fire_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ac12ea700dcdf42b437d555e056e4ee2d43dd6
--- /dev/null
+++ b/hloc/extractors/fire_local.py
@@ -0,0 +1,90 @@
+from pathlib import Path
+import subprocess
+import logging
+import sys
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+fire_path = Path(__file__).parent / "../../third_party/fire"
+
+sys.path.append(str(fire_path))
+
+
+import fire_network
+from lib.how.how.stages.evaluate import eval_asmk_fire, load_dataset_fire
+
+from lib.asmk import asmk
+from asmk import io_helpers, asmk_method, kernel as kern_pkg
+
+EPS = 1e-6
+
+
+class FIRe(BaseModel):
+ default_conf = {
+ "global": True,
+ "asmk": False,
+ "model_name": "fire_SfM_120k.pth",
+ "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params
+ "features_num": 1000,
+ "asmk_name": "asmk_codebook.bin",
+ "config_name": "eval_fire.yml",
+ }
+ required_inputs = ["image"]
+
+ # Pretrained FIRe models
+ fire_models = {
+ "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth",
+ "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth",
+ }
+
+ def _init(self, conf):
+
+ assert conf["model_name"] in self.fire_models.keys()
+
+ # Config paths
+ model_path = fire_path / "model" / conf["model_name"]
+ config_path = fire_path / conf["config_name"]
+ asmk_bin_path = fire_path / "model" / conf["asmk_name"]
+
+ # Download the model.
+ if not model_path.exists():
+ model_path.parent.mkdir(exist_ok=True)
+ link = self.fire_models[conf["model_name"]]
+ cmd = ["wget", link, "-O", str(model_path)]
+ logger.info(f"Downloading the FIRe model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ logger.info(f"Loading fire model...")
+
+ # Load net
+ state = torch.load(model_path)
+ state["net_params"]["pretrained"] = None
+ net = fire_network.init_network(**state["net_params"])
+ net.load_state_dict(state["state_dict"])
+ self.net = net
+
+ self.norm_rgb = tvf.Normalize(
+ **dict(zip(["mean", "std"], net.runtime["mean_std"]))
+ )
+
+ # params
+ self.scales = conf["scales"]
+ self.features_num = conf["features_num"]
+
+ def _forward(self, data):
+
+ image = self.norm_rgb(data["image"])
+
+ local_desc = self.net.forward_local(
+ image, features_num=self.features_num, scales=self.scales
+ )
+
+ logger.info(f"output[0].shape = {local_desc[0].shape}\n")
+
+ return {
+ # 'global_descriptor': desc
+ "local_descriptor": local_desc
+ }
diff --git a/hloc/extractors/lanet.py b/hloc/extractors/lanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b036ccc26eaf7897cd740903a93b191c9e55c01f
--- /dev/null
+++ b/hloc/extractors/lanet.py
@@ -0,0 +1,53 @@
+import sys
+from pathlib import Path
+import subprocess
+import torch
+
+from ..utils.base_model import BaseModel
+
+lanet_path = Path(__file__).parent / "../../third_party/lanet"
+sys.path.append(str(lanet_path))
+from network_v0.model import PointModel
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class LANet(BaseModel):
+ default_conf = {
+ "model_name": "v0",
+ "keypoint_threshold": 0.1,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_path = lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
+ if not model_path.exists():
+ print(f"No model found at {model_path}")
+ self.net = PointModel(is_test=True)
+ state_dict = torch.load(model_path, map_location="cpu")
+ self.net.load_state_dict(state_dict["model_state"])
+
+ def _forward(self, data):
+ image = data["image"]
+ keypoints, scores, descriptors = self.net(image)
+ _, _, Hc, Wc = descriptors.shape
+
+ # Scores & Descriptors
+ kpts_score = (
+ torch.cat([keypoints, scores], dim=1).view(3, -1).t().cpu().detach().numpy()
+ )
+ descriptors = (
+ descriptors.view(256, Hc, Wc).view(256, -1).t().cpu().detach().numpy()
+ )
+
+ # Filter based on confidence threshold
+ descriptors = descriptors[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
+ kpts_score = kpts_score[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
+ keypoints = kpts_score[:, 1:]
+ scores = kpts_score[:, 0]
+
+ return {
+ "keypoints": torch.from_numpy(keypoints)[None],
+ "scores": torch.from_numpy(scores)[None],
+ "descriptors": torch.from_numpy(descriptors.T)[None],
+ }
diff --git a/hloc/extractors/netvlad.py b/hloc/extractors/netvlad.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb1a44b30a7f91573a1e93ad2311f12782ab6ab
--- /dev/null
+++ b/hloc/extractors/netvlad.py
@@ -0,0 +1,147 @@
+from pathlib import Path
+import subprocess
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+from scipy.io import loadmat
+
+from ..utils.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+
+EPS = 1e-6
+
+
+class NetVLADLayer(nn.Module):
+ def __init__(self, input_dim=512, K=64, score_bias=False, intranorm=True):
+ super().__init__()
+ self.score_proj = nn.Conv1d(input_dim, K, kernel_size=1, bias=score_bias)
+ centers = nn.parameter.Parameter(torch.empty([input_dim, K]))
+ nn.init.xavier_uniform_(centers)
+ self.register_parameter("centers", centers)
+ self.intranorm = intranorm
+ self.output_dim = input_dim * K
+
+ def forward(self, x):
+ b = x.size(0)
+ scores = self.score_proj(x)
+ scores = F.softmax(scores, dim=1)
+ diff = x.unsqueeze(2) - self.centers.unsqueeze(0).unsqueeze(-1)
+ desc = (scores.unsqueeze(1) * diff).sum(dim=-1)
+ if self.intranorm:
+ # From the official MATLAB implementation.
+ desc = F.normalize(desc, dim=1)
+ desc = desc.view(b, -1)
+ desc = F.normalize(desc, dim=1)
+ return desc
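+# Shape sketch for the layer above: the input x is expected as (B, C, M) local
+# descriptors (C = input_dim, M = number of spatial locations); the output is
+# the flattened VLAD vector of size (B, C * K), i.e. 512 * 64 = 32768 with the
+# defaults, before the optional 4096-D whitening applied in NetVLAD below.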
+
+
+class NetVLAD(BaseModel):
+ default_conf = {"model_name": "VGG16-NetVLAD-Pitts30K", "whiten": True}
+ required_inputs = ["image"]
+
+ # Models exported using
+ # https://github.com/uzh-rpg/netvlad_tf_open/blob/master/matlab/net_class2struct.m.
+ dir_models = {
+ "VGG16-NetVLAD-Pitts30K": "https://cvg-data.inf.ethz.ch/hloc/netvlad/Pitts30K_struct.mat",
+ "VGG16-NetVLAD-TokyoTM": "https://cvg-data.inf.ethz.ch/hloc/netvlad/TokyoTM_struct.mat",
+ }
+
+ def _init(self, conf):
+ assert conf["model_name"] in self.dir_models.keys()
+
+ # Download the checkpoint.
+ checkpoint = Path(torch.hub.get_dir(), "netvlad", conf["model_name"] + ".mat")
+ if not checkpoint.exists():
+ checkpoint.parent.mkdir(exist_ok=True, parents=True)
+ link = self.dir_models[conf["model_name"]]
+ cmd = ["wget", link, "-O", str(checkpoint)]
+ logger.info(f"Downloading the NetVLAD model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ # Create the network.
+ # Remove classification head.
+ backbone = list(models.vgg16().children())[0]
+ # Remove last ReLU + MaxPool2d.
+ self.backbone = nn.Sequential(*list(backbone.children())[:-2])
+
+ self.netvlad = NetVLADLayer()
+
+ if conf["whiten"]:
+ self.whiten = nn.Linear(self.netvlad.output_dim, 4096)
+
+ # Parse MATLAB weights using https://github.com/uzh-rpg/netvlad_tf_open
+ mat = loadmat(checkpoint, struct_as_record=False, squeeze_me=True)
+
+ # CNN weights.
+ for layer, mat_layer in zip(self.backbone.children(), mat["net"].layers):
+ if isinstance(layer, nn.Conv2d):
+ w = mat_layer.weights[0] # Shape: S x S x IN x OUT
+ b = mat_layer.weights[1] # Shape: OUT
+ # Prepare for PyTorch - enforce float32 and right shape.
+ # w should have shape: OUT x IN x S x S
+ # b should have shape: OUT
+ w = torch.tensor(w).float().permute([3, 2, 0, 1])
+ b = torch.tensor(b).float()
+ # Update layer weights.
+ layer.weight = nn.Parameter(w)
+ layer.bias = nn.Parameter(b)
+
+ # NetVLAD weights.
+ score_w = mat["net"].layers[30].weights[0] # D x K
+ # centers are stored with the opposite sign in the official MATLAB code
+ center_w = -mat["net"].layers[30].weights[1] # D x K
+ # Prepare for PyTorch - make sure it is float32 and has right shape.
+ # score_w should have shape K x D x 1
+ # center_w should have shape D x K
+ score_w = torch.tensor(score_w).float().permute([1, 0]).unsqueeze(-1)
+ center_w = torch.tensor(center_w).float()
+ # Update layer weights.
+ self.netvlad.score_proj.weight = nn.Parameter(score_w)
+ self.netvlad.centers = nn.Parameter(center_w)
+
+ # Whitening weights.
+ if conf["whiten"]:
+ w = mat["net"].layers[33].weights[0] # Shape: 1 x 1 x IN x OUT
+ b = mat["net"].layers[33].weights[1] # Shape: OUT
+ # Prepare for PyTorch - make sure it is float32 and has right shape
+ w = torch.tensor(w).float().squeeze().permute([1, 0]) # OUT x IN
+ b = torch.tensor(b.squeeze()).float() # Shape: OUT
+ # Update layer weights.
+ self.whiten.weight = nn.Parameter(w)
+ self.whiten.bias = nn.Parameter(b)
+
+ # Preprocessing parameters.
+ self.preprocess = {
+ "mean": mat["net"].meta.normalization.averageImage[0, 0],
+ "std": np.array([1, 1, 1], dtype=np.float32),
+ }
+
+ def _forward(self, data):
+ image = data["image"]
+ assert image.shape[1] == 3
+ assert image.min() >= -EPS and image.max() <= 1 + EPS
+ image = torch.clamp(image * 255, 0.0, 255.0) # Input should be 0-255.
+ mean = self.preprocess["mean"]
+ std = self.preprocess["std"]
+ image = image - image.new_tensor(mean).view(1, -1, 1, 1)
+ image = image / image.new_tensor(std).view(1, -1, 1, 1)
+
+ # Feature extraction.
+ descriptors = self.backbone(image)
+ b, c, _, _ = descriptors.size()
+ descriptors = descriptors.view(b, c, -1)
+
+ # NetVLAD layer.
+ descriptors = F.normalize(descriptors, dim=1) # Pre-normalization.
+ desc = self.netvlad(descriptors)
+
+ # Whiten if needed.
+ if hasattr(self, "whiten"):
+ desc = self.whiten(desc)
+ desc = F.normalize(desc, dim=1) # Final L2 normalization.
+
+ return {"global_descriptor": desc}
diff --git a/hloc/extractors/openibl.py b/hloc/extractors/openibl.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e332a4e0016fceb184dd850bd3b6f86231dad54
--- /dev/null
+++ b/hloc/extractors/openibl.py
@@ -0,0 +1,26 @@
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+
+class OpenIBL(BaseModel):
+ default_conf = {
+ "model_name": "vgg16_netvlad",
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.net = torch.hub.load(
+ "yxgeee/OpenIBL", conf["model_name"], pretrained=True
+ ).eval()
+ mean = [0.48501960784313836, 0.4579568627450961, 0.4076039215686255]
+ std = [0.00392156862745098, 0.00392156862745098, 0.00392156862745098]
+ self.norm_rgb = tvf.Normalize(mean=mean, std=std)
+
+ def _forward(self, data):
+ image = self.norm_rgb(data["image"])
+ desc = self.net(image)
+ return {
+ "global_descriptor": desc,
+ }
diff --git a/hloc/extractors/r2d2.py b/hloc/extractors/r2d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4f0ec1e3e2424607d11fc5cce2533325741754c
--- /dev/null
+++ b/hloc/extractors/r2d2.py
@@ -0,0 +1,61 @@
+import sys
+from pathlib import Path
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+r2d2_path = Path(__file__).parent / "../../third_party/r2d2"
+sys.path.append(str(r2d2_path))
+from extract import load_network, NonMaxSuppression, extract_multiscale
+
+
+class R2D2(BaseModel):
+ default_conf = {
+ "model_name": "r2d2_WASF_N16.pt",
+ "max_keypoints": 5000,
+ "scale_factor": 2**0.25,
+ "min_size": 256,
+ "max_size": 1024,
+ "min_scale": 0,
+ "max_scale": 1,
+ "reliability_threshold": 0.7,
+ "repetability_threshold": 0.7,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_fn = r2d2_path / "models" / conf["model_name"]
+ self.norm_rgb = tvf.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+ self.net = load_network(model_fn)
+ self.detector = NonMaxSuppression(
+ rel_thr=conf["reliability_threshold"],
+ rep_thr=conf["repetability_threshold"],
+ )
+
+ def _forward(self, data):
+ img = data["image"]
+ img = self.norm_rgb(img)
+
+ xys, desc, scores = extract_multiscale(
+ self.net,
+ img,
+ self.detector,
+ scale_f=self.conf["scale_factor"],
+ min_size=self.conf["min_size"],
+ max_size=self.conf["max_size"],
+ min_scale=self.conf["min_scale"],
+ max_scale=self.conf["max_scale"],
+ )
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
+ xy = xys[idxs, :2]
+ desc = desc[idxs].t()
+ scores = scores[idxs]
+
+ pred = {
+ "keypoints": xy[None],
+ "descriptors": desc[None],
+ "scores": scores[None],
+ }
+ return pred
diff --git a/hloc/extractors/rekd.py b/hloc/extractors/rekd.py
new file mode 100644
index 0000000000000000000000000000000000000000..b198b5dc4fd9ff60f3e96188da5abe6f7fa0f052
--- /dev/null
+++ b/hloc/extractors/rekd.py
@@ -0,0 +1,53 @@
+import sys
+from pathlib import Path
+import subprocess
+import torch
+
+from ..utils.base_model import BaseModel
+
+rekd_path = Path(__file__).parent / "../../third_party/REKD"
+sys.path.append(str(rekd_path))
+from training.model.REKD import REKD as REKD_
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class REKD(BaseModel):
+ default_conf = {
+ "model_name": "v0",
+ "keypoint_threshold": 0.1,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_path = rekd_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
+ if not model_path.exists():
+ print(f"No model found at {model_path}")
+ self.net = REKD_(is_test=True)
+ state_dict = torch.load(model_path, map_location="cpu")
+ self.net.load_state_dict(state_dict["model_state"])
+
+ def _forward(self, data):
+ image = data["image"]
+ keypoints, scores, descriptors = self.net(image)
+ _, _, Hc, Wc = descriptors.shape
+
+ # Scores & Descriptors
+ kpts_score = (
+ torch.cat([keypoints, scores], dim=1).view(3, -1).t().cpu().detach().numpy()
+ )
+ descriptors = (
+ descriptors.view(256, Hc, Wc).view(256, -1).t().cpu().detach().numpy()
+ )
+
+ # Filter based on confidence threshold
+ descriptors = descriptors[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
+ kpts_score = kpts_score[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
+ keypoints = kpts_score[:, 1:]
+ scores = kpts_score[:, 0]
+
+ return {
+ "keypoints": torch.from_numpy(keypoints)[None],
+ "scores": torch.from_numpy(scores)[None],
+ "descriptors": torch.from_numpy(descriptors.T)[None],
+ }
diff --git a/hloc/extractors/superpoint.py b/hloc/extractors/superpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..264ad8afe5bd849e00836f926edec2a648460806
--- /dev/null
+++ b/hloc/extractors/superpoint.py
@@ -0,0 +1,44 @@
+import sys
+from pathlib import Path
+import torch
+
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from SuperGluePretrainedNetwork.models import superpoint # noqa E402
+
+
+# The original keypoint sampling is incorrect. We patch it here but
+# do not fix it upstream so as not to impact existing evaluations.
+def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
+ """Interpolate descriptors at keypoint locations"""
+ b, c, h, w = descriptors.shape
+ keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s)
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
+ descriptors = torch.nn.functional.grid_sample(
+ descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
+ )
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1
+ )
+ return descriptors
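+# Sketch of what the fix above computes: a keypoint at pixel (x, y) on an image
+# whose descriptor map has width w and stride s (= 8) is sampled at the
+# normalized coordinate ((x + 0.5) / (w * s)) * 2 - 1, which aligns pixel
+# centers with the feature grid under align_corners=False.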
+
+
+class SuperPoint(BaseModel):
+ default_conf = {
+ "nms_radius": 4,
+ "keypoint_threshold": 0.005,
+ "max_keypoints": -1,
+ "remove_borders": 4,
+ "fix_sampling": False,
+ }
+ required_inputs = ["image"]
+ detection_noise = 2.0
+
+ def _init(self, conf):
+ if conf["fix_sampling"]:
+ superpoint.sample_descriptors = sample_descriptors_fix_sampling
+ self.net = superpoint.SuperPoint(conf)
+
+ def _forward(self, data):
+ return self.net(data)
diff --git a/hloc/match_dense.py b/hloc/match_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ee9f5fdac5365913af2ef29f78a1acbb5cda75f
--- /dev/null
+++ b/hloc/match_dense.py
@@ -0,0 +1,384 @@
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+from types import SimpleNamespace
+from .extract_features import read_image, resize_image
+import cv2
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+confs = {
+ # Best quality but loads of points. Only use for small scenes
+ "loftr": {
+ "output": "matches-loftr",
+ "model": {
+ "name": "loftr",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "width": 640,
+ "height": 480,
+ "force_resize": True,
+ },
+ "max_error": 1, # max error for assigned keypoints (in px)
+ "cell_size": 1, # size of quantization patch (max 1 kp/patch)
+ },
+ # Semi-scalable loftr which limits detected keypoints
+ "loftr_aachen": {
+ "output": "matches-loftr_aachen",
+ "model": {
+ "name": "loftr",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {"grayscale": True, "resize_max": 1024, "dfactor": 8},
+ "max_error": 2, # max error for assigned keypoints (in px)
+ "cell_size": 8, # size of quantization patch (max 1 kp/patch)
+ },
+ # Use for matching superpoint feats with loftr
+ "loftr_superpoint": {
+ "output": "matches-loftr_aachen",
+ "model": {
+ "name": "loftr",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {"grayscale": True, "resize_max": 1024, "dfactor": 8},
+ "max_error": 4, # max error for assigned keypoints (in px)
+ "cell_size": 4, # size of quantization patch (max 1 kp/patch)
+ },
+ # Use topicfm for matching feats
+ "topicfm": {
+ "output": "matches-topicfm",
+ "model": {
+ "name": "topicfm",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "width": 640,
+ "height": 480,
+ },
+ },
+    # Use aspanformer for matching feats
+ "aspanformer": {
+ "output": "matches-aspanformer",
+ "model": {
+ "name": "aspanformer",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "dkm": {
+ "output": "matches-dkm",
+ "model": {
+ "name": "dkm",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 80,
+ "height": 60,
+ "dfactor": 8,
+ },
+ },
+ "roma": {
+ "output": "matches-roma",
+ "model": {
+ "name": "roma",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 320,
+ "height": 240,
+ "dfactor": 8,
+ },
+ },
+ "dedode_sparse": {
+ "output": "matches-dedode",
+ "model": {
+ "name": "dedode",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ "dense": False,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 768,
+ "height": 768,
+ "dfactor": 8,
+ },
+ },
+ "sold2": {
+ "output": "matches-sold2",
+ "model": {
+ "name": "sold2",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "gluestick": {
+ "output": "matches-gluestick",
+ "model": {
+ "name": "gluestick",
+ "use_lines": True,
+ "max_keypoints": 1000,
+ "max_lines": 300,
+ "force_num_keypoints": False,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+}
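+# A minimal usage sketch, assuming a dense matcher `model` has already been
+# loaded elsewhere (e.g. via hloc's dynamic_load) and img0/img1 are numpy images:
+#
+#   pred = match_images(model, img0, img1, confs["loftr"]["preprocessing"],
+#                       device=device)
+#
+# The "preprocessing" sub-dict above is the `conf` argument expected by
+# match_images().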
+
+
+def scale_keypoints(kpts, scale):
+ if np.any(scale != 1.0):
+ kpts *= kpts.new_tensor(scale)
+ return kpts
+
+
+def scale_lines(lines, scale):
+ if np.any(scale != 1.0):
+ lines *= lines.new_tensor(scale)
+ return lines
+
+
+def match(model, path_0, path_1, conf):
+ default_conf = {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "cache_images": False,
+ "force_resize": False,
+ "width": 320,
+ "height": 240,
+ }
+
+ def preprocess(image: np.ndarray):
+ image = image.astype(np.float32, copy=False)
+ size = image.shape[:2][::-1]
+ scale = np.array([1.0, 1.0])
+ if conf.resize_max:
+ scale = conf.resize_max / max(size)
+ if scale < 1.0:
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, "cv2_area")
+ scale = np.array(size) / np.array(size_new)
+ if conf.force_resize:
+ size = image.shape[:2][::-1]
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
+ size_new = (conf.width, conf.height)
+ scale = np.array(size) / np.array(size_new)
+ if conf.grayscale:
+ assert image.ndim == 2, image.shape
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = torch.from_numpy(image / 255.0).float()
+ # assure that the size is divisible by dfactor
+ size_new = tuple(
+ map(lambda x: int(x // conf.dfactor * conf.dfactor), image.shape[-2:])
+ )
+ image = F.resize(image, size=size_new, antialias=True)
+ scale = np.array(size) / np.array(size_new)[::-1]
+ return image, scale
+
+ conf = SimpleNamespace(**{**default_conf, **conf})
+ image0 = read_image(path_0, conf.grayscale)
+ image1 = read_image(path_1, conf.grayscale)
+ image0, scale0 = preprocess(image0)
+ image1, scale1 = preprocess(image1)
+ image0 = image0.to(device)[None]
+ image1 = image1.to(device)[None]
+ pred = model({"image0": image0, "image1": image1})
+
+ # Rescale keypoints and move to cpu
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+ kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5
+ kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5
+
+ ret = {
+ "image0": image0.squeeze().cpu().numpy(),
+ "image1": image1.squeeze().cpu().numpy(),
+ "keypoints0": kpts0.cpu().numpy(),
+ "keypoints1": kpts1.cpu().numpy(),
+ }
+ if "mconf" in pred.keys():
+ ret["mconf"] = pred["mconf"].cpu().numpy()
+ return ret
+
+
+@torch.no_grad()
+def match_images(model, image_0, image_1, conf, device="cpu"):
+ default_conf = {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "cache_images": False,
+ "force_resize": False,
+ "width": 320,
+ "height": 240,
+ }
+
+ def preprocess(image: np.ndarray):
+ image = image.astype(np.float32, copy=False)
+ size = image.shape[:2][::-1]
+ scale = np.array([1.0, 1.0])
+ if conf.resize_max:
+ scale = conf.resize_max / max(size)
+ if scale < 1.0:
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, "cv2_area")
+ scale = np.array(size) / np.array(size_new)
+ if conf.force_resize:
+ size = image.shape[:2][::-1]
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
+ size_new = (conf.width, conf.height)
+ scale = np.array(size) / np.array(size_new)
+ if conf.grayscale:
+ assert image.ndim == 2, image.shape
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = torch.from_numpy(image / 255.0).float()
+
+ # assure that the size is divisible by dfactor
+ size_new = tuple(
+ map(lambda x: int(x // conf.dfactor * conf.dfactor), image.shape[-2:])
+ )
+ image = F.resize(image, size=size_new)
+ scale = np.array(size) / np.array(size_new)[::-1]
+ return image, scale
+
+ conf = SimpleNamespace(**{**default_conf, **conf})
+
+ if len(image_0.shape) == 3 and conf.grayscale:
+ image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY)
+ else:
+ image0 = image_0
+    if len(image_1.shape) == 3 and conf.grayscale:
+ image1 = cv2.cvtColor(image_1, cv2.COLOR_RGB2GRAY)
+ else:
+ image1 = image_1
+
+    # The following lines stay commented out because the input images are already RGB.
+ # if not conf.grayscale and len(image0.shape) == 3:
+ # image0 = image0[:, :, ::-1] # BGR to RGB
+ # if not conf.grayscale and len(image1.shape) == 3:
+ # image1 = image1[:, :, ::-1] # BGR to RGB
+
+ image0, scale0 = preprocess(image0)
+ image1, scale1 = preprocess(image1)
+ image0 = image0.to(device)[None]
+ image1 = image1.to(device)[None]
+ pred = model({"image0": image0, "image1": image1})
+
+ s0 = np.array(image_0.shape[:2][::-1]) / np.array(image0.shape[-2:][::-1])
+ s1 = np.array(image_1.shape[:2][::-1]) / np.array(image1.shape[-2:][::-1])
+
+ # Rescale keypoints and move to cpu
+ if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+ kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5
+ kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5
+
+ ret = {
+ "image0": image0.squeeze().cpu().numpy(),
+ "image1": image1.squeeze().cpu().numpy(),
+ "image0_orig": image_0,
+ "image1_orig": image_1,
+ "keypoints0": kpts0.cpu().numpy(),
+ "keypoints1": kpts1.cpu().numpy(),
+ "keypoints0_orig": kpts0_origin.cpu().numpy(),
+ "keypoints1_orig": kpts1_origin.cpu().numpy(),
+ "original_size0": np.array(image_0.shape[:2][::-1]),
+ "original_size1": np.array(image_1.shape[:2][::-1]),
+ "new_size0": np.array(image0.shape[-2:][::-1]),
+ "new_size1": np.array(image1.shape[-2:][::-1]),
+ "scale0": s0,
+ "scale1": s1,
+ }
+ if "mconf" in pred.keys():
+ ret["mconf"] = pred["mconf"].cpu().numpy()
+ if "lines0" in pred.keys() and "lines1" in pred.keys():
+ if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+ kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5
+ kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5
+ kpts0_origin = kpts0_origin.cpu().numpy()
+ kpts1_origin = kpts1_origin.cpu().numpy()
+ else:
+ kpts0_origin, kpts1_origin = None, None # np.zeros([0]), np.zeros([0])
+ lines0, lines1 = pred["lines0"], pred["lines1"]
+ lines0_raw, lines1_raw = pred["raw_lines0"], pred["raw_lines1"]
+
+ lines0_raw = torch.from_numpy(lines0_raw.copy())
+ lines1_raw = torch.from_numpy(lines1_raw.copy())
+ lines0_raw = scale_lines(lines0_raw + 0.5, s0) - 0.5
+ lines1_raw = scale_lines(lines1_raw + 0.5, s1) - 0.5
+
+ lines0 = torch.from_numpy(lines0.copy())
+ lines1 = torch.from_numpy(lines1.copy())
+ lines0 = scale_lines(lines0 + 0.5, s0) - 0.5
+ lines1 = scale_lines(lines1 + 0.5, s1) - 0.5
+
+ ret = {
+ "image0_orig": image_0,
+ "image1_orig": image_1,
+ "line0": lines0_raw.cpu().numpy(),
+ "line1": lines1_raw.cpu().numpy(),
+ "line0_orig": lines0.cpu().numpy(),
+ "line1_orig": lines1.cpu().numpy(),
+ "line_keypoints0_orig": kpts0_origin,
+ "line_keypoints1_orig": kpts1_origin,
+ }
+ del pred
+ torch.cuda.empty_cache()
+ return ret
diff --git a/hloc/match_features.py b/hloc/match_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..478c02459bd6b037425aa85ac5c54fae9497e354
--- /dev/null
+++ b/hloc/match_features.py
@@ -0,0 +1,389 @@
+import argparse
+from typing import Union, Optional, Dict, List, Tuple
+from pathlib import Path
+import pprint
+from queue import Queue
+from threading import Thread
+from functools import partial
+from tqdm import tqdm
+import h5py
+import torch
+
+from . import matchers, logger
+from .utils.base_model import dynamic_load
+from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
+import numpy as np
+
+"""
+A set of standard configurations that can be directly selected from the command
+line using their name. Each is a dictionary with the following entries:
+ - output: the name of the match file that will be generated.
+ - model: the model configuration, as passed to a feature matcher.
+"""
+confs = {
+ "superglue": {
+ "output": "matches-superglue",
+ "model": {
+ "name": "superglue",
+ "weights": "outdoor",
+ "sinkhorn_iterations": 50,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "superglue-fast": {
+ "output": "matches-superglue-it5",
+ "model": {
+ "name": "superglue",
+ "weights": "outdoor",
+ "sinkhorn_iterations": 5,
+ "match_threshold": 0.2,
+ },
+ },
+ "superpoint-lightglue": {
+ "output": "matches-lightglue",
+ "model": {
+ "name": "lightglue",
+ "match_threshold": 0.2,
+ "width_confidence": 0.99, # for point pruning
+ "depth_confidence": 0.95, # for early stopping,
+ "features": "superpoint",
+ "model_name": "superpoint_lightglue.pth",
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "disk-lightglue": {
+ "output": "matches-lightglue",
+ "model": {
+ "name": "lightglue",
+ "match_threshold": 0.2,
+ "width_confidence": 0.99, # for point pruning
+ "depth_confidence": 0.95, # for early stopping,
+ "features": "disk",
+ "model_name": "disk_lightglue.pth",
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "sgmnet": {
+ "output": "matches-sgmnet",
+ "model": {
+ "name": "sgmnet",
+ "seed_top_k": [256, 256],
+ "seed_radius_coe": 0.01,
+ "net_channels": 128,
+ "layer_num": 9,
+ "head": 4,
+ "seedlayer": [0, 6],
+ "use_mc_seeding": True,
+ "use_score_encoding": False,
+ "conf_bar": [1.11, 0.1],
+ "sink_iter": [10, 100],
+ "detach_iter": 1000000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "NN-superpoint": {
+ "output": "matches-NN-mutual-dist.7",
+ "model": {
+ "name": "nearest_neighbor",
+ "do_mutual_check": True,
+ "distance_threshold": 0.7,
+ "match_threshold": 0.2,
+ },
+ },
+ "NN-ratio": {
+ "output": "matches-NN-mutual-ratio.8",
+ "model": {
+ "name": "nearest_neighbor",
+ "do_mutual_check": True,
+ "ratio_threshold": 0.8,
+ "match_threshold": 0.2,
+ },
+ },
+ "NN-mutual": {
+ "output": "matches-NN-mutual",
+ "model": {
+ "name": "nearest_neighbor",
+ "do_mutual_check": True,
+ "match_threshold": 0.2,
+ },
+ },
+ "Dual-Softmax": {
+ "output": "matches-Dual-Softmax",
+ "model": {
+ "name": "dual_softmax",
+ "do_mutual_check": True,
+ "match_threshold": 0.2, # TODO
+ },
+ },
+ "adalam": {
+ "output": "matches-adalam",
+ "model": {
+ "name": "adalam",
+ "match_threshold": 0.2,
+ },
+ },
+}
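+# A minimal selection sketch (paths below are placeholders): configurations are
+# picked by name and passed to main(), mirroring the __main__ block at the end
+# of this file:
+#
+#   conf = confs["superpoint-lightglue"]
+#   match_path = main(conf, pairs_path, "feats-superpoint-n4096-r1024", export_dir)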
+
+
+class WorkQueue:
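+    """Small thread pool fed by a bounded queue, used below to write matches to
+    disk in the background while the main loop keeps matching pairs."""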
+ def __init__(self, work_fn, num_threads=1):
+ self.queue = Queue(num_threads)
+ self.threads = [
+ Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads)
+ ]
+ for thread in self.threads:
+ thread.start()
+
+ def join(self):
+ for thread in self.threads:
+ self.queue.put(None)
+ for thread in self.threads:
+ thread.join()
+
+ def thread_fn(self, work_fn):
+ item = self.queue.get()
+ while item is not None:
+ work_fn(item)
+ item = self.queue.get()
+
+ def put(self, data):
+ self.queue.put(data)
+
+
+class FeaturePairsDataset(torch.utils.data.Dataset):
+ def __init__(self, pairs, feature_path_q, feature_path_r):
+ self.pairs = pairs
+ self.feature_path_q = feature_path_q
+ self.feature_path_r = feature_path_r
+
+ def __getitem__(self, idx):
+ name0, name1 = self.pairs[idx]
+ data = {}
+ with h5py.File(self.feature_path_q, "r") as fd:
+ grp = fd[name0]
+ for k, v in grp.items():
+ data[k + "0"] = torch.from_numpy(v.__array__()).float()
+ # some matchers might expect an image but only use its size
+ data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
+ with h5py.File(self.feature_path_r, "r") as fd:
+ grp = fd[name1]
+ for k, v in grp.items():
+ data[k + "1"] = torch.from_numpy(v.__array__()).float()
+ data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
+ return data
+
+ def __len__(self):
+ return len(self.pairs)
+
+
+def writer_fn(inp, match_path):
+ pair, pred = inp
+ with h5py.File(str(match_path), "a", libver="latest") as fd:
+ if pair in fd:
+ del fd[pair]
+ grp = fd.create_group(pair)
+ matches = pred["matches0"][0].cpu().short().numpy()
+ grp.create_dataset("matches0", data=matches)
+ if "matching_scores0" in pred:
+ scores = pred["matching_scores0"][0].cpu().half().numpy()
+ grp.create_dataset("matching_scores0", data=scores)
+
+
+def main(
+ conf: Dict,
+ pairs: Path,
+ features: Union[Path, str],
+ export_dir: Optional[Path] = None,
+ matches: Optional[Path] = None,
+ features_ref: Optional[Path] = None,
+ overwrite: bool = False,
+) -> Path:
+
+ if isinstance(features, Path) or Path(features).exists():
+ features_q = features
+ if matches is None:
+ raise ValueError(
+ "Either provide both features and matches as Path" " or both as names."
+ )
+ else:
+ if export_dir is None:
+ raise ValueError(
+ "Provide an export_dir if features is not" f" a file path: {features}."
+ )
+ features_q = Path(export_dir, features + ".h5")
+ if matches is None:
+ matches = Path(export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5')
+
+ if features_ref is None:
+ features_ref = features_q
+ match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite)
+
+ return matches
+
+
+def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None):
+ """Avoid to recompute duplicates to save time."""
+ pairs = set()
+ for i, j in pairs_all:
+ if (j, i) not in pairs:
+ pairs.add((i, j))
+ pairs = list(pairs)
+ if match_path is not None and match_path.exists():
+ with h5py.File(str(match_path), "r", libver="latest") as fd:
+ pairs_filtered = []
+ for i, j in pairs:
+ if (
+ names_to_pair(i, j) in fd
+ or names_to_pair(j, i) in fd
+ or names_to_pair_old(i, j) in fd
+ or names_to_pair_old(j, i) in fd
+ ):
+ continue
+ pairs_filtered.append((i, j))
+ return pairs_filtered
+ return pairs
+
+
+@torch.no_grad()
+def match_from_paths(
+ conf: Dict,
+ pairs_path: Path,
+ match_path: Path,
+ feature_path_q: Path,
+ feature_path_ref: Path,
+ overwrite: bool = False,
+) -> Path:
+ logger.info(
+ "Matching local features with configuration:" f"\n{pprint.pformat(conf)}"
+ )
+
+ if not feature_path_q.exists():
+ raise FileNotFoundError(f"Query feature file {feature_path_q}.")
+ if not feature_path_ref.exists():
+ raise FileNotFoundError(f"Reference feature file {feature_path_ref}.")
+ match_path.parent.mkdir(exist_ok=True, parents=True)
+
+ assert pairs_path.exists(), pairs_path
+ pairs = parse_retrieval(pairs_path)
+ pairs = [(q, r) for q, rs in pairs.items() for r in rs]
+ pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
+ if len(pairs) == 0:
+ logger.info("Skipping the matching.")
+ return
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ Model = dynamic_load(matchers, conf["model"]["name"])
+ model = Model(conf["model"]).eval().to(device)
+
+ dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
+ loader = torch.utils.data.DataLoader(
+ dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True
+ )
+ writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)
+
+ for idx, data in enumerate(tqdm(loader, smoothing=0.1)):
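+        # Move tensors to the device, but keep the dummy image tensors on the
+        # CPU: matchers only read their size (see FeaturePairsDataset above).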
+ data = {
+ k: v if k.startswith("image") else v.to(device, non_blocking=True)
+ for k, v in data.items()
+ }
+ pred = model(data)
+ pair = names_to_pair(*pairs[idx])
+ writer_queue.put((pair, pred))
+ writer_queue.join()
+ logger.info("Finished exporting matches.")
+
+
+def scale_keypoints(kpts, scale):
+ if np.any(scale != 1.0):
+ kpts *= kpts.new_tensor(scale)
+ return kpts
+
+
+@torch.no_grad()
+def match_images(model, feat0, feat1):
+ # forward pass to match keypoints
+ desc0 = feat0["descriptors"][0]
+ desc1 = feat1["descriptors"][0]
+ if len(desc0.shape) == 2:
+ desc0 = desc0.unsqueeze(0)
+ if len(desc1.shape) == 2:
+ desc1 = desc1.unsqueeze(0)
+ pred = model(
+ {
+ "image0": feat0["image"],
+ "keypoints0": feat0["keypoints"][0],
+ "scores0": feat0["scores"][0].unsqueeze(0),
+ "descriptors0": desc0,
+ "image1": feat1["image"],
+ "keypoints1": feat1["keypoints"][0],
+ "scores1": feat1["scores"][0].unsqueeze(0),
+ "descriptors1": desc1,
+ }
+ )
+ pred = {
+ k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v
+ for k, v in pred.items()
+ }
+ kpts0, kpts1 = (
+ feat0["keypoints"][0].cpu().numpy(),
+ feat1["keypoints"][0].cpu().numpy(),
+ )
+ matches, confid = pred["matches0"], pred["matching_scores0"]
+ # Keep the matching keypoints.
+ valid = matches > -1
+ mkpts0 = kpts0[valid]
+ mkpts1 = kpts1[matches[valid]]
+ mconfid = confid[valid]
+ # rescale the keypoints to their original size
+ s0 = feat0["original_size"] / feat0["size"]
+ s1 = feat1["original_size"] / feat1["size"]
+ kpts0_origin = scale_keypoints(torch.from_numpy(mkpts0 + 0.5), s0) - 0.5
+ kpts1_origin = scale_keypoints(torch.from_numpy(mkpts1 + 0.5), s1) - 0.5
+ ret = {
+ "image0_orig": feat0["image_orig"],
+ "image1_orig": feat1["image_orig"],
+ "keypoints0": kpts0,
+ "keypoints1": kpts1,
+ "keypoints0_orig": kpts0_origin.numpy(),
+ "keypoints1_orig": kpts1_origin.numpy(),
+ "mconf": mconfid,
+ }
+ del feat0, feat1, desc0, desc1, kpts0, kpts1, kpts0_origin, kpts1_origin
+ torch.cuda.empty_cache()
+
+ return ret
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pairs", type=Path, required=True)
+ parser.add_argument("--export_dir", type=Path)
+ parser.add_argument("--features", type=str, default="feats-superpoint-n4096-r1024")
+ parser.add_argument("--matches", type=Path)
+ parser.add_argument(
+ "--conf", type=str, default="superglue", choices=list(confs.keys())
+ )
+ args = parser.parse_args()
+ main(confs[args.conf], args.pairs, args.features, args.export_dir)
diff --git a/hloc/matchers/__init__.py b/hloc/matchers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9fd381eb391604db3d3c3c03d278c0e08de2531
--- /dev/null
+++ b/hloc/matchers/__init__.py
@@ -0,0 +1,3 @@
+def get_matcher(matcher):
+ mod = __import__(f"{__name__}.{matcher}", fromlist=[""])
+ return getattr(mod, "Model")
diff --git a/hloc/matchers/adalam.py b/hloc/matchers/adalam.py
new file mode 100644
index 0000000000000000000000000000000000000000..b51491ad7ed46a467811a793d819c7da8e63b3d4
--- /dev/null
+++ b/hloc/matchers/adalam.py
@@ -0,0 +1,69 @@
+import torch
+
+from ..utils.base_model import BaseModel
+
+from kornia.feature.adalam import AdalamFilter
+from kornia.utils.helpers import get_cuda_device_if_available
+
+
+class AdaLAM(BaseModel):
+ # See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html.
+ default_conf = {
+ "area_ratio": 100,
+ "search_expansion": 4,
+ "ransac_iters": 128,
+ "min_inliers": 6,
+ "min_confidence": 200,
+ "orientation_difference_threshold": 30,
+ "scale_rate_threshold": 1.5,
+ "detected_scale_rate_threshold": 5,
+ "refit": True,
+ "force_seed_mnn": True,
+ "device": get_cuda_device_if_available(),
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ "descriptors0",
+ "descriptors1",
+ "keypoints0",
+ "keypoints1",
+ "scales0",
+ "scales1",
+ "oris0",
+ "oris1",
+ ]
+
+ def _init(self, conf):
+ self.adalam = AdalamFilter(conf)
+
+ def _forward(self, data):
+ assert data["keypoints0"].size(0) == 1
+ if data["keypoints0"].size(1) < 2 or data["keypoints1"].size(1) < 2:
+ matches = torch.zeros(
+ (0, 2), dtype=torch.int64, device=data["keypoints0"].device
+ )
+ else:
+ matches = self.adalam.match_and_filter(
+ data["keypoints0"][0],
+ data["keypoints1"][0],
+ data["descriptors0"][0].T,
+ data["descriptors1"][0].T,
+ data["image0"].shape[2:],
+ data["image1"].shape[2:],
+ data["oris0"][0],
+ data["oris1"][0],
+ data["scales0"][0],
+ data["scales1"][0],
+ )
+ matches_new = torch.full(
+ (data["keypoints0"].size(1),),
+ -1,
+ dtype=torch.int64,
+ device=data["keypoints0"].device,
+ )
+ matches_new[matches[:, 0]] = matches[:, 1]
+ return {
+ "matches0": matches_new.unsqueeze(0),
+ "matching_scores0": torch.zeros(matches_new.size(0)).unsqueeze(0),
+ }
diff --git a/hloc/matchers/aspanformer.py b/hloc/matchers/aspanformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..501fdfffc1004347d6adcb89903f70dc640bf4ee
--- /dev/null
+++ b/hloc/matchers/aspanformer.py
@@ -0,0 +1,76 @@
+import sys
+import torch
+from ..utils.base_model import BaseModel
+from ..utils import do_system
+from pathlib import Path
+import subprocess
+import logging
+
+logger = logging.getLogger(__name__)
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
+from ASpanFormer.src.config.default import get_cfg_defaults
+from ASpanFormer.src.utils.misc import lower_config
+from ASpanFormer.demo import demo_utils
+
+aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
+
+
+class ASpanFormer(BaseModel):
+ default_conf = {
+ "weights": "outdoor",
+ "match_threshold": 0.2,
+ "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
+ "model_name": "weights_aspanformer.tar",
+ }
+ required_inputs = ["image0", "image1"]
+ proxy = "http://localhost:1080"
+ aspanformer_models = {
+ "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
+ }
+
+ def _init(self, conf):
+ model_path = aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt")
+ # Download the model.
+ if not model_path.exists():
+ # model_path.parent.mkdir(exist_ok=True)
+ tar_path = aspanformer_path / conf["model_name"]
+ if not tar_path.exists():
+ link = self.aspanformer_models[conf["model_name"]]
+ cmd = ["gdown", link, "-O", str(tar_path), "--proxy", self.proxy]
+ cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
+ logger.info(f"Downloading the Aspanformer model with `{cmd_wo_proxy}`.")
+ try:
+ subprocess.run(cmd_wo_proxy, check=True)
+ except subprocess.CalledProcessError as e:
+ logger.info(f"Downloading the Aspanformer model with `{cmd}`.")
+ try:
+ subprocess.run(cmd, check=True)
+ except subprocess.CalledProcessError as e:
+ logger.error(f"Failed to download the Aspanformer model.")
+ raise e
+
+            do_system(f"cd {str(aspanformer_path)} && tar -xvf {str(tar_path)}")
+
+ logger.info(f"Loading Aspanformer model...")
+
+ config = get_cfg_defaults()
+ config.merge_from_file(conf["config_path"])
+ _config = lower_config(config)
+ self.net = _ASpanFormer(config=_config["aspan"])
+ weight_path = model_path
+ state_dict = torch.load(str(weight_path), map_location="cpu")["state_dict"]
+ self.net.load_state_dict(state_dict, strict=False)
+
+ def _forward(self, data):
+ data_ = {
+ "image0": data["image0"],
+ "image1": data["image1"],
+ }
+ self.net(data_, online_resize=True)
+ corr0 = data_["mkpts0_f"]
+ corr1 = data_["mkpts1_f"]
+ pred = {}
+ pred["keypoints0"], pred["keypoints1"] = corr0, corr1
+ return pred
diff --git a/hloc/matchers/dkm.py b/hloc/matchers/dkm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b6bd9dd890dddb6245f07edd2ef8a79955b29f7
--- /dev/null
+++ b/hloc/matchers/dkm.py
@@ -0,0 +1,61 @@
+import sys
+from pathlib import Path
+import torch
+from PIL import Image
+import subprocess
+import logging
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from DKM.dkm import DKMv3_outdoor
+
+dkm_path = Path(__file__).parent / "../../third_party/DKM"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logger = logging.getLogger(__name__)
+
+
+class DKMv3(BaseModel):
+ default_conf = {
+ "model_name": "DKMv3_outdoor.pth",
+ "match_threshold": 0.2,
+ "checkpoint_dir": dkm_path / "pretrained",
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+    # URLs of the pretrained DKMv3 weights.
+ dkm_models = {
+ "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
+ "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
+ }
+
+ def _init(self, conf):
+ model_path = dkm_path / "pretrained" / conf["model_name"]
+
+ # Download the model.
+ if not model_path.exists():
+ model_path.parent.mkdir(exist_ok=True)
+ link = self.dkm_models[conf["model_name"]]
+ cmd = ["wget", link, "-O", str(model_path)]
+ logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+ logger.info(f"Loading DKMv3 model...")
+ self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=device)
+
+ def _forward(self, data):
+ img0 = data["image0"].cpu().numpy().squeeze() * 255
+ img1 = data["image1"].cpu().numpy().squeeze() * 255
+ img0 = img0.transpose(1, 2, 0)
+ img1 = img1.transpose(1, 2, 0)
+ img0 = Image.fromarray(img0.astype("uint8"))
+ img1 = Image.fromarray(img1.astype("uint8"))
+ W_A, H_A = img0.size
+ W_B, H_B = img1.size
+
+ warp, certainty = self.net.match(img0, img1, device=device)
+ matches, certainty = self.net.sample(warp, certainty)
+ kpts1, kpts2 = self.net.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ pred = {}
+ pred["keypoints0"], pred["keypoints1"] = kpts1, kpts2
+ return pred
diff --git a/hloc/matchers/dual_softmax.py b/hloc/matchers/dual_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e43bd8d72c6178bc9d40b641bab518b070cac5c
--- /dev/null
+++ b/hloc/matchers/dual_softmax.py
@@ -0,0 +1,68 @@
+import torch
+
+from ..utils.base_model import BaseModel
+import numpy as np
+
+# borrow from dedode
+def dual_softmax_matcher(
+ desc_A: tuple["B", "C", "N"],
+ desc_B: tuple["B", "C", "M"],
+ threshold=0.1,
+ inv_temperature=20,
+ normalize=True,
+):
+    if len(desc_A.shape) < 3:
+        desc_A, desc_B = desc_A[None], desc_B[None]
+    B, C, N = desc_A.shape
+ if normalize:
+ desc_A = desc_A / desc_A.norm(dim=1, keepdim=True)
+ desc_B = desc_B / desc_B.norm(dim=1, keepdim=True)
+ sim = torch.einsum("b c n, b c m -> b n m", desc_A, desc_B) * inv_temperature
+ P = sim.softmax(dim=-2) * sim.softmax(dim=-1)
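+    # Keep mutual nearest neighbours: row-wise and column-wise argmax must agree
+    # and the dual-softmax probability must exceed the threshold.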
+ mask = torch.nonzero(
+ (P == P.max(dim=-1, keepdim=True).values)
+ * (P == P.max(dim=-2, keepdim=True).values)
+ * (P > threshold)
+ )
+ mask = mask.cpu().numpy()
+ matches0 = np.ones((B, P.shape[-2]), dtype=int) * (-1)
+ scores0 = np.zeros((B, P.shape[-2]), dtype=float)
+ matches0[:, mask[:, 1]] = mask[:, 2]
+ tmp_P = P.cpu().numpy()
+ scores0[:, mask[:, 1]] = tmp_P[mask[:, 0], mask[:, 1], mask[:, 2]]
+ matches0 = torch.from_numpy(matches0).to(P.device)
+ scores0 = torch.from_numpy(scores0).to(P.device)
+ return matches0, scores0
+
+
+class DualSoftMax(BaseModel):
+ default_conf = {
+ "match_threshold": 0.2,
+ "inv_temperature": 20,
+ }
+ # shape: B x DIM x M
+ required_inputs = ["descriptors0", "descriptors1"]
+
+ def _init(self, conf):
+ pass
+
+ def _forward(self, data):
+ if data["descriptors0"].size(-1) == 0 or data["descriptors1"].size(-1) == 0:
+ matches0 = torch.full(
+ data["descriptors0"].shape[:2], -1, device=data["descriptors0"].device
+ )
+ return {
+ "matches0": matches0,
+ "matching_scores0": torch.zeros_like(matches0),
+ }
+
+ matches0, scores0 = dual_softmax_matcher(
+ data["descriptors0"],
+ data["descriptors1"],
+ threshold=self.conf["match_threshold"],
+ inv_temperature=self.conf["inv_temperature"],
+ )
+ return {
+ "matches0": matches0, # 1 x M
+ "matching_scores0": scores0,
+ }
diff --git a/hloc/matchers/gluestick.py b/hloc/matchers/gluestick.py
new file mode 100644
index 0000000000000000000000000000000000000000..8432d63221e2be0d075375c0cb69b8ebec56612a
--- /dev/null
+++ b/hloc/matchers/gluestick.py
@@ -0,0 +1,108 @@
+import sys
+from pathlib import Path
+import subprocess
+import logging
+import torch
+from ..utils.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+
+gluestick_path = Path(__file__).parent / "../../third_party/GlueStick"
+sys.path.append(str(gluestick_path))
+
+from gluestick import batch_to_np
+from gluestick.models.two_view_pipeline import TwoViewPipeline
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class GlueStick(BaseModel):
+ default_conf = {
+ "name": "two_view_pipeline",
+ "model_name": "checkpoint_GlueStick_MD.tar",
+ "use_lines": True,
+ "max_keypoints": 1000,
+ "max_lines": 300,
+ "force_num_keypoints": False,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+ gluestick_models = {
+ "checkpoint_GlueStick_MD.tar": "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar",
+ }
+ # Initialize the line matcher
+ def _init(self, conf):
+ model_path = gluestick_path / "resources" / "weights" / conf["model_name"]
+
+ # Download the model.
+ if not model_path.exists():
+ model_path.parent.mkdir(exist_ok=True)
+ link = self.gluestick_models[conf["model_name"]]
+ cmd = ["wget", link, "-O", str(model_path)]
+ logger.info(f"Downloading the Gluestick model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+ logger.info(f"Loading GlueStick model...")
+
+ gluestick_conf = {
+ "name": "two_view_pipeline",
+ "use_lines": True,
+ "extractor": {
+ "name": "wireframe",
+ "sp_params": {
+ "force_num_keypoints": False,
+ "max_num_keypoints": 1000,
+ },
+ "wireframe_params": {
+ "merge_points": True,
+ "merge_line_endpoints": True,
+ },
+ "max_n_lines": 300,
+ },
+ "matcher": {
+ "name": "gluestick",
+ "weights": str(model_path),
+ "trainable": False,
+ },
+ "ground_truth": {
+ "from_pose_depth": False,
+ },
+ }
+ gluestick_conf["extractor"]["sp_params"]["max_num_keypoints"] = conf[
+ "max_keypoints"
+ ]
+ gluestick_conf["extractor"]["sp_params"]["force_num_keypoints"] = conf[
+ "force_num_keypoints"
+ ]
+ gluestick_conf["extractor"]["max_n_lines"] = conf["max_lines"]
+ self.net = TwoViewPipeline(gluestick_conf)
+
+ def _forward(self, data):
+ pred = self.net(data)
+
+ pred = batch_to_np(pred)
+ kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
+ m0 = pred["matches0"]
+
+ line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
+ line_matches = pred["line_matches0"]
+
+ valid_matches = m0 != -1
+ match_indices = m0[valid_matches]
+ matched_kps0 = kp0[valid_matches]
+ matched_kps1 = kp1[match_indices]
+
+ valid_matches = line_matches != -1
+ match_indices = line_matches[valid_matches]
+ matched_lines0 = line_seg0[valid_matches]
+ matched_lines1 = line_seg1[match_indices]
+
+ pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1
+ pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1
+ pred["keypoints0"], pred["keypoints1"] = torch.from_numpy(
+ matched_kps0
+ ), torch.from_numpy(matched_kps1)
+ pred = {**pred, **data}
+ return pred
diff --git a/hloc/matchers/lightglue.py b/hloc/matchers/lightglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..94fa43d637be2cb3064b234e2232d176a798e4a7
--- /dev/null
+++ b/hloc/matchers/lightglue.py
@@ -0,0 +1,53 @@
+import sys
+from pathlib import Path
+import logging
+from ..utils.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
+sys.path.append(str(lightglue_path))
+from lightglue import LightGlue as LG
+
+
+class LightGlue(BaseModel):
+ default_conf = {
+ "match_threshold": 0.2,
+ "filter_threshold": 0.2,
+ "width_confidence": 0.99, # for point pruning
+ "depth_confidence": 0.95, # for early stopping,
+ "features": "superpoint",
+ "model_name": "superpoint_lightglue.pth",
+ "flash": True, # enable FlashAttention if available.
+ "mp": False, # enable mixed precision
+ }
+ required_inputs = [
+ "image0",
+ "keypoints0",
+ "scores0",
+ "descriptors0",
+ "image1",
+ "keypoints1",
+ "scores1",
+ "descriptors1",
+ ]
+
+ def _init(self, conf):
+ weight_path = lightglue_path / "weights" / conf["model_name"]
+ conf["weights"] = str(weight_path)
+ conf["filter_threshold"] = conf["match_threshold"]
+ self.net = LG(**conf)
+ logger.info(f"Load lightglue model done.")
+
+ def _forward(self, data):
+ input = {}
+ input["image0"] = {
+ "image": data["image0"],
+ "keypoints": data["keypoints0"][None],
+ "descriptors": data["descriptors0"].permute(0, 2, 1),
+ }
+ input["image1"] = {
+ "image": data["image1"],
+ "keypoints": data["keypoints1"][None],
+ "descriptors": data["descriptors1"].permute(0, 2, 1),
+ }
+ return self.net(input)
diff --git a/hloc/matchers/loftr.py b/hloc/matchers/loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d5d4a66694056e60243c8f179a3c19ed2c9be71
--- /dev/null
+++ b/hloc/matchers/loftr.py
@@ -0,0 +1,52 @@
+import torch
+import warnings
+from kornia.feature.loftr.loftr import default_cfg
+from kornia.feature import LoFTR as LoFTR_
+
+from ..utils.base_model import BaseModel
+
+
+class LoFTR(BaseModel):
+ default_conf = {
+ "weights": "outdoor",
+ "match_threshold": 0.2,
+ "max_num_matches": None,
+ }
+ required_inputs = ["image0", "image1"]
+
+ def _init(self, conf):
+ cfg = default_cfg
+ cfg["match_coarse"]["thr"] = conf["match_threshold"]
+ self.net = LoFTR_(pretrained=conf["weights"], config=cfg)
+
+ def _forward(self, data):
+ # For consistency with hloc pairs, we refine kpts in image0!
+ rename = {
+ "keypoints0": "keypoints1",
+ "keypoints1": "keypoints0",
+ "image0": "image1",
+ "image1": "image0",
+ "mask0": "mask1",
+ "mask1": "mask0",
+ }
+ data_ = {rename[k]: v for k, v in data.items()}
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ pred = self.net(data_)
+
+ scores = pred["confidence"]
+
+ top_k = self.conf["max_num_matches"]
+ if top_k is not None and len(scores) > top_k:
+ keep = torch.argsort(scores, descending=True)[:top_k]
+ pred["keypoints0"], pred["keypoints1"] = (
+ pred["keypoints0"][keep],
+ pred["keypoints1"][keep],
+ )
+ scores = scores[keep]
+
+ # Switch back indices
+ pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
+ pred["scores"] = scores
+ del pred["confidence"]
+ return pred
diff --git a/hloc/matchers/nearest_neighbor.py b/hloc/matchers/nearest_neighbor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c036f8e76b01514ee714551a0846842c16283846
--- /dev/null
+++ b/hloc/matchers/nearest_neighbor.py
@@ -0,0 +1,62 @@
+import torch
+
+from ..utils.base_model import BaseModel
+
+
+def find_nn(sim, ratio_thresh, distance_thresh):
+ sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True)
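+    # For L2-normalized descriptors, squared distance = 2 - 2 * cosine similarity.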
+ dist_nn = 2 * (1 - sim_nn)
+ mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
+ if ratio_thresh:
+ mask = mask & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1])
+ if distance_thresh:
+ mask = mask & (dist_nn[..., 0] <= distance_thresh**2)
+ matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1))
+ scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0))
+ return matches, scores
+
+
+def mutual_check(m0, m1):
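+    # Keep m0[i] = j only if the reverse match agrees, i.e. m1[j] = i.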
+ inds0 = torch.arange(m0.shape[-1], device=m0.device)
+ loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0)))
+ ok = (m0 > -1) & (inds0 == loop)
+ m0_new = torch.where(ok, m0, m0.new_tensor(-1))
+ return m0_new
+
+
+class NearestNeighbor(BaseModel):
+ default_conf = {
+ "ratio_threshold": None,
+ "distance_threshold": None,
+ "do_mutual_check": True,
+ }
+ required_inputs = ["descriptors0", "descriptors1"]
+
+ def _init(self, conf):
+ pass
+
+ def _forward(self, data):
+ if data["descriptors0"].size(-1) == 0 or data["descriptors1"].size(-1) == 0:
+ matches0 = torch.full(
+ data["descriptors0"].shape[:2], -1, device=data["descriptors0"].device
+ )
+ return {
+ "matches0": matches0,
+ "matching_scores0": torch.zeros_like(matches0),
+ }
+ ratio_threshold = self.conf["ratio_threshold"]
+ if data["descriptors0"].size(-1) == 1 or data["descriptors1"].size(-1) == 1:
+ ratio_threshold = None
+ sim = torch.einsum("bdn,bdm->bnm", data["descriptors0"], data["descriptors1"])
+ matches0, scores0 = find_nn(
+ sim, ratio_threshold, self.conf["distance_threshold"]
+ )
+ if self.conf["do_mutual_check"]:
+ matches1, scores1 = find_nn(
+ sim.transpose(1, 2), ratio_threshold, self.conf["distance_threshold"]
+ )
+ matches0 = mutual_check(matches0, matches1)
+ return {
+ "matches0": matches0,
+ "matching_scores0": scores0,
+ }
diff --git a/hloc/matchers/roma.py b/hloc/matchers/roma.py
new file mode 100644
index 0000000000000000000000000000000000000000..8394ae271e8855e48a171a66227e6e97fa698426
--- /dev/null
+++ b/hloc/matchers/roma.py
@@ -0,0 +1,91 @@
+import sys
+from pathlib import Path
+import subprocess
+import logging
+import torch
+from PIL import Image
+from ..utils.base_model import BaseModel
+
+roma_path = Path(__file__).parent / "../../third_party/Roma"
+sys.path.append(str(roma_path))
+
+from roma.models.model_zoo.roma_models import roma_model
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logger = logging.getLogger(__name__)
+
+
+class Roma(BaseModel):
+ default_conf = {
+ "name": "two_view_pipeline",
+ "model_name": "roma_outdoor.pth",
+ "model_utils_name": "dinov2_vitl14_pretrain.pth",
+ "max_keypoints": 3000,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+ weight_urls = {
+ "roma": {
+ "roma_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
+ "roma_indoor.pth": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
+ },
+ "dinov2_vitl14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
+ }
+
+    # Initialize the dense matcher
+ def _init(self, conf):
+ model_path = roma_path / "pretrained" / conf["model_name"]
+ dinov2_weights = roma_path / "pretrained" / conf["model_utils_name"]
+
+ # Download the model.
+ if not model_path.exists():
+ model_path.parent.mkdir(exist_ok=True)
+ link = self.weight_urls["roma"][conf["model_name"]]
+ cmd = ["wget", link, "-O", str(model_path)]
+ logger.info(f"Downloading the Roma model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ if not dinov2_weights.exists():
+ dinov2_weights.parent.mkdir(exist_ok=True)
+ link = self.weight_urls[conf["model_utils_name"]]
+ cmd = ["wget", link, "-O", str(dinov2_weights)]
+ logger.info(f"Downloading the dinov2 model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ logger.info(f"Loading Roma model...")
+ # load the model
+ weights = torch.load(model_path, map_location="cpu")
+ dinov2_weights = torch.load(dinov2_weights, map_location="cpu")
+
+ self.net = roma_model(
+ resolution=(14 * 8 * 6, 14 * 8 * 6),
+ upsample_preds=False,
+ weights=weights,
+ dinov2_weights=dinov2_weights,
+ device=device,
+ )
+ logger.info(f"Load Roma model done.")
+
+ def _forward(self, data):
+ img0 = data["image0"].cpu().numpy().squeeze() * 255
+ img1 = data["image1"].cpu().numpy().squeeze() * 255
+ img0 = img0.transpose(1, 2, 0)
+ img1 = img1.transpose(1, 2, 0)
+ img0 = Image.fromarray(img0.astype("uint8"))
+ img1 = Image.fromarray(img1.astype("uint8"))
+ W_A, H_A = img0.size
+ W_B, H_B = img1.size
+
+ # Match
+ warp, certainty = self.net.match(img0, img1, device=device)
+ # Sample matches for estimation
+ matches, certainty = self.net.sample(
+ warp, certainty, num=self.conf["max_keypoints"]
+ )
+ kpts1, kpts2 = self.net.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ pred = {}
+ pred["keypoints0"], pred["keypoints1"] = kpts1, kpts2
+ pred["mconf"] = certainty
+ return pred
diff --git a/hloc/matchers/sgmnet.py b/hloc/matchers/sgmnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..387bfa5ea4316a5794b1a6550783b4fb8d59600a
--- /dev/null
+++ b/hloc/matchers/sgmnet.py
@@ -0,0 +1,128 @@
+import sys
+from pathlib import Path
+import subprocess
+import logging
+import torch
+from PIL import Image
+from collections import OrderedDict, namedtuple
+from ..utils.base_model import BaseModel
+from ..utils import do_system
+
+sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
+sys.path.append(str(sgmnet_path))
+
+from sgmnet import matcher as SGM_Model
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logger = logging.getLogger(__name__)
+
+
+class SGMNet(BaseModel):
+ default_conf = {
+ "name": "SGM",
+ "model_name": "model_best.pth",
+ "seed_top_k": [256, 256],
+ "seed_radius_coe": 0.01,
+ "net_channels": 128,
+ "layer_num": 9,
+ "head": 4,
+ "seedlayer": [0, 6],
+ "use_mc_seeding": True,
+ "use_score_encoding": False,
+ "conf_bar": [1.11, 0.1],
+ "sink_iter": [10, 100],
+ "detach_iter": 1000000,
+ "match_threshold": 0.2,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+ weight_urls = {
+ "model_best.pth": "https://drive.google.com/uc?id=1Ca0WmKSSt2G6P7m8YAOlSAHEFar_TAWb&confirm=t",
+ }
+ proxy = "http://localhost:1080"
+
+    # Initialize the SGMNet matcher
+ def _init(self, conf):
+ sgmnet_weights = sgmnet_path / "weights/sgm/root" / conf["model_name"]
+
+ link = self.weight_urls[conf["model_name"]]
+ tar_path = sgmnet_path / "weights.tar.gz"
+ # Download the model.
+ if not sgmnet_weights.exists():
+ if not tar_path.exists():
+ cmd = ["gdown", link, "-O", str(tar_path), "--proxy", self.proxy]
+ cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
+ logger.info(f"Downloading the SGMNet model with `{cmd_wo_proxy}`.")
+ try:
+ subprocess.run(cmd_wo_proxy, check=True)
+ except subprocess.CalledProcessError as e:
+ logger.info(f"Downloading the SGMNet model with `{cmd}`.")
+ try:
+ subprocess.run(cmd, check=True)
+ except subprocess.CalledProcessError as e:
+ logger.error(f"Failed to download the SGMNet model.")
+ raise e
+            logger.info(f"Unzipping model file `{tar_path}`.")
+            do_system(f"cd {str(sgmnet_path)} && tar -xvf {str(tar_path)}")
+
+ # config
+ config = namedtuple("config", conf.keys())(*conf.values())
+ self.net = SGM_Model(config)
+ checkpoint = torch.load(sgmnet_weights, map_location="cpu")
+        # Strip the "module." prefix if the checkpoint was saved from a DDP-wrapped model.
+ if list(checkpoint["state_dict"].items())[0][0].split(".")[0] == "module":
+ new_stat_dict = OrderedDict()
+ for key, value in checkpoint["state_dict"].items():
+ new_stat_dict[key[7:]] = value
+ checkpoint["state_dict"] = new_stat_dict
+ self.net.load_state_dict(checkpoint["state_dict"])
+ logger.info(f"Load SGMNet model done.")
+
+ def _forward(self, data):
+ x1 = data["keypoints0"] # N x 2
+ x2 = data["keypoints1"]
+ score1 = data["scores0"].reshape(-1, 1) # N x 1
+ score2 = data["scores1"].reshape(-1, 1)
+ desc1 = data["descriptors0"].permute(0, 2, 1) # 1 x N x 128
+ desc2 = data["descriptors1"].permute(0, 2, 1)
+ size1 = torch.tensor(data["image0"].shape[2:]).flip(0) # W x H -> x & y
+ size2 = torch.tensor(data["image1"].shape[2:]).flip(0) # W x H
+ norm_x1 = self.normalize_size(x1, size1)
+ norm_x2 = self.normalize_size(x2, size2)
+
+ x1 = torch.cat((norm_x1, score1), dim=-1) # N x 3
+ x2 = torch.cat((norm_x2, score2), dim=-1)
+ input = {"x1": x1[None], "x2": x2[None], "desc1": desc1, "desc2": desc2}
+ input = {
+ k: v.to(device).float() if isinstance(v, torch.Tensor) else v
+ for k, v in input.items()
+ }
+ pred = self.net(input, test_mode=True)
+
+ p = pred["p"] # shape: N * M
+ indices0 = self.match_p(p[0, :-1, :-1])
+ pred = {
+ "matches0": indices0.unsqueeze(0),
+ "matching_scores0": torch.zeros(indices0.size(0)).unsqueeze(0),
+ }
+ return pred
+
+ def match_p(self, p):
+ score, index = torch.topk(p, k=1, dim=-1)
+ _, index2 = torch.topk(p, k=1, dim=-2)
+ mask_th, index, index2 = (
+ score[:, 0] > self.conf["match_threshold"],
+ index[:, 0],
+ index2.squeeze(0),
+ )
+        mask_mc = index2[index] == torch.arange(len(p), device=p.device)
+ mask = mask_th & mask_mc
+ indices0 = torch.where(mask, index, index.new_tensor(-1))
+ return indices0
+
+ def normalize_size(self, x, size, scale=1):
+ norm_fac = size.max()
+ return (x - size / 2 + 0.5) / (norm_fac * scale)
diff --git a/hloc/matchers/sold2.py b/hloc/matchers/sold2.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f551ce376d3c1f6d5f256b9880bae820704708
--- /dev/null
+++ b/hloc/matchers/sold2.py
@@ -0,0 +1,144 @@
+import sys
+from pathlib import Path
+import torch
+
+from ..utils.base_model import BaseModel
+
+sold2_path = Path(__file__).parent / "../../third_party/SOLD2"
+sys.path.append(str(sold2_path))
+
+from sold2.model.line_matcher import LineMatcher
+from sold2.misc.visualize_util import (
+ plot_images,
+ plot_lines,
+ plot_line_matches,
+ plot_color_line_matches,
+ plot_keypoints,
+)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class SOLD2(BaseModel):
+ default_conf = {
+ "weights": "sold2_wireframe.tar",
+ "match_threshold": 0.2,
+ "checkpoint_dir": sold2_path / "pretrained",
+ "detect_thresh": 0.25,
+ "multiscale": False,
+ "valid_thresh": 1e-3,
+ "num_blocks": 20,
+ "overlap_ratio": 0.5,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+ # Initialize the line matcher
+ def _init(self, conf):
+ checkpoint_path = conf["checkpoint_dir"] / conf["weights"]
+ mode = "dynamic" # 'dynamic' or 'static'
+ match_config = {
+ "model_cfg": {
+ "model_name": "lcnn_simple",
+ "model_architecture": "simple",
+ # Backbone related config
+ "backbone": "lcnn",
+ "backbone_cfg": {
+                    "input_channel": 1,  # 1 for grayscale input, 3 for RGB.
+ "depth": 4,
+ "num_stacks": 2,
+ "num_blocks": 1,
+ "num_classes": 5,
+ },
+ # Junction decoder related config
+ "junction_decoder": "superpoint_decoder",
+ "junc_decoder_cfg": {},
+ # Heatmap decoder related config
+ "heatmap_decoder": "pixel_shuffle",
+ "heatmap_decoder_cfg": {},
+ # Descriptor decoder related config
+ "descriptor_decoder": "superpoint_descriptor",
+ "descriptor_decoder_cfg": {},
+ # Shared configurations
+ "grid_size": 8,
+ "keep_border_valid": True,
+ # Threshold of junction detection
+ "detection_thresh": 0.0153846, # 1/65
+ "max_num_junctions": 300,
+ # Threshold of heatmap detection
+ "prob_thresh": 0.5,
+ # Weighting related parameters
+ "weighting_policy": mode,
+ # [Heatmap loss]
+ "w_heatmap": 0.0,
+ "w_heatmap_class": 1,
+ "heatmap_loss_func": "cross_entropy",
+ "heatmap_loss_cfg": {"policy": mode},
+ # [Heatmap consistency loss]
+ # [Junction loss]
+ "w_junc": 0.0,
+ "junction_loss_func": "superpoint",
+ "junction_loss_cfg": {"policy": mode},
+ # [Descriptor loss]
+ "w_desc": 0.0,
+ "descriptor_loss_func": "regular_sampling",
+ "descriptor_loss_cfg": {
+ "dist_threshold": 8,
+ "grid_size": 4,
+ "margin": 1,
+ "policy": mode,
+ },
+ },
+ "line_detector_cfg": {
+ "detect_thresh": 0.25, # depending on your images, you might need to tune this parameter
+ "num_samples": 64,
+ "sampling_method": "local_max",
+ "inlier_thresh": 0.9,
+ "use_candidate_suppression": True,
+ "nms_dist_tolerance": 3.0,
+ "use_heatmap_refinement": True,
+ "heatmap_refine_cfg": {
+ "mode": "local",
+ "ratio": 0.2,
+ "valid_thresh": 1e-3,
+ "num_blocks": 20,
+ "overlap_ratio": 0.5,
+ },
+ },
+ "multiscale": False,
+ "line_matcher_cfg": {
+ "cross_check": True,
+ "num_samples": 5,
+ "min_dist_pts": 8,
+ "top_k_candidates": 10,
+ "grid_size": 4,
+ },
+ }
+ self.net = LineMatcher(
+ match_config["model_cfg"],
+ checkpoint_path,
+ device,
+ match_config["line_detector_cfg"],
+ match_config["line_matcher_cfg"],
+ match_config["multiscale"],
+ )
+
+ def _forward(self, data):
+ img0 = data["image0"]
+ img1 = data["image1"]
+ pred = self.net([img0, img1])
+ line_seg1 = pred["line_segments"][0]
+ line_seg2 = pred["line_segments"][1]
+ matches = pred["matches"]
+
+ valid_matches = matches != -1
+ match_indices = matches[valid_matches]
+ matched_lines1 = line_seg1[valid_matches][:, :, ::-1]
+ matched_lines2 = line_seg2[match_indices][:, :, ::-1]
+
+ pred["raw_lines0"], pred["raw_lines1"] = line_seg1, line_seg2
+ pred["lines0"], pred["lines1"] = matched_lines1, matched_lines2
+ pred = {**pred, **data}
+ return pred
diff --git a/hloc/matchers/superglue.py b/hloc/matchers/superglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e427f4908f9af676d0627643393a8090c40a00a
--- /dev/null
+++ b/hloc/matchers/superglue.py
@@ -0,0 +1,31 @@
+import sys
+from pathlib import Path
+
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from SuperGluePretrainedNetwork.models.superglue import SuperGlue as SG
+
+
+class SuperGlue(BaseModel):
+ default_conf = {
+ "weights": "outdoor",
+ "sinkhorn_iterations": 100,
+ "match_threshold": 0.2,
+ }
+ required_inputs = [
+ "image0",
+ "keypoints0",
+ "scores0",
+ "descriptors0",
+ "image1",
+ "keypoints1",
+ "scores1",
+ "descriptors1",
+ ]
+
+ def _init(self, conf):
+ self.net = SG(conf)
+
+ def _forward(self, data):
+ return self.net(data)
diff --git a/hloc/matchers/topicfm.py b/hloc/matchers/topicfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1122147b8079021f0d0e501787d1343c588bd7
--- /dev/null
+++ b/hloc/matchers/topicfm.py
@@ -0,0 +1,45 @@
+import torch
+import warnings
+from ..utils.base_model import BaseModel
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from TopicFM.src.models.topic_fm import TopicFM as _TopicFM
+from TopicFM.src import get_model_cfg
+
+topicfm_path = Path(__file__).parent / "../../third_party/TopicFM"
+
+
+class TopicFM(BaseModel):
+ default_conf = {
+ "weights": "outdoor",
+ "match_threshold": 0.2,
+ "n_sampling_topics": 4,
+ }
+ required_inputs = ["image0", "image1"]
+
+ def _init(self, conf):
+ _conf = dict(get_model_cfg())
+ _conf["match_coarse"]["thr"] = conf["match_threshold"]
+ _conf["coarse"]["n_samples"] = conf["n_sampling_topics"]
+ weight_path = topicfm_path / "pretrained/model_best.ckpt"
+ self.net = _TopicFM(config=_conf)
+ ckpt_dict = torch.load(weight_path, map_location="cpu")
+ self.net.load_state_dict(ckpt_dict["state_dict"])
+
+ def _forward(self, data):
+ data_ = {
+ "image0": data["image0"],
+ "image1": data["image1"],
+ }
+ self.net(data_)
+ mkpts0 = data_["mkpts0_f"]
+ mkpts1 = data_["mkpts1_f"]
+ mconf = data_["mconf"]
+ total_n_matches = len(data_["mkpts0_f"])
+
+ pred = {}
+ pred["keypoints0"], pred["keypoints1"] = mkpts0, mkpts1
+ pred["mconf"] = mconf
+ return pred
diff --git a/hloc/pipelines/4Seasons/README.md b/hloc/pipelines/4Seasons/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad23ac8348ae9f0963611bc9a342240d5ae97255
--- /dev/null
+++ b/hloc/pipelines/4Seasons/README.md
@@ -0,0 +1,43 @@
+# 4Seasons dataset
+
+This pipeline localizes sequences from the [4Seasons dataset](https://arxiv.org/abs/2009.06364) and can reproduce our winning submission to the challenge of the [ECCV 2020 Workshop on Map-based Localization for Autonomous Driving](https://sites.google.com/view/mlad-eccv2020/home).
+
+## Installation
+
+Download the sequences from the [challenge webpage](https://sites.google.com/view/mlad-eccv2020/challenge) and run:
+```bash
+unzip recording_2020-04-07_10-20-32.zip -d datasets/4Seasons/reference
+unzip recording_2020-03-24_17-36-22.zip -d datasets/4Seasons/training
+unzip recording_2020-03-03_12-03-23.zip -d datasets/4Seasons/validation
+unzip recording_2020-03-24_17-45-31.zip -d datasets/4Seasons/test0
+unzip recording_2020-04-23_19-37-00.zip -d datasets/4Seasons/test1
+```
+Note that the provided scripts might modify the dataset files by deleting unused images to speed up feature extraction.
+
+## Pipeline
+
+The process is presented in our workshop talk, whose recording can be found [here](https://youtu.be/M-X6HX1JxYk?t=5245).
+
+We first triangulate a 3D model from the given poses of the reference sequence:
+```bash
+python3 -m hloc.pipelines.4Seasons.prepare_reference
+```
+
+We then relocalize a given sequence:
+```bash
+python3 -m hloc.pipelines.4Seasons.localize --sequence [training|validation|test0|test1]
+```
+
+The final submission files can be found in `outputs/4Seasons/submission_hloc+superglue/`. The script will also evaluate these results if the training or validation sequences are selected.
+
+## Results
+
+We evaluate the localization recall (in %) at distance thresholds of 0.1 m, 0.2 m, and 0.5 m.
+
+| Methods | test0 | test1 |
+| -------------------- | ---------------------- | ---------------------- |
+| **hloc + SuperGlue** | **91.8 / 97.7 / 99.2** | **67.3 / 93.5 / 98.7** |
+| Baseline SuperGlue | 21.2 / 33.9 / 60.0 | 12.4 / 26.5 / 54.4 |
+| Baseline R2D2 | 21.5 / 33.1 / 53.0 | 12.3 / 23.7 / 42.0 |
+| Baseline D2Net | 12.5 / 29.3 / 56.7 | 7.5 / 21.4 / 47.7 |
+| Baseline SuperPoint | 15.5 / 27.5 / 47.5 | 9.0 / 19.4 / 36.4 |
diff --git a/hloc/pipelines/4Seasons/__init__.py b/hloc/pipelines/4Seasons/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/pipelines/4Seasons/localize.py b/hloc/pipelines/4Seasons/localize.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e26bef5116367feaf2123b5cf55f7d621e2536f
--- /dev/null
+++ b/hloc/pipelines/4Seasons/localize.py
@@ -0,0 +1,84 @@
+from pathlib import Path
+import argparse
+
+from ... import extract_features, match_features, localize_sfm, logger
+from .utils import get_timestamps, delete_unused_images
+from .utils import generate_query_lists, generate_localization_pairs
+from .utils import prepare_submission, evaluate_submission
+
+relocalization_files = {
+ "training": "RelocalizationFilesTrain//relocalizationFile_recording_2020-03-24_17-36-22.txt",
+ "validation": "RelocalizationFilesVal/relocalizationFile_recording_2020-03-03_12-03-23.txt",
+ "test0": "RelocalizationFilesTest/relocalizationFile_recording_2020-03-24_17-45-31_*.txt",
+ "test1": "RelocalizationFilesTest/relocalizationFile_recording_2020-04-23_19-37-00_*.txt",
+}
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--sequence",
+ type=str,
+ required=True,
+ choices=["training", "validation", "test0", "test1"],
+ help="Sequence to be relocalized.",
+)
+parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/4Seasons",
+ help="Path to the dataset, default: %(default)s",
+)
+parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/4Seasons",
+ help="Path to the output directory, default: %(default)s",
+)
+args = parser.parse_args()
+sequence = args.sequence
+
+data_dir = args.dataset
+ref_dir = data_dir / "reference"
+assert ref_dir.exists(), f"{ref_dir} does not exist"
+seq_dir = data_dir / sequence
+assert seq_dir.exists(), f"{seq_dir} does not exist"
+seq_images = seq_dir / "undistorted_images"
+reloc = ref_dir / relocalization_files[sequence]
+
+output_dir = args.outputs
+output_dir.mkdir(exist_ok=True, parents=True)
+query_list = output_dir / f"{sequence}_queries_with_intrinsics.txt"
+ref_pairs = output_dir / "pairs-db-dist20.txt"
+ref_sfm = output_dir / "sfm_superpoint+superglue"
+results_path = output_dir / f"localization_{sequence}_hloc+superglue.txt"
+submission_dir = output_dir / "submission_hloc+superglue"
+
+num_loc_pairs = 10
+loc_pairs = output_dir / f"pairs-query-{sequence}-dist{num_loc_pairs}.txt"
+
+fconf = extract_features.confs["superpoint_max"]
+mconf = match_features.confs["superglue"]
+
+# Not all query images are used for the evaluation.
+# To save time in feature extraction, we delete unused images.
+timestamps = get_timestamps(reloc, 1)
+delete_unused_images(seq_images, timestamps)
+
+# Generate a list of query images with their intrinsics.
+generate_query_lists(timestamps, seq_dir, query_list)
+
+# Generate the localization pairs from the given reference frames.
+generate_localization_pairs(sequence, reloc, num_loc_pairs, ref_pairs, loc_pairs)
+
+# Extract, match, and localize.
+ffile = extract_features.main(fconf, seq_images, output_dir)
+mfile = match_features.main(mconf, loc_pairs, fconf["output"], output_dir)
+localize_sfm.main(ref_sfm, query_list, loc_pairs, ffile, mfile, results_path)
+
+# Convert the absolute poses to relative poses with the reference frames.
+submission_dir.mkdir(exist_ok=True)
+prepare_submission(results_path, reloc, ref_dir / "poses.txt", submission_dir)
+
+# If not a test sequence: evaluate the localization accuracy.
+if "test" not in sequence:
+ logger.info("Evaluating the relocalization submission...")
+ evaluate_submission(submission_dir, reloc)
diff --git a/hloc/pipelines/4Seasons/prepare_reference.py b/hloc/pipelines/4Seasons/prepare_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..9074df808f5cf345f0cf2c7dada2b6b03f50261a
--- /dev/null
+++ b/hloc/pipelines/4Seasons/prepare_reference.py
@@ -0,0 +1,53 @@
+from pathlib import Path
+import argparse
+
+from ... import extract_features, match_features
+from ... import pairs_from_poses, triangulation
+from .utils import get_timestamps, delete_unused_images
+from .utils import build_empty_colmap_model
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/4Seasons",
+ help="Path to the dataset, default: %(default)s",
+)
+parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/4Seasons",
+ help="Path to the output directory, default: %(default)s",
+)
+args = parser.parse_args()
+
+ref_dir = args.dataset / "reference"
+assert ref_dir.exists(), f"{ref_dir} does not exist"
+ref_images = ref_dir / "undistorted_images"
+
+output_dir = args.outputs
+output_dir.mkdir(exist_ok=True, parents=True)
+ref_sfm_empty = output_dir / "sfm_reference_empty"
+ref_sfm = output_dir / "sfm_superpoint+superglue"
+
+num_ref_pairs = 20
+ref_pairs = output_dir / f"pairs-db-dist{num_ref_pairs}.txt"
+
+fconf = extract_features.confs["superpoint_max"]
+mconf = match_features.confs["superglue"]
+
+# Only reference images that have a pose are used in the pipeline.
+# To save time in feature extraction, we delete unused images.
+delete_unused_images(ref_images, get_timestamps(ref_dir / "poses.txt", 0))
+
+# Build an empty COLMAP model containing only cameras and images
+# from the provided poses and intrinsics.
+build_empty_colmap_model(ref_dir, ref_sfm_empty)
+
+# Match reference images that are spatially close.
+pairs_from_poses.main(ref_sfm_empty, ref_pairs, num_ref_pairs)
+
+# Extract, match, and triangulate the reference SfM model.
+ffile = extract_features.main(fconf, ref_images, output_dir)
+mfile = match_features.main(mconf, ref_pairs, fconf["output"], output_dir)
+triangulation.main(ref_sfm, ref_sfm_empty, ref_images, ref_pairs, ffile, mfile)
diff --git a/hloc/pipelines/4Seasons/utils.py b/hloc/pipelines/4Seasons/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5ecfa7d52ddb0e35de1c77280e6664b93526305
--- /dev/null
+++ b/hloc/pipelines/4Seasons/utils.py
@@ -0,0 +1,224 @@
+import os
+import numpy as np
+import logging
+from pathlib import Path
+
+from ...utils.read_write_model import qvec2rotmat, rotmat2qvec
+from ...utils.read_write_model import Image, write_model, Camera
+from ...utils.parsers import parse_retrieval
+
+logger = logging.getLogger(__name__)
+
+
+def get_timestamps(files, idx):
+ """Extract timestamps from a pose or relocalization file."""
+ lines = []
+ for p in files.parent.glob(files.name):
+ with open(p) as f:
+ lines += f.readlines()
+ timestamps = set()
+ for line in lines:
+ line = line.rstrip("\n")
+ if line[0] == "#" or line == "":
+ continue
+ ts = line.replace(",", " ").split()[idx]
+ timestamps.add(ts)
+ return timestamps
+
+
+def delete_unused_images(root, timestamps):
+ """Delete all images in root if they are not contained in timestamps."""
+ images = list(root.glob("**/*.png"))
+ deleted = 0
+ for image in images:
+ ts = image.stem
+ if ts not in timestamps:
+ os.remove(image)
+ deleted += 1
+ logger.info(f"Deleted {deleted} images in {root}.")
+
+
+def camera_from_calibration_file(id_, path):
+ """Create a COLMAP camera from an MLAD calibration file."""
+ with open(path, "r") as f:
+ data = f.readlines()
+ model, fx, fy, cx, cy = data[0].split()[:5]
+ width, height = data[1].split()
+ assert model == "Pinhole"
+ model_name = "PINHOLE"
+ params = [float(i) for i in [fx, fy, cx, cy]]
+ camera = Camera(
+ id=id_, model=model_name, width=int(width), height=int(height), params=params
+ )
+ return camera
+
+
+def parse_poses(path, colmap=False):
+ """Parse a list of poses in COLMAP or MLAD quaternion convention."""
+ poses = []
+ with open(path) as f:
+ for line in f.readlines():
+ line = line.rstrip("\n")
+ if line[0] == "#" or line == "":
+ continue
+ data = line.replace(",", " ").split()
+ ts, p = data[0], np.array(data[1:], float)
+ if colmap:
+ q, t = np.split(p, [4])
+ else:
+ t, q = np.split(p, [3])
+ q = q[[3, 0, 1, 2]] # xyzw to wxyz
+ R = qvec2rotmat(q)
+ poses.append((ts, R, t))
+ return poses
+
+
+def parse_relocalization(path, has_poses=False):
+ """Parse a relocalization file, possibly with poses."""
+ reloc = []
+ with open(path) as f:
+ for line in f.readlines():
+ line = line.rstrip("\n")
+ if line[0] == "#" or line == "":
+ continue
+ data = line.replace(",", " ").split()
+ out = data[:2] # ref_ts, q_ts
+ if has_poses:
+ assert len(data) == 9
+ t, q = np.split(np.array(data[2:], float), [3])
+ q = q[[3, 0, 1, 2]] # xyzw to wxyz
+ R = qvec2rotmat(q)
+ out += [R, t]
+ reloc.append(out)
+ return reloc
+
+
+def build_empty_colmap_model(root, sfm_dir):
+ """Build a COLMAP model with images and cameras only."""
+ calibration = "Calibration/undistorted_calib_{}.txt"
+ cam0 = camera_from_calibration_file(0, root / calibration.format(0))
+ cam1 = camera_from_calibration_file(1, root / calibration.format(1))
+ cameras = {0: cam0, 1: cam1}
+
+ T_0to1 = np.loadtxt(root / "Calibration/undistorted_calib_stereo.txt")
+ poses = parse_poses(root / "poses.txt")
+ images = {}
+ id_ = 0
+ for ts, R_cam0_to_w, t_cam0_to_w in poses:
+ R_w_to_cam0 = R_cam0_to_w.T
+ t_w_to_cam0 = -(R_w_to_cam0 @ t_cam0_to_w)
+
+ R_w_to_cam1 = T_0to1[:3, :3] @ R_w_to_cam0
+ t_w_to_cam1 = T_0to1[:3, :3] @ t_w_to_cam0 + T_0to1[:3, 3]
+
+ for idx, (R_w_to_cam, t_w_to_cam) in enumerate(
+ zip([R_w_to_cam0, R_w_to_cam1], [t_w_to_cam0, t_w_to_cam1])
+ ):
+ image = Image(
+ id=id_,
+ qvec=rotmat2qvec(R_w_to_cam),
+ tvec=t_w_to_cam,
+ camera_id=idx,
+ name=f"cam{idx}/{ts}.png",
+ xys=np.zeros((0, 2), float),
+ point3D_ids=np.full(0, -1, int),
+ )
+ images[id_] = image
+ id_ += 1
+
+ sfm_dir.mkdir(exist_ok=True, parents=True)
+ write_model(cameras, images, {}, path=str(sfm_dir), ext=".bin")
+
+
+def generate_query_lists(timestamps, seq_dir, out_path):
+ """Create a list of query images with intrinsics from timestamps."""
+ cam0 = camera_from_calibration_file(
+ 0, seq_dir / "Calibration/undistorted_calib_0.txt"
+ )
+ intrinsics = [cam0.model, cam0.width, cam0.height] + cam0.params
+ intrinsics = [str(p) for p in intrinsics]
+ data = map(lambda ts: " ".join([f"cam0/{ts}.png"] + intrinsics), timestamps)
+ with open(out_path, "w") as f:
+ f.write("\n".join(data))
+
+
+def generate_localization_pairs(sequence, reloc, num, ref_pairs, out_path):
+ """Create the matching pairs for the localization.
+ We simply lookup the corresponding reference frame
+ and extract its `num` closest frames from the existing pair list.
+ """
+ if "test" in sequence:
+ # hard pairs will be overwritten by easy ones if available
+ relocs = [str(reloc).replace("*", d) for d in ["hard", "moderate", "easy"]]
+ else:
+ relocs = [reloc]
+ query_to_ref_ts = {}
+ for reloc in relocs:
+ with open(reloc, "r") as f:
+ for line in f.readlines():
+ line = line.rstrip("\n")
+ if line[0] == "#" or line == "":
+ continue
+ ref_ts, q_ts = line.split()[:2]
+ query_to_ref_ts[q_ts] = ref_ts
+
+ ts_to_name = "cam0/{}.png".format
+ ref_pairs = parse_retrieval(ref_pairs)
+ loc_pairs = []
+ for q_ts, ref_ts in query_to_ref_ts.items():
+ ref_name = ts_to_name(ref_ts)
+ selected = [ref_name] + ref_pairs[ref_name][: num - 1]
+ loc_pairs.extend([" ".join((ts_to_name(q_ts), s)) for s in selected])
+ with open(out_path, "w") as f:
+ f.write("\n".join(loc_pairs))
+
+
+def prepare_submission(results, relocs, poses_path, out_dir):
+ """Obtain relative poses from estimated absolute and reference poses."""
+ gt_poses = parse_poses(poses_path)
+ all_T_ref0_to_w = {ts: (R, t) for ts, R, t in gt_poses}
+
+ pred_poses = parse_poses(results, colmap=True)
+ all_T_w_to_q0 = {Path(name).stem: (R, t) for name, R, t in pred_poses}
+
+ for reloc in relocs.parent.glob(relocs.name):
+ relative_poses = []
+ reloc_ts = parse_relocalization(reloc)
+ for ref_ts, q_ts in reloc_ts:
+ R_w_to_q0, t_w_to_q0 = all_T_w_to_q0[q_ts]
+ R_ref0_to_w, t_ref0_to_w = all_T_ref0_to_w[ref_ts]
+
+ R_ref0_to_q0 = R_w_to_q0 @ R_ref0_to_w
+ t_ref0_to_q0 = R_w_to_q0 @ t_ref0_to_w + t_w_to_q0
+
+ tvec = t_ref0_to_q0.tolist()
+ qvec = rotmat2qvec(R_ref0_to_q0)[[1, 2, 3, 0]] # wxyz to xyzw
+
+ out = [ref_ts, q_ts] + list(map(str, tvec)) + list(map(str, qvec))
+ relative_poses.append(" ".join(out))
+
+ out_path = out_dir / reloc.name
+ with open(out_path, "w") as f:
+ f.write("\n".join(relative_poses))
+ logger.info(f"Submission file written to {out_path}.")
+
+
+def evaluate_submission(submission_dir, relocs, ths=[0.1, 0.2, 0.5]):
+ """Compute the relocalization recall from predicted and ground truth poses."""
+ for reloc in relocs.parent.glob(relocs.name):
+ poses_gt = parse_relocalization(reloc, has_poses=True)
+ poses_pred = parse_relocalization(submission_dir / reloc.name, has_poses=True)
+ poses_pred = {(ref_ts, q_ts): (R, t) for ref_ts, q_ts, R, t in poses_pred}
+
+ error = []
+ for ref_ts, q_ts, R_gt, t_gt in poses_gt:
+ R, t = poses_pred[(ref_ts, q_ts)]
+ e = np.linalg.norm(t - t_gt)
+ error.append(e)
+
+ error = np.array(error)
+ recall = [np.mean(error <= th) for th in ths]
+ s = f"Relocalization evaluation {submission_dir.name}/{reloc.name}\n"
+ s += " / ".join([f"{th:>7}m" for th in ths]) + "\n"
+ s += " / ".join([f"{100*r:>7.3f}%" for r in recall])
+ logger.info(s)
diff --git a/hloc/pipelines/7Scenes/README.md b/hloc/pipelines/7Scenes/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2124779c43ec8d1ffc552e07790d39c3578526a9
--- /dev/null
+++ b/hloc/pipelines/7Scenes/README.md
@@ -0,0 +1,65 @@
+# 7Scenes dataset
+
+## Installation
+
+Download the images from the [7Scenes project page](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/):
+```bash
+export dataset=datasets/7scenes
+for scene in chess fire heads office pumpkin redkitchen stairs; \
+do wget http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/$scene.zip -P $dataset \
+&& unzip $dataset/$scene.zip -d $dataset && unzip $dataset/$scene/'*.zip' -d $dataset/$scene; done
+```
+
+Download the SIFT SfM models and DenseVLAD image pairs, courtesy of Torsten Sattler:
+```bash
+function download {
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$1" -O $2 && rm -rf /tmp/cookies.txt
+unzip $2 -d $dataset && rm $2;
+}
+download 1cu6KUR7WHO7G4EO49Qi3HEKU6n_yYDjb $dataset/7scenes_sfm_triangulated.zip
+download 1IbS2vLmxr1N0f3CEnd_wsYlgclwTyvB1 $dataset/7scenes_densevlad_retrieval_top_10.zip
+```
+
+Download the rendered depth maps, courtesy of Eric Brachmann for [DSAC\*](https://github.com/vislearn/dsacstar):
+```bash
+wget https://heidata.uni-heidelberg.de/api/access/datafile/4037 -O $dataset/7scenes_rendered_depth.tar.gz
+mkdir $dataset/depth/
+tar xzf $dataset/7scenes_rendered_depth.tar.gz -C $dataset/depth/ && rm $dataset/7scenes_rendered_depth.tar.gz
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.7Scenes.pipeline [--use_dense_depth]
+```
+By default, hloc triangulates a sparse point cloud that can be noisy in indoor environments due to image noise and lack of texture. With the flag `--use_dense_depth`, the pipeline improves the accuracy of the sparse point cloud using dense depth maps provided by the dataset. The original depth maps captured by the RGBD sensor are miscalibrated, so we use depth maps rendered from the mesh obtained by fusing the RGBD data.
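+
+The correction implemented in `create_gt_sfm.py` back-projects each triangulated keypoint with the rendered depth and overwrites its noisy 3D position. A minimal sketch of the underlying back-projection, assuming an ideal pinhole camera and hypothetical inputs (the actual script handles the camera model through pycolmap and filters invalid depths):
+
+```python
+import numpy as np
+
+# Hypothetical inputs: 2D keypoints (N,2), per-point rendered depth (N,),
+# 3x3 intrinsics K, and a world-to-camera pose (R_w2c, t_w2c) as in COLMAP.
+p2D = np.array([[320.0, 240.0], [100.0, 50.0]])
+depth = np.array([2.5, 3.1])
+K = np.array([[525.0, 0.0, 320.0], [0.0, 525.0, 240.0], [0.0, 0.0, 1.0]])
+R_w2c, t_w2c = np.eye(3), np.zeros(3)
+
+# Normalized rays scaled by depth give camera-frame points; then camera -> world.
+p2D_h = np.concatenate([p2D, np.ones((len(p2D), 1))], axis=1)
+p3D_c = (p2D_h @ np.linalg.inv(K).T) * depth[:, None]
+p3D_w = (p3D_c - t_w2c) @ R_w2c  # world = R^T (X_c - t), as in scene_coordinates()
+```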
+
+## Results
+We report the median error in translation/rotation in cm/deg over all scenes:
+| Method \ Scene | Chess | Fire | Heads | Office | Pumpkin | Kitchen | Stairs |
+| ------------------------------- | -------------- | -------------- | -------------- | -------------- | -------------- | -------------- | ---------- |
+| Active Search | 3/0.87 | **2**/1.01 | **1**/0.82 | 4/1.15 | 7/1.69 | 5/1.72 | 4/**1.01** |
+| DSAC* | **2**/1.10 | **2**/1.24 | **1**/1.82 | **3**/1.15 | **4**/1.34 | 4/1.68 | **3**/1.16 |
+| **SuperPoint+SuperGlue** (sfm) | **2**/0.84 | **2**/0.93 | **1**/**0.74** | **3**/0.92 | 5/1.27 | 4/1.40 | 5/1.47 |
+| **SuperPoint+SuperGlue** (RGBD) | **2**/**0.80** | **2**/**0.77** | **1**/0.79 | **3**/**0.80** | **4**/**1.07** | **3**/**1.13** | 4/1.15 |
+
+## Citation
+Please cite the following paper if you use the 7Scenes dataset:
+```
+@inproceedings{shotton2013scene,
+ title={Scene coordinate regression forests for camera relocalization in {RGB-D} images},
+ author={Shotton, Jamie and Glocker, Ben and Zach, Christopher and Izadi, Shahram and Criminisi, Antonio and Fitzgibbon, Andrew},
+ booktitle={CVPR},
+ year={2013}
+}
+```
+
+Also cite DSAC* if you use dense depth maps with the flag `--use_dense_depth`:
+```
+@article{brachmann2020dsacstar,
+ title={Visual Camera Re-Localization from {RGB} and {RGB-D} Images Using {DSAC}},
+ author={Brachmann, Eric and Rother, Carsten},
+ journal={TPAMI},
+ year={2021}
+}
+```
diff --git a/hloc/pipelines/7Scenes/__init__.py b/hloc/pipelines/7Scenes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/pipelines/7Scenes/create_gt_sfm.py b/hloc/pipelines/7Scenes/create_gt_sfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..af219c9666fe96c57bfa59f28cec4b1bd6e233cb
--- /dev/null
+++ b/hloc/pipelines/7Scenes/create_gt_sfm.py
@@ -0,0 +1,135 @@
+from pathlib import Path
+import numpy as np
+import torch
+import PIL.Image
+from tqdm import tqdm
+import pycolmap
+
+from ...utils.read_write_model import write_model, read_model
+
+
+def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera):
+ assert len(depth) == len(p2D)
+ ret = pycolmap.image_to_world(p2D, camera._asdict())
+ p2D_norm = np.asarray(ret["world_points"])
+ p2D_h = np.concatenate([p2D_norm, np.ones_like(p2D_norm[:, :1])], 1)
+ p3D_c = p2D_h * depth[:, None]
+ p3D_w = (p3D_c - t_w2c) @ R_w2c
+ return p3D_w
+
+
+def interpolate_depth(depth, kp):
+ h, w = depth.shape
+ kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1
+ assert np.all(kp > -1) and np.all(kp < 1)
+ depth = torch.from_numpy(depth)[None, None]
+ kp = torch.from_numpy(kp)[None, None]
+ grid_sample = torch.nn.functional.grid_sample
+
+ # To maximize the number of points that have depth:
+ # do bilinear interpolation first and then nearest for the remaining points
+ interp_lin = grid_sample(depth, kp, align_corners=True, mode="bilinear")[0, :, 0]
+ interp_nn = torch.nn.functional.grid_sample(
+ depth, kp, align_corners=True, mode="nearest"
+ )[0, :, 0]
+ interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)
+ valid = ~torch.any(torch.isnan(interp), 0)
+
+ interp_depth = interp.T.numpy().flatten()
+ valid = valid.numpy()
+ return interp_depth, valid
+
+
+def image_path_to_rendered_depth_path(image_name):
+ parts = image_name.split("/")
+ name = "_".join(["".join(parts[0].split("-")), parts[1]])
+ name = name.replace("color", "pose")
+ name = name.replace("png", "depth.tiff")
+ return name
+
+
+def project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1):
+ p3D = (p3D @ R.T) + t
+ visible = p3D[:, -1] >= eps # keep points in front of the camera
+ p2D_norm = p3D[:, :-1] / p3D[:, -1:].clip(min=eps)
+ ret = pycolmap.world_to_image(p2D_norm, camera._asdict())
+ p2D = np.asarray(ret["image_points"])
+ size = np.array([camera.width - pad - 1, camera.height - pad - 1])
+ valid = np.all((p2D >= pad) & (p2D <= size), -1)
+ valid &= visible
+ return p2D[valid], valid
+
+
+def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path):
+ cameras, images, points3D = read_model(sfm_path)
+ for imgid, img in tqdm(images.items()):
+ image_name = img.name
+ depth_name = image_path_to_rendered_depth_path(image_name)
+
+ depth = PIL.Image.open(Path(depth_folder_path) / depth_name)
+ depth = np.array(depth).astype("float64")
+ depth = depth / 1000.0 # mm to meter
+ depth[(depth == 0.0) | (depth > 1000.0)] = np.nan
+
+ R_w2c, t_w2c = img.qvec2rotmat(), img.tvec
+ camera = cameras[img.camera_id]
+ p3D_ids = img.point3D_ids
+ p3Ds = np.stack([points3D[i].xyz for i in p3D_ids[p3D_ids != -1]], 0)
+
+ p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera)
+ invalid_p3D_ids = p3D_ids[p3D_ids != -1][~valids_projected]
+ interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds)
+ scs = scene_coordinates(
+ p2Ds[valids_backprojected],
+ R_w2c,
+ t_w2c,
+ interp_depth[valids_backprojected],
+ camera,
+ )
+ invalid_p3D_ids = np.append(
+ invalid_p3D_ids,
+ p3D_ids[p3D_ids != -1][valids_projected][~valids_backprojected],
+ )
+ for p3did in invalid_p3D_ids:
+ if p3did == -1:
+ continue
+ else:
+ obs_imgids = points3D[p3did].image_ids
+ invalid_imgids = list(np.where(obs_imgids == img.id)[0])
+ points3D[p3did] = points3D[p3did]._replace(
+ image_ids=np.delete(obs_imgids, invalid_imgids),
+ point2D_idxs=np.delete(
+ points3D[p3did].point2D_idxs, invalid_imgids
+ ),
+ )
+
+ new_p3D_ids = p3D_ids.copy()
+ sub_p3D_ids = new_p3D_ids[new_p3D_ids != -1]
+ valids = np.ones(np.count_nonzero(new_p3D_ids != -1), dtype=bool)
+ valids[~valids_projected] = False
+ valids[valids_projected] = valids_backprojected
+ sub_p3D_ids[~valids] = -1
+ new_p3D_ids[new_p3D_ids != -1] = sub_p3D_ids
+ img = img._replace(point3D_ids=new_p3D_ids)
+
+ assert len(img.point3D_ids[img.point3D_ids != -1]) == len(
+ scs
+ ), f"{len(scs)}, {len(img.point3D_ids[img.point3D_ids != -1])}"
+ for i, p3did in enumerate(img.point3D_ids[img.point3D_ids != -1]):
+ points3D[p3did] = points3D[p3did]._replace(xyz=scs[i])
+ images[imgid] = img
+
+ output_path.mkdir(parents=True, exist_ok=True)
+ write_model(cameras, images, points3D, output_path)
+
+
+if __name__ == "__main__":
+ dataset = Path("datasets/7scenes")
+ outputs = Path("outputs/7Scenes")
+
+ SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"]
+ for scene in SCENES:
+ sfm_path = outputs / scene / "sfm_superpoint+superglue"
+ depth_path = dataset / f"depth/7scenes_{scene}/train/depth"
+ output_path = outputs / scene / "sfm_superpoint+superglue+depth"
+ correct_sfm_with_gt_depth(sfm_path, depth_path, output_path)
diff --git a/hloc/pipelines/7Scenes/pipeline.py b/hloc/pipelines/7Scenes/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..f44bbaeef87ac3c6091c51cc367f1ba98bab3c3f
--- /dev/null
+++ b/hloc/pipelines/7Scenes/pipeline.py
@@ -0,0 +1,133 @@
+from pathlib import Path
+import argparse
+
+from .utils import create_reference_sfm
+from .create_gt_sfm import correct_sfm_with_gt_depth
+from ..Cambridge.utils import create_query_list_with_intrinsics, evaluate
+from ... import extract_features, match_features, pairs_from_covisibility
+from ... import triangulation, localize_sfm, logger
+
+SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"]
+
+
+def run_scene(
+ images,
+ gt_dir,
+ retrieval,
+ outputs,
+ results,
+ num_covis,
+ use_dense_depth,
+ depth_dir=None,
+):
+ outputs.mkdir(exist_ok=True, parents=True)
+ ref_sfm_sift = outputs / "sfm_sift"
+ ref_sfm = outputs / "sfm_superpoint+superglue"
+ query_list = outputs / "query_list_with_intrinsics.txt"
+
+ feature_conf = {
+ "output": "feats-superpoint-n4096-r1024",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 3,
+ "max_keypoints": 4096,
+ },
+ "preprocessing": {
+ "globs": ["*.color.png"],
+ "grayscale": True,
+ "resize_max": 1024,
+ },
+ }
+ matcher_conf = match_features.confs["superglue"]
+ matcher_conf["model"]["sinkhorn_iterations"] = 5
+
+ test_list = gt_dir / "list_test.txt"
+ create_reference_sfm(gt_dir, ref_sfm_sift, test_list)
+ create_query_list_with_intrinsics(gt_dir, query_list, test_list)
+
+ features = extract_features.main(feature_conf, images, outputs, as_half=True)
+
+ sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt"
+ pairs_from_covisibility.main(ref_sfm_sift, sfm_pairs, num_matched=num_covis)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+ if not (use_dense_depth and ref_sfm.exists()):
+ triangulation.main(
+ ref_sfm, ref_sfm_sift, images, sfm_pairs, features, sfm_matches
+ )
+ if use_dense_depth:
+ assert depth_dir is not None
+ ref_sfm_fix = outputs / "sfm_superpoint+superglue+depth"
+ correct_sfm_with_gt_depth(ref_sfm, depth_dir, ref_sfm_fix)
+ ref_sfm = ref_sfm_fix
+
+ loc_matches = match_features.main(
+ matcher_conf, retrieval, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ ref_sfm,
+ query_list,
+ retrieval,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ prepend_camera_name=True,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--scenes", default=SCENES, choices=SCENES, nargs="+")
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/7scenes",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/7scenes",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument("--use_dense_depth", action="store_true")
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=30,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ args = parser.parse_args()
+
+ gt_dirs = args.dataset / "7scenes_sfm_triangulated/{scene}/triangulated"
+ retrieval_dirs = args.dataset / "7scenes_densevlad_retrieval_top_10"
+
+ all_results = {}
+ for scene in args.scenes:
+ logger.info(f'Working on scene "{scene}".')
+ results = (
+ args.outputs
+ / scene
+ / "results_{}.txt".format("dense" if args.use_dense_depth else "sparse")
+ )
+ if args.overwrite or not results.exists():
+ run_scene(
+ args.dataset / scene,
+ Path(str(gt_dirs).format(scene=scene)),
+ retrieval_dirs / f"{scene}_top10.txt",
+ args.outputs / scene,
+ results,
+ args.num_covis,
+ args.use_dense_depth,
+ depth_dir=args.dataset / f"depth/7scenes_{scene}/train/depth",
+ )
+ all_results[scene] = results
+
+ for scene in args.scenes:
+ logger.info(f'Evaluate scene "{scene}".')
+ gt_dir = Path(str(gt_dirs).format(scene=scene))
+ evaluate(gt_dir, all_results[scene], gt_dir / "list_test.txt")
diff --git a/hloc/pipelines/7Scenes/utils.py b/hloc/pipelines/7Scenes/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51070a0f6888e075fe54e3543d8b2283baf5fee5
--- /dev/null
+++ b/hloc/pipelines/7Scenes/utils.py
@@ -0,0 +1,33 @@
+import logging
+import numpy as np
+
+from hloc.utils.read_write_model import read_model, write_model
+
+logger = logging.getLogger(__name__)
+
+
+def create_reference_sfm(full_model, ref_model, blacklist=None, ext=".bin"):
+ """Create a new COLMAP model with only training images."""
+ logger.info("Creating the reference model.")
+ ref_model.mkdir(exist_ok=True)
+ cameras, images, points3D = read_model(full_model, ext)
+
+ if blacklist is not None:
+ with open(blacklist, "r") as f:
+ blacklist = f.read().rstrip().split("\n")
+
+ images_ref = dict()
+ for id_, image in images.items():
+ if blacklist and image.name in blacklist:
+ continue
+ images_ref[id_] = image
+
+ points3D_ref = dict()
+ for id_, point3D in points3D.items():
+ ref_ids = [i for i in point3D.image_ids if i in images_ref]
+ if len(ref_ids) == 0:
+ continue
+ points3D_ref[id_] = point3D._replace(image_ids=np.array(ref_ids))
+
+ write_model(cameras, images_ref, points3D_ref, ref_model, ".bin")
+ logger.info(f"Kept {len(images_ref)} images out of {len(images)}.")
diff --git a/hloc/pipelines/Aachen/README.md b/hloc/pipelines/Aachen/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1aefdb7ddb3371335ba5a6a354acf3692206ecf7
--- /dev/null
+++ b/hloc/pipelines/Aachen/README.md
@@ -0,0 +1,16 @@
+# Aachen-Day-Night dataset
+
+## Installation
+
+Download the dataset from [visuallocalization.net](https://www.visuallocalization.net):
+```bash
+export dataset=datasets/aachen
+wget -r -np -nH -R "index.html*,aachen_v1_1.zip" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/ -P $dataset
+unzip $dataset/images/database_and_query_images.zip -d $dataset/images
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.Aachen.pipeline
+```
diff --git a/hloc/pipelines/Aachen/__init__.py b/hloc/pipelines/Aachen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/pipelines/Aachen/pipeline.py b/hloc/pipelines/Aachen/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a6b6661c92a4e32115358803043b1f428afba95
--- /dev/null
+++ b/hloc/pipelines/Aachen/pipeline.py
@@ -0,0 +1,96 @@
+from pathlib import Path
+from pprint import pformat
+import argparse
+
+from ... import extract_features, match_features
+from ... import pairs_from_covisibility, pairs_from_retrieval
+from ... import colmap_from_nvm, triangulation, localize_sfm
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/aachen",
+ help="Path to the dataset, default: %(default)s",
+)
+parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/aachen",
+ help="Path to the output directory, default: %(default)s",
+)
+parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+)
+parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=50,
+ help="Number of image pairs for loc, default: %(default)s",
+)
+args = parser.parse_args()
+
+# Setup the paths
+dataset = args.dataset
+images = dataset / "images/images_upright/"
+
+outputs = args.outputs # where everything will be saved
+sift_sfm = outputs / "sfm_sift" # from which we extract the reference poses
+reference_sfm = outputs / "sfm_superpoint+superglue" # the SfM model we will build
+sfm_pairs = (
+ outputs / f"pairs-db-covis{args.num_covis}.txt"
+) # top-k most covisible in SIFT model
+loc_pairs = (
+ outputs / f"pairs-query-netvlad{args.num_loc}.txt"
+) # top-k retrieved by NetVLAD
+results = outputs / f"Aachen_hloc_superpoint+superglue_netvlad{args.num_loc}.txt"
+
+# list the standard configurations available
+print(f"Configs for feature extractors:\n{pformat(extract_features.confs)}")
+print(f"Configs for feature matchers:\n{pformat(match_features.confs)}")
+
+# pick one of the configurations for extraction and matching
+retrieval_conf = extract_features.confs["netvlad"]
+feature_conf = extract_features.confs["superpoint_aachen"]
+matcher_conf = match_features.confs["superglue"]
+
+features = extract_features.main(feature_conf, images, outputs)
+
+colmap_from_nvm.main(
+ dataset / "3D-models/aachen_cvpr2018_db.nvm",
+ dataset / "3D-models/database_intrinsics.txt",
+ dataset / "aachen.db",
+ sift_sfm,
+)
+pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis)
+sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+)
+
+triangulation.main(reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches)
+
+global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ args.num_loc,
+ query_prefix="query",
+ db_model=reference_sfm,
+)
+loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+)
+
+localize_sfm.main(
+ reference_sfm,
+ dataset / "queries/*_time_queries_with_intrinsics.txt",
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+) # not required with SuperPoint+SuperGlue
diff --git a/hloc/pipelines/Aachen_v1_1/README.md b/hloc/pipelines/Aachen_v1_1/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..33a310bcb2625d6d307e0c95f65f23186c0c9f71
--- /dev/null
+++ b/hloc/pipelines/Aachen_v1_1/README.md
@@ -0,0 +1,18 @@
+# Aachen-Day-Night dataset v1.1
+
+## Installation
+
+Download the dataset from [visuallocalization.net](https://www.visuallocalization.net):
+```bash
+export dataset=datasets/aachen_v1.1
+wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/ -P $dataset
+unzip $dataset/images/database_and_query_images.zip -d $dataset/images
+unzip $dataset/aachen_v1_1.zip -d $dataset
+rsync -a $dataset/images_upright/ $dataset/images/images_upright/
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.Aachen_v1_1.pipeline
+```
diff --git a/hloc/pipelines/Aachen_v1_1/__init__.py b/hloc/pipelines/Aachen_v1_1/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/pipelines/Aachen_v1_1/pipeline.py b/hloc/pipelines/Aachen_v1_1/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3feb42562dc4c1031fb373a8b7241fae055c27c
--- /dev/null
+++ b/hloc/pipelines/Aachen_v1_1/pipeline.py
@@ -0,0 +1,89 @@
+from pathlib import Path
+from pprint import pformat
+import argparse
+
+from ... import extract_features, match_features, triangulation
+from ... import pairs_from_covisibility, pairs_from_retrieval, localize_sfm
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/aachen_v1.1",
+ help="Path to the dataset, default: %(default)s",
+)
+parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/aachen_v1.1",
+ help="Path to the output directory, default: %(default)s",
+)
+parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+)
+parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=50,
+ help="Number of image pairs for loc, default: %(default)s",
+)
+args = parser.parse_args()
+
+# Setup the paths
+dataset = args.dataset
+images = dataset / "images/images_upright/"
+sift_sfm = dataset / "3D-models/aachen_v_1_1"
+
+outputs = args.outputs # where everything will be saved
+reference_sfm = outputs / "sfm_superpoint+superglue" # the SfM model we will build
+sfm_pairs = (
+ outputs / f"pairs-db-covis{args.num_covis}.txt"
+) # top-k most covisible in SIFT model
+loc_pairs = (
+ outputs / f"pairs-query-netvlad{args.num_loc}.txt"
+) # top-k retrieved by NetVLAD
+results = outputs / f"Aachen-v1.1_hloc_superpoint+superglue_netvlad{args.num_loc}.txt"
+
+# list the standard configurations available
+print(f"Configs for feature extractors:\n{pformat(extract_features.confs)}")
+print(f"Configs for feature matchers:\n{pformat(match_features.confs)}")
+
+# pick one of the configurations for extraction and matching
+retrieval_conf = extract_features.confs["netvlad"]
+feature_conf = extract_features.confs["superpoint_max"]
+matcher_conf = match_features.confs["superglue"]
+
+features = extract_features.main(feature_conf, images, outputs)
+
+pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis)
+sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+)
+
+triangulation.main(reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches)
+
+global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ args.num_loc,
+ query_prefix="query",
+ db_model=reference_sfm,
+)
+loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+)
+
+localize_sfm.main(
+ reference_sfm,
+ dataset / "queries/*_time_queries_with_intrinsics.txt",
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+) # not required with SuperPoint+SuperGlue
diff --git a/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py b/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..34cd1761d8e78096e842e37edf546579a3cf8542
--- /dev/null
+++ b/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py
@@ -0,0 +1,92 @@
+from pathlib import Path
+from pprint import pformat
+import argparse
+
+from ... import extract_features, match_dense, triangulation
+from ... import pairs_from_covisibility, pairs_from_retrieval, localize_sfm
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/aachen_v1.1",
+ help="Path to the dataset, default: %(default)s",
+)
+parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/aachen_v1.1",
+ help="Path to the output directory, default: %(default)s",
+)
+parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+)
+parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=50,
+ help="Number of image pairs for loc, default: %(default)s",
+)
+args = parser.parse_args()
+
+# Setup the paths
+dataset = args.dataset
+images = dataset / "images/images_upright/"
+sift_sfm = dataset / "3D-models/aachen_v_1_1"
+
+outputs = args.outputs # where everything will be saved
+outputs.mkdir(exist_ok=True, parents=True)
+reference_sfm = outputs / "sfm_loftr" # the SfM model we will build
+sfm_pairs = (
+ outputs / f"pairs-db-covis{args.num_covis}.txt"
+) # top-k most covisible in SIFT model
+loc_pairs = (
+ outputs / f"pairs-query-netvlad{args.num_loc}.txt"
+) # top-k retrieved by NetVLAD
+results = outputs / f"Aachen-v1.1_hloc_loftr_netvlad{args.num_loc}.txt"
+
+# list the standard configurations available
+print(f"Configs for dense feature matchers:\n{pformat(match_dense.confs)}")
+
+# pick one of the configurations for extraction and matching
+retrieval_conf = extract_features.confs["netvlad"]
+matcher_conf = match_dense.confs["loftr_aachen"]
+
+pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis)
+features, sfm_matches = match_dense.main(
+ matcher_conf, sfm_pairs, images, outputs, max_kps=8192, overwrite=False
+)
+
+triangulation.main(reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches)
+
+global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ args.num_loc,
+ query_prefix="query",
+ db_model=reference_sfm,
+)
+features, loc_matches = match_dense.main(
+ matcher_conf,
+ loc_pairs,
+ images,
+ outputs,
+ features=features,
+ max_kps=None,
+ matches=sfm_matches,
+)
+
+localize_sfm.main(
+ reference_sfm,
+ dataset / "queries/*_time_queries_with_intrinsics.txt",
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+) # not required with loftr
diff --git a/hloc/pipelines/CMU/README.md b/hloc/pipelines/CMU/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..566ba352c53ada2a13dce21c8ec1041b56969d03
--- /dev/null
+++ b/hloc/pipelines/CMU/README.md
@@ -0,0 +1,16 @@
+# Extended CMU Seasons dataset
+
+## Installation
+
+Download the dataset from [visuallocalization.net](https://www.visuallocalization.net):
+```bash
+export dataset=datasets/cmu_extended
+wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Extended-CMU-Seasons/ -P $dataset
+for slice in $dataset/*.tar; do tar -xf $slice -C $dataset && rm $slice; done
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.CMU.pipeline
+```
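+
+By default all test slices are processed. The `--slices` argument restricts the run to a subset: a single slice number, an interval such as `2-6`, or a Python-style list such as `[2, 3, 4]`.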
diff --git a/hloc/pipelines/CMU/__init__.py b/hloc/pipelines/CMU/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/pipelines/CMU/pipeline.py b/hloc/pipelines/CMU/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8bab47c3abfc79412cbcfa18071ed49e13cd8d
--- /dev/null
+++ b/hloc/pipelines/CMU/pipeline.py
@@ -0,0 +1,126 @@
+from pathlib import Path
+import argparse
+
+from ... import extract_features, match_features, triangulation, logger
+from ... import pairs_from_covisibility, pairs_from_retrieval, localize_sfm
+
+TEST_SLICES = [2, 3, 4, 5, 6, 13, 14, 15, 16, 17, 18, 19, 20, 21]
+
+
+def generate_query_list(dataset, path, slice_):
+ cameras = {}
+ with open(dataset / "intrinsics.txt", "r") as f:
+ for line in f.readlines():
+ if line[0] == "#" or line == "\n":
+ continue
+ data = line.split()
+ cameras[data[0]] = data[1:]
+ assert len(cameras) == 2
+
+ queries = dataset / f"{slice_}/test-images-{slice_}.txt"
+ with open(queries, "r") as f:
+ queries = [q.rstrip("\n") for q in f.readlines()]
+
+ out = [[q] + cameras[q.split("_")[2]] for q in queries]
+ with open(path, "w") as f:
+ f.write("\n".join(map(" ".join, out)))
+
+
+def run_slice(slice_, root, outputs, num_covis, num_loc):
+ dataset = root / slice_
+ ref_images = dataset / "database"
+ query_images = dataset / "query"
+ sift_sfm = dataset / "sparse"
+
+ outputs = outputs / slice_
+ outputs.mkdir(exist_ok=True, parents=True)
+ query_list = dataset / "queries_with_intrinsics.txt"
+ sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt"
+ loc_pairs = outputs / f"pairs-query-netvlad{num_loc}.txt"
+ ref_sfm = outputs / "sfm_superpoint+superglue"
+ results = outputs / f"CMU_hloc_superpoint+superglue_netvlad{num_loc}.txt"
+
+ # pick one of the configurations for extraction and matching
+ retrieval_conf = extract_features.confs["netvlad"]
+ feature_conf = extract_features.confs["superpoint_aachen"]
+ matcher_conf = match_features.confs["superglue"]
+
+ pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=num_covis)
+ features = extract_features.main(feature_conf, ref_images, outputs, as_half=True)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+ triangulation.main(ref_sfm, sift_sfm, ref_images, sfm_pairs, features, sfm_matches)
+
+ generate_query_list(root, query_list, slice_)
+ global_descriptors = extract_features.main(retrieval_conf, ref_images, outputs)
+ global_descriptors = extract_features.main(retrieval_conf, query_images, outputs)
+ pairs_from_retrieval.main(
+ global_descriptors, loc_pairs, num_loc, query_list=query_list, db_model=ref_sfm
+ )
+
+ features = extract_features.main(feature_conf, query_images, outputs, as_half=True)
+ loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ ref_sfm,
+ dataset / "queries/*_time_queries_with_intrinsics.txt",
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--slices",
+ type=str,
+ default="*",
+ help="a single number, an interval (e.g. 2-6), "
+ "or a Python-style list or int (e.g. [2, 3, 4]",
+ )
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/cmu_extended",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/aachen_extended",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=10,
+ help="Number of image pairs for loc, default: %(default)s",
+ )
+ args = parser.parse_args()
+
+ if args.slice == "*":
+ slices = TEST_SLICES
+ if "-" in args.slices:
+ min_, max_ = args.slices.split("-")
+ slices = list(range(int(min_), int(max_) + 1))
+ else:
+ slices = eval(args.slices)
+ if isinstance(slices, int):
+ slices = [slices]
+
+ for slice_ in slices:
+ logger.info("Working on slice %s.", slice_)
+ run_slice(
+ f"slice{slice_}", args.dataset, args.outputs, args.num_covis, args.num_loc
+ )
diff --git a/hloc/pipelines/Cambridge/README.md b/hloc/pipelines/Cambridge/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d5ae07b71c48a98fa9235f0dfb0234c3c18c74c6
--- /dev/null
+++ b/hloc/pipelines/Cambridge/README.md
@@ -0,0 +1,47 @@
+# Cambridge Landmarks dataset
+
+## Installation
+
+Download the dataset from the [PoseNet project page](http://mi.eng.cam.ac.uk/projects/relocalisation/):
+```bash
+export dataset=datasets/cambridge
+export scenes=( "KingsCollege" "OldHospital" "StMarysChurch" "ShopFacade" "GreatCourt" )
+export IDs=( "251342" "251340" "251294" "251336" "251291" )
+for i in "${!scenes[@]}"; do
+wget https://www.repository.cam.ac.uk/bitstream/handle/1810/${IDs[i]}/${scenes[i]}.zip -P $dataset \
+&& unzip $dataset/${scenes[i]}.zip -d $dataset && rm $dataset/${scenes[i]}.zip; done
+```
+
+Download the SIFT SfM models, courtesy of Torsten Sattler:
+```bash
+export fileid=1esqzZ1zEQlzZVic-H32V6kkZvc4NeS15
+export filename=$dataset/CambridgeLandmarks_Colmap_Retriangulated_1024px.zip
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$fileid" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$fileid" -O $filename && rm -rf /tmp/cookies.txt
+unzip $filename -d $dataset
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.Cambridge.pipeline
+```
+
+## Results
+We report the median error in translation/rotation in cm/deg over all scenes:
+| Method \ Scene | Court | King's | Hospital | Shop | St. Mary's |
+| ------------------------ | --------------- | --------------- | --------------- | -------------- | -------------- |
+| Active Search | 24/0.13 | 13/0.22 | 20/0.36 | **4**/0.21 | 8/0.25 |
+| DSAC* | 49/0.3 | 15/0.3 | 21/0.4 | 5/0.3 | 13/0.4 |
+| **SuperPoint+SuperGlue** | **17**/**0.11** | **12**/**0.21** | **14**/**0.30** | **4**/**0.19** | **7**/**0.22** |
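+
+These errors are computed as in `evaluate` in `hloc/pipelines/Cambridge/utils.py`: the translation error is the distance between camera centers and the rotation error is the angle of the relative rotation. A minimal sketch with hypothetical poses in the COLMAP world-to-camera convention:
+
+```python
+import numpy as np
+
+# Hypothetical ground-truth and estimated world-to-camera poses (R, t).
+R_gt, t_gt = np.eye(3), np.array([0.0, 0.0, 1.0])
+R, t = np.eye(3), np.array([0.02, 0.0, 1.05])
+
+# Translation error: distance between camera centers c = -R^T t.
+e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t)
+
+# Rotation error: angle of the relative rotation R_gt^T R, in degrees.
+cos = np.clip((np.trace(R_gt.T @ R) - 1) / 2, -1.0, 1.0)
+e_R = np.rad2deg(np.arccos(cos))
+print(f"{e_t * 100:.1f} cm, {e_R:.2f} deg")  # 5.4 cm, 0.00 deg
+```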
+
+## Citation
+
+Please cite the following paper if you use the Cambridge Landmarks dataset:
+```
+@inproceedings{kendall2015posenet,
+ title={{PoseNet}: A convolutional network for real-time {6-DoF} camera relocalization},
+ author={Kendall, Alex and Grimes, Matthew and Cipolla, Roberto},
+ booktitle={ICCV},
+ year={2015}
+}
+```
diff --git a/hloc/pipelines/Cambridge/__init__.py b/hloc/pipelines/Cambridge/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/pipelines/Cambridge/pipeline.py b/hloc/pipelines/Cambridge/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..49a032c0b7cc8a205b206a36d2edd87df6a3abd1
--- /dev/null
+++ b/hloc/pipelines/Cambridge/pipeline.py
@@ -0,0 +1,133 @@
+from pathlib import Path
+import argparse
+
+from .utils import create_query_list_with_intrinsics, scale_sfm_images, evaluate
+from ... import extract_features, match_features, pairs_from_covisibility
+from ... import triangulation, localize_sfm, pairs_from_retrieval, logger
+
+SCENES = ["KingsCollege", "OldHospital", "ShopFacade", "StMarysChurch", "GreatCourt"]
+
+
+def run_scene(images, gt_dir, outputs, results, num_covis, num_loc):
+ ref_sfm_sift = gt_dir / "model_train"
+ test_list = gt_dir / "list_query.txt"
+
+ outputs.mkdir(exist_ok=True, parents=True)
+ ref_sfm = outputs / "sfm_superpoint+superglue"
+ ref_sfm_scaled = outputs / "sfm_sift_scaled"
+ query_list = outputs / "query_list_with_intrinsics.txt"
+ sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt"
+ loc_pairs = outputs / f"pairs-query-netvlad{num_loc}.txt"
+
+ feature_conf = {
+ "output": "feats-superpoint-n4096-r1024",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 3,
+ "max_keypoints": 4096,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ },
+ }
+ matcher_conf = match_features.confs["superglue"]
+ retrieval_conf = extract_features.confs["netvlad"]
+
+ create_query_list_with_intrinsics(
+ gt_dir / "empty_all", query_list, test_list, ext=".txt", image_dir=images
+ )
+ with open(test_list, "r") as f:
+ query_seqs = {q.split("/")[0] for q in f.read().rstrip().split("\n")}
+
+ global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+ pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ num_loc,
+ db_model=ref_sfm_sift,
+ query_prefix=query_seqs,
+ )
+
+ features = extract_features.main(feature_conf, images, outputs, as_half=True)
+ pairs_from_covisibility.main(ref_sfm_sift, sfm_pairs, num_matched=num_covis)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+
+ scale_sfm_images(ref_sfm_sift, ref_sfm_scaled, images)
+ triangulation.main(
+ ref_sfm, ref_sfm_scaled, images, sfm_pairs, features, sfm_matches
+ )
+
+ loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ ref_sfm,
+ query_list,
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ prepend_camera_name=True,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--scenes", default=SCENES, choices=SCENES, nargs="+")
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/cambridge",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/cambridge",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=10,
+ help="Number of image pairs for loc, default: %(default)s",
+ )
+ args = parser.parse_args()
+
+ gt_dirs = args.dataset / "CambridgeLandmarks_Colmap_Retriangulated_1024px"
+
+ all_results = {}
+ for scene in args.scenes:
+ logger.info(f'Working on scene "{scene}".')
+ results = args.outputs / scene / "results.txt"
+ if args.overwrite or not results.exists():
+ run_scene(
+ args.dataset / scene,
+ gt_dirs / scene,
+ args.outputs / scene,
+ results,
+ args.num_covis,
+ args.num_loc,
+ )
+ all_results[scene] = results
+
+ for scene in args.scenes:
+ logger.info(f'Evaluate scene "{scene}".')
+ evaluate(
+ gt_dirs / scene / "empty_all",
+ all_results[scene],
+ gt_dirs / scene / "list_query.txt",
+ ext=".txt",
+ )
diff --git a/hloc/pipelines/Cambridge/utils.py b/hloc/pipelines/Cambridge/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..81883defd3adf5a65baed02f4b282c380c36d742
--- /dev/null
+++ b/hloc/pipelines/Cambridge/utils.py
@@ -0,0 +1,144 @@
+import cv2
+import logging
+import numpy as np
+
+from hloc.utils.read_write_model import (
+ read_cameras_binary,
+ read_images_binary,
+ read_model,
+ write_model,
+ qvec2rotmat,
+ read_images_text,
+ read_cameras_text,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def scale_sfm_images(full_model, scaled_model, image_dir):
+ """Duplicate the provided model and scale the camera intrinsics so that
+ they match the original image resolution - makes everything easier.
+ """
+ logger.info("Scaling the COLMAP model to the original image size.")
+ scaled_model.mkdir(exist_ok=True)
+ cameras, images, points3D = read_model(full_model)
+
+ scaled_cameras = {}
+ for id_, image in images.items():
+ name = image.name
+ img = cv2.imread(str(image_dir / name))
+ assert img is not None, image_dir / name
+ h, w = img.shape[:2]
+
+ cam_id = image.camera_id
+ if cam_id in scaled_cameras:
+ assert scaled_cameras[cam_id].width == w
+ assert scaled_cameras[cam_id].height == h
+ continue
+
+ camera = cameras[cam_id]
+ assert camera.model == "SIMPLE_RADIAL"
+ sx = w / camera.width
+ sy = h / camera.height
+ assert sx == sy, (sx, sy)
+ scaled_cameras[cam_id] = camera._replace(
+ width=w, height=h, params=camera.params * np.array([sx, sx, sy, 1.0])
+ )
+
+ write_model(scaled_cameras, images, points3D, scaled_model)
+
+
+def create_query_list_with_intrinsics(
+ model, out, list_file=None, ext=".bin", image_dir=None
+):
+ """Create a list of query images with intrinsics from the colmap model."""
+ if ext == ".bin":
+ images = read_images_binary(model / "images.bin")
+ cameras = read_cameras_binary(model / "cameras.bin")
+ else:
+ images = read_images_text(model / "images.txt")
+ cameras = read_cameras_text(model / "cameras.txt")
+
+ name2id = {image.name: i for i, image in images.items()}
+ if list_file is None:
+ names = list(name2id)
+ else:
+ with open(list_file, "r") as f:
+ names = f.read().rstrip().split("\n")
+ data = []
+ for name in names:
+ image = images[name2id[name]]
+ camera = cameras[image.camera_id]
+ w, h, params = camera.width, camera.height, camera.params
+
+ if image_dir is not None:
+ # Check the original image size and rescale the camera intrinsics
+ img = cv2.imread(str(image_dir / name))
+ assert img is not None, image_dir / name
+ h_orig, w_orig = img.shape[:2]
+ assert camera.model == "SIMPLE_RADIAL"
+ sx = w_orig / w
+ sy = h_orig / h
+ assert sx == sy, (sx, sy)
+ w, h = w_orig, h_orig
+ params = params * np.array([sx, sx, sy, 1.0])
+
+ p = [name, camera.model, w, h] + params.tolist()
+ data.append(" ".join(map(str, p)))
+ with open(out, "w") as f:
+ f.write("\n".join(data))
+
+
+def evaluate(model, results, list_file=None, ext=".bin", only_localized=False):
+ predictions = {}
+ with open(results, "r") as f:
+ for data in f.read().rstrip().split("\n"):
+ data = data.split()
+ name = data[0]
+ q, t = np.split(np.array(data[1:], float), [4])
+ predictions[name] = (qvec2rotmat(q), t)
+ if ext == ".bin":
+ images = read_images_binary(model / "images.bin")
+ else:
+ images = read_images_text(model / "images.txt")
+ name2id = {image.name: i for i, image in images.items()}
+
+ if list_file is None:
+ test_names = list(name2id)
+ else:
+ with open(list_file, "r") as f:
+ test_names = f.read().rstrip().split("\n")
+
+ errors_t = []
+ errors_R = []
+ for name in test_names:
+ if name not in predictions:
+ if only_localized:
+ continue
+ e_t = np.inf
+ e_R = 180.0
+ else:
+ image = images[name2id[name]]
+ R_gt, t_gt = image.qvec2rotmat(), image.tvec
+ R, t = predictions[name]
+ e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0)
+ cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1.0, 1.0)
+ e_R = np.rad2deg(np.abs(np.arccos(cos)))
+ errors_t.append(e_t)
+ errors_R.append(e_R)
+
+ errors_t = np.array(errors_t)
+ errors_R = np.array(errors_R)
+
+ med_t = np.median(errors_t)
+ med_R = np.median(errors_R)
+ out = f"Results for file {results.name}:"
+ out += f"\nMedian errors: {med_t:.3f}m, {med_R:.3f}deg"
+
+ out += "\nPercentage of test images localized within:"
+ threshs_t = [0.01, 0.02, 0.03, 0.05, 0.25, 0.5, 5.0]
+ threshs_R = [1.0, 2.0, 3.0, 5.0, 2.0, 5.0, 10.0]
+ for th_t, th_R in zip(threshs_t, threshs_R):
+ ratio = np.mean((errors_t < th_t) & (errors_R < th_R))
+ out += f"\n\t{th_t*100:.0f}cm, {th_R:.0f}deg : {ratio*100:.2f}%"
+ logger.info(out)
diff --git a/hloc/pipelines/RobotCar/README.md b/hloc/pipelines/RobotCar/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9881d153d4930cf32b5481ecd4fa2c900fa58c8c
--- /dev/null
+++ b/hloc/pipelines/RobotCar/README.md
@@ -0,0 +1,16 @@
+# RobotCar Seasons dataset
+
+## Installation
+
+Download the dataset from [visuallocalization.net](https://www.visuallocalization.net):
+```bash
+export dataset=datasets/robotcar
+wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/RobotCar-Seasons/ -P $dataset
+for condition in $dataset/images/*.zip; do unzip $condition -d $dataset/images/; done
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.RobotCar.pipeline
+```
diff --git a/hloc/pipelines/RobotCar/__init__.py b/hloc/pipelines/RobotCar/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/pipelines/RobotCar/colmap_from_nvm.py b/hloc/pipelines/RobotCar/colmap_from_nvm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7acf11e4f0edd2c369728e2a89748a4b7cf0c33b
--- /dev/null
+++ b/hloc/pipelines/RobotCar/colmap_from_nvm.py
@@ -0,0 +1,171 @@
+import argparse
+import sqlite3
+from tqdm import tqdm
+from collections import defaultdict
+import numpy as np
+from pathlib import Path
+import logging
+
+from ...colmap_from_nvm import (
+ recover_database_images_and_ids,
+ camera_center_to_translation,
+)
+from ...utils.read_write_model import Camera, Image, Point3D, CAMERA_MODEL_IDS
+from ...utils.read_write_model import write_model
+
+logger = logging.getLogger(__name__)
+
+
+def read_nvm_model(nvm_path, database_path, image_ids, camera_ids, skip_points=False):
+
+ # Extract the intrinsics from the db file instead of the NVM model
+ db = sqlite3.connect(str(database_path))
+ ret = db.execute("SELECT camera_id, model, width, height, params FROM cameras;")
+ cameras = {}
+ for camera_id, camera_model, width, height, params in ret:
+ params = np.frombuffer(params, dtype=np.double).reshape(-1)
+ camera_model = CAMERA_MODEL_IDS[camera_model]
+ assert len(params) == camera_model.num_params, (
+ len(params),
+ camera_model.num_params,
+ )
+ camera = Camera(
+ id=camera_id,
+ model=camera_model.model_name,
+ width=int(width),
+ height=int(height),
+ params=params,
+ )
+ cameras[camera_id] = camera
+
+ nvm_f = open(nvm_path, "r")
+ line = nvm_f.readline()
+ while line == "\n" or line.startswith("NVM_V3"):
+ line = nvm_f.readline()
+ num_images = int(line)
+ # assert num_images == len(cameras), (num_images, len(cameras))
+
+ logger.info(f"Reading {num_images} images...")
+ image_idx_to_db_image_id = []
+ image_data = []
+ i = 0
+ while i < num_images:
+ line = nvm_f.readline()
+ if line == "\n":
+ continue
+ data = line.strip("\n").lstrip("./").split(" ")
+ image_data.append(data)
+ image_idx_to_db_image_id.append(image_ids[data[0]])
+ i += 1
+
+ line = nvm_f.readline()
+ while line == "\n":
+ line = nvm_f.readline()
+ num_points = int(line)
+
+ if skip_points:
+ logger.info(f"Skipping {num_points} points.")
+ num_points = 0
+ else:
+ logger.info(f"Reading {num_points} points...")
+ points3D = {}
+ image_idx_to_keypoints = defaultdict(list)
+ i = 0
+ pbar = tqdm(total=num_points, unit="pts")
+ while i < num_points:
+ line = nvm_f.readline()
+ if line == "\n":
+ continue
+
+ data = line.strip("\n").split(" ")
+ x, y, z, r, g, b, num_observations = data[:7]
+ obs_image_ids, point2D_idxs = [], []
+ for j in range(int(num_observations)):
+ s = 7 + 4 * j
+ img_index, kp_index, kx, ky = data[s : s + 4]
+ image_idx_to_keypoints[int(img_index)].append(
+ (int(kp_index), float(kx), float(ky), i)
+ )
+ db_image_id = image_idx_to_db_image_id[int(img_index)]
+ obs_image_ids.append(db_image_id)
+ point2D_idxs.append(kp_index)
+
+ point = Point3D(
+ id=i,
+ xyz=np.array([x, y, z], float),
+ rgb=np.array([r, g, b], int),
+ error=1.0, # fake
+ image_ids=np.array(obs_image_ids, int),
+ point2D_idxs=np.array(point2D_idxs, int),
+ )
+ points3D[i] = point
+
+ i += 1
+ pbar.update(1)
+ pbar.close()
+
+ logger.info("Parsing image data...")
+ images = {}
+ for i, data in enumerate(image_data):
+ # Skip the focal length. Skip the distortion and terminal 0.
+ name, _, qw, qx, qy, qz, cx, cy, cz, _, _ = data
+ qvec = np.array([qw, qx, qy, qz], float)
+ c = np.array([cx, cy, cz], float)
+ t = camera_center_to_translation(c, qvec)
+
+ if i in image_idx_to_keypoints:
+ # NVM only stores triangulated 2D keypoints: add dummy ones
+ keypoints = image_idx_to_keypoints[i]
+ point2D_idxs = np.array([d[0] for d in keypoints])
+ tri_xys = np.array([[x, y] for _, x, y, _ in keypoints])
+ tri_ids = np.array([i for _, _, _, i in keypoints])
+
+ num_2Dpoints = max(point2D_idxs) + 1
+ xys = np.zeros((num_2Dpoints, 2), float)
+ point3D_ids = np.full(num_2Dpoints, -1, int)
+ xys[point2D_idxs] = tri_xys
+ point3D_ids[point2D_idxs] = tri_ids
+ else:
+ xys = np.zeros((0, 2), float)
+ point3D_ids = np.full(0, -1, int)
+
+ image_id = image_ids[name]
+ image = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=t,
+ camera_id=camera_ids[name],
+ name=name.replace("png", "jpg"), # some hack required for RobotCar
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ images[image_id] = image
+
+ return cameras, images, points3D
+
+
+def main(nvm, database, output, skip_points=False):
+ assert nvm.exists(), nvm
+ assert database.exists(), database
+
+ image_ids, camera_ids = recover_database_images_and_ids(database)
+
+ logger.info("Reading the NVM model...")
+ model = read_nvm_model(
+ nvm, database, image_ids, camera_ids, skip_points=skip_points
+ )
+
+ logger.info("Writing the COLMAP model...")
+ output.mkdir(exist_ok=True, parents=True)
+ write_model(*model, path=str(output), ext=".bin")
+ logger.info("Done.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--nvm", required=True, type=Path)
+ parser.add_argument("--database", required=True, type=Path)
+ parser.add_argument("--output", required=True, type=Path)
+ parser.add_argument("--skip_points", action="store_true")
+ args = parser.parse_args()
+ main(**args.__dict__)
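> Note: for reference, the same NVM-to-COLMAP conversion can be driven from Python instead of the CLI above. A minimal sketch, using the default RobotCar paths that `pipeline.py` (below) passes in; adjust them to your own layout:

```python
# Minimal sketch; paths mirror the defaults used by pipeline.py below.
from pathlib import Path
from hloc.pipelines.RobotCar import colmap_from_nvm

dataset = Path("datasets/robotcar")
colmap_from_nvm.main(
    nvm=dataset / "3D-models/all-merged/all.nvm",
    database=dataset / "3D-models/overcast-reference.db",
    output=Path("outputs/robotcar/sfm_sift"),
    skip_points=False,  # True skips reading the 3D points (faster, smaller model)
)
```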
diff --git a/hloc/pipelines/RobotCar/pipeline.py b/hloc/pipelines/RobotCar/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f1e9dc82e8bb369efe3711db8381de103e7373f
--- /dev/null
+++ b/hloc/pipelines/RobotCar/pipeline.py
@@ -0,0 +1,130 @@
+from pathlib import Path
+import argparse
+
+from . import colmap_from_nvm
+from ... import extract_features, match_features, triangulation
+from ... import pairs_from_covisibility, pairs_from_retrieval, localize_sfm
+
+
+CONDITIONS = [
+ "dawn",
+ "dusk",
+ "night",
+ "night-rain",
+ "overcast-summer",
+ "overcast-winter",
+ "rain",
+ "snow",
+ "sun",
+]
+
+
+def generate_query_list(dataset, image_dir, path):
+ h, w = 1024, 1024
+ intrinsics_filename = "intrinsics/{}_intrinsics.txt"
+ cameras = {}
+ for side in ["left", "right", "rear"]:
+ with open(dataset / intrinsics_filename.format(side), "r") as f:
+ fx = f.readline().split()[1]
+ fy = f.readline().split()[1]
+ cx = f.readline().split()[1]
+ cy = f.readline().split()[1]
+ assert fx == fy
+ params = ["SIMPLE_RADIAL", w, h, fx, cx, cy, 0.0]
+ cameras[side] = [str(p) for p in params]
+
+ queries = sorted(image_dir.glob("**/*.jpg"))
+ queries = [str(q.relative_to(image_dir.parents[0])) for q in queries]
+
+ out = [[q] + cameras[Path(q).parent.name] for q in queries]
+ with open(path, "w") as f:
+ f.write("\n".join(map(" ".join, out)))
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/robotcar",
+ help="Path to the dataset, default: %(default)s",
+)
+parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/robotcar",
+ help="Path to the output directory, default: %(default)s",
+)
+parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+)
+parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=20,
+ help="Number of image pairs for loc, default: %(default)s",
+)
+args = parser.parse_args()
+
+# Setup the paths
+dataset = args.dataset
+images = dataset / "images/"
+
+outputs = args.outputs # where everything will be saved
+outputs.mkdir(exist_ok=True, parents=True)
+query_list = outputs / "{condition}_queries_with_intrinsics.txt"
+sift_sfm = outputs / "sfm_sift"
+reference_sfm = outputs / "sfm_superpoint+superglue"
+sfm_pairs = outputs / f"pairs-db-covis{args.num_covis}.txt"
+loc_pairs = outputs / f"pairs-query-netvlad{args.num_loc}.txt"
+results = outputs / f"RobotCar_hloc_superpoint+superglue_netvlad{args.num_loc}.txt"
+
+# pick one of the configurations for extraction and matching
+retrieval_conf = extract_features.confs["netvlad"]
+feature_conf = extract_features.confs["superpoint_aachen"]
+matcher_conf = match_features.confs["superglue"]
+
+for condition in CONDITIONS:
+ generate_query_list(
+ dataset, images / condition, str(query_list).format(condition=condition)
+ )
+
+features = extract_features.main(feature_conf, images, outputs, as_half=True)
+
+colmap_from_nvm.main(
+ dataset / "3D-models/all-merged/all.nvm",
+ dataset / "3D-models/overcast-reference.db",
+ sift_sfm,
+)
+pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis)
+sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+)
+
+triangulation.main(reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches)
+
+global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+# TODO: do per location and per camera
+pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ args.num_loc,
+ query_prefix=CONDITIONS,
+ db_model=reference_sfm,
+)
+loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+)
+
+localize_sfm.main(
+ reference_sfm,
+ Path(str(query_list).format(condition="*")),
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ prepend_camera_name=True,
+)
diff --git a/hloc/pipelines/__init__.py b/hloc/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hloc/utils/__init__.py b/hloc/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d5079d59ac615ce4c3d4b2e9e869eca9a4c411c
--- /dev/null
+++ b/hloc/utils/__init__.py
@@ -0,0 +1,14 @@
+import os
+import sys
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def do_system(cmd, verbose=False):
+ if verbose:
+ logger.info(f"Run cmd: `{cmd}`.")
+ err = os.system(cmd)
+ if err:
+ logger.info(f"Run cmd err: {err}.")
+ sys.exit(err)
diff --git a/hloc/utils/base_model.py b/hloc/utils/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f560a2664eeb7ff53b49e169289ab284c16cb0ec
--- /dev/null
+++ b/hloc/utils/base_model.py
@@ -0,0 +1,47 @@
+import sys
+from abc import ABCMeta, abstractmethod
+from torch import nn
+from copy import copy
+import inspect
+
+
+class BaseModel(nn.Module, metaclass=ABCMeta):
+ default_conf = {}
+ required_inputs = []
+
+ def __init__(self, conf):
+ """Perform some logic and call the _init method of the child model."""
+ super().__init__()
+ self.conf = conf = {**self.default_conf, **conf}
+ self.required_inputs = copy(self.required_inputs)
+ self._init(conf)
+ sys.stdout.flush()
+
+ def forward(self, data):
+ """Check the data and call the _forward method of the child model."""
+ for key in self.required_inputs:
+ assert key in data, "Missing key {} in data".format(key)
+ return self._forward(data)
+
+ @abstractmethod
+ def _init(self, conf):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _forward(self, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+
+def dynamic_load(root, model):
+ module_path = f"{root.__name__}.{model}"
+ module = __import__(module_path, fromlist=[""])
+ classes = inspect.getmembers(module, inspect.isclass)
+ # Filter classes defined in the module
+ classes = [c for c in classes if c[1].__module__ == module_path]
+ # Filter classes inherited from BaseModel
+ classes = [c for c in classes if issubclass(c[1], BaseModel)]
+ assert len(classes) == 1, classes
+ return classes[0][1]
+ # return getattr(module, 'Model')
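> Note: a minimal sketch of how a model plugs into this interface; `MyExtractor` and its config keys are illustrative, not part of hloc. `dynamic_load(root, name)` then resolves a model by returning the single `BaseModel` subclass defined in the module `root.name`.

```python
# Illustrative subclass; the class name and config keys are made up.
import torch
from hloc.utils.base_model import BaseModel


class MyExtractor(BaseModel):
    default_conf = {"max_keypoints": 1024}
    required_inputs = ["image"]

    def _init(self, conf):
        # conf is default_conf merged with the user-provided dict
        self.conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def _forward(self, data):
        # BaseModel.forward has already checked that data["image"] exists
        return {"scores": self.conv(data["image"])}


model = MyExtractor({"max_keypoints": 512}).eval()
out = model({"image": torch.zeros(1, 1, 32, 32)})
```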
diff --git a/hloc/utils/database.py b/hloc/utils/database.py
new file mode 100644
index 0000000000000000000000000000000000000000..65f59e77404990291057454f65f8756f8a183005
--- /dev/null
+++ b/hloc/utils/database.py
@@ -0,0 +1,409 @@
+# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+# This script is based on an original implementation by True Price.
+
+import sys
+import sqlite3
+import numpy as np
+
+
+IS_PYTHON3 = sys.version_info[0] >= 3
+
+MAX_IMAGE_ID = 2**31 - 1
+
+CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
+ camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ model INTEGER NOT NULL,
+ width INTEGER NOT NULL,
+ height INTEGER NOT NULL,
+ params BLOB,
+ prior_focal_length INTEGER NOT NULL)"""
+
+CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
+ image_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
+
+CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
+ image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ name TEXT NOT NULL UNIQUE,
+ camera_id INTEGER NOT NULL,
+ prior_qw REAL,
+ prior_qx REAL,
+ prior_qy REAL,
+ prior_qz REAL,
+ prior_tx REAL,
+ prior_ty REAL,
+ prior_tz REAL,
+ CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
+ FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
+""".format(
+ MAX_IMAGE_ID
+)
+
+CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
+CREATE TABLE IF NOT EXISTS two_view_geometries (
+ pair_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ config INTEGER NOT NULL,
+ F BLOB,
+ E BLOB,
+ H BLOB,
+ qvec BLOB,
+ tvec BLOB)
+"""
+
+CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
+ image_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
+"""
+
+CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
+ pair_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB)"""
+
+CREATE_NAME_INDEX = "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
+
+CREATE_ALL = "; ".join(
+ [
+ CREATE_CAMERAS_TABLE,
+ CREATE_IMAGES_TABLE,
+ CREATE_KEYPOINTS_TABLE,
+ CREATE_DESCRIPTORS_TABLE,
+ CREATE_MATCHES_TABLE,
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE,
+ CREATE_NAME_INDEX,
+ ]
+)
+
+
+def image_ids_to_pair_id(image_id1, image_id2):
+ if image_id1 > image_id2:
+ image_id1, image_id2 = image_id2, image_id1
+ return image_id1 * MAX_IMAGE_ID + image_id2
+
+
+def pair_id_to_image_ids(pair_id):
+ image_id2 = pair_id % MAX_IMAGE_ID
+ image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID
+ return image_id1, image_id2
+
+
+def array_to_blob(array):
+ if IS_PYTHON3:
+ return array.tobytes()
+ else:
+ return np.getbuffer(array)
+
+
+def blob_to_array(blob, dtype, shape=(-1,)):
+ if IS_PYTHON3:
+ return np.fromstring(blob, dtype=dtype).reshape(*shape)
+ else:
+ return np.frombuffer(blob, dtype=dtype).reshape(*shape)
+
+
+class COLMAPDatabase(sqlite3.Connection):
+ @staticmethod
+ def connect(database_path):
+ return sqlite3.connect(str(database_path), factory=COLMAPDatabase)
+
+ def __init__(self, *args, **kwargs):
+ super(COLMAPDatabase, self).__init__(*args, **kwargs)
+
+ self.create_tables = lambda: self.executescript(CREATE_ALL)
+ self.create_cameras_table = lambda: self.executescript(CREATE_CAMERAS_TABLE)
+ self.create_descriptors_table = lambda: self.executescript(
+ CREATE_DESCRIPTORS_TABLE
+ )
+ self.create_images_table = lambda: self.executescript(CREATE_IMAGES_TABLE)
+ self.create_two_view_geometries_table = lambda: self.executescript(
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE
+ )
+ self.create_keypoints_table = lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
+ self.create_matches_table = lambda: self.executescript(CREATE_MATCHES_TABLE)
+ self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
+
+ def add_camera(
+ self, model, width, height, params, prior_focal_length=False, camera_id=None
+ ):
+ params = np.asarray(params, np.float64)
+ cursor = self.execute(
+ "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
+ (
+ camera_id,
+ model,
+ width,
+ height,
+ array_to_blob(params),
+ prior_focal_length,
+ ),
+ )
+ return cursor.lastrowid
+
+ def add_image(
+ self,
+ name,
+ camera_id,
+ prior_q=np.full(4, np.NaN),
+ prior_t=np.full(3, np.NaN),
+ image_id=None,
+ ):
+ cursor = self.execute(
+ "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (
+ image_id,
+ name,
+ camera_id,
+ prior_q[0],
+ prior_q[1],
+ prior_q[2],
+ prior_q[3],
+ prior_t[0],
+ prior_t[1],
+ prior_t[2],
+ ),
+ )
+ return cursor.lastrowid
+
+ def add_keypoints(self, image_id, keypoints):
+ assert len(keypoints.shape) == 2
+ assert keypoints.shape[1] in [2, 4, 6]
+
+ keypoints = np.asarray(keypoints, np.float32)
+ self.execute(
+ "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
+ (image_id,) + keypoints.shape + (array_to_blob(keypoints),),
+ )
+
+ def add_descriptors(self, image_id, descriptors):
+ descriptors = np.ascontiguousarray(descriptors, np.uint8)
+ self.execute(
+ "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
+ (image_id,) + descriptors.shape + (array_to_blob(descriptors),),
+ )
+
+ def add_matches(self, image_id1, image_id2, matches):
+ assert len(matches.shape) == 2
+ assert matches.shape[1] == 2
+
+ if image_id1 > image_id2:
+ matches = matches[:, ::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ self.execute(
+ "INSERT INTO matches VALUES (?, ?, ?, ?)",
+ (pair_id,) + matches.shape + (array_to_blob(matches),),
+ )
+
+ def add_two_view_geometry(
+ self,
+ image_id1,
+ image_id2,
+ matches,
+ F=np.eye(3),
+ E=np.eye(3),
+ H=np.eye(3),
+ qvec=np.array([1.0, 0.0, 0.0, 0.0]),
+ tvec=np.zeros(3),
+ config=2,
+ ):
+ assert len(matches.shape) == 2
+ assert matches.shape[1] == 2
+
+ if image_id1 > image_id2:
+ matches = matches[:, ::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ F = np.asarray(F, dtype=np.float64)
+ E = np.asarray(E, dtype=np.float64)
+ H = np.asarray(H, dtype=np.float64)
+ qvec = np.asarray(qvec, dtype=np.float64)
+ tvec = np.asarray(tvec, dtype=np.float64)
+ self.execute(
+ "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (pair_id,)
+ + matches.shape
+ + (
+ array_to_blob(matches),
+ config,
+ array_to_blob(F),
+ array_to_blob(E),
+ array_to_blob(H),
+ array_to_blob(qvec),
+ array_to_blob(tvec),
+ ),
+ )
+
+
+def example_usage():
+ import os
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--database_path", default="database.db")
+ args = parser.parse_args()
+
+ if os.path.exists(args.database_path):
+ print("ERROR: database path already exists -- will not modify it.")
+ return
+
+ # Open the database.
+
+ db = COLMAPDatabase.connect(args.database_path)
+
+ # For convenience, try creating all the tables upfront.
+
+ db.create_tables()
+
+ # Create dummy cameras.
+
+ model1, width1, height1, params1 = 0, 1024, 768, np.array((1024.0, 512.0, 384.0))
+ model2, width2, height2, params2 = (
+ 2,
+ 1024,
+ 768,
+ np.array((1024.0, 512.0, 384.0, 0.1)),
+ )
+
+ camera_id1 = db.add_camera(model1, width1, height1, params1)
+ camera_id2 = db.add_camera(model2, width2, height2, params2)
+
+ # Create dummy images.
+
+ image_id1 = db.add_image("image1.png", camera_id1)
+ image_id2 = db.add_image("image2.png", camera_id1)
+ image_id3 = db.add_image("image3.png", camera_id2)
+ image_id4 = db.add_image("image4.png", camera_id2)
+
+ # Create dummy keypoints.
+ #
+ # Note that COLMAP supports:
+ # - 2D keypoints: (x, y)
+ # - 4D keypoints: (x, y, theta, scale)
+ # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
+
+ num_keypoints = 1000
+ keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
+ keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
+ keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
+ keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
+
+ db.add_keypoints(image_id1, keypoints1)
+ db.add_keypoints(image_id2, keypoints2)
+ db.add_keypoints(image_id3, keypoints3)
+ db.add_keypoints(image_id4, keypoints4)
+
+ # Create dummy matches.
+
+ M = 50
+ matches12 = np.random.randint(num_keypoints, size=(M, 2))
+ matches23 = np.random.randint(num_keypoints, size=(M, 2))
+ matches34 = np.random.randint(num_keypoints, size=(M, 2))
+
+ db.add_matches(image_id1, image_id2, matches12)
+ db.add_matches(image_id2, image_id3, matches23)
+ db.add_matches(image_id3, image_id4, matches34)
+
+ # Commit the data to the file.
+
+ db.commit()
+
+ # Read and check cameras.
+
+ rows = db.execute("SELECT * FROM cameras")
+
+ camera_id, model, width, height, params, prior = next(rows)
+ params = blob_to_array(params, np.float64)
+ assert camera_id == camera_id1
+ assert model == model1 and width == width1 and height == height1
+ assert np.allclose(params, params1)
+
+ camera_id, model, width, height, params, prior = next(rows)
+ params = blob_to_array(params, np.float64)
+ assert camera_id == camera_id2
+ assert model == model2 and width == width2 and height == height2
+ assert np.allclose(params, params2)
+
+ # Read and check keypoints.
+
+ keypoints = dict(
+ (image_id, blob_to_array(data, np.float32, (-1, 2)))
+ for image_id, data in db.execute("SELECT image_id, data FROM keypoints")
+ )
+
+ assert np.allclose(keypoints[image_id1], keypoints1)
+ assert np.allclose(keypoints[image_id2], keypoints2)
+ assert np.allclose(keypoints[image_id3], keypoints3)
+ assert np.allclose(keypoints[image_id4], keypoints4)
+
+ # Read and check matches.
+
+ pair_ids = [
+ image_ids_to_pair_id(*pair)
+ for pair in (
+ (image_id1, image_id2),
+ (image_id2, image_id3),
+ (image_id3, image_id4),
+ )
+ ]
+
+ matches = dict(
+ (pair_id_to_image_ids(pair_id), blob_to_array(data, np.uint32, (-1, 2)))
+ for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
+ )
+
+ assert np.all(matches[(image_id1, image_id2)] == matches12)
+ assert np.all(matches[(image_id2, image_id3)] == matches23)
+ assert np.all(matches[(image_id3, image_id4)] == matches34)
+
+ # Clean up.
+
+ db.close()
+
+ if os.path.exists(args.database_path):
+ os.remove(args.database_path)
+
+
+if __name__ == "__main__":
+ example_usage()
diff --git a/hloc/utils/geometry.py b/hloc/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fc1bb6914ee8a8b6770711c7effbc1f216ba2e7
--- /dev/null
+++ b/hloc/utils/geometry.py
@@ -0,0 +1,33 @@
+import numpy as np
+import pycolmap
+
+
+def to_homogeneous(p):
+ return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1)
+
+
+def vector_to_cross_product_matrix(v):
+ return np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
+
+
+def compute_epipolar_errors(qvec_r2t, tvec_r2t, p2d_r, p2d_t):
+ T_r2t = pose_matrix_from_qvec_tvec(qvec_r2t, tvec_r2t)
+ # Compute errors in normalized plane to avoid distortion.
+ E = vector_to_cross_product_matrix(T_r2t[:3, -1]) @ T_r2t[:3, :3]
+ l2d_r2t = (E @ to_homogeneous(p2d_r).T).T
+ l2d_t2r = (E.T @ to_homogeneous(p2d_t).T).T
+ errors_r = np.abs(np.sum(to_homogeneous(p2d_r) * l2d_t2r, axis=1)) / np.linalg.norm(
+ l2d_t2r[:, :2], axis=1
+ )
+ errors_t = np.abs(np.sum(to_homogeneous(p2d_t) * l2d_r2t, axis=1)) / np.linalg.norm(
+ l2d_r2t[:, :2], axis=1
+ )
+ return E, errors_r, errors_t
+
+
+def pose_matrix_from_qvec_tvec(qvec, tvec):
+ pose = np.zeros((4, 4))
+ pose[:3, :3] = pycolmap.qvec_to_rotmat(qvec)
+ pose[:3, -1] = tvec
+ pose[-1, -1] = 1
+ return pose
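> Note: a small sanity check for `compute_epipolar_errors`, assuming a pycolmap version that still exposes `qvec_to_rotmat` (as this module does); the pose and point values are illustrative:

```python
import numpy as np
from hloc.utils.geometry import compute_epipolar_errors

qvec_r2t = np.array([1.0, 0.0, 0.0, 0.0])  # identity rotation (qw, qx, qy, qz)
tvec_r2t = np.array([1.0, 0.0, 0.0])       # translation along x, reference -> target

P = np.array([[0.5, 0.2, 2.0]])            # 3D point in the reference frame
p2d_r = P[:, :2] / P[:, 2:]                # normalized image coordinates (reference)
p2d_t = (P + tvec_r2t)[:, :2] / P[:, 2:]   # the same point seen by the target camera

E, err_r, err_t = compute_epipolar_errors(qvec_r2t, tvec_r2t, p2d_r, p2d_t)
# err_r and err_t are ~0 for this perfect correspondence
```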
diff --git a/hloc/utils/io.py b/hloc/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd55d4c30b41c3754634a164312dc5e8c294274
--- /dev/null
+++ b/hloc/utils/io.py
@@ -0,0 +1,77 @@
+from typing import Tuple
+from pathlib import Path
+import numpy as np
+import cv2
+import h5py
+
+from .parsers import names_to_pair, names_to_pair_old
+
+
+def read_image(path, grayscale=False):
+ if grayscale:
+ mode = cv2.IMREAD_GRAYSCALE
+ else:
+ mode = cv2.IMREAD_COLOR
+ image = cv2.imread(str(path), mode)
+ if image is None:
+ raise ValueError(f"Cannot read image {path}.")
+ if not grayscale and len(image.shape) == 3:
+ image = image[:, :, ::-1] # BGR to RGB
+ return image
+
+
+def list_h5_names(path):
+ names = []
+ with h5py.File(str(path), "r", libver="latest") as fd:
+
+ def visit_fn(_, obj):
+ if isinstance(obj, h5py.Dataset):
+ names.append(obj.parent.name.strip("/"))
+
+ fd.visititems(visit_fn)
+ return list(set(names))
+
+
+def get_keypoints(
+ path: Path, name: str, return_uncertainty: bool = False
+) -> np.ndarray:
+ with h5py.File(str(path), "r", libver="latest") as hfile:
+ dset = hfile[name]["keypoints"]
+ p = dset.__array__()
+ uncertainty = dset.attrs.get("uncertainty")
+ if return_uncertainty:
+ return p, uncertainty
+ return p
+
+
+def find_pair(hfile: h5py.File, name0: str, name1: str):
+ pair = names_to_pair(name0, name1)
+ if pair in hfile:
+ return pair, False
+ pair = names_to_pair(name1, name0)
+ if pair in hfile:
+ return pair, True
+ # older, less efficient format
+ pair = names_to_pair_old(name0, name1)
+ if pair in hfile:
+ return pair, False
+ pair = names_to_pair_old(name1, name0)
+ if pair in hfile:
+ return pair, True
+ raise ValueError(
+ f"Could not find pair {(name0, name1)}... "
+ "Maybe you matched with a different list of pairs? "
+ )
+
+
+def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]:
+ with h5py.File(str(path), "r", libver="latest") as hfile:
+ pair, reverse = find_pair(hfile, name0, name1)
+ matches = hfile[pair]["matches0"].__array__()
+ scores = hfile[pair]["matching_scores0"].__array__()
+ idx = np.where(matches != -1)[0]
+ matches = np.stack([idx, matches[idx]], -1)
+ if reverse:
+ matches = np.flip(matches, -1)
+ scores = scores[idx]
+ return matches, scores
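> Note: a hedged usage sketch. The file paths and image names below are placeholders; in practice they are the HDF5 files written by `extract_features.main` / `match_features.main` and the image names stored inside them.

```python
from hloc.utils.io import get_keypoints, get_matches

feature_path = "outputs/demo/features.h5"        # placeholder paths
match_path = "outputs/demo/matches.h5"
name0, name1 = "query/img0.jpg", "db/img1.jpg"   # placeholder image names

kpts0 = get_keypoints(feature_path, name0)       # (N0, 2) keypoints
kpts1 = get_keypoints(feature_path, name1)       # (N1, 2) keypoints
matches, scores = get_matches(match_path, name0, name1)  # (M, 2) index pairs, (M,) scores
mkpts0, mkpts1 = kpts0[matches[:, 0]], kpts1[matches[:, 1]]
```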
diff --git a/hloc/utils/parsers.py b/hloc/utils/parsers.py
new file mode 100644
index 0000000000000000000000000000000000000000..faaa8f2de952673abdb580abc5754efe1bfc5f40
--- /dev/null
+++ b/hloc/utils/parsers.py
@@ -0,0 +1,56 @@
+from pathlib import Path
+import logging
+import numpy as np
+from collections import defaultdict
+import pycolmap
+
+logger = logging.getLogger(__name__)
+
+
+def parse_image_list(path, with_intrinsics=False):
+ images = []
+ with open(path, "r") as f:
+ for line in f:
+ line = line.strip("\n")
+ if len(line) == 0 or line[0] == "#":
+ continue
+ name, *data = line.split()
+ if with_intrinsics:
+ model, width, height, *params = data
+ params = np.array(params, float)
+ cam = pycolmap.Camera(model, int(width), int(height), params)
+ images.append((name, cam))
+ else:
+ images.append(name)
+
+ assert len(images) > 0
+ logger.info(f"Imported {len(images)} images from {path.name}")
+ return images
+
+
+def parse_image_lists(paths, with_intrinsics=False):
+ images = []
+ files = list(Path(paths.parent).glob(paths.name))
+ assert len(files) > 0
+ for lfile in files:
+ images += parse_image_list(lfile, with_intrinsics=with_intrinsics)
+ return images
+
+
+def parse_retrieval(path):
+ retrieval = defaultdict(list)
+ with open(path, "r") as f:
+ for p in f.read().rstrip("\n").split("\n"):
+ if len(p) == 0:
+ continue
+ q, r = p.split()
+ retrieval[q].append(r)
+ return dict(retrieval)
+
+
+def names_to_pair(name0, name1, separator="/"):
+ return separator.join((name0.replace("/", "-"), name1.replace("/", "-")))
+
+
+def names_to_pair_old(name0, name1):
+ return names_to_pair(name0, name1, separator="_")
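> Note: `parse_image_list(..., with_intrinsics=True)` expects one image per line in the format `NAME MODEL WIDTH HEIGHT PARAMS...`, which is exactly what `generate_query_list` writes for RobotCar. A short illustrative example (file name and numeric values are made up):

```python
from pathlib import Path
from hloc.utils.parsers import parse_image_list

list_path = Path("queries_with_intrinsics.txt")
list_path.write_text(
    "rear/1417176447897128.jpg SIMPLE_RADIAL 1024 1024 400.0 508.2 498.5 0.0\n"
)
images = parse_image_list(list_path, with_intrinsics=True)
name, camera = images[0]  # image name and a pycolmap.Camera
```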
diff --git a/hloc/utils/read_write_model.py b/hloc/utils/read_write_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e10bf55139f2f22d05d82a5153a29c493105bd75
--- /dev/null
+++ b/hloc/utils/read_write_model.py
@@ -0,0 +1,587 @@
+# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+import os
+import collections
+import numpy as np
+import struct
+import argparse
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+CameraModel = collections.namedtuple(
+ "CameraModel", ["model_id", "model_name", "num_params"]
+)
+Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
+)
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
+)
+
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
+}
+CAMERA_MODEL_IDS = dict(
+ [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
+)
+CAMERA_MODEL_NAMES = dict(
+ [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
+)
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
+
+
+def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
+ """pack and write to a binary file.
+ :param fid:
+ :param data: data to send, if multiple elements are sent at the same time,
+ they should be encapsulated either in a list or a tuple
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ should be the same length as the data list or tuple
+ :param endian_character: Any of {@, =, <, >, !}
+ """
+ if isinstance(data, (list, tuple)):
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
+ else:
+ bytes = struct.pack(endian_character + format_char_sequence, data)
+ fid.write(bytes)
+
+
+def read_cameras_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ cameras = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(
+ id=camera_id, model=model, width=width, height=height, params=params
+ )
+ return cameras
+
+
+def read_cameras_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_cameras):
+ camera_properties = read_next_bytes(
+ fid, num_bytes=24, format_char_sequence="iiQQ"
+ )
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(
+ fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params
+ )
+ cameras[camera_id] = Camera(
+ id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params),
+ )
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def write_cameras_text(cameras, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ HEADER = (
+ "# Camera list with one line of data per camera:\n"
+ + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
+ + "# Number of cameras: {}\n".format(len(cameras))
+ )
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, cam in cameras.items():
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
+ line = " ".join([str(elem) for elem in to_write])
+ fid.write(line + "\n")
+
+
+def write_cameras_binary(cameras, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(cameras), "Q")
+ for _, cam in cameras.items():
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
+ camera_properties = [cam.id, model_id, cam.width, cam.height]
+ write_next_bytes(fid, camera_properties, "iiQQ")
+ for p in cam.params:
+ write_next_bytes(fid, float(p), "d")
+ return cameras
+
+
+def read_images_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ images = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack(
+ [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]
+ )
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def read_images_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi"
+ )
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ image_name = ""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ image_name += current_char.decode("utf-8")
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
+ 0
+ ]
+ x_y_id_s = read_next_bytes(
+ fid,
+ num_bytes=24 * num_points2D,
+ format_char_sequence="ddq" * num_points2D,
+ )
+ xys = np.column_stack(
+ [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]
+ )
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def write_images_text(images, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ if len(images) == 0:
+ mean_observations = 0
+ else:
+ mean_observations = sum(
+ (len(img.point3D_ids) for _, img in images.items())
+ ) / len(images)
+ HEADER = (
+ "# Image list with two lines of data per image:\n"
+ + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
+ + "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
+ + "# Number of images: {}, mean observations per image: {}\n".format(
+ len(images), mean_observations
+ )
+ )
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, img in images.items():
+ image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
+ first_line = " ".join(map(str, image_header))
+ fid.write(first_line + "\n")
+
+ points_strings = []
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
+ fid.write(" ".join(points_strings) + "\n")
+
+
+def write_images_binary(images, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(images), "Q")
+ for _, img in images.items():
+ write_next_bytes(fid, img.id, "i")
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
+ write_next_bytes(fid, img.camera_id, "i")
+ for char in img.name:
+ write_next_bytes(fid, char.encode("utf-8"), "c")
+ write_next_bytes(fid, b"\x00", "c")
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
+
+def read_points3D_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ points3D = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ point3D_id = int(elems[0])
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = float(elems[7])
+ image_ids = np.array(tuple(map(int, elems[8::2])))
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id,
+ xyz=xyz,
+ rgb=rgb,
+ error=error,
+ image_ids=image_ids,
+ point2D_idxs=point2D_idxs,
+ )
+ return points3D
+
+
+def read_points3D_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ points3D = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd"
+ )
+ point3D_id = binary_point_line_properties[0]
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
+ 0
+ ]
+ track_elems = read_next_bytes(
+ fid,
+ num_bytes=8 * track_length,
+ format_char_sequence="ii" * track_length,
+ )
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id,
+ xyz=xyz,
+ rgb=rgb,
+ error=error,
+ image_ids=image_ids,
+ point2D_idxs=point2D_idxs,
+ )
+ return points3D
+
+
+def write_points3D_text(points3D, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ if len(points3D) == 0:
+ mean_track_length = 0
+ else:
+ mean_track_length = sum(
+ (len(pt.image_ids) for _, pt in points3D.items())
+ ) / len(points3D)
+ HEADER = (
+ "# 3D point list with one line of data per point:\n"
+ + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
+ + "# Number of points: {}, mean track length: {}\n".format(
+ len(points3D), mean_track_length
+ )
+ )
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, pt in points3D.items():
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
+ fid.write(" ".join(map(str, point_header)) + " ")
+ track_strings = []
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
+ fid.write(" ".join(track_strings) + "\n")
+
+
+def write_points3D_binary(points3D, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(points3D), "Q")
+ for _, pt in points3D.items():
+ write_next_bytes(fid, pt.id, "Q")
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
+ write_next_bytes(fid, pt.error, "d")
+ track_length = pt.image_ids.shape[0]
+ write_next_bytes(fid, track_length, "Q")
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
+
+
+def detect_model_format(path, ext):
+ if (
+ os.path.isfile(os.path.join(path, "cameras" + ext))
+ and os.path.isfile(os.path.join(path, "images" + ext))
+ and os.path.isfile(os.path.join(path, "points3D" + ext))
+ ):
+ return True
+
+ return False
+
+
+def read_model(path, ext=""):
+ # try to detect the extension automatically
+ if ext == "":
+ if detect_model_format(path, ".bin"):
+ ext = ".bin"
+ elif detect_model_format(path, ".txt"):
+ ext = ".txt"
+ else:
+ try:
+ cameras, images, points3D = read_model(os.path.join(path, "model/"))
+ logger.warning("This SfM file structure was deprecated in hloc v1.1")
+ return cameras, images, points3D
+ except FileNotFoundError:
+ raise FileNotFoundError(
+ f"Could not find binary or text COLMAP model at {path}"
+ )
+
+ if ext == ".txt":
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
+ images = read_images_text(os.path.join(path, "images" + ext))
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
+ else:
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
+ images = read_images_binary(os.path.join(path, "images" + ext))
+ points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def write_model(cameras, images, points3D, path, ext=".bin"):
+ if ext == ".txt":
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
+ write_images_text(images, os.path.join(path, "images" + ext))
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
+ else:
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
+ write_images_binary(images, os.path.join(path, "images" + ext))
+ write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def qvec2rotmat(qvec):
+ return np.array(
+ [
+ [
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
+ ],
+ [
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
+ ],
+ [
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
+ ],
+ ]
+ )
+
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = (
+ np.array(
+ [
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
+ ]
+ )
+ / 3.0
+ )
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Read and write COLMAP binary and text models"
+ )
+ parser.add_argument("--input_model", help="path to input model folder")
+ parser.add_argument(
+ "--input_format",
+ choices=[".bin", ".txt"],
+ help="input model format",
+ default="",
+ )
+ parser.add_argument("--output_model", help="path to output model folder")
+ parser.add_argument(
+ "--output_format",
+ choices=[".bin", ".txt"],
+ help="output model format",
+ default=".txt",
+ )
+ args = parser.parse_args()
+
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
+
+ print("num_cameras:", len(cameras))
+ print("num_images:", len(images))
+ print("num_points3D:", len(points3D))
+
+ if args.output_model is not None:
+ write_model(
+ cameras, images, points3D, path=args.output_model, ext=args.output_format
+ )
+
+
+if __name__ == "__main__":
+ main()
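> Note: a quick consistency check for the quaternion helpers (illustrative values); `read_model`/`write_model` themselves are exercised by the CLI above:

```python
import numpy as np
from hloc.utils.read_write_model import qvec2rotmat, rotmat2qvec

qvec = np.array([0.9238795, 0.0, 0.3826834, 0.0])    # 45 degrees about the y axis
R = qvec2rotmat(qvec)
assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)    # valid rotation matrix
assert np.allclose(rotmat2qvec(R), qvec, atol=1e-6)  # round trip recovers the quaternion
```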
diff --git a/hloc/utils/viz.py b/hloc/utils/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..eace4fc20a6a4ebda9997acaf0842af5ae13ff42
--- /dev/null
+++ b/hloc/utils/viz.py
@@ -0,0 +1,145 @@
+"""
+2D visualization primitives based on Matplotlib.
+
+1) Plot images with `plot_images`.
+2) Call `plot_keypoints` or `plot_matches` any number of times.
+3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
+"""
+
+import matplotlib
+import matplotlib.pyplot as plt
+import matplotlib.patheffects as path_effects
+import numpy as np
+
+
+def cm_RdGn(x):
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
+ x = np.clip(x, 0, 1)[..., None] * 2
+ c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
+ return np.clip(c, 0, 1)
+
+
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
+ """Plot a set of images horizontally.
+ Args:
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+ titles: a list of strings, as titles for each image.
+ cmaps: colormaps for monochrome images.
+ adaptive: whether the figure size should fit the image aspect ratios.
+ """
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ if adaptive:
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
+ else:
+ ratios = [4 / 3] * n
+ figsize = [sum(ratios) * 4.5, 4.5]
+ fig, ax = plt.subplots(
+ 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+ )
+ if n == 1:
+ ax = [ax]
+ for i in range(n):
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+ ax[i].get_yaxis().set_ticks([])
+ ax[i].get_xaxis().set_ticks([])
+ ax[i].set_axis_off()
+ for spine in ax[i].spines.values(): # remove frame
+ spine.set_visible(False)
+ if titles:
+ ax[i].set_title(titles[i])
+ fig.tight_layout(pad=pad)
+
+
+def plot_keypoints(kpts, colors="lime", ps=4):
+ """Plot keypoints for existing images.
+ Args:
+ kpts: list of ndarrays of size (N, 2).
+ colors: string, or list of lists of tuples (one per keypoint).
+ ps: size of the keypoints as float.
+ """
+ if not isinstance(colors, list):
+ colors = [colors] * len(kpts)
+ axes = plt.gcf().axes
+ for a, k, c in zip(axes, kpts, colors):
+ a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
+ """Plot matches for a pair of existing images.
+ Args:
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ ps: size of the end points (no endpoint if ps=0)
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1,
+ transform=fig.transFigure,
+ c=color[i],
+ linewidth=lw,
+ alpha=a,
+ )
+ for i in range(len(kpts0))
+ ]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+ if ps > 0:
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def add_text(
+ idx,
+ text,
+ pos=(0.01, 0.99),
+ fs=15,
+ color="w",
+ lcolor="k",
+ lwidth=2,
+ ha="left",
+ va="top",
+):
+ ax = plt.gcf().axes[idx]
+ t = ax.text(
+ *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
+ )
+ if lcolor is not None:
+ t.set_path_effects(
+ [
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
+ path_effects.Normal(),
+ ]
+ )
+
+
+def save_plot(path, **kw):
+ """Save the current figure without any white margin."""
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
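> Note: a minimal sketch of the three-step recipe from the module docstring. The image paths are placeholders and the keypoints are random stand-ins; real ones would come from `get_keypoints`/`get_matches` in `hloc.utils.io`.

```python
import numpy as np
from hloc.utils.io import read_image
from hloc.utils.viz import plot_images, plot_matches, add_text, save_plot

image0 = read_image("assets/query.jpg")        # placeholder paths
image1 = read_image("assets/reference.jpg")
kpts0 = np.random.rand(50, 2) * image0.shape[1::-1]  # random (x, y) stand-ins
kpts1 = np.random.rand(50, 2) * image1.shape[1::-1]

plot_images([image0, image1], titles=["query", "reference"])
plot_matches(kpts0, kpts1, color="lime", lw=1.0)
add_text(0, f"{len(kpts0)} matches")
save_plot("matches.png")
```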
diff --git a/hloc/utils/viz_3d.py b/hloc/utils/viz_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..71254c3f1830f4eee92e56882d18ff2da355522d
--- /dev/null
+++ b/hloc/utils/viz_3d.py
@@ -0,0 +1,200 @@
+"""
+3D visualization based on plotly.
+Works for a small number of points and cameras, might be slow otherwise.
+
+1) Initialize a figure with `init_figure`
+2) Add 3D points, camera frustums, or both as a pycolmap.Reconstruction
+
+Written by Paul-Edouard Sarlin and Philipp Lindenberger.
+"""
+
+from typing import Optional
+import numpy as np
+import pycolmap
+import plotly.graph_objects as go
+
+
+def to_homogeneous(points):
+ pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
+ return np.concatenate([points, pad], axis=-1)
+
+
+def init_figure(height: int = 800) -> go.Figure:
+ """Initialize a 3D figure."""
+ fig = go.Figure()
+ axes = dict(
+ visible=False,
+ showbackground=False,
+ showgrid=False,
+ showline=False,
+ showticklabels=True,
+ autorange=True,
+ )
+ fig.update_layout(
+ template="plotly_dark",
+ height=height,
+ scene_camera=dict(
+ eye=dict(x=0.0, y=-0.1, z=-2),
+ up=dict(x=0, y=-1.0, z=0),
+ projection=dict(type="orthographic"),
+ ),
+ scene=dict(
+ xaxis=axes,
+ yaxis=axes,
+ zaxis=axes,
+ aspectmode="data",
+ dragmode="orbit",
+ ),
+ margin=dict(l=0, r=0, b=0, t=0, pad=0),
+ legend=dict(orientation="h", yanchor="top", y=0.99, xanchor="left", x=0.1),
+ )
+ return fig
+
+
+def plot_points(
+ fig: go.Figure,
+ pts: np.ndarray,
+ color: str = "rgba(255, 0, 0, 1)",
+ ps: int = 2,
+ colorscale: Optional[str] = None,
+ name: Optional[str] = None,
+):
+ """Plot a set of 3D points."""
+ x, y, z = pts.T
+ tr = go.Scatter3d(
+ x=x,
+ y=y,
+ z=z,
+ mode="markers",
+ name=name,
+ legendgroup=name,
+ marker=dict(size=ps, color=color, line_width=0.0, colorscale=colorscale),
+ )
+ fig.add_trace(tr)
+
+
+def plot_camera(
+ fig: go.Figure,
+ R: np.ndarray,
+ t: np.ndarray,
+ K: np.ndarray,
+ color: str = "rgb(0, 0, 255)",
+ name: Optional[str] = None,
+ legendgroup: Optional[str] = None,
+ size: float = 1.0,
+):
+ """Plot a camera frustum from pose and intrinsic matrix."""
+ W, H = K[0, 2] * 2, K[1, 2] * 2
+ corners = np.array([[0, 0], [W, 0], [W, H], [0, H], [0, 0]])
+ if size is not None:
+ image_extent = max(size * W / 1024.0, size * H / 1024.0)
+ world_extent = max(W, H) / (K[0, 0] + K[1, 1]) / 0.5
+ scale = 0.5 * image_extent / world_extent
+ else:
+ scale = 1.0
+ corners = to_homogeneous(corners) @ np.linalg.inv(K).T
+ corners = (corners / 2 * scale) @ R.T + t
+
+ x, y, z = corners.T
+ rect = go.Scatter3d(
+ x=x,
+ y=y,
+ z=z,
+ line=dict(color=color),
+ legendgroup=legendgroup,
+ name=name,
+ marker=dict(size=0.0001),
+ showlegend=False,
+ )
+ fig.add_trace(rect)
+
+ x, y, z = np.concatenate(([t], corners)).T
+ i = [0, 0, 0, 0]
+ j = [1, 2, 3, 4]
+ k = [2, 3, 4, 1]
+
+ pyramid = go.Mesh3d(
+ x=x,
+ y=y,
+ z=z,
+ color=color,
+ i=i,
+ j=j,
+ k=k,
+ legendgroup=legendgroup,
+ name=name,
+ showlegend=False,
+ )
+ fig.add_trace(pyramid)
+ triangles = np.vstack((i, j, k)).T
+ vertices = np.concatenate(([t], corners))
+ tri_points = np.array([vertices[i] for i in triangles.reshape(-1)])
+
+ x, y, z = tri_points.T
+ pyramid = go.Scatter3d(
+ x=x,
+ y=y,
+ z=z,
+ mode="lines",
+ legendgroup=legendgroup,
+ name=name,
+ line=dict(color=color, width=1),
+ showlegend=False,
+ )
+ fig.add_trace(pyramid)
+
+
+def plot_camera_colmap(
+ fig: go.Figure,
+ image: pycolmap.Image,
+ camera: pycolmap.Camera,
+ name: Optional[str] = None,
+ **kwargs
+):
+ """Plot a camera frustum from PyCOLMAP objects"""
+ plot_camera(
+ fig,
+ image.rotmat().T,
+ image.projection_center(),
+ camera.calibration_matrix(),
+ name=name or str(image.image_id),
+ **kwargs
+ )
+
+
+def plot_cameras(fig: go.Figure, reconstruction: pycolmap.Reconstruction, **kwargs):
+ """Plot the camera frustum of every image in the reconstruction."""
+ for image_id, image in reconstruction.images.items():
+ plot_camera_colmap(
+ fig, image, reconstruction.cameras[image.camera_id], **kwargs
+ )
+
+
+def plot_reconstruction(
+ fig: go.Figure,
+ rec: pycolmap.Reconstruction,
+ max_reproj_error: float = 6.0,
+ color: str = "rgb(0, 0, 255)",
+ name: Optional[str] = None,
+ min_track_length: int = 2,
+ points: bool = True,
+ cameras: bool = True,
+ cs: float = 1.0,
+):
+ # Filter outliers
+ bbs = rec.compute_bounding_box(0.001, 0.999)
+ # Filter points, use original reproj error here
+ xyzs = [
+ p3D.xyz
+ for _, p3D in rec.points3D.items()
+ if (
+ (p3D.xyz >= bbs[0]).all()
+ and (p3D.xyz <= bbs[1]).all()
+ and p3D.error <= max_reproj_error
+ and p3D.track.length() >= min_track_length
+ )
+ ]
+ if points:
+ plot_points(fig, np.array(xyzs), color=color, ps=1, name=name)
+ if cameras:
+ plot_cameras(fig, rec, color=color, legendgroup=name, size=cs)
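> Note: a short usage sketch following the docstring's workflow. The model path is the reference reconstruction written by the RobotCar pipeline above; any pycolmap-readable model works.

```python
import pycolmap
from hloc.utils import viz_3d

rec = pycolmap.Reconstruction("outputs/robotcar/sfm_superpoint+superglue")
fig = viz_3d.init_figure()
viz_3d.plot_reconstruction(fig, rec, color="rgba(255,0,0,0.5)", name="reference")
fig.show()
```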
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c9f47ce5d57cff18e3cde27cb9aed37a0d9959a8
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,30 @@
+python==3.10.4
+pytorch==1.12.1
+torchvision==0.13.1
+torchmetrics==0.6.0
+pytorch-lightning==1.4.9
+numpy==1.23.5
+opencv-python==4.6.0.66
+tqdm>=4.36.0
+matplotlib
+plotly
+scipy
+scikit_learn
+scikit-image
+h5py
+pycolmap>=0.3.0
+kornia>=0.6.7
+gdown
+seaborn
+omegaconf
+pytlsd
+tensorboardX
+shapely
+yacs
+einops
+loguru
+gradio
+e2cnn
+./third_party/disk/submodules/torch-localize
+./third_party/disk/submodules/torch-dimcheck
+./third_party/disk/submodules/unets