Realcat committed on
Commit
63f3cf2
1 Parent(s): 4487d43

fix: eloftr

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. hloc/extractors/sfd2.py +5 -7
  2. hloc/matchers/eloftr.py +10 -6
  3. hloc/matchers/imp.py +5 -6
  4. third_party/pram/.gitignore +13 -0
  5. third_party/pram/LICENSE +2 -0
  6. third_party/pram/README.md +207 -0
  7. third_party/pram/assets/map_sparsification.gif +3 -0
  8. third_party/pram/assets/multi_recognition.png +3 -0
  9. third_party/pram/assets/overview.png +3 -0
  10. third_party/pram/assets/pipeline1.png +3 -0
  11. third_party/pram/assets/pram_demo.gif +3 -0
  12. third_party/pram/assets/sam_openvoc.png +3 -0
  13. third_party/pram/colmap_utils/camera_intrinsics.py +30 -0
  14. third_party/pram/colmap_utils/database.py +352 -0
  15. third_party/pram/colmap_utils/geometry.py +17 -0
  16. third_party/pram/colmap_utils/io.py +78 -0
  17. third_party/pram/colmap_utils/parsers.py +73 -0
  18. third_party/pram/colmap_utils/read_write_model.py +627 -0
  19. third_party/pram/colmap_utils/utils.py +1 -0
  20. third_party/pram/configs/config_train_12scenes_sfd2.yaml +102 -0
  21. third_party/pram/configs/config_train_7scenes_sfd2.yaml +104 -0
  22. third_party/pram/configs/config_train_aachen_sfd2.yaml +104 -0
  23. third_party/pram/configs/config_train_cambridge_sfd2.yaml +103 -0
  24. third_party/pram/configs/config_train_multiset_sfd2.yaml +100 -0
  25. third_party/pram/configs/datasets/12Scenes.yaml +166 -0
  26. third_party/pram/configs/datasets/7Scenes.yaml +96 -0
  27. third_party/pram/configs/datasets/Aachen.yaml +15 -0
  28. third_party/pram/configs/datasets/CambridgeLandmarks.yaml +67 -0
  29. third_party/pram/dataset/aachen.py +119 -0
  30. third_party/pram/dataset/basicdataset.py +477 -0
  31. third_party/pram/dataset/cambridge_landmarks.py +101 -0
  32. third_party/pram/dataset/customdataset.py +93 -0
  33. third_party/pram/dataset/get_dataset.py +89 -0
  34. third_party/pram/dataset/recdataset.py +95 -0
  35. third_party/pram/dataset/seven_scenes.py +115 -0
  36. third_party/pram/dataset/twelve_scenes.py +121 -0
  37. third_party/pram/dataset/utils.py +31 -0
  38. third_party/pram/environment.yml +173 -0
  39. third_party/pram/inference.py +62 -0
  40. third_party/pram/localization/base_model.py +45 -0
  41. third_party/pram/localization/camera.py +11 -0
  42. third_party/pram/localization/extract_features.py +256 -0
  43. third_party/pram/localization/frame.py +195 -0
  44. third_party/pram/localization/loc_by_rec_eval.py +299 -0
  45. third_party/pram/localization/loc_by_rec_online.py +225 -0
  46. third_party/pram/localization/localizer.py +217 -0
  47. third_party/pram/localization/match_features.py +156 -0
  48. third_party/pram/localization/match_features_batch.py +242 -0
  49. third_party/pram/localization/matchers/__init__.py +3 -0
  50. third_party/pram/localization/matchers/adagml.py +41 -0
hloc/extractors/sfd2.py CHANGED
@@ -1,4 +1,3 @@
- # -*- coding: UTF-8 -*-
  import sys
  from pathlib import Path

@@ -7,10 +6,9 @@ import torchvision.transforms as tvf
  from .. import logger
  from ..utils.base_model import BaseModel

- pram_path = Path(__file__).parent / "../../third_party/pram"
- sys.path.append(str(pram_path))
-
- from nets.sfd2 import load_sfd2
+ tp_path = Path(__file__).parent / "../../third_party"
+ sys.path.append(str(tp_path))
+ from pram.nets.sfd2 import load_sfd2


  class SFD2(BaseModel):
@@ -26,8 +24,8 @@ class SFD2(BaseModel):
          self.norm_rgb = tvf.Normalize(
              mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
          )
-         model_fn = pram_path / "weights" / self.conf["model_name"]
-         self.net = load_sfd2(weight_path=model_fn).eval()
+         model_path = tp_path / "pram" / "weights" / self.conf["model_name"]
+         self.net = load_sfd2(weight_path=model_path).eval()

          logger.info("Load SFD2 model done.")
hloc/matchers/eloftr.py CHANGED
@@ -5,18 +5,22 @@ from pathlib import Path

  import torch

- eloftr_path = Path(__file__).parent / "../../third_party/EfficientLoFTR"
- sys.path.append(str(eloftr_path))
+ tp_path = Path(__file__).parent / "../../third_party"
+ sys.path.append(str(tp_path))

- from src.loftr import LoFTR as ELoFTR_
- from src.loftr import full_default_cfg, opt_default_cfg, reparameter
+ from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_
+ from EfficientLoFTR.src.loftr import (
+     full_default_cfg,
+     opt_default_cfg,
+     reparameter,
+ )

  from hloc import logger

  from ..utils.base_model import BaseModel


- class LoFTR(BaseModel):
+ class ELoFTR(BaseModel):
      default_conf = {
          "weights": "weights/eloftr_outdoor.ckpt",
          "match_threshold": 0.2,
@@ -40,7 +44,7 @@ class LoFTR(BaseModel):
              _default_cfg["mp"] = True
          elif self.conf["precision"] == "fp16":
              _default_cfg["half"] = True
-         model_path = eloftr_path / self.conf["weights"]
+         model_path = tp_path / "EfficientLoFTR" / self.conf["weights"]
          cfg = _default_cfg
          cfg["match_coarse"]["thr"] = conf["match_threshold"]
          # cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
hloc/matchers/imp.py CHANGED
@@ -1,4 +1,3 @@
- # -*- coding: UTF-8 -*-
  import sys
  from pathlib import Path

@@ -7,10 +6,9 @@ import torch
  from .. import DEVICE, logger
  from ..utils.base_model import BaseModel

- pram_path = Path(__file__).parent / "../../third_party/pram"
- sys.path.append(str(pram_path))
-
- from nets.gml import GML
+ tp_path = Path(__file__).parent / "../../third_party"
+ sys.path.append(str(tp_path))
+ from pram.nets.gml import GML


  class IMP(BaseModel):
@@ -33,7 +31,8 @@ class IMP(BaseModel):

      def _init(self, conf):
          self.conf = {**self.default_conf, **conf}
-         weight_path = pram_path / "weights" / self.conf["model_name"]
+         weight_path = tp_path / "pram" / "weights" / self.conf["model_name"]
+         # self.net = nets.gml(self.conf).eval().to(DEVICE)
          self.net = GML(self.conf).eval().to(DEVICE)
          self.net.load_state_dict(
              torch.load(weight_path, map_location="cpu")["model"], strict=True
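The fix applied in `hloc/extractors/sfd2.py`, `hloc/matchers/eloftr.py`, and `hloc/matchers/imp.py` above follows one pattern: append the shared `third_party` directory to `sys.path` once and import each vendored repository as a package, so top-level module names such as `src` or `nets` from different repositories cannot shadow each other. A minimal sketch of the idea (the duplicate-entry guard is added here for illustration and is not part of the commit):

```python
import sys
from pathlib import Path

# Point at the shared third_party directory rather than one vendored repo.
tp_path = Path(__file__).parent / "../../third_party"
if str(tp_path) not in sys.path:  # avoid stacking duplicates on repeated imports
    sys.path.append(str(tp_path))

# Import vendored code through its repository name so top-level modules
# like "src" or "nets" from different repositories do not collide.
from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_  # noqa: E402
from pram.nets.sfd2 import load_sfd2  # noqa: E402
```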
third_party/pram/.gitignore ADDED
@@ -0,0 +1,13 @@
+ .idea
+ __pycache__
+ weights/12scenes*
+ weights/7scenes*
+ weights/aachen*
+ weights/cambridgelandmarks*
+ weights/imp_adagml.80.pth
+ landmarks
+ 3D-models
+ log_*
+ *.log
+ .nfs*
+ Pangolin
third_party/pram/LICENSE ADDED
@@ -0,0 +1,2 @@
+ This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License.
+ To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/.
third_party/pram/README.md ADDED
@@ -0,0 +1,207 @@
+ ## PRAM: Place Recognition Anywhere Model for Efficient Visual Localization
+
+ <p align="center">
+ <img src="assets/overview.png" width="960">
+ </p>
+
+ Humans localize themselves efficiently in known environments by first recognizing landmarks defined on certain objects
+ and their spatial relationships, and then verifying the location by aligning detailed structures of recognized objects
+ with those in memory. Inspired by this, we propose the place recognition anywhere model (PRAM) to perform visual
+ localization as efficiently as humans do. PRAM consists of two main components - recognition and registration. In
+ detail, first of all, a self-supervised map-centric landmark definition strategy is adopted, making places in either
+ indoor or outdoor scenes act as unique landmarks. Then, sparse keypoints extracted from images are utilized as the
+ input to a transformer-based deep neural network for landmark recognition; these keypoints enable PRAM to recognize
+ hundreds of landmarks with high time and memory efficiency. Keypoints along with recognized landmark labels are further
+ used for registration between query images and the 3D landmark map. Different from previous hierarchical methods, PRAM
+ discards global and local descriptors, and reduces storage by over 90%. Since PRAM utilizes recognition and landmark-wise
+ verification to replace global reference search and exhaustive matching respectively, it runs 2.4 times faster than
+ prior state-of-the-art approaches. Moreover, PRAM opens new directions for visual localization including multi-modality
+ localization, map-centric feature learning, and hierarchical scene coordinate regression.
+
+ * Full paper
+   PDF: [Place Recognition Anywhere Model for Efficient Visual Localization](https://arxiv.org/pdf/2404.07785.pdf).
+
+ * Authors: *Fei Xue, Ignas Budvytis, Roberto Cipolla*
+
+ * Website: [PRAM](https://feixue94.github.io/pram-project) for videos, slides, recent updates, and datasets.
+
+ ## Key Features
+
+ ### 1. Self-supervised landmark definition in 3D space
+
+ - No need for segmentation on images
+ - No inconsistent semantic results from multi-view images
+ - Not limited to labels of known objects only
+ - Works in any place with known or unknown objects
+ - Landmark-wise 3D map sparsification
+
+ <p align="center">
+ <img src="assets/map_sparsification.gif" width="640">
+ </p>
+
+ ### 2. Efficient landmark-wise coarse and fine localization
+
+ - Recognizes landmarks as opposed to doing global retrieval
+ - Local landmark-wise matching as opposed to exhaustive matching
+ - No global descriptors (e.g. NetVLAD)
+ - No reference images and their heavy, repetitive 2D keypoints and descriptors
+ - Automatic inlier/outlier identification
+
+ <p align="center">
+ <img src="assets/pipeline1.png" width="640">
+ </p>
+
+ ### 4. Sparse recognition
+
+ - Sparse SFD2 keypoints as tokens
+ - No uncertainties of points at boundaries
+ - Flexible to accept multi-modality inputs
+
+ ### 5. Relocalization and temporal localization
+
+ - Per-frame relocalization from scratch
+ - Tracking previous frames for higher efficiency
+
+ ### 6. One model, one dataset
+
+ - All 7 subscenes in the 7Scenes dataset share a model
+ - All 12 subscenes in the 12Scenes dataset share a model
+ - All 5 subscenes in CambridgeLandmarks share a model
+
+ ### 7. Robust to long-term changes
+
+ <p align="center">
+ <img src="assets/pram_demo.gif" width="640">
+ </p>
+
+ ## Open problems
+
+ - Adaptive determination of the number of landmarks
+ - Using SAM + open vocabulary to generate semantic maps
+ - Multi-modality localization with other tokenized signals (e.g. text, language, GPS, magnetometer)
+ - More effective solutions to 3D sparsification
+
+ ## Preparation
+
+ 1. Download the 7Scenes, 12Scenes, CambridgeLandmarks, and Aachen datasets (remove redundant depth images, otherwise they
+    will be picked up in the SfM process)
+ 2. Environments
+
+    2.1 Create a virtual environment
+
+ ```
+ conda env create -f environment.yml
+ (do not activate pram before pangolin is installed)
+ ```
+
+    2.2 Compile Pangolin for the installed python
+
+ ```
+ git clone --recursive https://github.com/stevenlovegrove/Pangolin.git
+ cd Pangolin
+ git checkout v0.8
+
+ # Install dependencies
+ ./scripts/install_prerequisites.sh recommended
+
+ # Compile with your python
+ cmake -DPython_EXECUTABLE=/your path to/anaconda3/envs/pram/bin/python3 -B build
+ cmake --build build -t pypangolin_pip_install
+
+ conda activate pram
+ ```
+
+ ## Run the localization with online visualization
+
+ 1. Download the [3D-models](https://drive.google.com/drive/folders/1DUB073KxAjsc8lxhMpFuxPRf0ZBQS6NS?usp=drive_link),
+    pretrained [models](https://drive.google.com/drive/folders/1E2QvujCevqnyg_CM9FGAa0AxKkt4KbLD?usp=drive_link),
+    and [landmarks](https://drive.google.com/drive/folders/1r9src9bz7k3WYGfaPmKJ9gqxuvdfxZU0?usp=sharing)
+ 2. Put the pretrained models in the ```weights``` directory
+ 3. Run the demo (e.g. 7Scenes)
+
+ ```
+ python3 inference.py --config configs/config_train_7scenes_sfd2.yaml --rec_weight_path weights/7scenes_nc113_birch_segnetvit.199.pth --landmark_path /your path to/landmarks --online
+ ```
+
+ ## Train the recognition model (e.g. for 7Scenes)
+
+ ### 1. Do SfM with SFD2, including feature extraction (modify dataset_dir, ref_sfm_dir, and output_dir)
+
+ ```
+ ./sfm_scripts/reconstruct_7scenes.sh
+ ```
+
+ This step produces the SfM results together with the extracted keypoints.
+
+ ### 2. Generate 3D landmarks
+
+ ```
+ python3 -m recognition.recmap --dataset 7Scenes --dataset_dir /your path to/7Scenes --sfm_dir /sfm_path/7Scenes --save_dir /save_path/landmarks
+ ```
+
+ This step generates 3D landmarks, creates virtual reference frames, and sparsifies the 3D points of each landmark for
+ all scenes in 7Scenes.
+
+ ### 3. Train the sparse recognition model (one model, one dataset)
+
+ ```
+ python3 train.py --config configs/config_train_7scenes_sfd2.yaml
+ ```
+
+ Remember to modify the paths in 'config_train_7scenes_sfd2.yaml'.
+
+ ## Your own dataset
+
+ 1. Run colmap or hloc to obtain the SfM results
+ 2. Do the reconstruction with SFD2 keypoints, using the SfM from the previous step as the reference SfM
+ 3. Do 3D landmark generation, VRF, map sparsification, etc. (add DatasetName.yaml to configs/datasets)
+ 4. Train the recognition model
+ 5. Do evaluation
+
+ ## Previous works can be found here
+
+ 1. [Efficient large-scale localization by landmark recognition, CVPR 2022](https://github.com/feixue94/lbr)
+ 2. [IMP: Iterative Matching and Pose Estimation with Adaptive Pooling, CVPR 2023](https://github.com/feixue94/imp-release)
+ 3. [SFD2: Semantic-guided Feature Detection and Description, CVPR 2023](https://github.com/feixue94/sfd2)
+ 4. [VRS-NeRF: Visual Relocalization with Sparse Neural Radiance Field, under review](https://github.com/feixue94/vrs-nerf)
+
+ ## BibTeX Citation
+
+ If you use any ideas from the paper or code in this repo, please consider citing:
+
+ ```
+ @article{xue2024pram,
+   author = {Fei Xue and Ignas Budvytis and Roberto Cipolla},
+   title = {PRAM: Place Recognition Anywhere Model for Efficient Visual Localization},
+   journal = {arXiv preprint arXiv:2404.07785},
+   year = {2024}
+ }
+
+ @inproceedings{xue2023sfd2,
+   author = {Fei Xue and Ignas Budvytis and Roberto Cipolla},
+   title = {SFD2: Semantic-guided Feature Detection and Description},
+   booktitle = {CVPR},
+   year = {2023}
+ }
+
+ @inproceedings{xue2022imp,
+   author = {Fei Xue and Ignas Budvytis and Roberto Cipolla},
+   title = {IMP: Iterative Matching and Pose Estimation with Adaptive Pooling},
+   booktitle = {CVPR},
+   year = {2023}
+ }
+
+ @inproceedings{xue2022efficient,
+   author = {Fei Xue and Ignas Budvytis and Daniel Olmeda Reino and Roberto Cipolla},
+   title = {Efficient Large-scale Localization by Global Instance Recognition},
+   booktitle = {CVPR},
+   year = {2022}
+ }
+ ```
+
+ ## Acknowledgements
+
+ Part of the code is from previous excellent works
+ including [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork)
+ and [hloc](https://github.com/cvg/Hierarchical-Localization). You can find more details in their released
+ repositories if you are interested in their works.
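As a quick sanity check after the Pangolin build step in the README above, the Python binding can be probed from the activated `pram` environment. The module name `pypangolin` (what the `pypangolin_pip_install` target is expected to install) is an assumption worth verifying locally:

```python
# Run inside the activated "pram" conda environment.
try:
    import pypangolin  # assumed name of the binding built by pypangolin_pip_install
    print("pypangolin imported from:", pypangolin.__file__)
except ImportError as err:
    print("Pangolin python bindings are not visible to this interpreter:", err)
```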
third_party/pram/assets/map_sparsification.gif ADDED

Git LFS Details

  • SHA256: fd7bbe3b0bad7c6ae330eaa702b2839533a6f27ad5a0b104c4a37597c0c37aad
  • Pointer size: 131 Bytes
  • Size of remote file: 493 kB
third_party/pram/assets/multi_recognition.png ADDED

Git LFS Details

  • SHA256: c84e81cb990adedc25ef612b31d1ec53f7cb9f2168ef2246f2f03ca479cca9cf
  • Pointer size: 132 Bytes
  • Size of remote file: 2.46 MB
third_party/pram/assets/overview.png ADDED

Git LFS Details

  • SHA256: 466b1f2b6a38cb956a389c1fc69c213c1655579c0c944174b6e95e247209eedc
  • Pointer size: 131 Bytes
  • Size of remote file: 662 kB
third_party/pram/assets/pipeline1.png ADDED

Git LFS Details

  • SHA256: 0bd0545bc3f4814d4b9f18893965529a08a73263e80a3978755162935e05d2b3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.99 MB
third_party/pram/assets/pram_demo.gif ADDED

Git LFS Details

  • SHA256: 95e56e33824789b650f4760b4246eca89c9cd1a8c138afc2d2ab5e24ec665fac
  • Pointer size: 133 Bytes
  • Size of remote file: 14.7 MB
third_party/pram/assets/sam_openvoc.png ADDED

Git LFS Details

  • SHA256: b3e0b06b6917402ed010cd4054e2efcf75c04ede84be53f17d147e2dd388d15a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
third_party/pram/colmap_utils/camera_intrinsics.py ADDED
@@ -0,0 +1,30 @@
+ # -*- coding: UTF-8 -*-
+ '''=================================================
+ @Project -> File localizer -> camera_intrinsics
+ @IDE PyCharm
+ @Author fx221@cam.ac.uk
+ @Date 15/08/2023 12:33
+ =================================================='''
+ import numpy as np
+
+
+ def intrinsics_from_camera(camera_model, params):
+     if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
+         fx = fy = params[0]
+         cx = params[1]
+         cy = params[2]
+     elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
+         fx = params[0]
+         fy = params[1]
+         cx = params[2]
+         cy = params[3]
+     else:
+         raise Exception("Camera model not supported")
+
+     # intrinsics
+     K = np.identity(3)
+     K[0, 0] = fx
+     K[1, 1] = fy
+     K[0, 2] = cx
+     K[1, 2] = cy
+     return K
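A small usage sketch for the helper above; the import assumes the pram repository root is on `sys.path`, and the SIMPLE_RADIAL parameters are made-up values:

```python
import numpy as np

from colmap_utils.camera_intrinsics import intrinsics_from_camera

# Hypothetical SIMPLE_RADIAL parameters: f, cx, cy, k (the trailing distortion
# term is not used when building K).
params = np.array([868.99, 512.0, 384.0, -0.04])
K = intrinsics_from_camera("SIMPLE_RADIAL", params)
# K = [[868.99, 0, 512.0], [0, 868.99, 384.0], [0, 0, 1]]
print(K)
```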
third_party/pram/colmap_utils/database.py ADDED
@@ -0,0 +1,352 @@
1
+ # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2
+ # All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ #
7
+ # * Redistributions of source code must retain the above copyright
8
+ # notice, this list of conditions and the following disclaimer.
9
+ #
10
+ # * Redistributions in binary form must reproduce the above copyright
11
+ # notice, this list of conditions and the following disclaimer in the
12
+ # documentation and/or other materials provided with the distribution.
13
+ #
14
+ # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15
+ # its contributors may be used to endorse or promote products derived
16
+ # from this software without specific prior written permission.
17
+ #
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
+ # POSSIBILITY OF SUCH DAMAGE.
29
+ #
30
+ # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31
+
32
+ # This script is based on an original implementation by True Price.
33
+
34
+ import sys
35
+ import sqlite3
36
+ import numpy as np
37
+
38
+
39
+ IS_PYTHON3 = sys.version_info[0] >= 3
40
+
41
+ MAX_IMAGE_ID = 2**31 - 1
42
+
43
+ CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
44
+ camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
45
+ model INTEGER NOT NULL,
46
+ width INTEGER NOT NULL,
47
+ height INTEGER NOT NULL,
48
+ params BLOB,
49
+ prior_focal_length INTEGER NOT NULL)"""
50
+
51
+ CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
52
+ image_id INTEGER PRIMARY KEY NOT NULL,
53
+ rows INTEGER NOT NULL,
54
+ cols INTEGER NOT NULL,
55
+ data BLOB,
56
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
57
+
58
+ CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
59
+ image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
60
+ name TEXT NOT NULL UNIQUE,
61
+ camera_id INTEGER NOT NULL,
62
+ prior_qw REAL,
63
+ prior_qx REAL,
64
+ prior_qy REAL,
65
+ prior_qz REAL,
66
+ prior_tx REAL,
67
+ prior_ty REAL,
68
+ prior_tz REAL,
69
+ CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
70
+ FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
71
+ """.format(MAX_IMAGE_ID)
72
+
73
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
74
+ CREATE TABLE IF NOT EXISTS two_view_geometries (
75
+ pair_id INTEGER PRIMARY KEY NOT NULL,
76
+ rows INTEGER NOT NULL,
77
+ cols INTEGER NOT NULL,
78
+ data BLOB,
79
+ config INTEGER NOT NULL,
80
+ F BLOB,
81
+ E BLOB,
82
+ H BLOB)
83
+ """
84
+
85
+ CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
86
+ image_id INTEGER PRIMARY KEY NOT NULL,
87
+ rows INTEGER NOT NULL,
88
+ cols INTEGER NOT NULL,
89
+ data BLOB,
90
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
91
+ """
92
+
93
+ CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
94
+ pair_id INTEGER PRIMARY KEY NOT NULL,
95
+ rows INTEGER NOT NULL,
96
+ cols INTEGER NOT NULL,
97
+ data BLOB)"""
98
+
99
+ CREATE_NAME_INDEX = \
100
+ "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
101
+
102
+ CREATE_ALL = "; ".join([
103
+ CREATE_CAMERAS_TABLE,
104
+ CREATE_IMAGES_TABLE,
105
+ CREATE_KEYPOINTS_TABLE,
106
+ CREATE_DESCRIPTORS_TABLE,
107
+ CREATE_MATCHES_TABLE,
108
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE,
109
+ CREATE_NAME_INDEX
110
+ ])
111
+
112
+
113
+ def image_ids_to_pair_id(image_id1, image_id2):
114
+ if image_id1 > image_id2:
115
+ image_id1, image_id2 = image_id2, image_id1
116
+ return image_id1 * MAX_IMAGE_ID + image_id2
117
+
118
+
119
+ def pair_id_to_image_ids(pair_id):
120
+ image_id2 = pair_id % MAX_IMAGE_ID
121
+ image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID
122
+ return image_id1, image_id2
123
+
124
+
125
+ def array_to_blob(array):
126
+ if IS_PYTHON3:
127
+ return array.tostring()
128
+ else:
129
+ return np.getbuffer(array)
130
+
131
+
132
+ def blob_to_array(blob, dtype, shape=(-1,)):
133
+ if IS_PYTHON3:
134
+ return np.fromstring(blob, dtype=dtype).reshape(*shape)
135
+ else:
136
+ return np.frombuffer(blob, dtype=dtype).reshape(*shape)
137
+
138
+
139
+ class COLMAPDatabase(sqlite3.Connection):
140
+
141
+ @staticmethod
142
+ def connect(database_path):
143
+ return sqlite3.connect(str(database_path), factory=COLMAPDatabase)
144
+
145
+
146
+ def __init__(self, *args, **kwargs):
147
+ super(COLMAPDatabase, self).__init__(*args, **kwargs)
148
+
149
+ self.create_tables = lambda: self.executescript(CREATE_ALL)
150
+ self.create_cameras_table = \
151
+ lambda: self.executescript(CREATE_CAMERAS_TABLE)
152
+ self.create_descriptors_table = \
153
+ lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
154
+ self.create_images_table = \
155
+ lambda: self.executescript(CREATE_IMAGES_TABLE)
156
+ self.create_two_view_geometries_table = \
157
+ lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
158
+ self.create_keypoints_table = \
159
+ lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
160
+ self.create_matches_table = \
161
+ lambda: self.executescript(CREATE_MATCHES_TABLE)
162
+ self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
163
+
164
+ def add_camera(self, model, width, height, params,
165
+ prior_focal_length=False, camera_id=None):
166
+ params = np.asarray(params, np.float64)
167
+ cursor = self.execute(
168
+ "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
169
+ (camera_id, model, width, height, array_to_blob(params),
170
+ prior_focal_length))
171
+ return cursor.lastrowid
172
+
173
+ def add_image(self, name, camera_id,
174
+ prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None):
175
+ cursor = self.execute(
176
+ "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
177
+ (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
178
+ prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
179
+ return cursor.lastrowid
180
+
181
+ def add_keypoints(self, image_id, keypoints):
182
+ assert(len(keypoints.shape) == 2)
183
+ assert(keypoints.shape[1] in [2, 4, 6])
184
+
185
+ keypoints = np.asarray(keypoints, np.float32)
186
+ self.execute(
187
+ "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
188
+ (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
189
+
190
+ def add_descriptors(self, image_id, descriptors):
191
+ descriptors = np.ascontiguousarray(descriptors, np.uint8)
192
+ self.execute(
193
+ "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
194
+ (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
195
+
196
+ def add_matches(self, image_id1, image_id2, matches):
197
+ assert(len(matches.shape) == 2)
198
+ assert(matches.shape[1] == 2)
199
+
200
+ if image_id1 > image_id2:
201
+ matches = matches[:,::-1]
202
+
203
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
204
+ matches = np.asarray(matches, np.uint32)
205
+ self.execute(
206
+ "INSERT INTO matches VALUES (?, ?, ?, ?)",
207
+ (pair_id,) + matches.shape + (array_to_blob(matches),))
208
+
209
+ def add_two_view_geometry(self, image_id1, image_id2, matches,
210
+ F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
211
+ assert(len(matches.shape) == 2)
212
+ assert(matches.shape[1] == 2)
213
+
214
+ if image_id1 > image_id2:
215
+ matches = matches[:,::-1]
216
+
217
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
218
+ matches = np.asarray(matches, np.uint32)
219
+ F = np.asarray(F, dtype=np.float64)
220
+ E = np.asarray(E, dtype=np.float64)
221
+ H = np.asarray(H, dtype=np.float64)
222
+ self.execute(
223
+ "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
224
+ (pair_id,) + matches.shape + (array_to_blob(matches), config,
225
+ array_to_blob(F), array_to_blob(E), array_to_blob(H)))
226
+
227
+
228
+ def example_usage():
229
+ import os
230
+ import argparse
231
+
232
+ parser = argparse.ArgumentParser()
233
+ parser.add_argument("--database_path", default="database.db")
234
+ args = parser.parse_args()
235
+
236
+ if os.path.exists(args.database_path):
237
+ print("ERROR: database path already exists -- will not modify it.")
238
+ return
239
+
240
+ # Open the database.
241
+
242
+ db = COLMAPDatabase.connect(args.database_path)
243
+
244
+ # For convenience, try creating all the tables upfront.
245
+
246
+ db.create_tables()
247
+
248
+ # Create dummy cameras.
249
+
250
+ model1, width1, height1, params1 = \
251
+ 0, 1024, 768, np.array((1024., 512., 384.))
252
+ model2, width2, height2, params2 = \
253
+ 2, 1024, 768, np.array((1024., 512., 384., 0.1))
254
+
255
+ camera_id1 = db.add_camera(model1, width1, height1, params1)
256
+ camera_id2 = db.add_camera(model2, width2, height2, params2)
257
+
258
+ # Create dummy images.
259
+
260
+ image_id1 = db.add_image("image1.png", camera_id1)
261
+ image_id2 = db.add_image("image2.png", camera_id1)
262
+ image_id3 = db.add_image("image3.png", camera_id2)
263
+ image_id4 = db.add_image("image4.png", camera_id2)
264
+
265
+ # Create dummy keypoints.
266
+ #
267
+ # Note that COLMAP supports:
268
+ # - 2D keypoints: (x, y)
269
+ # - 4D keypoints: (x, y, theta, scale)
270
+ # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
271
+
272
+ num_keypoints = 1000
273
+ keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
274
+ keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
275
+ keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
276
+ keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
277
+
278
+ db.add_keypoints(image_id1, keypoints1)
279
+ db.add_keypoints(image_id2, keypoints2)
280
+ db.add_keypoints(image_id3, keypoints3)
281
+ db.add_keypoints(image_id4, keypoints4)
282
+
283
+ # Create dummy matches.
284
+
285
+ M = 50
286
+ matches12 = np.random.randint(num_keypoints, size=(M, 2))
287
+ matches23 = np.random.randint(num_keypoints, size=(M, 2))
288
+ matches34 = np.random.randint(num_keypoints, size=(M, 2))
289
+
290
+ db.add_matches(image_id1, image_id2, matches12)
291
+ db.add_matches(image_id2, image_id3, matches23)
292
+ db.add_matches(image_id3, image_id4, matches34)
293
+
294
+ # Commit the data to the file.
295
+
296
+ db.commit()
297
+
298
+ # Read and check cameras.
299
+
300
+ rows = db.execute("SELECT * FROM cameras")
301
+
302
+ camera_id, model, width, height, params, prior = next(rows)
303
+ params = blob_to_array(params, np.float64)
304
+ assert camera_id == camera_id1
305
+ assert model == model1 and width == width1 and height == height1
306
+ assert np.allclose(params, params1)
307
+
308
+ camera_id, model, width, height, params, prior = next(rows)
309
+ params = blob_to_array(params, np.float64)
310
+ assert camera_id == camera_id2
311
+ assert model == model2 and width == width2 and height == height2
312
+ assert np.allclose(params, params2)
313
+
314
+ # Read and check keypoints.
315
+
316
+ keypoints = dict(
317
+ (image_id, blob_to_array(data, np.float32, (-1, 2)))
318
+ for image_id, data in db.execute(
319
+ "SELECT image_id, data FROM keypoints"))
320
+
321
+ assert np.allclose(keypoints[image_id1], keypoints1)
322
+ assert np.allclose(keypoints[image_id2], keypoints2)
323
+ assert np.allclose(keypoints[image_id3], keypoints3)
324
+ assert np.allclose(keypoints[image_id4], keypoints4)
325
+
326
+ # Read and check matches.
327
+
328
+ pair_ids = [image_ids_to_pair_id(*pair) for pair in
329
+ ((image_id1, image_id2),
330
+ (image_id2, image_id3),
331
+ (image_id3, image_id4))]
332
+
333
+ matches = dict(
334
+ (pair_id_to_image_ids(pair_id),
335
+ blob_to_array(data, np.uint32, (-1, 2)))
336
+ for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
337
+ )
338
+
339
+ assert np.all(matches[(image_id1, image_id2)] == matches12)
340
+ assert np.all(matches[(image_id2, image_id3)] == matches23)
341
+ assert np.all(matches[(image_id3, image_id4)] == matches34)
342
+
343
+ # Clean up.
344
+
345
+ db.close()
346
+
347
+ if os.path.exists(args.database_path):
348
+ os.remove(args.database_path)
349
+
350
+
351
+ if __name__ == "__main__":
352
+ example_usage()
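For reference, a small self-contained sketch of the pair-id convention used throughout this database module; `image_ids_to_pair_id` below mirrors the function defined above, and the example ids are arbitrary:

```python
MAX_IMAGE_ID = 2**31 - 1

def image_ids_to_pair_id(image_id1, image_id2):
    # The smaller id always comes first, so (7, 3) and (3, 7) map to the same pair.
    if image_id1 > image_id2:
        image_id1, image_id2 = image_id2, image_id1
    return image_id1 * MAX_IMAGE_ID + image_id2

pair_id = image_ids_to_pair_id(7, 3)
print(pair_id)                        # 3 * (2**31 - 1) + 7
print(pair_id % MAX_IMAGE_ID)         # 7  (image_id2)
print((pair_id - 7) // MAX_IMAGE_ID)  # 3  (image_id1)
```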
third_party/pram/colmap_utils/geometry.py ADDED
@@ -0,0 +1,17 @@
+ # -*- coding: UTF-8 -*-
+ import numpy as np
+ import pycolmap
+
+
+ def to_homogeneous(p):
+     return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1)
+
+
+ def compute_epipolar_errors(j_from_i: pycolmap.Rigid3d, p2d_i, p2d_j):
+     j_E_i = j_from_i.essential_matrix()
+     l2d_j = to_homogeneous(p2d_i) @ j_E_i.T
+     l2d_i = to_homogeneous(p2d_j) @ j_E_i
+     dist = np.abs(np.sum(to_homogeneous(p2d_i) * l2d_i, axis=1))
+     errors_i = dist / np.linalg.norm(l2d_i[:, :2], axis=1)
+     errors_j = dist / np.linalg.norm(l2d_j[:, :2], axis=1)
+     return errors_i, errors_j
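A quick illustration of what `to_homogeneous` does for a 2D point array, written with plain NumPy so it runs even without pycolmap installed; the coordinates are arbitrary:

```python
import numpy as np

pts = np.array([[100.0, 200.0],
                [320.5, 240.0]])
# Same padding spec as to_homogeneous for ndim == 2: append a column of ones,
# turning a (2, 2) array of pixel coordinates into (2, 3) homogeneous points.
pts_h = np.pad(pts, ((0, 0), (0, 1)), constant_values=1)
# [[100.  200.    1. ]
#  [320.5 240.    1. ]]
print(pts_h)
```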
third_party/pram/colmap_utils/io.py ADDED
@@ -0,0 +1,78 @@
+ # -*- coding: UTF-8 -*-
+ from pathlib import Path
+ from typing import Tuple
+
+ import cv2
+ import h5py
+ import numpy as np
+
+ from .parsers import names_to_pair, names_to_pair_old
+
+
+ def read_image(path, grayscale=False):
+     if grayscale:
+         mode = cv2.IMREAD_GRAYSCALE
+     else:
+         mode = cv2.IMREAD_COLOR
+     image = cv2.imread(str(path), mode)
+     if image is None:
+         raise ValueError(f"Cannot read image {path}.")
+     if not grayscale and len(image.shape) == 3:
+         image = image[:, :, ::-1]  # BGR to RGB
+     return image
+
+
+ def list_h5_names(path):
+     names = []
+     with h5py.File(str(path), "r", libver="latest") as fd:
+         def visit_fn(_, obj):
+             if isinstance(obj, h5py.Dataset):
+                 names.append(obj.parent.name.strip("/"))
+
+         fd.visititems(visit_fn)
+     return list(set(names))
+
+
+ def get_keypoints(
+     path: Path, name: str, return_uncertainty: bool = False
+ ) -> np.ndarray:
+     with h5py.File(str(path), "r", libver="latest") as hfile:
+         dset = hfile[name]["keypoints"]
+         p = dset.__array__()
+         uncertainty = dset.attrs.get("uncertainty")
+     if return_uncertainty:
+         return p, uncertainty
+     return p
+
+
+ def find_pair(hfile: h5py.File, name0: str, name1: str):
+     pair = names_to_pair(name0, name1)
+     if pair in hfile:
+         return pair, False
+     pair = names_to_pair(name1, name0)
+     if pair in hfile:
+         return pair, True
+     # older, less efficient format
+     pair = names_to_pair_old(name0, name1)
+     if pair in hfile:
+         return pair, False
+     pair = names_to_pair_old(name1, name0)
+     if pair in hfile:
+         return pair, True
+     raise ValueError(
+         f"Could not find pair {(name0, name1)}... "
+         "Maybe you matched with a different list of pairs? "
+     )
+
+
+ def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]:
+     with h5py.File(str(path), "r", libver="latest") as hfile:
+         pair, reverse = find_pair(hfile, name0, name1)
+         matches = hfile[pair]["matches0"].__array__()
+         scores = hfile[pair]["matching_scores0"].__array__()
+     idx = np.where(matches != -1)[0]
+     matches = np.stack([idx, matches[idx]], -1)
+     if reverse:
+         matches = np.flip(matches, -1)
+     scores = scores[idx]
+     return matches, scores
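A hedged usage sketch for the HDF5 readers above; the file paths and image names below are hypothetical placeholders, and the import assumes the pram repository root is on `sys.path`:

```python
from pathlib import Path

from colmap_utils.io import get_keypoints, get_matches

features_h5 = Path("outputs/feats-sfd2.h5")  # hypothetical feature file
matches_h5 = Path("outputs/matches-imp.h5")  # hypothetical match file

# Keypoints for one image, and matches/scores for one image pair.
kpts = get_keypoints(features_h5, "query/frame_000001.png")
matches, scores = get_matches(matches_h5, "query/frame_000001.png", "db/frame_000042.png")
print(kpts.shape, matches.shape, scores.shape)
```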
third_party/pram/colmap_utils/parsers.py ADDED
@@ -0,0 +1,73 @@
+ # -*- coding: UTF-8 -*-
+
+ from pathlib import Path
+ import logging
+ import numpy as np
+ from collections import defaultdict
+
+
+ def parse_image_lists_with_intrinsics(paths):
+     results = []
+     files = list(Path(paths.parent).glob(paths.name))
+     assert len(files) > 0
+
+     for lfile in files:
+         with open(lfile, 'r') as f:
+             raw_data = f.readlines()
+
+         logging.info(f'Importing {len(raw_data)} queries in {lfile.name}')
+         for data in raw_data:
+             data = data.strip('\n').split(' ')
+             name, camera_model, width, height = data[:4]
+             params = np.array(data[4:], float)
+             info = (camera_model, int(width), int(height), params)
+             results.append((name, info))
+
+     assert len(results) > 0
+     return results
+
+
+ def parse_img_lists_for_extended_cmu_seaons(paths):
+     Ks = {
+         "c0": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571",
+         "c1": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571"
+     }
+
+     results = []
+     files = list(Path(paths.parent).glob(paths.name))
+     assert len(files) > 0
+
+     for lfile in files:
+         with open(lfile, 'r') as f:
+             raw_data = f.readlines()
+
+         logging.info(f'Importing {len(raw_data)} queries in {lfile.name}')
+         for name in raw_data:
+             name = name.strip('\n')
+             camera = name.split('_')[2]
+             K = Ks[camera].split(' ')
+             camera_model, width, height = K[:3]
+             params = np.array(K[3:], float)
+             # print("camera: ", camera_model, width, height, params)
+             info = (camera_model, int(width), int(height), params)
+             results.append((name, info))
+
+     assert len(results) > 0
+     return results
+
+
+ def parse_retrieval(path):
+     retrieval = defaultdict(list)
+     with open(path, 'r') as f:
+         for p in f.read().rstrip('\n').split('\n'):
+             q, r = p.split(' ')
+             retrieval[q].append(r)
+     return dict(retrieval)
+
+
+ def names_to_pair_old(name0, name1):
+     return '_'.join((name0.replace('/', '-'), name1.replace('/', '-')))
+
+
+ def names_to_pair(name0, name1, separator="/"):
+     return separator.join((name0.replace("/", "-"), name1.replace("/", "-")))
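A short, hedged sketch of how `parse_retrieval` and `names_to_pair` behave; the pairs file name and image names are made up, and the import assumes the pram repository root is on `sys.path`:

```python
from colmap_utils.parsers import names_to_pair, parse_retrieval

# A retrieval/pairs file has one "query reference" pair per line, e.g.:
#   query/0001.png db/0042.png
#   query/0001.png db/0043.png
retrieval = parse_retrieval("pairs-query-netvlad.txt")  # hypothetical file name
print(retrieval["query/0001.png"])  # -> ['db/0042.png', 'db/0043.png']

# The pair key used in match HDF5 files: "/" inside names becomes "-",
# and the two names are joined with the separator.
print(names_to_pair("query/0001.png", "db/0042.png"))  # -> 'query-0001.png/db-0042.png'
```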
third_party/pram/colmap_utils/read_write_model.py ADDED
@@ -0,0 +1,627 @@
1
+ # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2
+ # All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ #
7
+ # * Redistributions of source code must retain the above copyright
8
+ # notice, this list of conditions and the following disclaimer.
9
+ #
10
+ # * Redistributions in binary form must reproduce the above copyright
11
+ # notice, this list of conditions and the following disclaimer in the
12
+ # documentation and/or other materials provided with the distribution.
13
+ #
14
+ # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15
+ # its contributors may be used to endorse or promote products derived
16
+ # from this software without specific prior written permission.
17
+ #
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
+ # POSSIBILITY OF SUCH DAMAGE.
29
+ #
30
+ # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31
+
32
+ import os
33
+ import sys
34
+ import collections
35
+ import numpy as np
36
+ import struct
37
+ import argparse
38
+
39
+ CameraModel = collections.namedtuple(
40
+ "CameraModel", ["model_id", "model_name", "num_params"])
41
+ Camera = collections.namedtuple(
42
+ "Camera", ["id", "model", "width", "height", "params"])
43
+ BaseImage = collections.namedtuple(
44
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
45
+ Point3D = collections.namedtuple(
46
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
47
+
48
+
49
+ class Image(BaseImage):
50
+ def qvec2rotmat(self):
51
+ return qvec2rotmat(self.qvec)
52
+
53
+
54
+ CAMERA_MODELS = {
55
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
56
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
57
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
58
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
59
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
60
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
61
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
62
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
63
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
64
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
65
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
66
+ }
67
+ CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
68
+ for camera_model in CAMERA_MODELS])
69
+ CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
70
+ for camera_model in CAMERA_MODELS])
71
+
72
+
73
+ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
74
+ """Read and unpack the next bytes from a binary file.
75
+ :param fid:
76
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
77
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
78
+ :param endian_character: Any of {@, =, <, >, !}
79
+ :return: Tuple of read and unpacked values.
80
+ """
81
+ data = fid.read(num_bytes)
82
+ return struct.unpack(endian_character + format_char_sequence, data)
83
+
84
+
85
+ def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
86
+ """pack and write to a binary file.
87
+ :param fid:
88
+ :param data: data to send, if multiple elements are sent at the same time,
89
+ they should be encapsuled either in a list or a tuple
90
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
91
+ should be the same length as the data list or tuple
92
+ :param endian_character: Any of {@, =, <, >, !}
93
+ """
94
+ if isinstance(data, (list, tuple)):
95
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
96
+ else:
97
+ bytes = struct.pack(endian_character + format_char_sequence, data)
98
+ fid.write(bytes)
99
+
100
+
101
+ def read_cameras_text(path):
102
+ """
103
+ see: src/base/reconstruction.cc
104
+ void Reconstruction::WriteCamerasText(const std::string& path)
105
+ void Reconstruction::ReadCamerasText(const std::string& path)
106
+ """
107
+ cameras = {}
108
+ with open(path, "r") as fid:
109
+ while True:
110
+ line = fid.readline()
111
+ if not line:
112
+ break
113
+ line = line.strip()
114
+ if len(line) > 0 and line[0] != "#":
115
+ elems = line.split()
116
+ camera_id = int(elems[0])
117
+ model = elems[1]
118
+ width = int(elems[2])
119
+ height = int(elems[3])
120
+ params = np.array(tuple(map(float, elems[4:])))
121
+ cameras[camera_id] = Camera(id=camera_id, model=model,
122
+ width=width, height=height,
123
+ params=params)
124
+ return cameras
125
+
126
+
127
+ def read_cameras_binary(path_to_model_file):
128
+ """
129
+ see: src/base/reconstruction.cc
130
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
131
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
132
+ """
133
+ cameras = {}
134
+ with open(path_to_model_file, "rb") as fid:
135
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
136
+ for camera_line_index in range(num_cameras):
137
+ camera_properties = read_next_bytes(
138
+ fid, num_bytes=24, format_char_sequence="iiQQ")
139
+ camera_id = camera_properties[0]
140
+ model_id = camera_properties[1]
141
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
142
+ width = camera_properties[2]
143
+ height = camera_properties[3]
144
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
145
+ params = read_next_bytes(fid, num_bytes=8 * num_params,
146
+ format_char_sequence="d" * num_params)
147
+ cameras[camera_id] = Camera(id=camera_id,
148
+ model=model_name,
149
+ width=width,
150
+ height=height,
151
+ params=np.array(params))
152
+ assert len(cameras) == num_cameras
153
+ return cameras
154
+
155
+
156
+ def write_cameras_text(cameras, path):
157
+ """
158
+ see: src/base/reconstruction.cc
159
+ void Reconstruction::WriteCamerasText(const std::string& path)
160
+ void Reconstruction::ReadCamerasText(const std::string& path)
161
+ """
162
+ HEADER = '# Camera list with one line of data per camera:\n'
163
+ '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n'
164
+ '# Number of cameras: {}\n'.format(len(cameras))
165
+ with open(path, "w") as fid:
166
+ fid.write(HEADER)
167
+ for _, cam in cameras.items():
168
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
169
+ line = " ".join([str(elem) for elem in to_write])
170
+ fid.write(line + "\n")
171
+
172
+
173
+ def write_cameras_binary(cameras, path_to_model_file):
174
+ """
175
+ see: src/base/reconstruction.cc
176
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
177
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
178
+ """
179
+ with open(path_to_model_file, "wb") as fid:
180
+ write_next_bytes(fid, len(cameras), "Q")
181
+ for _, cam in cameras.items():
182
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
183
+ camera_properties = [cam.id,
184
+ model_id,
185
+ cam.width,
186
+ cam.height]
187
+ write_next_bytes(fid, camera_properties, "iiQQ")
188
+ for p in cam.params:
189
+ write_next_bytes(fid, float(p), "d")
190
+ return cameras
191
+
192
+
193
+ def read_images_text(path):
194
+ """
195
+ see: src/base/reconstruction.cc
196
+ void Reconstruction::ReadImagesText(const std::string& path)
197
+ void Reconstruction::WriteImagesText(const std::string& path)
198
+ """
199
+ images = {}
200
+ with open(path, "r") as fid:
201
+ while True:
202
+ line = fid.readline()
203
+ if not line:
204
+ break
205
+ line = line.strip()
206
+ if len(line) > 0 and line[0] != "#":
207
+ elems = line.split()
208
+ image_id = int(elems[0])
209
+ qvec = np.array(tuple(map(float, elems[1:5])))
210
+ tvec = np.array(tuple(map(float, elems[5:8])))
211
+ camera_id = int(elems[8])
212
+ image_name = elems[9]
213
+ elems = fid.readline().split()
214
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
215
+ tuple(map(float, elems[1::3]))])
216
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
217
+ images[image_id] = Image(
218
+ id=image_id, qvec=qvec, tvec=tvec,
219
+ camera_id=camera_id, name=image_name,
220
+ xys=xys, point3D_ids=point3D_ids)
221
+ return images
222
+
223
+
224
+ def read_images_binary(path_to_model_file):
225
+ """
226
+ see: src/base/reconstruction.cc
227
+ void Reconstruction::ReadImagesBinary(const std::string& path)
228
+ void Reconstruction::WriteImagesBinary(const std::string& path)
229
+ """
230
+ images = {}
231
+ with open(path_to_model_file, "rb") as fid:
232
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
233
+ for image_index in range(num_reg_images):
234
+ binary_image_properties = read_next_bytes(
235
+ fid, num_bytes=64, format_char_sequence="idddddddi")
236
+ image_id = binary_image_properties[0]
237
+ qvec = np.array(binary_image_properties[1:5])
238
+ tvec = np.array(binary_image_properties[5:8])
239
+ camera_id = binary_image_properties[8]
240
+ image_name = ""
241
+ current_char = read_next_bytes(fid, 1, "c")[0]
242
+ while current_char != b"\x00": # look for the ASCII 0 entry
243
+ image_name += current_char.decode("utf-8")
244
+ current_char = read_next_bytes(fid, 1, "c")[0]
245
+ num_points2D = read_next_bytes(fid, num_bytes=8,
246
+ format_char_sequence="Q")[0]
247
+ x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D,
248
+ format_char_sequence="ddq" * num_points2D)
249
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
250
+ tuple(map(float, x_y_id_s[1::3]))])
251
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
252
+ images[image_id] = Image(
253
+ id=image_id, qvec=qvec, tvec=tvec,
254
+ camera_id=camera_id, name=image_name,
255
+ xys=xys, point3D_ids=point3D_ids)
256
+ return images
257
+
258
+
259
+ def write_images_text(images, path):
260
+ """
261
+ see: src/base/reconstruction.cc
262
+ void Reconstruction::ReadImagesText(const std::string& path)
263
+ void Reconstruction::WriteImagesText(const std::string& path)
264
+ """
265
+ if len(images) == 0:
266
+ mean_observations = 0
267
+ else:
268
+ mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images)
269
+ HEADER = '# Image list with two lines of data per image:\n'
270
+ '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
271
+ '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
272
+ '# Number of images: {}, mean observations per image: {}\n'.format(len(images), mean_observations)
273
+
274
+ with open(path, "w") as fid:
275
+ fid.write(HEADER)
276
+ for _, img in images.items():
277
+ image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
278
+ first_line = " ".join(map(str, image_header))
279
+ fid.write(first_line + "\n")
280
+
281
+ points_strings = []
282
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
283
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
284
+ fid.write(" ".join(points_strings) + "\n")
285
+
286
+
287
+ def write_images_binary(images, path_to_model_file):
288
+ """
289
+ see: src/base/reconstruction.cc
290
+ void Reconstruction::ReadImagesBinary(const std::string& path)
291
+ void Reconstruction::WriteImagesBinary(const std::string& path)
292
+ """
293
+ with open(path_to_model_file, "wb") as fid:
294
+ write_next_bytes(fid, len(images), "Q")
295
+ for _, img in images.items():
296
+ write_next_bytes(fid, img.id, "i")
297
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
298
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
299
+ write_next_bytes(fid, img.camera_id, "i")
300
+ for char in img.name:
301
+ write_next_bytes(fid, char.encode("utf-8"), "c")
302
+ write_next_bytes(fid, b"\x00", "c")
303
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
304
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
305
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
306
+
307
+
308
+ def read_points3D_text(path):
309
+ """
310
+ see: src/base/reconstruction.cc
311
+ void Reconstruction::ReadPoints3DText(const std::string& path)
312
+ void Reconstruction::WritePoints3DText(const std::string& path)
313
+ """
314
+ points3D = {}
315
+ with open(path, "r") as fid:
316
+ while True:
317
+ line = fid.readline()
318
+ if not line:
319
+ break
320
+ line = line.strip()
321
+ if len(line) > 0 and line[0] != "#":
322
+ elems = line.split()
323
+ point3D_id = int(elems[0])
324
+ xyz = np.array(tuple(map(float, elems[1:4])))
325
+ rgb = np.array(tuple(map(int, elems[4:7])))
326
+ error = float(elems[7])
327
+ image_ids = np.array(tuple(map(int, elems[8::2])))
328
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
329
+ points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
330
+ error=error, image_ids=image_ids,
331
+ point2D_idxs=point2D_idxs)
332
+ return points3D
333
+
334
+
335
+ def read_points3d_binary(path_to_model_file):
336
+ """
337
+ see: src/base/reconstruction.cc
338
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
339
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
340
+ """
341
+ points3D = {}
342
+ with open(path_to_model_file, "rb") as fid:
343
+ num_points = read_next_bytes(fid, 8, "Q")[0]
344
+ for point_line_index in range(num_points):
345
+ binary_point_line_properties = read_next_bytes(
346
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
347
+ point3D_id = binary_point_line_properties[0]
348
+ xyz = np.array(binary_point_line_properties[1:4])
349
+ rgb = np.array(binary_point_line_properties[4:7])
350
+ error = np.array(binary_point_line_properties[7])
351
+ track_length = read_next_bytes(
352
+ fid, num_bytes=8, format_char_sequence="Q")[0]
353
+ track_elems = read_next_bytes(
354
+ fid, num_bytes=8 * track_length,
355
+ format_char_sequence="ii" * track_length)
356
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
357
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
358
+ points3D[point3D_id] = Point3D(
359
+ id=point3D_id, xyz=xyz, rgb=rgb,
360
+ error=error, image_ids=image_ids,
361
+ point2D_idxs=point2D_idxs)
362
+ return points3D
363
+
364
+
365
+ def write_points3D_text(points3D, path):
366
+ """
367
+ see: src/base/reconstruction.cc
368
+ void Reconstruction::ReadPoints3DText(const std::string& path)
369
+ void Reconstruction::WritePoints3DText(const std::string& path)
370
+ """
371
+ if len(points3D) == 0:
372
+ mean_track_length = 0
373
+ else:
374
+ mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D)
375
+ HEADER = '# 3D point list with one line of data per point:\n'
376
+ '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n'
377
+ '# Number of points: {}, mean track length: {}\n'.format(len(points3D), mean_track_length)
378
+
379
+ with open(path, "w") as fid:
380
+ fid.write(HEADER)
381
+ for _, pt in points3D.items():
382
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
383
+ fid.write(" ".join(map(str, point_header)) + " ")
384
+ track_strings = []
385
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
386
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
387
+ fid.write(" ".join(track_strings) + "\n")
388
+
389
+
390
+ def write_points3d_binary(points3D, path_to_model_file):
391
+ """
392
+ see: src/base/reconstruction.cc
393
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
394
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
395
+ """
396
+ with open(path_to_model_file, "wb") as fid:
397
+ write_next_bytes(fid, len(points3D), "Q")
398
+ for _, pt in points3D.items():
399
+ write_next_bytes(fid, pt.id, "Q")
400
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
401
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
402
+ write_next_bytes(fid, pt.error, "d")
403
+ track_length = pt.image_ids.shape[0]
404
+ write_next_bytes(fid, track_length, "Q")
405
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
406
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
407
+
408
+
409
+ def read_model(path, ext):
410
+ if ext == ".txt":
411
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
412
+ images = read_images_text(os.path.join(path, "images" + ext))
413
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
414
+ else:
415
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
416
+ images = read_images_binary(os.path.join(path, "images" + ext))
417
+ points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
418
+ return cameras, images, points3D
419
+
420
+
421
+ def write_model(cameras, images, points3D, path, ext):
422
+ if ext == ".txt":
423
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
424
+ write_images_text(images, os.path.join(path, "images" + ext))
425
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
426
+ else:
427
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
428
+ write_images_binary(images, os.path.join(path, "images" + ext))
429
+ write_points3d_binary(points3D, os.path.join(path, "points3D") + ext)
430
+ return cameras, images, points3D
431
+
432
+
433
+ def read_compressed_images_binary(path_to_model_file):
434
+ """
435
+ see: src/base/reconstruction.cc
436
+ void Reconstruction::ReadImagesBinary(const std::string& path)
437
+ void Reconstruction::WriteImagesBinary(const std::string& path)
438
+ """
439
+ images = {}
440
+ with open(path_to_model_file, "rb") as fid:
441
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
442
+ for image_index in range(num_reg_images):
443
+ binary_image_properties = read_next_bytes(
444
+ fid, num_bytes=64, format_char_sequence="idddddddi")
445
+ image_id = binary_image_properties[0]
446
+ qvec = np.array(binary_image_properties[1:5])
447
+ tvec = np.array(binary_image_properties[5:8])
448
+ camera_id = binary_image_properties[8]
449
+ image_name = ""
450
+ current_char = read_next_bytes(fid, 1, "c")[0]
451
+ while current_char != b"\x00": # look for the ASCII 0 entry
452
+ image_name += current_char.decode("utf-8")
453
+ current_char = read_next_bytes(fid, 1, "c")[0]
454
+ num_points2D = read_next_bytes(fid, num_bytes=8,
455
+ format_char_sequence="Q")[0]
456
+ # x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D,
457
+ # format_char_sequence="ddq" * num_points2D)
458
+ # xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
459
+ # tuple(map(float, x_y_id_s[1::3]))])
460
+ x_y_id_s = read_next_bytes(fid, num_bytes=8 * num_points2D,
461
+ format_char_sequence="q" * num_points2D)
462
+ point3D_ids = np.array(x_y_id_s)
463
+ images[image_id] = Image(
464
+ id=image_id, qvec=qvec, tvec=tvec,
465
+ camera_id=camera_id, name=image_name,
466
+ xys=np.array([]), point3D_ids=point3D_ids)
467
+ return images
468
+
469
+
470
+ def write_compressed_images_binary(images, path_to_model_file):
471
+ """
472
+ see: src/base/reconstruction.cc
473
+ void Reconstruction::ReadImagesBinary(const std::string& path)
474
+ void Reconstruction::WriteImagesBinary(const std::string& path)
475
+ """
476
+ with open(path_to_model_file, "wb") as fid:
477
+ write_next_bytes(fid, len(images), "Q")
478
+ for _, img in images.items():
479
+ write_next_bytes(fid, img.id, "i")
480
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
481
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
482
+ write_next_bytes(fid, img.camera_id, "i")
483
+ for char in img.name:
484
+ write_next_bytes(fid, char.encode("utf-8"), "c")
485
+ write_next_bytes(fid, b"\x00", "c")
486
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
487
+ for p3d_id in img.point3D_ids:
488
+ write_next_bytes(fid, p3d_id, "q")
489
+ # for xy, p3d_id in zip(img.xys, img.point3D_ids):
490
+ # write_next_bytes(fid, [*xy, p3d_id], "ddq")
491
+
492
+
493
+ def read_compressed_points3d_binary(path_to_model_file):
494
+ """
495
+ see: src/base/reconstruction.cc
496
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
497
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
498
+ """
499
+ points3D = {}
500
+ with open(path_to_model_file, "rb") as fid:
501
+ num_points = read_next_bytes(fid, 8, "Q")[0]
502
+ for point_line_index in range(num_points):
503
+ binary_point_line_properties = read_next_bytes(
504
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
505
+ point3D_id = binary_point_line_properties[0]
506
+ xyz = np.array(binary_point_line_properties[1:4])
507
+ rgb = np.array(binary_point_line_properties[4:7])
508
+ error = np.array(binary_point_line_properties[7])
509
+ track_length = read_next_bytes(
510
+ fid, num_bytes=8, format_char_sequence="Q")[0]
511
+ track_elems = read_next_bytes(
512
+ fid, num_bytes=4 * track_length,
513
+ format_char_sequence="i" * track_length)
514
+ image_ids = np.array(track_elems)
515
+ # point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
516
+ points3D[point3D_id] = Point3D(
517
+ id=point3D_id, xyz=xyz, rgb=rgb,
518
+ error=error, image_ids=image_ids,
519
+ point2D_idxs=np.array([]))
520
+ return points3D
521
+
522
+
523
+ def write_compressed_points3d_binary(points3D, path_to_model_file):
524
+ """
525
+ see: src/base/reconstruction.cc
526
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
527
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
528
+ """
529
+ with open(path_to_model_file, "wb") as fid:
530
+ write_next_bytes(fid, len(points3D), "Q")
531
+ for _, pt in points3D.items():
532
+ write_next_bytes(fid, pt.id, "Q")
533
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
534
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
535
+ write_next_bytes(fid, pt.error, "d")
536
+ track_length = pt.image_ids.shape[0]
537
+ write_next_bytes(fid, track_length, "Q")
538
+ # for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
539
+ # write_next_bytes(fid, [image_id, point2D_id], "ii")
540
+ for image_id in pt.image_ids:
541
+ write_next_bytes(fid, image_id, "i")
542
+
543
+
544
+ def read_compressed_model(path, ext):
545
+ if ext == ".txt":
546
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
547
+ images = read_images_text(os.path.join(path, "images" + ext))
548
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
549
+ else:
550
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
551
+ images = read_compressed_images_binary(os.path.join(path, "images" + ext))
552
+ points3D = read_compressed_points3d_binary(os.path.join(path, "points3D") + ext)
553
+ return cameras, images, points3D
554
+
555
+
556
+ def qvec2rotmat(qvec):
557
+ return np.array([
558
+ [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
559
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
560
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
561
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
562
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
563
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
564
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
565
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
566
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]])
567
+
568
+
569
+ def rotmat2qvec(R):
570
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
571
+ K = np.array([
572
+ [Rxx - Ryy - Rzz, 0, 0, 0],
573
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
574
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
575
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
576
+ eigvals, eigvecs = np.linalg.eigh(K)
577
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
578
+ if qvec[0] < 0:
579
+ qvec *= -1
580
+ return qvec
581
+
582
+
583
+ def intrinsics_from_camera(camera_model, params):
584
+ if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
585
+ fx = fy = params[0]
586
+ cx = params[1]
587
+ cy = params[2]
588
+ elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
589
+ fx = params[0]
590
+ fy = params[1]
591
+ cx = params[2]
592
+ cy = params[3]
593
+ else:
594
+ raise Exception("Camera model not supported")
595
+
596
+ # intrinsics
597
+ K = np.identity(3)
598
+ K[0, 0] = fx
599
+ K[1, 1] = fy
600
+ K[0, 2] = cx
601
+ K[1, 2] = cy
602
+ return K
603
+
604
+
605
+ def main():
606
+ parser = argparse.ArgumentParser(description='Read and write COLMAP binary and text models')
607
+ parser.add_argument('input_model', help='path to input model folder')
608
+ parser.add_argument('input_format', choices=['.bin', '.txt'],
609
+ help='input model format')
610
+ parser.add_argument('--output_model', metavar='PATH',
611
+ help='path to output model folder')
612
+ parser.add_argument('--output_format', choices=['.bin', '.txt'],
613
+ help='output model format', default='.txt')
614
+ args = parser.parse_args()
615
+
616
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
617
+
618
+ print("num_cameras:", len(cameras))
619
+ print("num_images:", len(images))
620
+ print("num_points3D:", len(points3D))
621
+
622
+ if args.output_model is not None:
623
+ write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
624
+
625
+
626
+ if __name__ == "__main__":
627
+ main()
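Editor's note (not part of the commit): the functions above reproduce COLMAP's standard model I/O, so a minimal, hedged usage sketch looks like the following. Only read_model and qvec2rotmat from this file are used; the model directory path is a placeholder.

import numpy as np

from colmap_utils.read_write_model import read_model, qvec2rotmat

# Load a sparse COLMAP reconstruction (placeholder path); ext may be '.bin' or '.txt'.
cameras, images, points3D = read_model(path='path/to/sparse/0', ext='.bin')
print(len(cameras), len(images), len(points3D))

for image_id, image in images.items():
    # Assemble the 4x4 world-to-camera pose from the stored quaternion and translation.
    T_w2c = np.eye(4)
    T_w2c[:3, :3] = qvec2rotmat(image.qvec)
    T_w2c[:3, 3] = image.tvec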
third_party/pram/colmap_utils/utils.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: UTF-8 -*-
third_party/pram/configs/config_train_12scenes_sfd2.yaml ADDED
@@ -0,0 +1,102 @@
1
+ dataset: [ '12Scenes' ]
2
+
3
+ network_1: "segnet"
4
+ network: "segnetvit"
5
+
6
+ local_rank: 0
7
+ gpu: [ 0 ]
8
+
9
+ feature: "sfd2"
10
+ save_path: '/scratches/flyer_2/fx221/exp/pram'
11
+ landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml"
12
+ dataset_path: "/scratches/flyer_3/fx221/dataset"
13
+ config_path: 'configs/datasets'
14
+
15
+ image_dim: 3
16
+ feat_dim: 128
17
+ min_inliers: 32
18
+ max_inliers: 512
19
+ random_inliers: true
20
+ max_keypoints: 512
21
+ ignore_index: -1
22
+ output_dim: 1024
23
+ output_dim_: 2048
24
+ jitter_params:
25
+ brightness: 0.5
26
+ contrast: 0.5
27
+ saturation: 0.25
28
+ hue: 0.15
29
+ blur: 0
30
+
31
+ scale_params: [ 0.5, 1.0 ]
32
+ pre_load: false
33
+ train: true
34
+ inlier_th: 0.5
35
+ lr: 0.0001
36
+ min_lr: 0.00001
37
+ optimizer: "adamw"
38
+ seg_loss: "cew"
39
+ seg_loss_nx: "cei"
40
+ cls_loss: "ce"
41
+ cls_loss_: "bce"
42
+ ac_fn: "relu"
43
+ norm_fn: "bn"
44
+ workers: 8
45
+ layers: 15
46
+ log_intervals: 50
47
+ eval_n_epoch: 10
48
+ do_eval: false
49
+
50
+ use_mid_feature: true
51
+ norm_desc: false
52
+ with_score: false
53
+ with_aug: true
54
+ with_dist: true
55
+
56
+ batch_size: 32
57
+ its_per_epoch: 1000
58
+ decay_rate: 0.999992
59
+ decay_iter: 60000
60
+ epochs: 500
61
+
62
+ cluster_method: 'birch'
63
+
64
+ weight_path: null
65
+ weight_path_1: '20230719_220620_segnet_L15_T_resnet4x_B32_K1024_relu_bn_od1024_nc193_adamw_cew_md_A_birch/segnet.499.pth'
66
+ weight_path_2: '20240202_145337_segnetvit_L15_T_resnet4x_B32_K512_relu_bn_od1024_nc193_adam_cew_md_A_birch/segnetvit.499.pth'
67
+
68
+ resume_path: null
69
+
70
+ n_class: 193
71
+
72
+ eval_max_keypoints: 1024
73
+
74
+ localization:
75
+ loc_scene_name: [ 'apt1/kitchen' ]
76
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
77
+ seg_k: 20
78
+ threshold: 8
79
+ min_kpts: 128
80
+ min_matches: 4
81
+ min_inliers: 64
82
+ matching_method_: "mnn"
83
+ matching_method_1: "spg"
84
+ matching_method_2: "gm"
85
+ matching_method: "gml"
86
+ matching_method_5: "adagml"
87
+ save: false
88
+ show: true
89
+ show_time: 1
90
+ max_vrf: 1
91
+ with_original: true
92
+ with_extra: false
93
+ with_compress: true
94
+ semantic_matching: true
95
+ do_refinement: true
96
+ refinement_method_: 'matching'
97
+ refinement_method: 'projection'
98
+ pre_filtering_th: 0.95
99
+ covisibility_frame: 20
100
+ refinement_radius: 20
101
+ refinement_nn_ratio: 0.9
102
+ refinement_max_matches: 0
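Editor's note (not part of the commit): the training configs in this directory are plain YAML, so they can be inspected directly with PyYAML. A hedged sketch, assuming it is run from the pram root; the keys used below all appear in the file above.

import yaml

with open('configs/config_train_12scenes_sfd2.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

# Top-level training options and the nested localization block.
print(cfg['network'], cfg['n_class'], cfg['batch_size'])
print(cfg['localization']['matching_method'], cfg['localization']['seg_k'])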
third_party/pram/configs/config_train_7scenes_sfd2.yaml ADDED
@@ -0,0 +1,104 @@
1
+ dataset: [ '7Scenes' ]
2
+
3
+ network: "segnetvit"
4
+
5
+ local_rank: 0
6
+ gpu: [ 0 ]
7
+ # when using ddp, set gpu: [0,1,2,3]
8
+ with_dist: true
9
+
10
+ feature: "sfd2"
11
+ save_path_: '/scratches/flyer_2/fx221/exp/pram'
12
+ save_path: '/scratches/flyer_2/fx221/publications/test_pram/exp'
13
+ landmark_path_: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml"
14
+ landmark_path: "/scratches/flyer_2/fx221/publications/test_pram/landmakrs/sfd2-gml"
15
+ dataset_path: "/scratches/flyer_3/fx221/dataset"
16
+ config_path: 'configs/datasets'
17
+
18
+ image_dim: 3
19
+ feat_dim: 128
20
+
21
+ min_inliers: 32
22
+ max_inliers: 256
23
+ random_inliers: 1
24
+ max_keypoints: 512
25
+ ignore_index: -1
26
+ output_dim: 1024
27
+ output_dim_: 2048
28
+ jitter_params:
29
+ brightness: 0.5
30
+ contrast: 0.5
31
+ saturation: 0.25
32
+ hue: 0.15
33
+ blur: 0
34
+
35
+ scale_params: [ 0.5, 1.0 ]
36
+ pre_load: false
37
+ train: true
38
+ inlier_th: 0.5
39
+ lr: 0.0001
40
+ min_lr: 0.00001
41
+ cls_loss: "ce"
42
+ ac_fn: "relu"
43
+ norm_fn: "bn"
44
+ workers: 8
45
+ layers: 15
46
+ log_intervals: 50
47
+ eval_n_epoch: 10
48
+ do_eval: false
49
+
50
+ use_mid_feature: true
51
+ norm_desc: false
52
+ with_cls: false
53
+ with_score: false
54
+ with_aug: true
55
+
56
+ batch_size: 32
57
+ its_per_epoch: 1000
58
+ decay_rate: 0.999992
59
+ decay_iter: 80000
60
+ epochs: 200
61
+
62
+ cluster_method: 'birch'
63
+
64
+ weight_path: null
65
+ weight_path_1: '20230724_203230_segnet_L15_S_resnet4x_B32_K1024_relu_bn_od1024_nc113_adam_cew_md_A_birch/segnet.180.pth'
66
+ weight_path_2: '20240202_152519_segnetvit_L15_S_resnet4x_B32_K512_relu_bn_od1024_nc113_adamw_cew_md_A_birch/segnetvit.199.pth'
67
+
68
+ # used for resuming training
69
+ resume_path: null
70
+
71
+ # used for localization
72
+ n_class: 113
73
+
74
+ eval_max_keypoints: 1024
75
+
76
+ localization:
77
+ loc_scene_name: [ 'chess' ]
78
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
79
+
80
+ seg_k: 20
81
+ threshold: 8
82
+ min_kpts: 128
83
+ min_matches: 16
84
+ min_inliers: 32
85
+ matching_method_: "mnn"
86
+ matching_method_1: "spg"
87
+ matching_method_2: "gm"
88
+ matching_method: "gml"
89
+ matching_method_4: "adagml"
90
+ save: false
91
+ show: true
92
+ show_time: 1
93
+ with_original: true
94
+ max_vrf: 1
95
+ with_compress: true
96
+ semantic_matching: true
97
+ do_refinement: true
98
+ pre_filtering_th: 0.95
99
+ refinement_method_: 'matching'
100
+ refinement_method: 'projection'
101
+ covisibility_frame: 20
102
+ refinement_radius: 20
103
+ refinement_nn_ratio: 0.9
104
+ refinement_max_matches: 0
third_party/pram/configs/config_train_aachen_sfd2.yaml ADDED
@@ -0,0 +1,104 @@
1
+ dataset: [ 'Aachen' ]
2
+
3
+ network_: "segnet"
4
+ network: "segnetvit"
5
+ local_rank: 0
6
+ gpu: [ 0 ]
7
+
8
+ feature: "sfd2"
9
+ save_path: '/scratches/flyer_2/fx221/exp/pram'
10
+ landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml"
11
+ dataset_path: "/scratches/flyer_3/fx221/dataset"
12
+
13
+ config_path: 'configs/datasets'
14
+
15
+ image_dim: 3
16
+ feat_dim: 128
17
+
18
+ min_inliers: 32
19
+ max_inliers: 512
20
+ random_inliers: true
21
+ max_keypoints: 1024
22
+ ignore_index: -1
23
+ output_dim: 1024
24
+ output_dim_: 2048
25
+ jitter_params:
26
+ brightness: 0.5
27
+ contrast: 0.5
28
+ saturation: 0.25
29
+ hue: 0.15
30
+ blur: 0
31
+
32
+ scale_params: [ 0.5, 1.0 ]
33
+ pre_load: false
34
+ do_eval: true
35
+ train: true
36
+ inlier_th: 0.5
37
+ lr: 0.0001
38
+ min_lr: 0.00001
39
+ optimizer: "adam"
40
+ seg_loss: "cew"
41
+ seg_loss_nx: "cei"
42
+ cls_loss: "ce"
43
+ cls_loss_: "bce"
44
+ ac_fn: "relu"
45
+ norm_fn: "bn"
46
+ workers: 8
47
+ layers: 15
48
+ log_intervals: 50
49
+ eval_n_epoch: 10
50
+
51
+ use_mid_feature: true
52
+ norm_desc: false
53
+ with_sc: false
54
+ with_cls: true
55
+ with_score: false
56
+ with_aug: true
57
+ with_dist: true
58
+
59
+ batch_size: 32
60
+ its_per_epoch: 1000
61
+ decay_rate: 0.999992
62
+ decay_iter: 80000
63
+ epochs: 800
64
+
65
+ cluster_method: 'birch'
66
+
67
+ weight_path: null
68
+ weight_path_1: '20230719_221442_segnet_L15_A_resnet4x_B32_K1024_relu_bn_od1024_nc513_adamw_cew_md_A_birch/segnet.899.pth'
69
+ weight_path_2: '20240211_142623_segnetvit_L15_A_resnet4x_B32_K1024_relu_bn_od1024_nc513_adam_cew_md_A_birch/segnetvit.799.pth'
70
+ resume_path: null
71
+
72
+ n_class: 513
73
+
74
+ eval_max_keypoints: 4096
75
+
76
+ localization:
77
+ loc_scene_name: [ ]
78
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
79
+ seg_k: 10
80
+ threshold: 12
81
+ min_kpts: 256
82
+ min_matches: 8
83
+ min_inliers: 128
84
+ matching_method_: "mnn"
85
+ matching_method_1: "spg"
86
+ matching_method_2: "gm"
87
+ matching_method: "gml"
88
+ matching_method_4: "adagml"
89
+ save: false
90
+ show: true
91
+ show_time: 1
92
+ with_original: true
93
+ with_extra: false
94
+ max_vrf: 1
95
+ with_compress: true
96
+ semantic_matching: true
97
+ refinement_method_: 'matching'
98
+ refinement_method: 'projection'
99
+ pre_filtering_th: 0.95
100
+ do_refinement: true
101
+ covisibility_frame: 50
102
+ refinement_radius: 30
103
+ refinement_nn_ratio: 0.9
104
+ refinement_max_matches: 0
third_party/pram/configs/config_train_cambridge_sfd2.yaml ADDED
@@ -0,0 +1,103 @@
1
+ dataset: [ 'CambridgeLandmarks' ]
2
+
3
+ network_: "segnet"
4
+ network: "segnetvit"
5
+
6
+ local_rank: 0
7
+ gpu: [ 0 ]
8
+
9
+ feature: "sfd2"
10
+ save_path: '/scratches/flyer_2/fx221/exp/pram'
11
+ landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml"
12
+ dataset_path: "/scratches/flyer_3/fx221/dataset"
13
+ config_path: 'configs/datasets'
14
+
15
+ image_dim: 3
16
+ feat_dim: 128
17
+
18
+ min_inliers: 32
19
+ max_inliers: 512
20
+ random_inliers: 1
21
+ max_keypoints: 1024
22
+ ignore_index: -1
23
+ output_dim: 1024
24
+ output_dim_: 2048
25
+ jitter_params:
26
+ brightness: 0.5
27
+ contrast: 0.5
28
+ saturation: 0.25
29
+ hue: 0.15
30
+ blur: 0
31
+
32
+ scale_params: [ 0.5, 1.0 ]
33
+ pre_load: false
34
+ do_eval: false
35
+ train: true
36
+ inlier_th: 0.5
37
+ lr: 0.0001
38
+ min_lr: 0.00001
39
+ epochs: 300
40
+ seg_loss: "cew"
41
+ ac_fn: "relu"
42
+ norm_fn: "bn"
43
+ workers: 8
44
+ layers: 15
45
+ log_intervals: 50
46
+ eval_n_epoch: 10
47
+
48
+ use_mid_feature: true
49
+ norm_desc: false
50
+ with_score: false
51
+ with_aug: true
52
+ with_dist: true
53
+
54
+ batch_size: 32
55
+ its_per_epoch: 1000
56
+ decay_rate: 0.999992
57
+ decay_iter: 60000
58
+
59
+ cluster_method: 'birch'
60
+
61
+ weight_path: null
62
+ weight_path_1: '20230725_144044_segnet_L15_C_resnet4x_B32_K1024_relu_bn_od1024_nc161_adam_cew_md_A_birch/segnet.260.pth'
63
+ weight_path_2: '20240204_130323_segnetvit_L15_C_resnet4x_B32_K1024_relu_bn_od1024_nc161_adamw_cew_md_A_birch/segnetvit.399.pth'
64
+
65
+ resume_path: null
66
+
67
+ n_class: 161
68
+
69
+ eval_max_keypoints: 2048
70
+
71
+ localization:
72
+ loc_scene_name_1: [ 'GreatCourt' ]
73
+ loc_scene_name_2: [ 'KingsCollege' ]
74
+ loc_scene_name: [ 'StMarysChurch' ]
75
+ loc_scene_name_4: [ 'OldHospital' ]
76
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
77
+ seg_k: 30
78
+ threshold: 12
79
+ min_kpts: 256
80
+ min_matches: 16
81
+ min_inliers_gm: 128
82
+ min_inliers: 128
83
+ matching_method_: "mnn"
84
+ matching_method_1: "spg"
85
+ matching_method_2: "gm"
86
+ matching_method: "gml"
87
+ matching_method_4: "adagml"
88
+ show: true
89
+ show_time: 1
90
+ save: false
91
+ with_original: true
92
+ max_vrf: 1
93
+ with_extra: false
94
+ with_compress: true
95
+ semantic_matching: true
96
+ do_refinement: true
97
+ pre_filtering_th: 0.95
98
+ refinement_method_: 'matching'
99
+ refinement_method: 'projection'
100
+ covisibility_frame: 20
101
+ refinement_radius: 20
102
+ refinement_nn_ratio: 0.9
103
+ refinement_max_matches: 0
third_party/pram/configs/config_train_multiset_sfd2.yaml ADDED
@@ -0,0 +1,100 @@
1
+ dataset: [ 'S', 'T', 'C', 'A' ]
2
+
3
+ network: "segnet"
4
+ network_: "gsegnet3"
5
+
6
+ local_rank: 0
7
+ gpu: [ 4 ]
8
+
9
+ feature: "resnet4x"
10
+ save_path: '/scratches/flyer_2/fx221/exp/localizer'
11
+ landmark_path: "/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gm"
12
+ dataset_path: "/scratches/flyer_3/fx221/dataset"
13
+ config_path: 'configs/datasets'
14
+
15
+ image_dim: 3
16
+ min_inliers: 32
17
+ max_inliers: 512
18
+ random_inliers: 1
19
+ max_keypoints: 1024
20
+ ignore_index: -1
21
+ output_dim: 1024
22
+ output_dim_: 2048
23
+ jitter_params:
24
+ brightness: 0.5
25
+ contrast: 0.5
26
+ saturation: 0.25
27
+ hue: 0.15
28
+ blur: 0
29
+
30
+ scale_params: [ 0.5, 1.0 ]
31
+ pre_load: false
32
+ do_eval: true
33
+ train: true
34
+ inlier_th: 0.5
35
+ lr: 0.0001
36
+ min_lr: 0.00001
37
+ optimizer: "adam"
38
+ seg_loss: "cew"
39
+ seg_loss_nx: "cei"
40
+ cls_loss: "ce"
41
+ cls_loss_: "bce"
42
+ sc_loss: 'l1g'
43
+ ac_fn: "relu"
44
+ norm_fn: "bn"
45
+ workers: 8
46
+ layers: 15
47
+ log_intervals: 50
48
+ eval_n_epoch: 10
49
+
50
+ use_mid_feature: true
51
+ norm_desc: false
52
+ with_sc: false
53
+ with_cls: true
54
+ with_score: false
55
+ with_aug: true
56
+ with_dist: true
57
+
58
+ batch_size: 32
59
+ its_per_epoch: 1000
60
+ decay_rate: 0.999992
61
+ decay_iter: 150000
62
+ epochs: 1500
63
+
64
+ cluster_method_: 'kmeans'
65
+ cluster_method: 'birch'
66
+
67
+ weight_path_: null
68
+ weight_path: '20230805_132653_segnet_L15_STCA_resnet4x_B32_K1024_relu_bn_od1024_nc977_adam_cew_md_A_birch/segnet.485.pth'
69
+ resume_path: null
70
+
71
+ eval: false
72
+ #loc: false
73
+ loc: true
74
+ #n_class: 977
75
+ online: false
76
+
77
+ eval_max_keypoints: 4096
78
+
79
+ localization:
80
+ loc_scene_name: [ ]
81
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
82
+ dataset: [ 'T' ]
83
+ seg_k: 50
84
+ threshold: 8 # 8 for indoor, 12 for outdoor
85
+ min_kpts: 256
86
+ min_matches: 4
87
+ min_inliers: 64
88
+ matching_method_: "mnn"
89
+ matching_method_1: "spg"
90
+ matching_method: "gm"
91
+ save: false
92
+ show: true
93
+ show_time: 1
94
+ do_refinement: true
95
+ with_original: true
96
+ with_extra: false
97
+ max_vrf: 1
98
+ with_compress: false
99
+ covisibility_frame: 20
100
+ observation_threshold: 3
third_party/pram/configs/datasets/12Scenes.yaml ADDED
@@ -0,0 +1,166 @@
1
+ dataset: '12Scenes'
2
+ scenes: [ 'apt1/kitchen',
3
+ 'apt1/living',
4
+ 'apt2/bed',
5
+ 'apt2/kitchen',
6
+ 'apt2/living',
7
+ 'apt2/luke',
8
+ 'office1/gates362',
9
+ 'office1/gates381',
10
+ 'office1/lounge',
11
+ 'office1/manolis',
12
+ 'office2/5a',
13
+ 'office2/5b'
14
+ ]
15
+
16
+ apt1/kitchen:
17
+ n_cluster: 16
18
+ cluster_mode: 'xy'
19
+ cluster_method: 'birch'
20
+
21
+ training_sample_ratio: 1
22
+ eval_sample_ratio: 5
23
+ query_path: 'queries_with_intrinsics.txt'
24
+ gt_pose_path: 'queries_poses.txt'
25
+ image_path_prefix: ''
26
+
27
+
28
+ apt1/living:
29
+ n_cluster: 16
30
+ cluster_mode: 'xy'
31
+ cluster_method: 'birch'
32
+
33
+ training_sample_ratio: 1
34
+ eval_sample_ratio: 5
35
+ image_path_prefix: ''
36
+ query_path: 'queries_with_intrinsics.txt'
37
+ gt_pose_path: 'queries_poses.txt'
38
+
39
+ apt2/bed:
40
+ n_cluster: 16
41
+ cluster_mode: 'xy'
42
+ cluster_method: 'birch'
43
+
44
+ training_sample_ratio: 1
45
+ eval_sample_ratio: 5
46
+ image_path_prefix: ''
47
+
48
+ query_path: 'queries_with_intrinsics.txt'
49
+ gt_pose_path: 'queries_poses.txt'
50
+
51
+
52
+ apt2/kitchen:
53
+ n_cluster: 16
54
+ cluster_mode: 'xy'
55
+ cluster_method: 'birch'
56
+
57
+ training_sample_ratio: 1
58
+ eval_sample_ratio: 5
59
+ image_path_prefix: ''
60
+
61
+ query_path: 'queries_with_intrinsics.txt'
62
+ gt_pose_path: 'queries_poses.txt'
63
+
64
+
65
+ apt2/living:
66
+ n_cluster: 16
67
+ cluster_mode: 'xy'
68
+ cluster_method: 'birch'
69
+
70
+ training_sample_ratio: 1
71
+ eval_sample_ratio: 5
72
+ image_path_prefix: ''
73
+
74
+ query_path: 'queries_with_intrinsics.txt'
75
+ gt_pose_path: 'queries_poses.txt'
76
+
77
+
78
+ apt2/luke:
79
+ n_cluster: 16
80
+ cluster_mode: 'xy'
81
+ cluster_method: 'birch'
82
+
83
+ training_sample_ratio: 1
84
+ eval_sample_ratio: 5
85
+ image_path_prefix: ''
86
+
87
+ query_path: 'queries_with_intrinsics.txt'
88
+ gt_pose_path: 'queries_poses.txt'
89
+
90
+
91
+ office1/gates362:
92
+ n_cluster: 16
93
+ cluster_mode: 'xy'
94
+ cluster_method: 'birch'
95
+
96
+ training_sample_ratio: 3
97
+ eval_sample_ratio: 5
98
+ image_path_prefix: ''
99
+
100
+ query_path: 'queries_with_intrinsics.txt'
101
+ gt_pose_path: 'queries_poses.txt'
102
+
103
+
104
+ office1/gates381:
105
+ n_cluster: 16
106
+ cluster_mode: 'xy'
107
+ cluster_method: 'birch'
108
+
109
+ training_sample_ratio: 3
110
+ eval_sample_ratio: 5
111
+ image_path_prefix: ''
112
+
113
+ query_path: 'queries_with_intrinsics.txt'
114
+ gt_pose_path: 'queries_poses.txt'
115
+
116
+
117
+ office1/lounge:
118
+ n_cluster: 16
119
+ cluster_mode: 'xy'
120
+ cluster_method: 'birch'
121
+
122
+ training_sample_ratio: 1
123
+ eval_sample_ratio: 5
124
+ image_path_prefix: ''
125
+
126
+ query_path: 'queries_with_intrinsics.txt'
127
+ gt_pose_path: 'queries_poses.txt'
128
+
129
+
130
+ office1/manolis:
131
+ n_cluster: 16
132
+ cluster_mode: 'xy'
133
+ cluster_method: 'birch'
134
+
135
+ training_sample_ratio: 1
136
+ eval_sample_ratio: 5
137
+ image_path_prefix: ''
138
+
139
+ query_path: 'queries_with_intrinsics.txt'
140
+ gt_pose_path: 'queries_poses.txt'
141
+
142
+
143
+ office2/5a:
144
+ n_cluster: 16
145
+ cluster_mode: 'xy'
146
+ cluster_method: 'birch'
147
+
148
+ training_sample_ratio: 1
149
+ eval_sample_ratio: 5
150
+ image_path_prefix: ''
151
+
152
+ query_path: 'queries_with_intrinsics.txt'
153
+ gt_pose_path: 'queries_poses.txt'
154
+
155
+
156
+ office2/5b:
157
+ n_cluster: 16
158
+ cluster_mode: 'xy'
159
+ cluster_method: 'birch'
160
+
161
+ training_sample_ratio: 1
162
+ eval_sample_ratio: 5
163
+ image_path_prefix: ''
164
+
165
+ query_path: 'queries_with_intrinsics.txt'
166
+ gt_pose_path: 'queries_poses.txt'
third_party/pram/configs/datasets/7Scenes.yaml ADDED
@@ -0,0 +1,96 @@
1
+ dataset: '7Scenes'
2
+ scenes: [ 'chess', 'heads', 'office', 'fire', 'stairs', 'redkitchen', 'pumpkin' ]
3
+
4
+
5
+ chess:
6
+ n_cluster: 16
7
+ cluster_mode: 'xz'
8
+ cluster_method_: 'kmeans'
9
+ cluster_method: 'birch'
10
+
11
+ training_sample_ratio: 2
12
+ eval_sample_ratio: 10
13
+ gt_pose_path: 'queries_poses.txt'
14
+ query_path: 'queries_with_intrinsics.txt'
15
+ image_path_prefix: ''
16
+
17
+
18
+
19
+ heads:
20
+ n_cluster: 16
21
+ cluster_mode: 'xz'
22
+ cluster_method_: 'kmeans'
23
+ cluster_method: 'birch'
24
+
25
+ training_sample_ratio: 1
26
+ eval_sample_ratio: 2
27
+ gt_pose_path: 'queries_poses.txt'
28
+ query_path: 'queries_with_intrinsics.txt'
29
+ image_path_prefix: ''
30
+
31
+
32
+ office:
33
+ n_cluster: 16
34
+ cluster_mode: 'xz'
35
+ cluster_method_: 'kmeans'
36
+ cluster_method: 'birch'
37
+
38
+ training_sample_ratio: 3
39
+ eval_sample_ratio: 10
40
+ gt_pose_path: 'queries_poses.txt'
41
+ query_path: 'queries_with_intrinsics.txt'
42
+ image_path_prefix: ''
43
+
44
+ fire:
45
+ n_cluster: 16
46
+ cluster_mode: 'xz'
47
+ cluster_method_: 'kmeans'
48
+ cluster_method: 'birch'
49
+
50
+ training_sample_ratio: 2
51
+ eval_sample_ratio: 5
52
+ gt_pose_path: 'queries_poses.txt'
53
+ query_path: 'queries_with_intrinsics.txt'
54
+ image_path_prefix: ''
55
+
56
+
57
+ stairs:
58
+ n_cluster: 16
59
+ cluster_mode: 'xz'
60
+ cluster_method_: 'kmeans'
61
+ cluster_method: 'birch'
62
+
63
+ training_sample_ratio: 1
64
+ eval_sample_ratio: 10
65
+ gt_pose_path: 'queries_poses.txt'
66
+ query_path: 'queries_with_intrinsics.txt'
67
+ image_path_prefix: ''
68
+
69
+
70
+ redkitchen:
71
+ n_cluster: 16
72
+ cluster_mode: 'xz'
73
+ cluster_method_: 'kmeans'
74
+ cluster_method: 'birch'
75
+
76
+ training_sample_ratio: 3
77
+ eval_sample_ratio: 10
78
+ gt_pose_path: 'queries_poses.txt'
79
+ query_path: 'queries_with_intrinsics.txt'
80
+ image_path_prefix: ''
81
+
82
+
83
+
84
+
85
+ pumpkin:
86
+ n_cluster: 16
87
+ cluster_mode: 'xz'
88
+ cluster_method_: 'kmeans'
89
+ cluster_method: 'birch'
90
+
91
+ training_sample_ratio: 2
92
+ eval_sample_ratio: 10
93
+ gt_pose_path: 'queries_poses.txt'
94
+ query_path: 'queries_with_intrinsics.txt'
95
+ image_path_prefix: ''
96
+
third_party/pram/configs/datasets/Aachen.yaml ADDED
@@ -0,0 +1,15 @@
1
+ dataset: 'Aachen'
2
+
3
+ scenes: [ 'Aachenv11' ]
4
+
5
+ Aachenv11:
6
+ n_cluster: 512
7
+ cluster_mode: 'xz'
8
+ cluster_method_: 'kmeans'
9
+ cluster_method: 'birch'
10
+ training_sample_ratio: 1
11
+ eval_sample_ratio: 1
12
+ image_path_prefix: 'images/images_upright'
13
+ query_path_: 'queries_with_intrinsics.txt'
14
+ query_path: 'queries_with_intrinsics_demo.txt'
15
+ gt_pose_path: 'queries_pose_spp_spg.txt'
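Editor's note (not part of the commit): a hedged illustration of how these per-scene keys are consumed. The dataset classes later in this diff (e.g. dataset/aachen.py) derive the landmark-cluster file name from the cluster settings, and the training configs appear to use n_class = total number of clusters plus one background class (here 512 + 1 = 513).

# For Aachenv11 above: n_cluster = 512, cluster_mode = 'xz', cluster_method = 'birch'.
n_class, seg_mode, seg_method = 513, 'xz', 'birch'
fn = 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)
print(fn)  # point3D_cluster_n512_xz_birch.npy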
third_party/pram/configs/datasets/CambridgeLandmarks.yaml ADDED
@@ -0,0 +1,67 @@
1
+ dataset: 'CambridgeLandmarks'
2
+ scenes: [ 'GreatCourt', 'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch' ]
3
+
4
+ GreatCourt:
5
+ n_cluster: 32
6
+ cluster_mode: 'xy'
7
+ cluster_method: 'birch'
8
+
9
+ training_sample_ratio: 1
10
+ eval_sample_ratio: 1
11
+ image_path_prefix: ''
12
+ query_path: 'queries_with_intrinsics.txt'
13
+ gt_pose_path: 'queries_poses.txt'
14
+
15
+
16
+ KingsCollege:
17
+ n_cluster: 32
18
+ cluster_mode: 'xy'
19
+ cluster_method: 'birch'
20
+
21
+ training_sample_ratio: 1
22
+ eval_sample_ratio: 1
23
+ image_path_prefix: ''
24
+
25
+ query_path: 'queries_with_intrinsics.txt'
26
+ gt_pose_path: 'queries_poses.txt'
27
+
28
+
29
+ OldHospital:
30
+ n_cluster: 32
31
+ cluster_mode: 'xz'
32
+ cluster_method: 'birch'
33
+
34
+ training_sample_ratio: 1
35
+ eval_sample_ratio: 1
36
+ image_path_prefix: ''
37
+ query_path: 'queries_with_intrinsics.txt'
38
+ gt_pose_path: 'queries_poses.txt'
39
+
40
+
41
+ ShopFacade:
42
+ n_cluster: 32
43
+ cluster_mode: 'xy'
44
+ cluster_method: 'birch'
45
+
46
+ training_sample_ratio: 1
47
+ eval_sample_ratio: 1
48
+ image_path_prefix: ''
49
+
50
+ query_path: 'queries_with_intrinsics.txt'
51
+ gt_pose_path: 'queries_poses.txt'
52
+
53
+
54
+ StMarysChurch:
55
+ n_cluster: 32
56
+ cluster_mode: 'xz'
57
+ cluster_method: 'birch'
58
+
59
+ training_sample_ratio: 1
60
+ eval_sample_ratio: 1
61
+ image_path_prefix: ''
62
+
63
+ query_path: 'queries_with_intrinsics.txt'
64
+ gt_pose_path: 'queries_poses.txt'
65
+
66
+
67
+
third_party/pram/dataset/aachen.py ADDED
@@ -0,0 +1,119 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> aachen
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:33
7
+ =================================================='''
8
+ import os.path as osp
9
+ import numpy as np
10
+ import cv2
11
+ from colmap_utils.read_write_model import read_model
12
+ import torchvision.transforms as tvt
13
+ from dataset.basicdataset import BasicDataset
14
+
15
+
16
+ class Aachen(BasicDataset):
17
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='Aachen',
18
+ nfeatures=1024,
19
+ query_p3d_fn=None,
20
+ train=True,
21
+ with_aug=False,
22
+ min_inliers=0,
23
+ max_inliers=4096,
24
+ random_inliers=False,
25
+ jitter_params=None,
26
+ scale_params=None,
27
+ image_dim=3,
28
+ query_info_path=None,
29
+ sample_ratio=1, ):
30
+ self.landmark_path = osp.join(landmark_path, scene)
31
+ self.dataset_path = osp.join(dataset_path, scene)
32
+ self.n_class = n_class
33
+ self.dataset = dataset + '/' + scene
34
+ self.nfeatures = nfeatures
35
+ self.with_aug = with_aug
36
+ self.jitter_params = jitter_params
37
+ self.scale_params = scale_params
38
+ self.image_dim = image_dim
39
+ self.train = train
40
+ self.min_inliers = min_inliers
41
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
42
+ self.random_inliers = random_inliers
43
+ self.image_prefix = 'images/images_upright'
44
+
45
+ train_transforms = []
46
+ if self.with_aug:
47
+ train_transforms.append(tvt.ColorJitter(
48
+ brightness=jitter_params['brightness'],
49
+ contrast=jitter_params['contrast'],
50
+ saturation=jitter_params['saturation'],
51
+ hue=jitter_params['hue']))
52
+ if jitter_params['blur'] > 0:
53
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
54
+ self.train_transforms = tvt.Compose(train_transforms)
55
+
56
+ if train:
57
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
58
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
59
+
60
+ # only for testing of query images
61
+ if not self.train:
62
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
63
+ self.img_p3d = data
64
+ else:
65
+ self.img_p3d = {}
66
+
67
+ self.img_fns = []
68
+ if train:
69
+ with open(osp.join(self.dataset_path, 'aachen_db_imglist.txt'), 'r') as f:
70
+ lines = f.readlines()
71
+ for l in lines:
72
+ l = l.strip()
73
+ if l not in self.name_to_id.keys():
74
+ continue
75
+ self.img_fns.append(l)
76
+ else:
77
+ with open(osp.join(self.dataset_path, 'queries', 'day_time_queries_with_intrinsics.txt'), 'r') as f:
78
+ lines = f.readlines()
79
+ for l in lines:
80
+ l = l.strip().split()[0]
81
+ if l not in self.img_p3d.keys():
82
+ continue
83
+ self.img_fns.append(l)
84
+ with open(osp.join(self.dataset_path, 'queries', 'night_time_queries_with_intrinsics.txt'), 'r') as f:
85
+ lines = f.readlines()
86
+ for l in lines:
87
+ l = l.strip().split()[0]
88
+ if l not in self.img_p3d.keys():
89
+ continue
90
+ self.img_fns.append(l)
91
+
92
+ print(
93
+ 'Load {} images from {} for {}...'.format(len(self.img_fns), self.dataset, 'training' if train else 'eval'))
94
+
95
+ data = np.load(osp.join(self.landmark_path,
96
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
97
+ allow_pickle=True)[()]
98
+ p3d_id = data['id']
99
+ seg_id = data['label']
100
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
101
+ xyzs = data['xyz']
102
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
103
+
104
+ with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f:
105
+ lines = f.readlines()
106
+ for l in lines:
107
+ l = l.strip().split()
108
+ self.mean_xyz = np.array([float(v) for v in l[:3]])
109
+ self.scale_xyz = np.array([float(v) for v in l[3:]])
110
+
111
+ if not train:
112
+ self.query_info = self.read_query_info(path=query_info_path)
113
+
114
+ self.nfeatures = nfeatures
115
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
116
+ self.feats = {}
117
+
118
+ def read_image(self, image_name):
119
+ return cv2.imread(osp.join(self.dataset_path, 'images/images_upright/', image_name))
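Editor's note (not part of the commit): a hedged construction sketch for the Aachen dataset above. All paths are placeholders, the landmark/cluster files referenced in __init__ must already exist on disk, and n_class / seg_mode / seg_method follow the Aachen configs earlier in this diff.

from dataset.aachen import Aachen

dataset = Aachen(
    landmark_path='/path/to/landmarks/sfd2-gml',  # placeholder
    scene='Aachenv11',
    dataset_path='/path/to/dataset',              # placeholder
    n_class=513,
    seg_mode='xz',
    seg_method='birch',
    nfeatures=1024,
    train=True,
    with_aug=False,
    image_dim=3,
)
print(len(dataset))  # number of database images with registered 3D points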
third_party/pram/dataset/basicdataset.py ADDED
@@ -0,0 +1,477 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> basicdataset
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:27
7
+ =================================================='''
8
+ import torchvision.transforms.functional as tvf
9
+ import torchvision.transforms as tvt
10
+ import os.path as osp
11
+ import numpy as np
12
+ import cv2
13
+ from colmap_utils.read_write_model import qvec2rotmat, read_model
14
+ from dataset.utils import normalize_size
15
+
16
+
17
+ class BasicDataset:
18
+ def __init__(self,
19
+ img_list_fn,
20
+ feature_dir,
21
+ sfm_path,
22
+ seg_fn,
23
+ dataset_path,
24
+ n_class,
25
+ dataset,
26
+ nfeatures=1024,
27
+ query_p3d_fn=None,
28
+ train=True,
29
+ with_aug=False,
30
+ min_inliers=0,
31
+ max_inliers=4096,
32
+ random_inliers=False,
33
+ jitter_params=None,
34
+ scale_params=None,
35
+ image_dim=1,
36
+ pre_load=False,
37
+ query_info_path=None,
38
+ sc_mean_scale_fn=None,
39
+ ):
40
+ self.n_class = n_class
41
+ self.train = train
42
+ self.min_inliers = min_inliers
43
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
44
+ self.random_inliers = random_inliers
45
+ self.dataset_path = dataset_path
46
+ self.with_aug = with_aug
47
+ self.dataset = dataset
48
+ self.jitter_params = jitter_params
49
+ self.scale_params = scale_params
50
+ self.image_dim = image_dim
51
+ self.image_prefix = ''
52
+
53
+ train_transforms = []
54
+ if self.with_aug:
55
+ train_transforms.append(tvt.ColorJitter(
56
+ brightness=jitter_params['brightness'],
57
+ contrast=jitter_params['contrast'],
58
+ saturation=jitter_params['saturation'],
59
+ hue=jitter_params['hue']))
60
+ if jitter_params['blur'] > 0:
61
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
62
+ self.train_transforms = tvt.Compose(train_transforms)
63
+
64
+ # only for testing of query images
65
+ if not self.train:
66
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
67
+ self.img_p3d = data
68
+ else:
69
+ self.img_p3d = {}
70
+
71
+ self.img_fns = []
72
+ with open(img_list_fn, 'r') as f:
73
+ lines = f.readlines()
74
+ for l in lines:
75
+ l = l.strip()
76
+ self.img_fns.append(l)
77
+ print('Load {} images from {} for {}...'.format(len(self.img_fns), dataset, 'training' if train else 'eval'))
78
+ self.feats = {}
79
+ if train:
80
+ self.cameras, self.images, point3Ds = read_model(path=sfm_path, ext='.bin')
81
+ self.name_to_id = {image.name: i for i, image in self.images.items()}
82
+
83
+ data = np.load(seg_fn, allow_pickle=True)[()]
84
+ p3d_id = data['id']
85
+ seg_id = data['label']
86
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
87
+ self.p3d_xyzs = {}
88
+
89
+ for pid in self.p3d_seg.keys():
90
+ p3d = point3Ds[pid]
91
+ self.p3d_xyzs[pid] = p3d.xyz
92
+
93
+ with open(sc_mean_scale_fn, 'r') as f:
94
+ lines = f.readlines()
95
+ for l in lines:
96
+ l = l.strip().split()
97
+ self.mean_xyz = np.array([float(v) for v in l[:3]])
98
+ self.scale_xyz = np.array([float(v) for v in l[3:]])
99
+
100
+ if not train:
101
+ self.query_info = self.read_query_info(path=query_info_path)
102
+
103
+ self.nfeatures = nfeatures
104
+ self.feature_dir = feature_dir
105
+ print('Pre-loaded {} feats, mean xyz {}, scale xyz {}'.format(len(self.feats.keys()), self.mean_xyz,
106
+ self.scale_xyz))
107
+
108
+ def normalize_p3ds(self, p3ds):
109
+ mean_p3ds = np.ceil(np.mean(p3ds, axis=0))
110
+ p3ds_ = p3ds - mean_p3ds
111
+ dx = np.max(abs(p3ds_[:, 0]))
112
+ dy = np.max(abs(p3ds_[:, 1]))
113
+ dz = np.max(abs(p3ds_[:, 2]))
114
+ scale_p3ds = np.ceil(np.array([dx, dy, dz], dtype=float).reshape(3, ))
115
+ scale_p3ds[scale_p3ds < 1] = 1
116
+ scale_p3ds[scale_p3ds == 0] = 1
117
+ return mean_p3ds, scale_p3ds
118
+
119
+ def read_query_info(self, path):
120
+ query_info = {}
121
+ with open(path, 'r') as f:
122
+ lines = f.readlines()
123
+ for l in lines:
124
+ l = l.strip().split()
125
+ image_name = l[0]
126
+ cam_model = l[1]
127
+ h, w = int(l[2]), int(l[3])
128
+ params = np.array([float(v) for v in l[4:]])
129
+ query_info[image_name] = {
130
+ 'width': w,
131
+ 'height': h,
132
+ 'model': cam_model,
133
+ 'params': params,
134
+ }
135
+ return query_info
136
+
137
+ def extract_intrinsic_extrinsic_params(self, image_id):
138
+ cam = self.cameras[self.images[image_id].camera_id]
139
+ params = cam.params
140
+ model = cam.model
141
+ if model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
142
+ fx = fy = params[0]
143
+ cx = params[1]
144
+ cy = params[2]
145
+ elif model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
146
+ fx = params[0]
147
+ fy = params[1]
148
+ cx = params[2]
149
+ cy = params[3]
150
+ else:
151
+ raise Exception("Camera model not supported")
152
+ K = np.eye(3, dtype=float)
153
+ K[0, 0] = fx
154
+ K[1, 1] = fy
155
+ K[0, 2] = cx
156
+ K[1, 2] = cy
157
+
158
+ qvec = self.images[image_id].qvec
159
+ tvec = self.images[image_id].tvec
160
+ R = qvec2rotmat(qvec=qvec)
161
+ P = np.eye(4, dtype=float)
162
+ P[:3, :3] = R
163
+ P[:3, 3] = tvec.reshape(3, )
164
+
165
+ return {'K': K, 'P': P}
166
+
167
+ def get_item_train(self, idx):
168
+ img_name = self.img_fns[idx]
169
+ if img_name in self.feats.keys():
170
+ feat_data = self.feats[img_name]
171
+ else:
172
+ feat_data = np.load(osp.join(self.feature_dir, img_name.replace('/', '+') + '.npy'), allow_pickle=True)[()]
173
+ # descs = feat_data['descriptors'] # [N, D]
174
+ scores = feat_data['scores'] # [N, 1]
175
+ kpts = feat_data['keypoints'] # [N, 2]
176
+ image_size = feat_data['image_size']
177
+
178
+ nfeat = kpts.shape[0]
179
+
180
+ # print(img_name, self.name_to_id[img_name])
181
+ p3d_ids = self.images[self.name_to_id[img_name]].point3D_ids
182
+ p3d_xyzs = np.zeros(shape=(nfeat, 3), dtype=float)
183
+
184
+ seg_ids = np.zeros(shape=(nfeat,), dtype=int) # + self.n_class - 1
185
+ for i in range(nfeat):
186
+ p3d = p3d_ids[i]
187
+ if p3d in self.p3d_seg.keys():
188
+ seg_ids[i] = self.p3d_seg[p3d] + 1 # 0 for invalid
189
+ if seg_ids[i] == -1:
190
+ seg_ids[i] = 0
191
+
192
+ if p3d in self.p3d_xyzs.keys():
193
+ p3d_xyzs[i] = self.p3d_xyzs[p3d]
194
+
195
+ seg_ids = np.array(seg_ids).reshape(-1, )
196
+
197
+ n_inliers = np.sum(seg_ids > 0)
198
+ n_outliers = np.sum(seg_ids == 0)
199
+ inlier_ids = np.where(seg_ids > 0)[0]
200
+ outlier_ids = np.where(seg_ids == 0)[0]
201
+
202
+ if n_inliers <= self.min_inliers:
203
+ sel_inliers = n_inliers
204
+ sel_outliers = self.nfeatures - sel_inliers
205
+
206
+ out_ids = np.arange(n_outliers)
207
+ np.random.shuffle(out_ids)
208
+ sel_ids = np.hstack([inlier_ids, outlier_ids[out_ids[:self.nfeatures - n_inliers]]])
209
+ else:
210
+ sel_inliers = np.random.randint(self.min_inliers, self.max_inliers)
211
+ if sel_inliers > n_inliers:
212
+ sel_inliers = n_inliers
213
+
214
+ if sel_inliers + n_outliers < self.nfeatures:
215
+ sel_inliers = self.nfeatures - n_outliers
216
+
217
+ sel_outliers = self.nfeatures - sel_inliers
218
+
219
+ in_ids = np.arange(n_inliers)
220
+ np.random.shuffle(in_ids)
221
+ sel_inlier_ids = inlier_ids[in_ids[:sel_inliers]]
222
+
223
+ out_ids = np.arange(n_outliers)
224
+ np.random.shuffle(out_ids)
225
+ sel_outlier_ids = outlier_ids[out_ids[:sel_outliers]]
226
+
227
+ sel_ids = np.hstack([sel_inlier_ids, sel_outlier_ids])
228
+
229
+ # sel_descs = descs[sel_ids]
230
+ sel_scores = scores[sel_ids]
231
+ sel_kpts = kpts[sel_ids]
232
+ sel_seg_ids = seg_ids[sel_ids]
233
+ sel_xyzs = p3d_xyzs[sel_ids]
234
+
235
+ shuffle_ids = np.arange(sel_ids.shape[0])
236
+ np.random.shuffle(shuffle_ids)
237
+ # sel_descs = sel_descs[shuffle_ids]
238
+ sel_scores = sel_scores[shuffle_ids]
239
+ sel_kpts = sel_kpts[shuffle_ids]
240
+ sel_seg_ids = sel_seg_ids[shuffle_ids]
241
+ sel_xyzs = sel_xyzs[shuffle_ids]
242
+
243
+ if sel_kpts.shape[0] < self.nfeatures:
244
+ # print(sel_descs.shape, sel_kpts.shape, sel_scores.shape, sel_seg_ids.shape, sel_xyzs.shape)
245
+ valid_sel_ids = np.array([v for v in range(sel_kpts.shape[0]) if sel_seg_ids[v] > 0], dtype=int)
246
+ # ref_sel_id = np.random.choice(valid_sel_ids, size=1)[0]
247
+ if valid_sel_ids.shape[0] == 0:
248
+ valid_sel_ids = np.array([v for v in range(sel_kpts.shape[0])], dtype=int)
249
+ random_n = self.nfeatures - sel_kpts.shape[0]
250
+ random_scores = np.random.random((random_n,))
251
+ random_kpts, random_seg_ids, random_xyzs = self.random_points_from_reference(
252
+ n=random_n,
253
+ ref_kpts=sel_kpts[valid_sel_ids],
254
+ ref_segs=sel_seg_ids[valid_sel_ids],
255
+ ref_xyzs=sel_xyzs[valid_sel_ids],
256
+ radius=5,
257
+ )
258
+ # sel_descs = np.vstack([sel_descs, random_descs])
259
+ sel_scores = np.hstack([sel_scores, random_scores])
260
+ sel_kpts = np.vstack([sel_kpts, random_kpts])
261
+ sel_seg_ids = np.hstack([sel_seg_ids, random_seg_ids])
262
+ sel_xyzs = np.vstack([sel_xyzs, random_xyzs])
263
+
264
+ gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
265
+ gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
266
+ gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)
267
+ uids = np.unique(sel_seg_ids).tolist()
268
+ for uid in uids:
269
+ if uid == 0:
270
+ continue
271
+ gt_cls[uid] = 1
272
+ gt_n_seg[uid] = np.sum(sel_seg_ids == uid)
273
+ gt_cls_dist[uid] = np.sum(seg_ids == uid) / np.sum(seg_ids > 0) # [valid_id / total_valid_id]
274
+
275
+ param_out = self.extract_intrinsic_extrinsic_params(image_id=self.name_to_id[img_name])
276
+
277
+ img = self.read_image(image_name=img_name)
278
+ image_size = img.shape[:2]
279
+ if self.image_dim == 1:
280
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
281
+ else:
282
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
283
+ if self.with_aug:
284
+ nh = img.shape[0]
285
+ nw = img.shape[1]
286
+ if self.scale_params is not None:
287
+ do_scale = np.random.random()
288
+ if do_scale <= 0.25:
289
+ p = np.random.randint(0, 11)
290
+ s = self.scale_params[0] + (self.scale_params[1] - self.scale_params[0]) / 10 * p
291
+ nh = int(img.shape[0] * s)
292
+ nw = int(img.shape[1] * s)
293
+ sh = nh / img.shape[0]
294
+ sw = nw / img.shape[1]
295
+ sel_kpts[:, 0] = sel_kpts[:, 0] * sw
296
+ sel_kpts[:, 1] = sel_kpts[:, 1] * sh
297
+ img = cv2.resize(img, dsize=(nw, nh))
298
+
299
+ brightness = np.random.uniform(-self.jitter_params['brightness'], self.jitter_params['brightness']) * 255
300
+ contrast = 1 + np.random.uniform(-self.jitter_params['contrast'], self.jitter_params['contrast'])
301
+ img = cv2.addWeighted(img, contrast, img, 0, brightness)
302
+ img = np.clip(img, a_min=0, a_max=255)
303
+ if self.image_dim == 1:
304
+ img = img[..., None]
305
+ img = img.astype(float) / 255.
306
+ image_size = np.array([nh, nw], dtype=int)
307
+ else:
308
+ if self.image_dim == 1:
309
+ img = img[..., None].astype(float) / 255.
310
+
311
+ output = {
312
+ # 'descriptors': sel_descs, # may not be used
313
+ 'scores': sel_scores,
314
+ 'keypoints': sel_kpts,
315
+ 'norm_keypoints': normalize_size(x=sel_kpts, size=image_size),
316
+ 'image': [img],
317
+ 'gt_seg': sel_seg_ids,
318
+ 'gt_cls': gt_cls,
319
+ 'gt_cls_dist': gt_cls_dist,
320
+ 'gt_n_seg': gt_n_seg,
321
+ 'file_name': img_name,
322
+ 'prefix_name': self.image_prefix,
323
+ # 'mean_xyz': self.mean_xyz,
324
+ # 'scale_xyz': self.scale_xyz,
325
+ # 'gt_sc': sel_xyzs,
326
+ # 'gt_norm_sc': (sel_xyzs - self.mean_xyz) / self.scale_xyz,
327
+ 'K': param_out['K'],
328
+ 'gt_P': param_out['P']
329
+ }
330
+ return output
331
+
332
+ def get_item_test(self, idx):
333
+
334
+ # evaluation of recognition only
335
+ img_name = self.img_fns[idx]
336
+ feat_data = np.load(osp.join(self.feature_dir, img_name.replace('/', '+') + '.npy'), allow_pickle=True)[()]
337
+ descs = feat_data['descriptors'] # [N, D]
338
+ scores = feat_data['scores'] # [N, 1]
339
+ kpts = feat_data['keypoints'] # [N, 2]
340
+ image_size = feat_data['image_size']
341
+
342
+ nfeat = descs.shape[0]
343
+
344
+ if img_name in self.img_p3d.keys():
345
+ p3d_ids = self.img_p3d[img_name]
346
+ p3d_xyzs = np.zeros(shape=(nfeat, 3), dtype=float)
347
+ seg_ids = np.zeros(shape=(nfeat,), dtype=int) # attention! by default invalid!!!
348
+ for i in range(nfeat):
349
+ p3d = p3d_ids[i]
350
+ if p3d in self.p3d_seg.keys():
351
+ seg_ids[i] = self.p3d_seg[p3d] + 1
352
+ if seg_ids[i] == -1:
353
+ seg_ids[i] = 0 # 0 for invalid
354
+
355
+ if p3d in self.p3d_xyzs.keys():
356
+ p3d_xyzs[i] = self.p3d_xyzs[p3d]
357
+
358
+ seg_ids = np.array(seg_ids).reshape(-1, )
359
+
360
+ if self.nfeatures > 0:
361
+ sorted_ids = np.argsort(scores)[::-1][:self.nfeatures] # large to small
362
+ descs = descs[sorted_ids]
363
+ scores = scores[sorted_ids]
364
+ kpts = kpts[sorted_ids]
365
+ p3d_xyzs = p3d_xyzs[sorted_ids]
366
+
367
+ seg_ids = seg_ids[sorted_ids]
368
+
369
+ gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
370
+ gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
371
+ gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)
372
+ uids = np.unique(seg_ids).tolist()
373
+ for uid in uids:
374
+ if uid == 0:
375
+ continue
376
+ gt_cls[uid] = 1
377
+ gt_n_seg[uid] = np.sum(seg_ids == uid)
378
+ gt_cls_dist[uid] = np.sum(seg_ids == uid) / np.sum(
379
+ seg_ids < self.n_class - 1) # [valid_id / total_valid_id]
380
+
381
+ gt_cls[0] = 0
382
+
383
+ img = self.read_image(image_name=img_name)
384
+ if self.image_dim == 1:
385
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
386
+ img = img[..., None].astype(float) / 255.
387
+ else:
388
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(float) / 255.
389
+ return {
390
+ 'descriptors': descs,
391
+ 'scores': scores,
392
+ 'keypoints': kpts,
393
+ 'image_size': image_size,
394
+ 'norm_keypoints': normalize_size(x=kpts, size=image_size),
395
+ 'gt_seg': seg_ids,
396
+ 'gt_cls': gt_cls,
397
+ 'gt_cls_dist': gt_cls_dist,
398
+ 'gt_n_seg': gt_n_seg,
399
+ 'file_name': img_name,
400
+ 'prefix_name': self.image_prefix,
401
+ 'image': [img],
402
+
403
+ 'mean_xyz': self.mean_xyz,
404
+ 'scale_xyz': self.scale_xyz,
405
+ 'gt_sc': p3d_xyzs,
406
+ 'gt_norm_sc': (p3d_xyzs - self.mean_xyz) / self.scale_xyz
407
+ }
408
+
409
+ def __getitem__(self, idx):
410
+ if self.train:
411
+ return self.get_item_train(idx=idx)
412
+ else:
413
+ return self.get_item_test(idx=idx)
414
+
415
+ def __len__(self):
416
+ return len(self.img_fns)
417
+
418
+ def read_image(self, image_name):
419
+ return cv2.imread(osp.join(self.dataset_path, image_name))
420
+
421
+ def jitter_augmentation(self, img, params):
422
+ brightness, contrast, saturation, hue = params
423
+ p = np.random.randint(0, 20) / 20
424
+ b = brightness[0] + (brightness[1] - brightness[0]) / 20 * p
425
+ img = tvf.adjust_brightness(img=img, brightness_factor=b)
426
+
427
+ p = np.random.randint(0, 20) / 20
428
+ c = contrast[0] + (contrast[1] - contrast[0]) / 20 * p
429
+ img = tvf.adjust_contrast(img=img, contrast_factor=c)
430
+
431
+ p = np.random.randint(0, 20) / 20
432
+ s = saturation[0] + (saturation[1] - saturation[0]) / 20 * p
433
+ img = tvf.adjust_saturation(img=img, saturation_factor=s)
434
+
435
+ p = np.random.randint(0, 20) / 20
436
+ h = hue[0] + (hue[1] - hue[0]) / 20 * p
437
+ img = tvf.adjust_hue(img=img, hue_factor=h)
438
+
439
+ return img
440
+
441
+ def random_points(self, n, d, h, w):
442
+ desc = np.random.random((n, d))
443
+ desc = desc / np.linalg.norm(desc, ord=2, axis=1)[..., None]
444
+ xs = np.random.randint(0, w - 1, size=(n, 1))
445
+ ys = np.random.randint(0, h - 1, size=(n, 1))
446
+ kpts = np.hstack([xs, ys])
447
+ return desc, kpts
448
+
449
+ def random_points_from_reference(self, n, ref_kpts, ref_segs, ref_xyzs, radius=5):
450
+ n_ref = ref_kpts.shape[0]
451
+ if n_ref < n:
452
+ ref_ids = np.random.choice([i for i in range(n_ref)], size=n).tolist()
453
+ else:
454
+ ref_ids = [i for i in range(n)]
455
+
456
+ new_xs = []
457
+ new_ys = []
458
+ # new_descs = []
459
+ new_segs = []
460
+ new_xyzs = []
461
+ for i in ref_ids:
462
+ nx = np.random.randint(-radius, radius) + ref_kpts[i, 0]
463
+ ny = np.random.randint(-radius, radius) + ref_kpts[i, 1]
464
+
465
+ new_xs.append(nx)
466
+ new_ys.append(ny)
467
+ # new_descs.append(ref_descs[i])
468
+ new_segs.append(ref_segs[i])
469
+ new_xyzs.append(ref_xyzs[i])
470
+
471
+ new_xs = np.array(new_xs).reshape(n, 1)
472
+ new_ys = np.array(new_ys).reshape(n, 1)
473
+ new_segs = np.array(new_segs).reshape(n, )
474
+ new_kpts = np.hstack([new_xs, new_ys])
475
+ # new_descs = np.array(new_descs).reshape(n, -1)
476
+ new_xyzs = np.array(new_xyzs)
477
+ return new_kpts, new_segs, new_xyzs
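Editor's note (not part of the commit): BasicDataset implements __len__ and __getitem__, so a subclass instance behaves as a map-style dataset. In training mode a single sample is the dict assembled in get_item_train above; a hedged access sketch, assuming 'dataset' is such an instance:

sample = dataset[0]
print(sample['keypoints'].shape)  # (nfeatures, 2) sampled keypoint coordinates
print(sample['gt_seg'].shape)     # per-keypoint segment (landmark) labels, 0 = invalid
print(sample['gt_P'].shape)       # (4, 4) ground-truth world-to-camera pose
print(sample['K'].shape)          # (3, 3) camera intrinsics
print(sample['file_name'])        # image name relative to the dataset root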
third_party/pram/dataset/cambridge_landmarks.py ADDED
@@ -0,0 +1,101 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> cambridge_landmarks
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:41
7
+ =================================================='''
8
+ import os.path as osp
9
+ import numpy as np
10
+ from colmap_utils.read_write_model import read_model
11
+ import torchvision.transforms as tvt
12
+ from dataset.basicdataset import BasicDataset
13
+
14
+
15
+ class CambridgeLandmarks(BasicDataset):
16
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='CambridgeLandmarks',
17
+ nfeatures=1024,
18
+ query_p3d_fn=None,
19
+ train=True,
20
+ with_aug=False,
21
+ min_inliers=0,
22
+ max_inliers=4096,
23
+ random_inliers=False,
24
+ jitter_params=None,
25
+ scale_params=None,
26
+ image_dim=3,
27
+ query_info_path=None,
28
+ sample_ratio=1,
29
+ ):
30
+ self.landmark_path = osp.join(landmark_path, scene)
31
+ self.dataset_path = osp.join(dataset_path, scene)
32
+ self.n_class = n_class
33
+ self.dataset = dataset + '/' + scene
34
+ self.nfeatures = nfeatures
35
+ self.with_aug = with_aug
36
+ self.jitter_params = jitter_params
37
+ self.scale_params = scale_params
38
+ self.image_dim = image_dim
39
+ self.train = train
40
+ self.min_inliers = min_inliers
41
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
42
+ self.random_inliers = random_inliers
43
+ self.image_prefix = ''
44
+ train_transforms = []
45
+ if self.with_aug:
46
+ train_transforms.append(tvt.ColorJitter(
47
+ brightness=jitter_params['brightness'],
48
+ contrast=jitter_params['contrast'],
49
+ saturation=jitter_params['saturation'],
50
+ hue=jitter_params['hue']))
51
+ if jitter_params['blur'] > 0:
52
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
53
+ self.train_transforms = tvt.Compose(train_transforms)
54
+
55
+ if train:
56
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
57
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
58
+
59
+ # only for testing of query images
60
+ if not self.train:
61
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
62
+ self.img_p3d = data
63
+ else:
64
+ self.img_p3d = {}
65
+
66
+ self.img_fns = []
67
+ with open(osp.join(self.dataset_path, 'dataset_train.txt' if train else 'dataset_test.txt'), 'r') as f:
68
+ lines = f.readlines()[3:] # ignore the first 3 lines
69
+ for l in lines:
70
+ l = l.strip().split()[0]
71
+ if train and l not in self.name_to_id.keys():
72
+ continue
73
+ if not train and l not in self.img_p3d.keys():
74
+ continue
75
+ self.img_fns.append(l)
76
+
77
+ print('Load {} images from {} for {}...'.format(len(self.img_fns),
78
+ self.dataset, 'training' if train else 'eval'))
79
+
80
+ data = np.load(osp.join(self.landmark_path,
81
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
82
+ allow_pickle=True)[()]
83
+ p3d_id = data['id']
84
+ seg_id = data['label']
85
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
86
+ xyzs = data['xyz']
87
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
88
+
89
+ # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f:
90
+ # lines = f.readlines()
91
+ # for l in lines:
92
+ # l = l.strip().split()
93
+ # self.mean_xyz = np.array([float(v) for v in l[:3]])
94
+ # self.scale_xyz = np.array([float(v) for v in l[3:]])
95
+
96
+ if not train:
97
+ self.query_info = self.read_query_info(path=query_info_path)
98
+
99
+ self.nfeatures = nfeatures
100
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
101
+ self.feats = {}
third_party/pram/dataset/customdataset.py ADDED
@@ -0,0 +1,93 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> customdataset.py
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:38
7
+ =================================================='''
8
+ import os.path as osp
9
+ import numpy as np
10
+ from colmap_utils.read_write_model import read_model
11
+ import torchvision.transforms as tvt
12
+ from dataset.basicdataset import BasicDataset
13
+
14
+
15
+ class CustomDataset(BasicDataset):
16
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset,
17
+ nfeatures=1024,
18
+ query_p3d_fn=None,
19
+ train=True,
20
+ with_aug=False,
21
+ min_inliers=0,
22
+ max_inliers=4096,
23
+ random_inliers=False,
24
+ jitter_params=None,
25
+ scale_params=None,
26
+ image_dim=3,
27
+ query_info_path=None,
28
+ sample_ratio=1,
29
+ ):
30
+ self.landmark_path = osp.join(landmark_path, scene)
31
+ self.dataset_path = osp.join(dataset_path, scene)
32
+ self.n_class = n_class
33
+ self.dataset = dataset + '/' + scene
34
+ self.nfeatures = nfeatures
35
+ self.with_aug = with_aug
36
+ self.jitter_params = jitter_params
37
+ self.scale_params = scale_params
38
+ self.image_dim = image_dim
39
+ self.train = train
40
+ self.min_inliers = min_inliers
41
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
42
+ self.random_inliers = random_inliers
43
+ self.image_prefix = ''
44
+
45
+ train_transforms = []
46
+ if self.with_aug:
47
+ train_transforms.append(tvt.ColorJitter(
48
+ brightness=jitter_params['brightness'],
49
+ contrast=jitter_params['contrast'],
50
+ saturation=jitter_params['saturation'],
51
+ hue=jitter_params['hue']))
52
+ if jitter_params['blur'] > 0:
53
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
54
+ self.train_transforms = tvt.Compose(train_transforms)
55
+
56
+ if train:
57
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
58
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
59
+
60
+ # only for testing of query images
61
+ if not self.train:
62
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
63
+ self.img_p3d = data
64
+ else:
65
+ self.img_p3d = {}
66
+
67
+ if train:
68
+ self.img_fns = [self.images[v].name for v in self.images.keys() if
69
+ self.images[v].name in self.name_to_id.keys()]
70
+ else:
71
+ self.img_fns = []
72
+ with open(osp.join(self.dataset_path, 'queries_with_intrinsics.txt'), 'r') as f:
73
+ lines = f.readlines()
74
+ for l in lines:
75
+ self.img_fns.append(l.strip().split()[0])
76
+ print('Load {} images from {} for {}...'.format(len(self.img_fns),
77
+ self.dataset, 'training' if train else 'eval'))
78
+
79
+ data = np.load(osp.join(self.landmark_path,
80
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
81
+ allow_pickle=True)[()]
82
+ p3d_id = data['id']
83
+ seg_id = data['label']
84
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
85
+ xyzs = data['xyz']
86
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
87
+
88
+ if not train:
89
+ self.query_info = self.read_query_info(path=query_info_path)
90
+
91
+ self.nfeatures = nfeatures
92
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
93
+ self.feats = {}
third_party/pram/dataset/get_dataset.py ADDED
@@ -0,0 +1,89 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> get_dataset
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:40
7
+ =================================================='''
8
+ import os.path as osp
9
+ import yaml
10
+ from dataset.aachen import Aachen
11
+ from dataset.twelve_scenes import TwelveScenes
12
+ from dataset.seven_scenes import SevenScenes
13
+ from dataset.cambridge_landmarks import CambridgeLandmarks
14
+ from dataset.customdataset import CustomDataset
15
+ from dataset.recdataset import RecDataset
16
+
17
+
18
+ def get_dataset(dataset):
19
+ if dataset in ['7Scenes', 'S']:
20
+ return SevenScenes
21
+ elif dataset in ['12Scenes', 'T']:
22
+ return TwelveScenes
23
+ elif dataset in ['Aachen', 'A']:
24
+ return Aachen
25
+ elif dataset in ['CambridgeLandmarks', 'C']:
26
+ return CambridgeLandmarks
27
+ else:
28
+ return CustomDataset
29
+
30
+
31
+ def compose_datasets(datasets, config, train=True, sample_ratio=None):
32
+ sub_sets = []
33
+ for name in datasets:
34
+ if name == 'S':
35
+ ds_name = '7Scenes'
36
+ elif name == 'T':
37
+ ds_name = '12Scenes'
38
+ elif name == 'A':
39
+ ds_name = 'Aachen'
40
+ elif name == 'R':
41
+ ds_name = 'RobotCar-Seasons'
42
+ elif name == 'C':
43
+ ds_name = 'CambridgeLandmarks'
44
+ else:
45
+ ds_name = name
46
+ # raise '{} dataset does not exist'.format(name)
47
+ landmark_path = osp.join(config['landmark_path'], ds_name)
48
+ dataset_path = osp.join(config['dataset_path'], ds_name)
49
+ scene_config_path = 'configs/datasets/{:s}.yaml'.format(ds_name)
50
+
51
+ with open(scene_config_path, 'r') as f:
52
+ scene_config = yaml.load(f, Loader=yaml.Loader)
53
+ DSet = get_dataset(dataset=ds_name)
54
+
55
+ for scene in scene_config['scenes']:
56
+ if sample_ratio is None:
57
+ scene_sample_ratio = scene_config[scene]['training_sample_ratio'] if train else scene_config[scene][
58
+ 'eval_sample_ratio']
59
+ else:
60
+ scene_sample_ratio = sample_ratio
61
+ scene_set = DSet(landmark_path=landmark_path,
62
+ dataset_path=dataset_path,
63
+ scene=scene,
64
+ seg_mode=scene_config[scene]['cluster_mode'],
65
+ seg_method=scene_config[scene]['cluster_method'],
66
+ n_class=scene_config[scene]['n_cluster'] + 1, # including invalid - 0
67
+ dataset=ds_name,
68
+ train=train,
69
+ nfeatures=config['max_keypoints'] if train else config['eval_max_keypoints'],
70
+ min_inliers=config['min_inliers'],
71
+ max_inliers=config['max_inliers'],
72
+ random_inliers=config['random_inliers'],
73
+ with_aug=config['with_aug'],
74
+ jitter_params=config['jitter_params'],
75
+ scale_params=config['scale_params'],
76
+ image_dim=config['image_dim'],
77
+ query_p3d_fn=osp.join(config['landmark_path'], ds_name, scene,
78
+ 'point3D_query_n{:d}_{:s}_{:s}.npy'.format(
79
+ scene_config[scene]['n_cluster'],
80
+ scene_config[scene]['cluster_mode'],
81
+ scene_config[scene]['cluster_method'])),
82
+ query_info_path=osp.join(config['dataset_path'], ds_name, scene,
83
+ 'queries_with_intrinsics.txt'),
84
+ sample_ratio=scene_sample_ratio,
85
+ )
86
+
87
+ sub_sets.append(scene_set)
88
+
89
+ return RecDataset(sub_sets=sub_sets)
third_party/pram/dataset/recdataset.py ADDED
@@ -0,0 +1,95 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> recdataset
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:42
7
+ =================================================='''
8
+ import numpy as np
9
+ from torch.utils.data import Dataset
10
+
11
+
12
+ class RecDataset(Dataset):
13
+ def __init__(self, sub_sets=[]):
14
+ assert len(sub_sets) >= 1
15
+
16
+ self.sub_sets = sub_sets
17
+ self.names = []
18
+
19
+ self.sub_set_index = []
20
+ self.seg_offsets = []
21
+ self.sub_set_item_index = []
22
+ self.dataset_names = []
23
+ self.scene_names = []
24
+ start_index_valid_seg = 1 # start from 1, 0 is for invalid
25
+
26
+ total_subset = 0
27
+ for scene_set in sub_sets: # [0, n_class]
28
+ name = scene_set.dataset
29
+ self.names.append(name)
30
+ n_samples = len(scene_set)
31
+
32
+ n_class = scene_set.n_class
33
+ self.seg_offsets = self.seg_offsets + [start_index_valid_seg for v in range(len(scene_set))]
34
+ start_index_valid_seg = start_index_valid_seg + n_class - 1
35
+
36
+ self.sub_set_index = self.sub_set_index + [total_subset for k in range(n_samples)]
37
+ self.sub_set_item_index = self.sub_set_item_index + [k for k in range(n_samples)]
38
+
39
+ # self.dataset_names = self.dataset_names + [name for k in range(n_samples)]
40
+ self.scene_names = self.scene_names + [name for k in range(n_samples)]
41
+ total_subset += 1
42
+
43
+ self.n_class = start_index_valid_seg
44
+
45
+ print('Load {} images {} segs from {} subsets from {}'.format(len(self.sub_set_item_index), self.n_class,
46
+ len(sub_sets), self.names))
47
+
48
+ def __len__(self):
49
+ return len(self.sub_set_item_index)
50
+
51
+ def __getitem__(self, idx):
52
+ subset_idx = self.sub_set_index[idx]
53
+ item_idx = self.sub_set_item_index[idx]
54
+ scene_name = self.scene_names[idx]
55
+
56
+ out = self.sub_sets[subset_idx][item_idx]
57
+
58
+ org_gt_seg = out['gt_seg']
59
+ org_gt_cls = out['gt_cls']
60
+ org_gt_cls_dist = out['gt_cls_dist']
61
+ org_gt_n_seg = out['gt_n_seg']
62
+ offset = self.seg_offsets[idx]
63
+ org_n_class = self.sub_sets[subset_idx].n_class
64
+
65
+ gt_seg = np.zeros(shape=(org_gt_seg.shape[0],), dtype=int) # [0, ..., n_features]
66
+ gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
67
+ gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
68
+ gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)
69
+
70
+ # copy invalid segments
71
+ gt_n_seg[0] = org_gt_n_seg[0]
72
+ gt_cls[0] = org_gt_cls[0]
73
+ gt_cls_dist[0] = org_gt_cls_dist[0]
74
+ # print('org: ', org_n_class, org_gt_seg.shape, org_gt_n_seg.shape, org_gt_seg)
75
+
76
+ # copy valid segments
77
+ gt_seg[org_gt_seg > 0] = org_gt_seg[org_gt_seg > 0] + offset - 1 # [0, ..., 1023]
78
+ gt_n_seg[offset:offset + org_n_class - 1] = org_gt_n_seg[1:] # [0...,n_seg]
79
+ gt_cls[offset:offset + org_n_class - 1] = org_gt_cls[1:] # [0, ..., n_seg]
80
+ gt_cls_dist[offset:offset + org_n_class - 1] = org_gt_cls_dist[1:] # [0, ..., n_seg]
81
+
82
+ out['gt_seg'] = gt_seg
83
+ out['gt_cls'] = gt_cls
84
+ out['gt_cls_dist'] = gt_cls_dist
85
+ out['gt_n_seg'] = gt_n_seg
86
+
87
+ # print('gt: ', org_n_class, gt_seg.shape, gt_n_seg.shape, gt_seg)
88
+ out['scene_name'] = scene_name
89
+
90
+ # out['org_gt_seg'] = org_gt_seg
91
+ # out['org_gt_n_seg'] = org_gt_n_seg
92
+ # out['org_gt_cls'] = org_gt_cls
93
+ # out['org_gt_cls_dist'] = org_gt_cls_dist
94
+
95
+ return out
third_party/pram/dataset/seven_scenes.py ADDED
@@ -0,0 +1,115 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> seven_scenes
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:36
7
+ =================================================='''
8
+ import os
9
+ import os.path as osp
10
+ import numpy as np
11
+ from colmap_utils.read_write_model import read_model
12
+ import torchvision.transforms as tvt
13
+ from dataset.basicdataset import BasicDataset
14
+
15
+
16
+ class SevenScenes(BasicDataset):
17
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='7Scenes',
18
+ nfeatures=1024,
19
+ query_p3d_fn=None,
20
+ train=True,
21
+ with_aug=False,
22
+ min_inliers=0,
23
+ max_inliers=4096,
24
+ random_inliers=False,
25
+ jitter_params=None,
26
+ scale_params=None,
27
+ image_dim=3,
28
+ query_info_path=None,
29
+ sample_ratio=1,
30
+ ):
31
+ self.landmark_path = osp.join(landmark_path, scene)
32
+ self.dataset_path = osp.join(dataset_path, scene)
33
+ self.n_class = n_class
34
+ self.dataset = dataset + '/' + scene
35
+ self.nfeatures = nfeatures
36
+ self.with_aug = with_aug
37
+ self.jitter_params = jitter_params
38
+ self.scale_params = scale_params
39
+ self.image_dim = image_dim
40
+ self.train = train
41
+ self.min_inliers = min_inliers
42
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
43
+ self.random_inliers = random_inliers
44
+ self.image_prefix = ''
45
+
46
+ train_transforms = []
47
+ if self.with_aug:
48
+ train_transforms.append(tvt.ColorJitter(
49
+ brightness=jitter_params['brightness'],
50
+ contrast=jitter_params['contrast'],
51
+ saturation=jitter_params['saturation'],
52
+ hue=jitter_params['hue']))
53
+ if jitter_params['blur'] > 0:
54
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
55
+ self.train_transforms = tvt.Compose(train_transforms)
56
+
57
+ if train:
58
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
59
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
60
+
61
+ # only for testing of query images
62
+ if not self.train:
63
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
64
+ self.img_p3d = data
65
+ else:
66
+ self.img_p3d = {}
67
+
68
+ if self.train:
69
+ split_fn = osp.join(self.dataset_path, 'TrainSplit.txt')
70
+ else:
71
+ split_fn = osp.join(self.dataset_path, 'TestSplit.txt')
72
+
73
+ self.img_fns = []
74
+ with open(split_fn, 'r') as f:
75
+ lines = f.readlines()
76
+ for l in lines:
77
+ seq = int(l.strip()[8:])
78
+ fns = os.listdir(osp.join(self.dataset_path, osp.join('seq-{:02d}'.format(seq))))
79
+ fns = sorted(fns)
80
+ nf = 0
81
+ for fn in fns:
82
+ if fn.find('png') >= 0:
83
+ if train and 'seq-{:02d}'.format(seq) + '/' + fn not in self.name_to_id.keys():
84
+ continue
85
+ if not train and 'seq-{:02d}'.format(seq) + '/' + fn not in self.img_p3d.keys():
86
+ continue
87
+ if nf % sample_ratio == 0:
88
+ self.img_fns.append('seq-{:02d}'.format(seq) + '/' + fn)
89
+ nf += 1
90
+
91
+ print('Load {} images from {} for {}...'.format(len(self.img_fns),
92
+ self.dataset, 'training' if train else 'eval'))
93
+
94
+ data = np.load(osp.join(self.landmark_path,
95
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
96
+ allow_pickle=True)[()]
97
+ p3d_id = data['id']
98
+ seg_id = data['label']
99
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
100
+ xyzs = data['xyz']
101
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
102
+
103
+ # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f:
104
+ # lines = f.readlines()
105
+ # for l in lines:
106
+ # l = l.strip().split()
107
+ # self.mean_xyz = np.array([float(v) for v in l[:3]])
108
+ # self.scale_xyz = np.array([float(v) for v in l[3:]])
109
+
110
+ if not train:
111
+ self.query_info = self.read_query_info(path=query_info_path)
112
+
113
+ self.nfeatures = nfeatures
114
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
115
+ self.feats = {}
third_party/pram/dataset/twelve_scenes.py ADDED
@@ -0,0 +1,121 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> twelve_scenes
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:37
7
+ =================================================='''
8
+ import os
9
+ import os.path as osp
10
+ import numpy as np
11
+ from colmap_utils.read_write_model import read_model
12
+ import torchvision.transforms as tvt
13
+ from dataset.basicdataset import BasicDataset
14
+
15
+
16
+ class TwelveScenes(BasicDataset):
17
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='12Scenes',
18
+ nfeatures=1024,
19
+ query_p3d_fn=None,
20
+ train=True,
21
+ with_aug=False,
22
+ min_inliers=0,
23
+ max_inliers=4096,
24
+ random_inliers=False,
25
+ jitter_params=None,
26
+ scale_params=None,
27
+ image_dim=3,
28
+ query_info_path=None,
29
+ sample_ratio=1,
30
+ ):
31
+ self.landmark_path = osp.join(landmark_path, scene)
32
+ self.dataset_path = osp.join(dataset_path, scene)
33
+ self.n_class = n_class
34
+ self.dataset = dataset + '/' + scene
35
+ self.nfeatures = nfeatures
36
+ self.with_aug = with_aug
37
+ self.jitter_params = jitter_params
38
+ self.scale_params = scale_params
39
+ self.image_dim = image_dim
40
+ self.train = train
41
+ self.min_inliers = min_inliers
42
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
43
+ self.random_inliers = random_inliers
44
+ self.image_prefix = ''
45
+
46
+ train_transforms = []
47
+ if self.with_aug:
48
+ train_transforms.append(tvt.ColorJitter(
49
+ brightness=jitter_params['brightness'],
50
+ contrast=jitter_params['contrast'],
51
+ saturation=jitter_params['saturation'],
52
+ hue=jitter_params['hue']))
53
+ if jitter_params['blur'] > 0:
54
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
55
+ self.train_transforms = tvt.Compose(train_transforms)
56
+
57
+ if train:
58
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
59
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
60
+
61
+ # only for testing of query images
62
+ if not self.train:
63
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
64
+ self.img_p3d = data
65
+ else:
66
+ self.img_p3d = {}
67
+
68
+ with open(osp.join(self.dataset_path, 'split.txt'), 'r') as f:
69
+ l = f.readline()
70
+ l = l.strip().split(' ') # sequence0 [frames=357] [start=0 ; end=356], first sequence for testing
71
+ start_img_id = l[-3].split('=')[-1]
72
+ end_img_id = l[-1].split('=')[-1][:-1]
73
+ test_start_img_id = int(start_img_id)
74
+ test_end_img_id = int(end_img_id)
75
+
76
+ self.img_fns = []
77
+ fns = os.listdir(osp.join(self.dataset_path, 'data'))
78
+ fns = sorted(fns)
79
+ nf = 0
80
+ for fn in fns:
81
+ if fn.find('jpg') >= 0: # frame-001098.color.jpg
82
+ frame_id = int(fn.split('.')[0].split('-')[-1])
83
+ if not train and frame_id > test_end_img_id:
84
+ continue
85
+ if train and frame_id <= test_end_img_id:
86
+ continue
87
+
88
+ if train and 'data' + '/' + fn not in self.name_to_id.keys():
89
+ continue
90
+
91
+ if not train and 'data' + '/' + fn not in self.img_p3d.keys():
92
+ continue
93
+ if nf % sample_ratio == 0:
94
+ self.img_fns.append('data' + '/' + fn)
95
+ nf += 1
96
+
97
+ print('Load {} images from {} for {}...'.format(len(self.img_fns),
98
+ self.dataset, 'training' if train else 'eval'))
99
+
100
+ data = np.load(osp.join(self.landmark_path,
101
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
102
+ allow_pickle=True)[()]
103
+ p3d_id = data['id']
104
+ seg_id = data['label']
105
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
106
+ xyzs = data['xyz']
107
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
108
+
109
+ # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f:
110
+ # lines = f.readlines()
111
+ # for l in lines:
112
+ # l = l.strip().split()
113
+ # self.mean_xyz = np.array([float(v) for v in l[:3]])
114
+ # self.scale_xyz = np.array([float(v) for v in l[3:]])
115
+
116
+ if not train:
117
+ self.query_info = self.read_query_info(path=query_info_path)
118
+
119
+ self.nfeatures = nfeatures
120
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
121
+ self.feats = {}
third_party/pram/dataset/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> utils
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 29/01/2024 14:31
7
+ =================================================='''
8
+ import torch
9
+
10
+
11
+ def normalize_size(x, size, scale=0.7):
12
+ size = size.reshape([1, 2])
13
+ norm_fac = size.max() + 0.5
14
+ return (x - size / 2) / (norm_fac * scale)
15
+
16
+
17
+ def collect_batch(batch):
18
+ out = {}
19
+ # if len(batch) == 0:
20
+ # return batch
21
+ # else:
22
+ for k in batch[0].keys():
23
+ tmp = []
24
+ for v in batch:
25
+ tmp.append(v[k])
26
+ if isinstance(batch[0][k], str) or isinstance(batch[0][k], list):
27
+ out[k] = tmp
28
+ else:
29
+ out[k] = torch.cat([torch.from_numpy(i)[None] for i in tmp], dim=0)
30
+
31
+ return out
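Usage note (not part of the diff): collect_batch above is written to act as a DataLoader collate_fn, stacking per-sample numpy entries into batched tensors while passing strings and lists through untouched. A minimal sketch, assuming `config` is a dict loaded from one of the training yaml files under configs/ and that config['dataset'] holds the abbreviations compose_datasets expects (e.g. 'S' for 7Scenes); the batch keys follow those set in recdataset.py:

from torch.utils.data import DataLoader
from dataset.get_dataset import compose_datasets
from dataset.utils import collect_batch

# `config` is assumed to be a loaded training yaml (landmark_path, dataset_path, max_keypoints, ...)
train_set = compose_datasets(datasets=config['dataset'], config=config, train=True)
loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=collect_batch)
for batch in loader:
    gt_seg = batch['gt_seg']           # numpy arrays are stacked into a [B, ...] tensor
    scene_names = batch['scene_name']  # strings stay as a plain Python list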
third_party/pram/environment.yml ADDED
@@ -0,0 +1,173 @@
1
+ name: pram
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=2_gnu
8
+ - binutils_impl_linux-64=2.38=h2a08ee3_1
9
+ - bzip2=1.0.8=h5eee18b_5
10
+ - ca-certificates=2024.3.11=h06a4308_0
11
+ - gcc=12.1.0=h9ea6d83_10
12
+ - gcc_impl_linux-64=12.1.0=hea43390_17
13
+ - kernel-headers_linux-64=2.6.32=he073ed8_17
14
+ - ld_impl_linux-64=2.38=h1181459_1
15
+ - libffi=3.4.4=h6a678d5_0
16
+ - libgcc-devel_linux-64=12.1.0=h1ec3361_17
17
+ - libgcc-ng=13.2.0=h807b86a_5
18
+ - libgomp=13.2.0=h807b86a_5
19
+ - libsanitizer=12.1.0=ha89aaad_17
20
+ - libstdcxx-ng=13.2.0=h7e041cc_5
21
+ - libuuid=1.41.5=h5eee18b_0
22
+ - ncurses=6.4=h6a678d5_0
23
+ - openssl=3.2.1=hd590300_1
24
+ - pip=23.3.1=py310h06a4308_0
25
+ - python=3.10.14=h955ad1f_0
26
+ - readline=8.2=h5eee18b_0
27
+ - setuptools=68.2.2=py310h06a4308_0
28
+ - sqlite=3.41.2=h5eee18b_0
29
+ - sysroot_linux-64=2.12=he073ed8_17
30
+ - tk=8.6.12=h1ccaba5_0
31
+ - wheel=0.41.2=py310h06a4308_0
32
+ - xz=5.4.6=h5eee18b_0
33
+ - zlib=1.2.13=h5eee18b_0
34
+ - pip:
35
+ - addict==2.4.0
36
+ - aiofiles==23.2.1
37
+ - aiohttp==3.9.3
38
+ - aioopenssl==0.6.0
39
+ - aiosasl==0.5.0
40
+ - aiosignal==1.3.1
41
+ - aioxmpp==0.13.3
42
+ - asttokens==2.4.1
43
+ - async-timeout==4.0.3
44
+ - attrs==23.2.0
45
+ - babel==2.14.0
46
+ - benbotasync==3.0.2
47
+ - blinker==1.7.0
48
+ - certifi==2024.2.2
49
+ - cffi==1.16.0
50
+ - charset-normalizer==3.3.2
51
+ - click==8.1.7
52
+ - colorama==0.4.6
53
+ - comm==0.2.2
54
+ - configargparse==1.7
55
+ - contourpy==1.2.1
56
+ - crayons==0.4.0
57
+ - cryptography==42.0.5
58
+ - cycler==0.12.1
59
+ - dash==2.16.1
60
+ - dash-core-components==2.0.0
61
+ - dash-html-components==2.0.0
62
+ - dash-table==5.0.0
63
+ - decorator==5.1.1
64
+ - dnspython==2.6.1
65
+ - einops==0.7.0
66
+ - exceptiongroup==1.2.0
67
+ - executing==2.0.1
68
+ - fastjsonschema==2.19.1
69
+ - filelock==3.13.3
70
+ - flask==3.0.2
71
+ - fonttools==4.50.0
72
+ - fortniteapiasync==0.1.7
73
+ - fortnitepy==3.6.9
74
+ - frozenlist==1.4.1
75
+ - fsspec==2024.3.1
76
+ - h5py==3.10.0
77
+ - html5tagger==1.3.0
78
+ - httptools==0.6.1
79
+ - idna==3.6
80
+ - importlib-metadata==7.1.0
81
+ - ipython==8.23.0
82
+ - ipywidgets==8.1.2
83
+ - itsdangerous==2.1.2
84
+ - jedi==0.19.1
85
+ - jinja2==3.1.3
86
+ - joblib==1.3.2
87
+ - jsonschema==4.21.1
88
+ - jsonschema-specifications==2023.12.1
89
+ - jupyter-core==5.7.2
90
+ - jupyterlab-widgets==3.0.10
91
+ - kiwisolver==1.4.5
92
+ - lxml==4.9.4
93
+ - markupsafe==2.1.5
94
+ - matplotlib==3.8.4
95
+ - matplotlib-inline==0.1.6
96
+ - mpmath==1.3.0
97
+ - multidict==6.0.5
98
+ - nbformat==5.10.4
99
+ - nest-asyncio==1.6.0
100
+ - networkx==3.2.1
101
+ - numpy==1.26.4
102
+ - nvidia-cublas-cu12==12.1.3.1
103
+ - nvidia-cuda-cupti-cu12==12.1.105
104
+ - nvidia-cuda-nvrtc-cu12==12.1.105
105
+ - nvidia-cuda-runtime-cu12==12.1.105
106
+ - nvidia-cudnn-cu12==8.9.2.26
107
+ - nvidia-cufft-cu12==11.0.2.54
108
+ - nvidia-curand-cu12==10.3.2.106
109
+ - nvidia-cusolver-cu12==11.4.5.107
110
+ - nvidia-cusparse-cu12==12.1.0.106
111
+ - nvidia-nccl-cu12==2.19.3
112
+ - nvidia-nvjitlink-cu12==12.4.127
113
+ - nvidia-nvtx-cu12==12.1.105
114
+ - open3d==0.18.0
115
+ - opencv-contrib-python==4.5.5.64
116
+ - packaging==24.0
117
+ - pandas==2.2.1
118
+ - parso==0.8.3
119
+ - pexpect==4.9.0
120
+ - pillow==10.3.0
121
+ - platformdirs==4.2.0
122
+ - plotly==5.20.0
123
+ - prompt-toolkit==3.0.43
124
+ - ptyprocess==0.7.0
125
+ - pure-eval==0.2.2
126
+ - pyasn1==0.6.0
127
+ - pyasn1-modules==0.4.0
128
+ - pybind11==2.12.0
129
+ - pycolmap==0.6.1
130
+ - pycparser==2.22
131
+ - pygments==2.17.2
132
+ - pyopengl==3.1.7
133
+ - pyopengl-accelerate==3.1.7
134
+ - pyopenssl==24.1.0
135
+ - pyparsing==3.1.2
136
+ - pyquaternion==0.9.9
137
+ - python-dateutil==2.9.0.post0
138
+ - pytz==2024.1
139
+ - pyyaml==6.0.1
140
+ - referencing==0.34.0
141
+ - requests==2.31.0
142
+ - retrying==1.3.4
143
+ - rpds-py==0.18.0
144
+ - sanic==23.12.1
145
+ - sanic-routing==23.12.0
146
+ - scikit-learn==1.4.1.post1
147
+ - scipy==1.13.0
148
+ - six==1.16.0
149
+ - sortedcollections==2.1.0
150
+ - sortedcontainers==2.4.0
151
+ - stack-data==0.6.3
152
+ - sympy==1.12
153
+ - tenacity==8.2.3
154
+ - threadpoolctl==3.4.0
155
+ - torch==2.2.2
156
+ - torchvision==0.17.2
157
+ - tqdm==4.66.2
158
+ - tracerite==1.1.1
159
+ - traitlets==5.14.2
160
+ - triton==2.2.0
161
+ - typing-extensions==4.10.0
162
+ - tzdata==2024.1
163
+ - tzlocal==5.2
164
+ - ujson==5.9.0
165
+ - urllib3==2.2.1
166
+ - uvloop==0.15.2
167
+ - wcwidth==0.2.13
168
+ - websockets==12.0
169
+ - werkzeug==3.0.2
170
+ - widgetsnbextension==4.0.10
171
+ - yaml2==0.0.1
172
+ - yarl==1.9.4
173
+ - zipp==3.18.1
third_party/pram/inference.py ADDED
@@ -0,0 +1,62 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> inference
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 03/04/2024 16:06
7
+ =================================================='''
8
+ import argparse
9
+ import torch
10
+ import torchvision.transforms.transforms as tvt
11
+ import yaml
12
+ from nets.load_segnet import load_segnet
13
+ from nets.sfd2 import load_sfd2
14
+ from dataset.get_dataset import compose_datasets
15
+
16
+ parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
17
+ parser.add_argument('--config', type=str, required=True, help='config of specifications')
18
+ parser.add_argument('--landmark_path', type=str, required=True, help='path of landmarks')
19
+ parser.add_argument('--feat_weight_path', type=str, default='weights/sfd2_20230511_210205_resnet4x.79.pth')
20
+ parser.add_argument('--rec_weight_path', type=str, required=True, help='recognition weight')
21
+ parser.add_argument('--online', action='store_true', help='online visualization with pangolin')
22
+
23
+ if __name__ == '__main__':
24
+ args = parser.parse_args()
25
+ with open(args.config, 'rt') as f:
26
+ config = yaml.load(f, Loader=yaml.Loader)
27
+ config['landmark_path'] = args.landmark_path
28
+
29
+ feat_model = load_sfd2(weight_path=args.feat_weight_path).cuda().eval()
30
+ print('Load SFD2 weight from {:s}'.format(args.feat_weight_path))
31
+
32
+ # rec_model = get_model(config=config)
33
+ rec_model = load_segnet(network=config['network'],
34
+ n_class=config['n_class'],
35
+ desc_dim=256 if config['use_mid_feature'] else 128,
36
+ n_layers=config['layers'],
37
+ output_dim=config['output_dim'])
38
+ state_dict = torch.load(args.rec_weight_path, map_location='cpu')['model']
39
+ rec_model.load_state_dict(state_dict, strict=True)
40
+ print('Load recognition weight from {:s}'.format(args.rec_weight_path))
41
+
42
+ img_transforms = []
43
+ img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
44
+ img_transforms = tvt.Compose(img_transforms)
45
+
46
+ dataset = config['dataset']
47
+ if not args.online:
48
+ from localization.loc_by_rec_eval import loc_by_rec_eval
49
+
50
+ test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1)
51
+ config['n_class'] = test_set.n_class
52
+
53
+ loc_by_rec_eval(rec_model=rec_model.cuda().eval(),
54
+ loader=test_set,
55
+ local_feat=feat_model.cuda().eval(),
56
+ config=config, img_transforms=img_transforms)
57
+ else:
58
+ from localization.loc_by_rec_online import loc_by_rec_online
59
+
60
+ loc_by_rec_online(rec_model=rec_model.cuda().eval(),
61
+ local_feat=feat_model.cuda().eval(),
62
+ config=config, img_transforms=img_transforms)
third_party/pram/localization/base_model.py ADDED
@@ -0,0 +1,45 @@
1
+ from abc import ABCMeta, abstractmethod
2
+ from torch import nn
3
+ from copy import copy
4
+ import inspect
5
+
6
+
7
+ class BaseModel(nn.Module, metaclass=ABCMeta):
8
+ default_conf = {}
9
+ required_data_keys = []
10
+
11
+ def __init__(self, conf):
12
+ """Perform some logic and call the _init method of the child model."""
13
+ super().__init__()
14
+ self.conf = conf = {**self.default_conf, **conf}
15
+ self.required_data_keys = copy(self.required_data_keys)
16
+ self._init(conf)
17
+
18
+ def forward(self, data):
19
+ """Check the data and call the _forward method of the child model."""
20
+ for key in self.required_data_keys:
21
+ assert key in data, 'Missing key {} in data'.format(key)
22
+ return self._forward(data)
23
+
24
+ @abstractmethod
25
+ def _init(self, conf):
26
+ """To be implemented by the child class."""
27
+ raise NotImplementedError
28
+
29
+ @abstractmethod
30
+ def _forward(self, data):
31
+ """To be implemented by the child class."""
32
+ raise NotImplementedError
33
+
34
+
35
+ def dynamic_load(root, model):
36
+ module_path = f'{root.__name__}.{model}'
37
+ module = __import__(module_path, fromlist=[''])
38
+ classes = inspect.getmembers(module, inspect.isclass)
39
+ # Filter classes defined in the module
40
+ classes = [c for c in classes if c[1].__module__ == module_path]
41
+ # Filter classes inherited from BaseModel
42
+ classes = [c for c in classes if issubclass(c[1], BaseModel)]
43
+ assert len(classes) == 1, classes
44
+ return classes[0][1]
45
+ # return getattr(module, 'Model')
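Usage note (not part of the diff): dynamic_load above returns the single BaseModel subclass defined in a submodule, so a matcher or extractor can be picked by a name string from a config. A minimal sketch, assuming 'adagml' as the module name of one of the matcher files added in this commit and an illustrative (possibly incomplete) conf dict:

import torch
from localization import matchers
from localization.base_model import dynamic_load

Model = dynamic_load(matchers, 'adagml')   # assumed module name; returns the lone BaseModel subclass in it
matcher = Model(conf={}).eval()            # the conf dict is merged over the class's default_conf
with torch.no_grad():
    pred = matcher(data)                   # `data` must provide the model's required_data_keys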
third_party/pram/localization/camera.py ADDED
@@ -0,0 +1,11 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> camera
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 04/03/2024 11:27
7
+ =================================================='''
8
+ import collections
9
+
10
+ Camera = collections.namedtuple(
11
+ "Camera", ["id", "model", "width", "height", "params"])
third_party/pram/localization/extract_features.py ADDED
@@ -0,0 +1,256 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> extract_features.py
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 07/02/2024 14:49
7
+ =================================================='''
8
+ import os
9
+ import os.path as osp
10
+ import h5py
11
+ import numpy as np
12
+ import progressbar
13
+ import yaml
14
+ import torch
15
+ import cv2
16
+ import torch.utils.data as Data
17
+ from tqdm import tqdm
18
+ from types import SimpleNamespace
19
+ import logging
20
+ import pprint
21
+ from pathlib import Path
22
+ import argparse
23
+ from nets.sfd2 import ResNet4x, extract_sfd2_return
24
+ from nets.superpoint import SuperPoint, extract_sp_return
25
+
26
+ confs = {
27
+ 'superpoint-n4096': {
28
+ 'output': 'feats-superpoint-n4096',
29
+ 'model': {
30
+ 'name': 'superpoint',
31
+ 'outdim': 256,
32
+ 'use_stability': False,
33
+ 'nms_radius': 3,
34
+ 'max_keypoints': 4096,
35
+ 'conf_th': 0.005,
36
+ 'multiscale': False,
37
+ 'scales': [1.0],
38
+ 'model_fn': osp.join(os.getcwd(),
39
+ "weights/superpoint_v1.pth"),
40
+ },
41
+ 'preprocessing': {
42
+ 'grayscale': True,
43
+ 'resize_max': False,
44
+ },
45
+ },
46
+
47
+ 'resnet4x-20230511-210205-pho-0005': {
48
+ 'output': 'feats-resnet4x-20230511-210205-pho-0005',
49
+ 'model': {
50
+ 'outdim': 128,
51
+ 'name': 'resnet4x',
52
+ 'use_stability': False,
53
+ 'max_keypoints': 4096,
54
+ 'conf_th': 0.005,
55
+ 'multiscale': False,
56
+ 'scales': [1.0],
57
+ 'model_fn': osp.join(os.getcwd(),
58
+ "weights/sfd2_20230511_210205_resnet4x.79.pth"),
59
+ },
60
+ 'preprocessing': {
61
+ 'grayscale': False,
62
+ 'resize_max': False,
63
+ },
64
+ 'mask': False,
65
+ },
66
+
67
+ 'sfd2': {
68
+ 'output': 'feats-sfd2',
69
+ 'model': {
70
+ 'outdim': 128,
71
+ 'name': 'resnet4x',
72
+ 'use_stability': False,
73
+ 'max_keypoints': 4096,
74
+ 'conf_th': 0.005,
75
+ 'multiscale': False,
76
+ 'scales': [1.0],
77
+ 'model_fn': osp.join(os.getcwd(),
78
+ "weights/sfd2_20230511_210205_resnet4x.79.pth"),
79
+ },
80
+ 'preprocessing': {
81
+ 'grayscale': False,
82
+ 'resize_max': False,
83
+ },
84
+ 'mask': False,
85
+ },
86
+ }
87
+
88
+
89
+ class ImageDataset(Data.Dataset):
90
+ default_conf = {
91
+ 'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'],
92
+ 'grayscale': False,
93
+ 'resize_max': None,
94
+ 'resize_force': False,
95
+ }
96
+
97
+ def __init__(self, root, conf, image_list=None,
98
+ mask_root=None):
99
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
100
+ self.root = root
101
+
102
+ self.paths = []
103
+ if image_list is None:
104
+ for g in conf.globs:
105
+ self.paths += list(Path(root).glob('**/' + g))
106
+ if len(self.paths) == 0:
107
+ raise ValueError(f'Could not find any image in root: {root}.')
108
+ self.paths = [i.relative_to(root) for i in self.paths]
109
+ else:
110
+ with open(image_list, "r") as f:
111
+ lines = f.readlines()
112
+ for l in lines:
113
+ l = l.strip()
114
+ self.paths.append(Path(l))
115
+
116
+ logging.info(f'Found {len(self.paths)} images in root {root}.')
117
+
118
+ if mask_root is not None:
119
+ self.mask_root = mask_root
120
+ else:
121
+ self.mask_root = None
122
+
123
+ def __getitem__(self, idx):
124
+ path = self.paths[idx]
125
+ if self.conf.grayscale:
126
+ mode = cv2.IMREAD_GRAYSCALE
127
+ else:
128
+ mode = cv2.IMREAD_COLOR
129
+ image = cv2.imread(str(self.root / path), mode)
130
+ if image is None:
131
+ raise ValueError(f'Cannot read image {str(path)}.')
132
+ if not self.conf.grayscale:
133
+ image = image[:, :, ::-1] # BGR to RGB
134
+ image = image.astype(np.float32)
135
+ size = image.shape[:2][::-1]
136
+ w, h = size
137
+
138
+ if self.conf.resize_max and (self.conf.resize_force
139
+ or max(w, h) > self.conf.resize_max):
140
+ scale = self.conf.resize_max / max(h, w)
141
+ h_new, w_new = int(round(h * scale)), int(round(w * scale))
142
+ image = cv2.resize(
143
+ image, (w_new, h_new), interpolation=cv2.INTER_CUBIC)
144
+
145
+ if self.conf.grayscale:
146
+ image = image[None]
147
+ else:
148
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
149
+ image = image / 255.
150
+
151
+ data = {
152
+ 'name': str(path),
153
+ 'image': image,
154
+ 'original_size': np.array(size),
155
+ }
156
+
157
+ if self.mask_root is not None:
158
+ mask_path = Path(str(path).replace("jpg", "png"))
159
+ if osp.exists(str(self.mask_root / mask_path)):
160
+ mask = cv2.imread(str(self.mask_root / mask_path))
161
+ mask = cv2.resize(mask, dsize=(image.shape[2], image.shape[1]), interpolation=cv2.INTER_NEAREST)
162
+ else:
163
+ mask = np.zeros(shape=(image.shape[1], image.shape[2], 3), dtype=np.uint8)
164
+
165
+ data['mask'] = mask
166
+
167
+ return data
168
+
169
+ def __len__(self):
170
+ return len(self.paths)
171
+
172
+
173
+ def get_model(model_name, weight_path, outdim=128, **kwargs):
174
+ if model_name == 'superpoint':
175
+ model = SuperPoint(config={
176
+ 'descriptor_dim': 256,
177
+ 'nms_radius': 4,
178
+ 'keypoint_threshold': 0.005,
179
+ 'max_keypoints': -1,
180
+ 'remove_borders': 4,
181
+ 'weight_path': weight_path,
182
+ }).eval()
183
+
184
+ extractor = extract_sp_return
185
+
186
+ if model_name == 'resnet4x':
187
+ model = ResNet4x(outdim=outdim).eval()
188
+ model.load_state_dict(torch.load(weight_path)['state_dict'], strict=True)
189
+ extractor = extract_sfd2_return
190
+
191
+ return model, extractor
192
+
193
+
194
+ @torch.no_grad()
195
+ def main(conf, image_dir, export_dir):
196
+ logging.info('Extracting local features with configuration:'
197
+ f'\n{pprint.pformat(conf)}')
198
+ model, extractor = get_model(model_name=conf['model']['name'], weight_path=conf["model"]["model_fn"],
199
+ use_stability=conf['model']['use_stability'], outdim=conf['model']['outdim'])
200
+ model = model.cuda()
201
+ loader = ImageDataset(image_dir,
202
+ conf['preprocessing'],
203
+ image_list=args.image_list,
204
+ mask_root=None)
205
+ loader = torch.utils.data.DataLoader(loader, num_workers=4)
206
+
207
+ os.makedirs(export_dir, exist_ok=True)
208
+ feature_path = Path(export_dir, conf['output'] + '.h5')
209
+ feature_path.parent.mkdir(exist_ok=True, parents=True)
210
+ feature_file = h5py.File(str(feature_path), 'a')
211
+
212
+ with tqdm(total=len(loader)) as t:
213
+ for idx, data in enumerate(loader):
214
+ t.update()
215
+ pred = extractor(model, img=data["image"],
216
+ topK=conf["model"]["max_keypoints"],
217
+ mask=None,
218
+ conf_th=conf["model"]["conf_th"],
219
+ scales=conf["model"]["scales"],
220
+ )
221
+
222
+ # pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
223
+ pred['descriptors'] = pred['descriptors'].transpose()
224
+
225
+ t.set_postfix(npoints=pred['keypoints'].shape[0])
226
+ # print(pred['keypoints'].shape)
227
+
228
+ pred['image_size'] = original_size = data['original_size'][0].numpy()
229
+ # pred['descriptors'] = pred['descriptors'].T
230
+ if 'keypoints' in pred.keys():
231
+ size = np.array(data['image'].shape[-2:][::-1])
232
+ scales = (original_size / size).astype(np.float32)
233
+ pred['keypoints'] = (pred['keypoints'] + .5) * scales[None] - .5
234
+
235
+ grp = feature_file.create_group(data['name'][0])
236
+ for k, v in pred.items():
237
+ # print(k, v.shape)
238
+ grp.create_dataset(k, data=v)
239
+
240
+ del pred
241
+
242
+ feature_file.close()
243
+ logging.info('Finished exporting features.')
244
+
245
+ return feature_path
246
+
247
+
248
+ if __name__ == '__main__':
249
+ parser = argparse.ArgumentParser()
250
+ parser.add_argument('--image_dir', type=Path, required=True)
251
+ parser.add_argument('--image_list', type=str, default=None)
252
+ parser.add_argument('--mask_dir', type=Path, default=None)
253
+ parser.add_argument('--export_dir', type=Path, required=True)
254
+ parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
255
+ args = parser.parse_args()
256
+ main(confs[args.conf], args.image_dir, args.export_dir)
third_party/pram/localization/frame.py ADDED
@@ -0,0 +1,195 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> frame
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 01/03/2024 10:08
7
+ =================================================='''
8
+ from collections import defaultdict
9
+
10
+ import numpy as np
11
+ import torch
12
+ import pycolmap
13
+
14
+ from localization.camera import Camera
15
+ from localization.utils import compute_pose_error
16
+
17
+
18
+ class Frame:
19
+ def __init__(self, image: np.ndarray, camera: pycolmap.Camera, id: int, name: str = None, qvec=None, tvec=None,
20
+ scene_name=None,
21
+ reference_frame_id=None):
22
+ self.image = image
23
+ self.camera = camera
24
+ self.id = id
25
+ self.name = name
26
+ self.image_size = np.array([camera.height, camera.width])
27
+ self.qvec = qvec
28
+ self.tvec = tvec
29
+ self.scene_name = scene_name
30
+ self.reference_frame_id = reference_frame_id
31
+
32
+ self.keypoints = None # [N, 3]
33
+ self.descriptors = None # [N, D]
34
+ self.segmentations = None # [N C]
35
+ self.seg_scores = None # [N C]
36
+ self.seg_ids = None # [N, 1]
37
+ self.point3D_ids = None # [N, 1]
38
+ self.xyzs = None
39
+
40
+ self.gt_qvec = None
41
+ self.gt_tvec = None
42
+
43
+ self.matched_scene_name = None
44
+ self.matched_keypoints = None
45
+ self.matched_keypoint_ids = None
46
+ self.matched_xyzs = None
47
+ self.matched_point3D_ids = None
48
+ self.matched_inliers = None
49
+ self.matched_sids = None
50
+ self.matched_order = None
51
+
52
+ self.refinement_reference_frame_ids = None
53
+ self.image_rec = None
54
+ self.image_matching = None
55
+ self.image_inlier = None
56
+ self.reference_frame_name = None
57
+ self.image_matching_tmp = None
58
+ self.image_inlier_tmp = None
59
+ self.reference_frame_name_tmp = None
60
+
61
+ self.tracking_status = None
62
+
63
+ self.time_feat = 0
64
+ self.time_rec = 0
65
+ self.time_loc = 0
66
+ self.time_ref = 0
67
+
68
+ def update_point3ds_old(self):
69
+ pt = torch.from_numpy(self.keypoints[:, :2]).unsqueeze(-1) # [M 2 1]
70
+ mpt = torch.from_numpy(self.matched_keypoints[:, :2].transpose()).unsqueeze(0) # [1 2 N]
71
+ dist = torch.sqrt(torch.sum((pt - mpt) ** 2, dim=1))
72
+ values, ids = torch.topk(dist, dim=1, k=1, largest=False)
73
+ values = values[:, 0].numpy()
74
+ ids = ids[:, 0].numpy()
75
+ mask = (values < 1) # 1 pixel error
76
+ self.point3D_ids = np.zeros(shape=(self.keypoints.shape[0],), dtype=int) - 1
77
+ self.point3D_ids[mask] = self.matched_point3D_ids[ids[mask]]
78
+
79
+ # self.xyzs = np.zeros(shape=(self.keypoints.shape[0], 3), dtype=float)
80
+ inlier_mask = self.matched_inliers
81
+ self.xyzs[mask] = self.matched_xyzs[ids[mask]]
82
+ self.seg_ids[mask] = self.matched_sids[ids[mask]]
83
+
84
+ def update_point3ds(self):
85
+ # print('Frame: update_point3ds: ', self.matched_keypoint_ids.shape, self.matched_xyzs.shape,
86
+ # self.matched_sids.shape, self.matched_point3D_ids.shape)
87
+ self.xyzs[self.matched_keypoint_ids] = self.matched_xyzs
88
+ self.seg_ids[self.matched_keypoint_ids] = self.matched_sids
89
+ self.point3D_ids[self.matched_keypoint_ids] = self.matched_point3D_ids
90
+
91
+ def add_keypoints(self, keypoints: np.ndarray, descriptors: np.ndarray):
92
+ self.keypoints = keypoints
93
+ self.descriptors = descriptors
94
+ self.initialize_localization_variables()
95
+
96
+ def add_segmentations(self, segmentations: torch.Tensor, filtering_threshold: float):
97
+ '''
98
+ :param segmentations: [number_points number_labels]
99
+ :return:
100
+ '''
101
+ seg_scores = torch.softmax(segmentations, dim=-1)
102
+ if filtering_threshold > 0:
103
+ scores_background = seg_scores[:, 0]
104
+ non_bg_mask = (scores_background < filtering_threshold)
105
+ print('pre filtering before: ', self.keypoints.shape)
106
+ if torch.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]:
107
+ self.keypoints = self.keypoints[non_bg_mask.cpu().numpy()]
108
+ self.descriptors = self.descriptors[non_bg_mask.cpu().numpy()]
109
+ # print('pre filtering after: ', self.keypoints.shape)
110
+
111
+ # update localization variables
112
+ self.initialize_localization_variables()
113
+
114
+ segmentations = segmentations[non_bg_mask]
115
+ seg_scores = seg_scores[non_bg_mask]
116
+ print('pre filtering after: ', self.keypoints.shape)
117
+
118
+ # extract initial segmentation info
119
+ self.segmentations = segmentations.cpu().numpy()
120
+ self.seg_scores = seg_scores.cpu().numpy()
121
+ self.seg_ids = segmentations.max(dim=-1)[1].cpu().numpy() - 1 # should start from 0
122
+
123
+ def filter_keypoints(self, seg_scores: np.ndarray, filtering_threshold: float):
124
+ scores_background = seg_scores[:, 0]
125
+ non_bg_mask = (scores_background < filtering_threshold)
126
+ print('pre filtering before: ', self.keypoints.shape)
127
+ if np.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]:
128
+ self.keypoints = self.keypoints[non_bg_mask]
129
+ self.descriptors = self.descriptors[non_bg_mask]
130
+ print('pre filtering after: ', self.keypoints.shape)
131
+
132
+ # update localization variables
133
+ self.initialize_localization_variables()
134
+ return non_bg_mask
135
+ else:
136
+ print('pre filtering after: ', self.keypoints.shape)
137
+ return None
138
+
139
+ def compute_pose_error(self, pred_qvec=None, pred_tvec=None):
140
+ if pred_qvec is not None and pred_tvec is not None:
141
+ if self.gt_qvec is not None and self.gt_tvec is not None:
142
+ return compute_pose_error(pred_qcw=pred_qvec, pred_tcw=pred_tvec,
143
+ gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec)
144
+ else:
145
+ return 100, 100
146
+
147
+ if self.qvec is None or self.tvec is None or self.gt_qvec is None or self.gt_tvec is None:
148
+ return 100, 100
149
+ else:
150
+ err_q, err_t = compute_pose_error(pred_qcw=self.qvec, pred_tcw=self.tvec,
151
+ gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec)
152
+ return err_q, err_t
153
+
154
+ def get_intrinsics(self) -> np.ndarray:
155
+ camera_model = self.camera.model.name
156
+ params = self.camera.params
157
+ if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
158
+ fx = fy = params[0]
159
+ cx = params[1]
160
+ cy = params[2]
161
+ elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
162
+ fx = params[0]
163
+ fy = params[1]
164
+ cx = params[2]
165
+ cy = params[3]
166
+ else:
167
+ raise Exception("Camera model not supported")
168
+
169
+ # intrinsics
170
+ K = np.identity(3)
171
+ K[0, 0] = fx
172
+ K[1, 1] = fy
173
+ K[0, 2] = cx
174
+ K[1, 2] = cy
175
+ return K
176
+
177
+ def get_dominate_seg_id(self):
178
+ counts = np.bincount(self.seg_ids[self.seg_ids > 0])
179
+ return np.argmax(counts)
180
+
181
+ def clear_localization_track(self):
182
+ self.matched_scene_name = None
183
+ self.matched_keypoints = None
184
+ self.matched_xyzs = None
185
+ self.matched_point3D_ids = None
186
+ self.matched_inliers = None
187
+ self.matched_sids = None
188
+
189
+ self.refinement_reference_frame_ids = None
190
+
191
+ def initialize_localization_variables(self):
192
+ nkpt = self.keypoints.shape[0]
193
+ self.seg_ids = np.zeros(shape=(nkpt,), dtype=int) - 1
194
+ self.point3D_ids = np.zeros(shape=(nkpt,), dtype=int) - 1
195
+ self.xyzs = np.zeros(shape=(nkpt, 3), dtype=float)
third_party/pram/localization/loc_by_rec_eval.py ADDED
@@ -0,0 +1,299 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> loc_by_rec
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 08/02/2024 15:26
7
+ =================================================='''
8
+ import torch
9
+ from torch.autograd import Variable
10
+ from localization.multimap3d import MultiMap3D
11
+ from localization.frame import Frame
12
+ import yaml, cv2, time
13
+ import numpy as np
14
+ import os.path as osp
15
+ import threading
16
+ import os
17
+ from tqdm import tqdm
18
+ from recognition.vis_seg import vis_seg_point, generate_color_dic
19
+ from tools.metrics import compute_iou, compute_precision
20
+ from localization.tracker import Tracker
21
+ from localization.utils import read_query_info
22
+ from localization.camera import Camera
23
+
24
+
25
+ def loc_by_rec_eval(rec_model, loader, config, local_feat, img_transforms=None):
26
+ n_epoch = int(config['weight_path'].split('.')[1])
27
+ save_fn = osp.join(config['localization']['save_path'],
28
+ config['weight_path'].split('/')[0] + '_{:d}'.format(n_epoch) + '_{:d}'.format(
29
+ config['feat_dim']))
30
+ tag = 'k{:d}_th{:d}_mm{:d}_mi{:d}'.format(config['localization']['seg_k'], config['localization']['threshold'],
31
+ config['localization']['min_matches'],
32
+ config['localization']['min_inliers'])
33
+ if config['localization']['do_refinement']:
34
+ tag += '_op{:d}'.format(config['localization']['covisibility_frame'])
35
+ if config['localization']['with_compress']:
36
+ tag += '_comp'
37
+
38
+ save_fn = save_fn + '_' + tag
39
+
40
+ save = config['localization']['save']
41
+ save = config['localization']['save']
42
+ if save:
43
+ save_dir = save_fn
44
+ os.makedirs(save_dir, exist_ok=True)
45
+ else:
46
+ save_dir = None
47
+
48
+ seg_color = generate_color_dic(n_seg=2000)
49
+ dataset_path = config['dataset_path']
50
+ show = config['localization']['show']
51
+ if show:
52
+ cv2.namedWindow('img', cv2.WINDOW_NORMAL)
53
+
54
+ locMap = MultiMap3D(config=config, save_dir=None)
55
+ # start tracker
56
+ mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config)
57
+
58
+ dataset_name = config['dataset'][0]
59
+ all_scene_query_info = {}
60
+ with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f:
61
+ scene_config = yaml.load(f, Loader=yaml.Loader)
62
+ scenes = scene_config['scenes']
63
+ for scene in scenes:
64
+ query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path'])
65
+ query_info = read_query_info(query_fn=query_path)
66
+ all_scene_query_info[dataset_name + '/' + scene] = query_info
67
+ # print(scene, query_info.keys())
68
+
69
+ tracking = False
70
+
71
+ full_log = ''
72
+ failed_cases = []
73
+ success_cases = []
74
+ poses = {}
75
+ err_ths_cnt = [0, 0, 0, 0]
76
+
77
+ seg_results = {}
78
+ time_results = {
79
+ 'feat': [],
80
+ 'rec': [],
81
+ 'loc': [],
82
+ 'ref': [],
83
+ 'total': [],
84
+ }
85
+ n_total = 0
86
+
87
+ loc_scene_names = config['localization']['loc_scene_name']
88
+ # loader = loader[8990:]
89
+ for bid, pred in tqdm(enumerate(loader), total=len(loader)):
90
+ pred = loader[bid]
91
+ image_name = pred['file_name'] # [0]
92
+ scene_name = pred['scene_name'] # [0] # dataset_scene
93
+ if len(loc_scene_names) > 0:
94
+ skip = True
95
+ for loc_scene in loc_scene_names:
96
+ if scene_name.find(loc_scene) > 0:
97
+ skip = False
98
+ break
99
+ if skip:
100
+ continue
101
+ with torch.no_grad():
102
+ for k in pred:
103
+ if k.find('name') >= 0:
104
+ continue
105
+ if k != 'image0' and k != 'image1' and k != 'depth0' and k != 'depth1':
106
+ if type(pred[k]) == np.ndarray:
107
+ pred[k] = Variable(torch.from_numpy(pred[k]).float().cuda())[None]
108
+ elif type(pred[k]) == torch.Tensor:
109
+ pred[k] = Variable(pred[k].float().cuda())
110
+ elif type(pred[k]) == list:
111
+ continue
112
+ else:
113
+ pred[k] = Variable(torch.stack(pred[k]).float().cuda())
114
+ print('scene: ', scene_name, image_name)
115
+
116
+ n_total += 1
117
+ with torch.no_grad():
118
+ img = pred['image']
119
+ while isinstance(img, list):
120
+ img = img[0]
121
+
122
+ new_im = torch.from_numpy(img).permute(2, 0, 1).cuda().float()
123
+ if img_transforms is not None:
124
+ new_im = img_transforms(new_im)[None]
125
+ else:
126
+ new_im = new_im[None]
127
+ img = (img * 255).astype(np.uint8)
128
+
129
+ fn = image_name
130
+ camera_model, width, height, params = all_scene_query_info[scene_name][fn]
131
+ camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params)
132
+ curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=scene_name)
133
+ gt_sub_map = locMap.sub_maps[curr_frame.scene_name]
134
+ if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys():
135
+ curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec']
136
+ curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec']
137
+
138
+ t_start = time.time()
139
+ encoder_out = local_feat.extract_local_global(data={'image': new_im},
140
+ config=
141
+ {
142
+ # 'min_keypoints': 128,
143
+ 'max_keypoints': config['eval_max_keypoints'],
144
+ }
145
+ )
146
+ t_feat = time.time() - t_start
147
+ # global_descriptors_cuda = encoder_out['global_descriptors']
148
+ # scores_cuda = encoder_out['scores'][0][None]
149
+ # kpts_cuda = encoder_out['keypoints'][0][None]
150
+ # descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1)
151
+
152
+ sparse_scores = pred['scores']
153
+ sparse_descs = pred['descriptors']
154
+ sparse_kpts = pred['keypoints']
155
+ gt_seg = pred['gt_seg']
156
+
157
+ curr_frame.add_keypoints(keypoints=np.hstack([sparse_kpts[0].cpu().numpy(),
158
+ sparse_scores[0].cpu().numpy().reshape(-1, 1)]),
159
+ descriptors=sparse_descs[0].cpu().numpy())
160
+ curr_frame.time_feat = t_feat
161
+
162
+ t_start = time.time()
163
+ _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'],
164
+ semi_descs=encoder_out['mid_features'],
165
+ # kpts=kpts_cuda[0],
166
+ kpts=sparse_kpts[0],
167
+ norm_desc=config['norm_desc'])
168
+ rec_out = rec_model({'scores': sparse_scores,
169
+ 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1),
170
+ 'keypoints': sparse_kpts,
171
+ 'image': new_im})
172
+ t_rec = time.time() - t_start
173
+ curr_frame.time_rec = t_rec
174
+
175
+ pred = {
176
+ # 'scores': scores_cuda,
177
+ # 'keypoints': kpts_cuda,
178
+ # 'descriptors': descriptors_cuda,
179
+ # 'global_descriptors': global_descriptors_cuda,
180
+ 'image_size': np.array([img.shape[1], img.shape[0]])[None],
181
+ }
182
+
183
+ pred = {**pred, **rec_out}
184
+ pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C]
185
+
186
+ pred_seg = pred_seg[0].cpu().numpy()
187
+ kpts = sparse_kpts[0].cpu().numpy()
188
+ img_pred_seg = vis_seg_point(img=img, kpts=kpts, segs=pred_seg, seg_color=seg_color, radius=9)
189
+ show_text = 'kpts: {:d}'.format(kpts.shape[0])
190
+ img_pred_seg = cv2.putText(img=img_pred_seg, text=show_text,
191
+ org=(50, 30),
192
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
193
+ fontScale=1, color=(0, 0, 255),
194
+ thickness=2, lineType=cv2.LINE_AA)
195
+ curr_frame.image_rec = img_pred_seg
196
+
197
+ if show:
198
+ cv2.imshow('img', img)
199
+ key = cv2.waitKey(1)
200
+ if key == ord('q'):
201
+ exit(0)
202
+ elif key == ord('s'):
203
+ show_time = -1
204
+ elif key == ord('c'):
205
+ show_time = 1
206
+
207
+ segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C]
208
+ curr_frame.add_segmentations(segmentations=segmentations,
209
+ filtering_threshold=config['localization']['pre_filtering_th'])
210
+
211
+ # Step1: do tracker first
212
+ success = not mTracker.lost and tracking
213
+ if success:
214
+ success = mTracker.run(frame=curr_frame)
215
+ if not success:
216
+ success = locMap.run(q_frame=curr_frame)
217
+ if success:
218
+ curr_frame.update_point3ds()
219
+ if tracking:
220
+ mTracker.lost = False
221
+ mTracker.last_frame = curr_frame
222
+ # '''
223
+ pred_seg = torch.max(pred['prediction'], dim=-1)[1] # argmax over classes: [B, N, C] -> [B, N]
224
+ pred_seg = pred_seg[0].cpu().numpy()
225
+ gt_seg = gt_seg[0].cpu().numpy()
226
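+ # segmentation quality against the ground-truth labels: IoU and precision, ignoring the background class (id 0)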
+ iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=pred_seg.shape[0],
227
+ ignored_ids=[0]) # 0 - background
228
+ prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0])
229
+
230
+ kpts = sparse_kpts[0].cpu().numpy()
231
+ if scene not in seg_results.keys():
232
+ seg_results[scene] = {
233
+ 'day': {
234
+ 'prec': [],
235
+ 'iou': [],
236
+ 'kpts': [],
237
+ },
238
+ 'night': {
239
+ 'prec': [],
240
+ 'iou': [],
241
+ 'kpts': [],
242
+
243
+ }
244
+ }
245
+ if fn.find('night') >= 0:
246
+ seg_results[scene]['night']['prec'].append(prec)
247
+ seg_results[scene]['night']['iou'].append(iou)
248
+ seg_results[scene]['night']['kpts'].append(kpts.shape[0])
249
+ else:
250
+ seg_results[scene]['day']['prec'].append(prec)
251
+ seg_results[scene]['day']['iou'].append(iou)
252
+ seg_results[scene]['day']['kpts'].append(kpts.shape[0])
253
+
254
+ print_text = 'name: {:s}, kpts: {:d}, iou: {:.3f}, prec: {:.3f}'.format(fn, kpts.shape[0], iou,
255
+ prec)
256
+ print(print_text)
257
+ # '''
258
+
259
+ t_feat = curr_frame.time_feat
260
+ t_rec = curr_frame.time_rec
261
+ t_loc = curr_frame.time_loc
262
+ t_ref = curr_frame.time_ref
263
+ t_total = t_feat + t_rec + t_loc + t_ref
264
+ time_results['feat'].append(t_feat)
265
+ time_results['rec'].append(t_rec)
266
+ time_results['loc'].append(t_loc)
267
+ time_results['ref'].append(t_ref)
268
+ time_results['total'].append(t_total)
269
+
270
+ poses[scene + '/' + fn] = (curr_frame.qvec, curr_frame.tvec)
271
+ q_err, t_err = curr_frame.compute_pose_error()
272
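+ # count poses that fall within several (rotation deg, translation m) error thresholds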
+ if q_err <= 5 and t_err <= 0.05:
273
+ err_ths_cnt[0] = err_ths_cnt[0] + 1
274
+ if q_err <= 2 and t_err <= 0.25:
275
+ err_ths_cnt[1] = err_ths_cnt[1] + 1
276
+ if q_err <= 5 and t_err <= 0.5:
277
+ err_ths_cnt[2] = err_ths_cnt[2] + 1
278
+ if q_err <= 10 and t_err <= 5:
279
+ err_ths_cnt[3] = err_ths_cnt[3] + 1
280
+
281
+ if success:
282
+ success_cases.append(scene + '/' + fn)
283
+ print_text = 'qname: {:s} localization success {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format(
284
+ scene + '/' + fn, len(success_cases), n_total, q_err, t_err, err_ths_cnt[0],
285
+ err_ths_cnt[1],
286
+ err_ths_cnt[2],
287
+ err_ths_cnt[3],
288
+ n_total,
289
+ t_feat, t_rec, t_loc, t_ref, t_total
290
+ )
291
+ else:
292
+ failed_cases.append(scene + '/' + fn)
293
+ print_text = 'qname: {:s} localization fail {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format(
294
+ scene + '/' + fn, len(failed_cases), n_total, q_err, t_err, err_ths_cnt[0],
295
+ err_ths_cnt[1],
296
+ err_ths_cnt[2],
297
+ err_ths_cnt[3],
298
+ n_total, t_feat, t_rec, t_loc, t_ref, t_total)
299
+ print(print_text)
third_party/pram/localization/loc_by_rec_online.py ADDED
@@ -0,0 +1,225 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> loc_by_rec
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 08/02/2024 15:26
7
+ =================================================='''
8
+ import torch
9
+ import pycolmap
10
+ from localization.multimap3d import MultiMap3D
11
+ from localization.frame import Frame
12
+ import yaml, cv2, time
13
+ import numpy as np
14
+ import os.path as osp
15
+ import threading
16
+ from recognition.vis_seg import vis_seg_point, generate_color_dic
17
+ from tools.common import resize_img
18
+ from localization.viewer import Viewer
19
+ from localization.tracker import Tracker
20
+ from localization.utils import read_query_info
21
+ from tools.common import puttext_with_background
22
+
23
+
24
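+ # online localization: per frame, extract features, predict segments, then try tracking and fall back to map-based localization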
+ def loc_by_rec_online(rec_model, config, local_feat, img_transforms=None):
25
+ seg_color = generate_color_dic(n_seg=2000)
26
+ dataset_path = config['dataset_path']
27
+ show = config['localization']['show']
28
+ if show:
29
+ cv2.namedWindow('img', cv2.WINDOW_NORMAL)
30
+
31
+ locMap = MultiMap3D(config=config, save_dir=None)
32
+ if config['dataset'][0] in ['Aachen']:
33
+ viewer_config = {'scene': 'outdoor',
34
+ 'image_size_indoor': 4,
35
+ 'image_line_width_indoor': 8, }
36
+ elif config['dataset'][0] in ['C']:
37
+ viewer_config = {'scene': 'outdoor'}
38
+ elif config['dataset'][0] in ['12Scenes', '7Scenes']:
39
+ viewer_config = {'scene': 'indoor', }
40
+ else:
41
+ viewer_config = {'scene': 'outdoor',
42
+ 'image_size_indoor': 0.4,
43
+ 'image_line_width_indoor': 2, }
44
+ # start viewer
45
+ mViewer = Viewer(locMap=locMap, seg_color=seg_color, config=viewer_config)
46
+ mViewer.refinement = locMap.do_refinement
47
+ # locMap.viewer = mViewer
48
+ viewer_thread = threading.Thread(target=mViewer.run)
49
+ viewer_thread.start()
50
+
51
+ # start tracker
52
+ mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config)
53
+
54
+ dataset_name = config['dataset'][0]
55
+ all_scene_query_info = {}
56
+ with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f:
57
+ scene_config = yaml.load(f, Loader=yaml.Loader)
58
+
59
+ # multiple scenes in a single dataset
60
+ err_ths_cnt = [0, 0, 0, 0]
61
+
62
+ show_time = -1
63
+ scenes = scene_config['scenes']
64
+ n_total = 0
65
+ for scene in scenes:
66
+ if len(config['localization']['loc_scene_name']) > 0:
67
+ if scene not in config['localization']['loc_scene_name']:
68
+ continue
69
+
70
+ query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path'])
71
+ query_info = read_query_info(query_fn=query_path)
72
+ all_scene_query_info[dataset_name + '/' + scene] = query_info
73
+ image_path = osp.join(dataset_path, dataset_name, scene)
74
+ for fn in sorted(query_info.keys()):
75
+ # for fn in sorted(query_info.keys())[880:][::5]: # darwinRGB-loc-outdoor-aligned
76
+ # for fn in sorted(query_info.keys())[3161:][::5]: # darwinRGB-loc-indoor-aligned
77
+ # for fn in sorted(query_info.keys())[2840:][::5]: # darwinRGB-loc-indoor-aligned
78
+
79
+ # for fn in sorted(query_info.keys())[2100:][::5]: # darwinRGB-loc-outdoor
80
+ # for fn in sorted(query_info.keys())[4360:][::5]: # darwinRGB-loc-indoor
81
+ # for fn in sorted(query_info.keys())[1380:]: # Cam-Church
82
+ # for fn in sorted(query_info.keys())[::5]: #ACUED-test2
83
+ # for fn in sorted(query_info.keys())[1260:]: # jesus aligned
84
+ # for fn in sorted(query_info.keys())[1260:]: # jesus aligned
85
+ # for fn in sorted(query_info.keys())[4850:]:
86
+ img = cv2.imread(osp.join(image_path, fn)) # BGR
87
+
88
+ camera_model, width, height, params = all_scene_query_info[dataset_name + '/' + scene][fn]
89
+ # camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params)
90
+ camera = pycolmap.Camera(model=camera_model, width=int(width), height=int(height), params=params)
91
+ curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=dataset_name + '/' + scene)
92
+ gt_sub_map = locMap.sub_maps[curr_frame.scene_name]
93
+ if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys():
94
+ curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec']
95
+ curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec']
96
+
97
+ with torch.no_grad():
98
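+ # normalise the image to [0, 1] and convert to the channel layout (grayscale or RGB) expected by the extractor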
+ if config['image_dim'] == 1:
99
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
100
+ img_cuda = torch.from_numpy(img_gray / 255)[None].cuda().float()
101
+ else:
102
+ img_cuda = torch.from_numpy(img / 255).permute(2, 0, 1).cuda().float()
103
+ if img_transforms is not None:
104
+ img_cuda = img_transforms(img_cuda)[None]
105
+ else:
106
+ img_cuda = img_cuda[None]
107
+
108
+ t_start = time.time()
109
+ encoder_out = local_feat.extract_local_global(data={'image': img_cuda},
110
+ config={'min_keypoints': 128,
111
+ 'max_keypoints': config['eval_max_keypoints'],
112
+ }
113
+ )
114
+ t_feat = time.time() - t_start
115
+ # global_descriptors_cuda = encoder_out['global_descriptors']
116
+ scores_cuda = encoder_out['scores'][0][None]
117
+ kpts_cuda = encoder_out['keypoints'][0][None]
118
+ descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1)
119
+
120
+ curr_frame.add_keypoints(keypoints=np.hstack([kpts_cuda[0].cpu().numpy(),
121
+ scores_cuda[0].cpu().numpy().reshape(-1, 1)]),
122
+ descriptors=descriptors_cuda[0].cpu().numpy())
123
+ curr_frame.time_feat = t_feat
124
+
125
+ t_start = time.time()
126
+ _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'],
127
+ semi_descs=encoder_out['mid_features'],
128
+ kpts=kpts_cuda[0],
129
+ norm_desc=config['norm_desc'])
130
+ rec_out = rec_model({'scores': scores_cuda,
131
+ 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1),
132
+ 'keypoints': kpts_cuda,
133
+ 'image': img_cuda})
134
+ t_rec = time.time() - t_start
135
+ curr_frame.time_rec = t_rec
136
+
137
+ pred = {
138
+ 'scores': scores_cuda,
139
+ 'keypoints': kpts_cuda,
140
+ 'descriptors': descriptors_cuda,
141
+ # 'global_descriptors': global_descriptors_cuda,
142
+ 'image_size': np.array([img.shape[1], img.shape[0]])[None],
143
+ }
144
+
145
+ pred = {**pred, **rec_out}
146
+ pred_seg = torch.max(pred['prediction'], dim=2)[1] # argmax over classes: [B, N, C] -> [B, N]
147
+
148
+ pred_seg = pred_seg[0].cpu().numpy()
149
+ kpts = kpts_cuda[0].cpu().numpy()
150
+ segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C]
151
+ curr_frame.add_segmentations(segmentations=segmentations,
152
+ filtering_threshold=config['localization']['pre_filtering_th'])
153
+
154
+ img_pred_seg = vis_seg_point(img=img, kpts=curr_frame.keypoints,
155
+ segs=curr_frame.seg_ids + 1, seg_color=seg_color, radius=9)
156
+ show_text = 'kpts: {:d}'.format(kpts.shape[0])
157
+ img_pred_seg = cv2.putText(img=img_pred_seg,
158
+ text=show_text,
159
+ org=(50, 30),
160
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
161
+ fontScale=1, color=(0, 0, 255),
162
+ thickness=2, lineType=cv2.LINE_AA)
163
+ curr_frame.image_rec = img_pred_seg
164
+
165
+ if show:
166
+ img_text = puttext_with_background(image=img, text='Press C - continue | S - pause | Q - exit',
167
+ org=(30, 50),
168
+ bg_color=(255, 255, 255),
169
+ text_color=(0, 0, 255),
170
+ fontScale=1, thickness=2)
171
+ cv2.imshow('img', img_text)
172
+ key = cv2.waitKey(show_time)
173
+ if key == ord('q'):
174
+ exit(0)
175
+ elif key == ord('s'):
176
+ show_time = -1
177
+ elif key == ord('c'):
178
+ show_time = 1
179
+
180
+ # Step1: do tracker first
181
+ success = not mTracker.lost and mViewer.tracking
182
+ if success:
183
+ success = mTracker.run(frame=curr_frame)
184
+ if success:
185
+ mViewer.update(curr_frame=curr_frame)
186
+
187
+ if not success:
188
+ # success = locMap.run(q_frame=curr_frame, q_segs=segmentations)
189
+ success = locMap.run(q_frame=curr_frame)
190
+ if success:
191
+ mViewer.update(curr_frame=curr_frame)
192
+
193
+ if success:
194
+ curr_frame.update_point3ds()
195
+ if mViewer.tracking:
196
+ mTracker.lost = False
197
+ mTracker.last_frame = curr_frame
198
+
199
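+ # short pause (50 ms) between frames, presumably to give the viewer thread time to render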
+ time.sleep(50 / 1000)
200
+ locMap.do_refinement = mViewer.refinement
201
+
202
+ n_total = n_total + 1
203
+ q_err, t_err = curr_frame.compute_pose_error()
204
+ if q_err <= 5 and t_err <= 0.05:
205
+ err_ths_cnt[0] = err_ths_cnt[0] + 1
206
+ if q_err <= 2 and t_err <= 0.25:
207
+ err_ths_cnt[1] = err_ths_cnt[1] + 1
208
+ if q_err <= 5 and t_err <= 0.5:
209
+ err_ths_cnt[2] = err_ths_cnt[2] + 1
210
+ if q_err <= 10 and t_err <= 5:
211
+ err_ths_cnt[3] = err_ths_cnt[3] + 1
212
+ time_total = curr_frame.time_feat + curr_frame.time_rec + curr_frame.time_loc + curr_frame.time_ref
213
+ print_text = 'qname: {:s} localization {:b}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format(
214
+ scene + '/' + fn, success, q_err, t_err,
215
+ err_ths_cnt[0],
216
+ err_ths_cnt[1],
217
+ err_ths_cnt[2],
218
+ err_ths_cnt[3],
219
+ n_total,
220
+ curr_frame.time_feat, curr_frame.time_rec, curr_frame.time_loc, curr_frame.time_ref, time_total
221
+ )
222
+ print(print_text)
223
+
224
+ mViewer.terminate()
225
+ viewer_thread.join()
third_party/pram/localization/localizer.py ADDED
@@ -0,0 +1,217 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> hloc
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 07/02/2024 16:45
7
+ =================================================='''
8
+
9
+ import os
10
+ import os.path as osp
11
+ from tqdm import tqdm
12
+ import argparse
13
+ import time
14
+ import logging
15
+ import h5py
16
+ import numpy as np
17
+ from pathlib import Path
18
+ from colmap_utils.read_write_model import read_model
19
+ from colmap_utils.parsers import parse_image_lists_with_intrinsics
20
+ # localization
21
+ from localization.match_features_batch import confs
22
+ from localization.base_model import dynamic_load
23
+ from localization import matchers
24
+ from localization.utils import compute_pose_error, read_gt_pose, read_retrieval_results
25
+ from localization.pose_estimator import pose_estimator_hloc, pose_estimator_iterative
26
+
27
+
28
+ def run(args):
29
+ if args.gt_pose_fn is not None:
30
+ gt_poses = read_gt_pose(path=args.gt_pose_fn)
31
+ else:
32
+ gt_poses = {}
33
+ retrievals = read_retrieval_results(args.retrieval)
34
+
35
+ save_root = args.save_root # path to save
36
+ os.makedirs(save_root, exist_ok=True)
37
+ matcher_name = args.matcher_method # matching method
38
+ print('matcher: ', confs[args.matcher_method]['model']['name'])
39
+ Model = dynamic_load(matchers, confs[args.matcher_method]['model']['name'])
40
+ matcher = Model(confs[args.matcher_method]['model']).eval().cuda()
41
+
42
+ local_feat_name = args.features.as_posix().split("/")[-1].split(".")[0] # name of local features
43
+ save_fn = '{:s}_{:s}'.format(local_feat_name, matcher_name)
44
+ if args.use_hloc:
45
+ save_fn = 'hloc_' + save_fn
46
+ save_fn = osp.join(save_root, save_fn)
47
+
48
+ queries = parse_image_lists_with_intrinsics(args.queries)
49
+ _, db_images, points3D = read_model(str(args.reference_sfm), '.bin')
50
+ db_name_to_id = {image.name: i for i, image in db_images.items()}
51
+ feature_file = h5py.File(args.features, 'r')
52
+
53
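+ # encode the covisibility-optimisation settings into a tag appended to the output file names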
+ tag = ''
54
+ if args.do_covisible_opt:
55
+ tag = tag + "_o" + str(int(args.obs_thresh)) + 'op' + str(int(args.covisibility_frame))
56
+ tag = tag + "th" + str(int(args.opt_thresh))
57
+ if args.iters > 0:
58
+ tag = tag + "i" + str(int(args.iters))
59
+
60
+ log_fn = save_fn + tag
61
+ vis_dir = save_fn + tag
62
+ results = save_fn + tag
63
+
64
+ full_log_fn = log_fn + '_full.log'
65
+ loc_log_fn = log_fn + '_loc.npy'
66
+ results = Path(results + '.txt')
67
+ vis_dir = Path(vis_dir)
68
+ if vis_dir is not None:
69
+ Path(vis_dir).mkdir(exist_ok=True)
70
+ print("save_fn: ", log_fn)
71
+
72
+ logging.info('Starting localization...')
73
+ poses = {}
74
+ failed_cases = []
75
+ n_total = 0
76
+ n_failed = 0
77
+ full_log_info = ''
78
+ loc_results = {}
79
+
80
+ error_ths = ((0.25, 2), (0.5, 5), (5, 10))
81
+ success = [0, 0, 0]
82
+ total_loc_time = []
83
+
84
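+ # localize each query: gather the retrieved database candidates, then estimate the pose with the selected estimator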
+ for qname, qinfo in tqdm(queries):
85
+ kpq = feature_file[qname]['keypoints'].__array__()
86
+ n_total += 1
87
+ time_start = time.time()
88
+
89
+ if qname in retrievals.keys():
90
+ cans = retrievals[qname]
91
+ db_ids = [db_name_to_id[v] for v in cans]
92
+ else:
93
+ cans = []
94
+ db_ids = []
95
+ time_coarse = time.time()
96
+
97
+ if args.use_hloc:
98
+ output = pose_estimator_hloc(qname=qname, qinfo=qinfo, db_ids=db_ids, db_images=db_images,
99
+ points3D=points3D,
100
+ feature_file=feature_file,
101
+ thresh=args.ransac_thresh,
102
+ image_dir=args.image_dir,
103
+ matcher=matcher,
104
+ log_info='',
105
+ query_img_prefix='',
106
+ db_img_prefix='')
107
+ else: # should be faster and more accurate than hloc
108
+ t_start = time.time()
109
+ output = pose_estimator_iterative(qname=qname,
110
+ qinfo=qinfo,
111
+ matcher=matcher,
112
+ db_ids=db_ids,
113
+ db_images=db_images,
114
+ points3D=points3D,
115
+ feature_file=feature_file,
116
+ thresh=args.ransac_thresh,
117
+ image_dir=args.image_dir,
118
+ do_covisibility_opt=args.do_covisible_opt,
119
+ covisibility_frame=args.covisibility_frame,
120
+ log_info='',
121
+ inlier_th=args.inlier_thresh,
122
+ obs_th=args.obs_thresh,
123
+ opt_th=args.opt_thresh,
124
+ gt_qvec=gt_poses[qname]['qvec'] if qname in gt_poses.keys() else None,
125
+ gt_tvec=gt_poses[qname]['tvec'] if qname in gt_poses.keys() else None,
126
+ query_img_prefix='',
127
+ db_img_prefix='database',
128
+ )
129
+ time_full = time.time()
130
+
131
+ qvec = output['qvec']
132
+ tvec = output['tvec']
133
+ loc_time = time_full - time_start
134
+ total_loc_time.append(loc_time)
135
+
136
+ poses[qname] = (qvec, tvec)
137
+ print_text = "All {:d}/{:d} failed cases, time[cs/fn]: {:.2f}/{:.2f}".format(
138
+ n_failed, n_total,
139
+ time_coarse - time_start,
140
+ time_full - time_coarse,
141
+ )
142
+
143
+ if qname in gt_poses.keys():
144
+ gt_qvec = gt_poses[qname]['qvec']
145
+ gt_tvec = gt_poses[qname]['tvec']
146
+
147
+ q_error, t_error = compute_pose_error(pred_qcw=qvec, pred_tcw=tvec, gt_qcw=gt_qvec, gt_tcw=gt_tvec)
148
+
149
+ for error_idx, th in enumerate(error_ths):
150
+ if t_error <= th[0] and q_error <= th[1]:
151
+ success[error_idx] += 1
152
+ print_text += (
153
+ ', q_error:{:.2f} t_error:{:.2f} {:d}/{:d}/{:d}/{:d}, time: {:.2f}, {:d}pts'.format(q_error, t_error,
154
+ success[0],
155
+ success[1],
156
+ success[2], n_total,
157
+ loc_time,
158
+ kpq.shape[0]))
159
+ if output['num_inliers'] == 0:
160
+ failed_cases.append(qname)
161
+
162
+ loc_results[qname] = {
163
+ 'keypoints_query': output['keypoints_query'],
164
+ 'points3D_ids': output['points3D_ids'],
165
+ }
166
+ full_log_info = full_log_info + output['log_info']
167
+ full_log_info += (print_text + "\n")
168
+ print(print_text)
169
+
170
+ logs_path = f'{results}.failed'
171
+ with open(logs_path, 'w') as f:
172
+ for v in failed_cases:
173
+ print(v)
174
+ f.write(v + "\n")
175
+
176
+ logging.info(f'Localized {len(poses)} / {len(queries)} images.')
177
+ logging.info(f'Writing poses to {results}...')
178
+ # logging.info(f'Mean loc time: {np.mean(total_loc_time)}...')
179
+ print('Mean loc time: {:.2f}...'.format(np.mean(total_loc_time)))
180
+ with open(results, 'w') as f:
181
+ for q in poses:
182
+ qvec, tvec = poses[q]
183
+ qvec = ' '.join(map(str, qvec))
184
+ tvec = ' '.join(map(str, tvec))
185
+ name = q
186
+ f.write(f'{name} {qvec} {tvec}\n')
187
+
188
+ with open(full_log_fn, 'w') as f:
189
+ f.write(full_log_info)
190
+
191
+ np.save(loc_log_fn, loc_results)
192
+ print('Save logs to ', loc_log_fn)
193
+ logging.info('Done!')
194
+
195
+
196
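+ # Example invocation (module path, file names and paths below are illustrative, not taken from the repository):
+ # python -m localization.localizer --dataset Aachen --image_dir /path/to/images \
+ #     --reference_sfm /path/to/sfm --queries /path/to/queries_with_intrinsics.txt \
+ #     --features /path/to/feats.h5 --retrieval /path/to/pairs.txt \
+ #     --matcher_method adagml --save_root /path/to/outputs --do_covisible_opt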
+ if __name__ == '__main__':
197
+ parser = argparse.ArgumentParser()
198
+ parser.add_argument('--image_dir', type=str, required=True)
199
+ parser.add_argument('--dataset', type=str, required=True)
200
+ parser.add_argument('--reference_sfm', type=Path, required=True)
201
+ parser.add_argument('--queries', type=Path, required=True)
202
+ parser.add_argument('--features', type=Path, required=True)
203
+ parser.add_argument('--ransac_thresh', type=float, default=12)
204
+ parser.add_argument('--covisibility_frame', type=int, default=50)
205
+ parser.add_argument('--do_covisible_opt', action='store_true')
206
+ parser.add_argument('--use_hloc', action='store_true')
207
+ parser.add_argument('--matcher_method', type=str, default="NNM")
208
+ parser.add_argument('--inlier_thresh', type=int, default=50)
209
+ parser.add_argument('--obs_thresh', type=float, default=3)
210
+ parser.add_argument('--opt_thresh', type=float, default=12)
+ # 'iters' is read when building the output tag above but had no corresponding argument; added here with a neutral default
+ parser.add_argument('--iters', type=int, default=0)
211
+ parser.add_argument('--save_root', type=str, required=True)
212
+ parser.add_argument('--retrieval', type=Path, default=None)
213
+ parser.add_argument('--gt_pose_fn', type=str, default=None)
214
+
215
+ args = parser.parse_args()
216
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
217
+ run(args=args)
third_party/pram/localization/match_features.py ADDED
@@ -0,0 +1,156 @@
1
+ import argparse
2
+ import torch
3
+ from pathlib import Path
4
+ import h5py
5
+ import logging
6
+ from tqdm import tqdm
7
+ import pprint
8
+
9
+ import localization.matchers as matchers
10
+ from localization.base_model import dynamic_load
11
+ from colmap_utils.parsers import names_to_pair
12
+
13
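+ # matcher configurations: IMP variants (gm / gml / adagml), SuperGlue, and mutual nearest-neighbour matching (NNM)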
+ confs = {
14
+ 'gm': {
15
+ 'output': 'gm',
16
+ 'model': {
17
+ 'name': 'gm',
18
+ 'weight_path': 'weights/imp_gm.900.pth',
19
+ 'sinkhorn_iterations': 20,
20
+ },
21
+ },
22
+ 'gml': {
23
+ 'output': 'gml',
24
+ 'model': {
25
+ 'name': 'gml',
26
+ 'weight_path': 'weights/imp_gml.920.pth',
27
+ 'sinkhorn_iterations': 20,
28
+ },
29
+ },
30
+
31
+ 'adagml': {
32
+ 'output': 'adagml',
33
+ 'model': {
34
+ 'name': 'adagml',
35
+ 'weight_path': 'weights/imp_adagml.80.pth',
36
+ 'sinkhorn_iterations': 20,
37
+ },
38
+ },
39
+
40
+ 'superglue': {
41
+ 'output': 'superglue',
42
+ 'model': {
43
+ 'name': 'superglue',
44
+ 'weights': 'outdoor',
45
+ 'sinkhorn_iterations': 20,
46
+ 'weight_path': 'weights/superglue_outdoor.pth',
47
+ },
48
+ },
49
+ 'NNM': {
50
+ 'output': 'NNM',
51
+ 'model': {
52
+ 'name': 'nearest_neighbor',
53
+ 'do_mutual_check': True,
54
+ 'distance_threshold': None,
55
+ },
56
+ },
57
+ }
58
+
59
+
60
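+ # match every image pair listed in the pairs file using the stored features and write the matches (and scores) to an HDF5 file in export_dir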
+ @torch.no_grad()
61
+ def main(conf, pairs, features, export_dir, exhaustive=False):
62
+ logging.info('Matching local features with configuration:'
63
+ f'\n{pprint.pformat(conf)}')
64
+
65
+ feature_path = Path(export_dir, features + '.h5')
66
+ assert feature_path.exists(), feature_path
67
+ feature_file = h5py.File(str(feature_path), 'r')
68
+ pairs_name = pairs.stem
69
+ if not exhaustive:
70
+ assert pairs.exists(), pairs
71
+ with open(pairs, 'r') as f:
72
+ pair_list = f.read().rstrip('\n').split('\n')
73
+ elif exhaustive:
74
+ logging.info(f'Writing exhaustive match pairs to {pairs}.')
75
+ assert not pairs.exists(), pairs
76
+
77
+ # get the list of images from the feature file
78
+ images = []
79
+ feature_file.visititems(
80
+ lambda name, obj: images.append(obj.parent.name.strip('/'))
81
+ if isinstance(obj, h5py.Dataset) else None)
82
+ images = list(set(images))
83
+
84
+ pair_list = [' '.join((images[i], images[j]))
85
+ for i in range(len(images)) for j in range(i)]
86
+ with open(str(pairs), 'w') as f:
87
+ f.write('\n'.join(pair_list))
88
+
89
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
90
+ Model = dynamic_load(matchers, conf['model']['name'])
91
+ model = Model(conf['model']).eval().to(device)
92
+
93
+ match_name = f'{features}-{conf["output"]}-{pairs_name}'
94
+ match_path = Path(export_dir, match_name + '.h5')
95
+
96
+ match_file = h5py.File(str(match_path), 'a')
97
+
98
+ matched = set()
99
+ for pair in tqdm(pair_list, smoothing=.1):
100
+ name0, name1 = pair.split(' ')
101
+ pair = names_to_pair(name0, name1)
102
+
103
+ # Avoid to recompute duplicates to save time
104
+ if len({(name0, name1), (name1, name0)} & matched) \
105
+ or pair in match_file:
106
+ continue
107
+
108
+ data = {}
109
+ feats0, feats1 = feature_file[name0], feature_file[name1]
110
+ for k in feats1.keys():
111
+ # data[k + '0'] = feats0[k].__array__()
112
+ if k == 'descriptors':
113
+ data[k + '0'] = feats0[k][()].transpose() # [N D]
114
+ else:
115
+ data[k + '0'] = feats0[k][()]
116
+ for k in feats1.keys():
117
+ # data[k + '1'] = feats1[k].__array__()
118
+ # data[k + '1'] = feats1[k][()].transpose() # [N D]
119
+ if k == 'descriptors':
120
+ data[k + '1'] = feats1[k][()].transpose() # [N D]
121
+ else:
122
+ data[k + '1'] = feats1[k][()]
123
+ data = {k: torch.from_numpy(v)[None].float().to(device)
124
+ for k, v in data.items()}
125
+
126
+ # some matchers might expect an image but only use its size
127
+ data['image0'] = torch.empty((1, 1,) + tuple(feats0['image_size'])[::-1])
128
+ data['image1'] = torch.empty((1, 1,) + tuple(feats1['image_size'])[::-1])
129
+
130
+ pred = model(data)
131
+ grp = match_file.create_group(pair)
132
+ matches = pred['matches0'][0].cpu().short().numpy()
133
+ grp.create_dataset('matches0', data=matches)
134
+
135
+ if 'matching_scores0' in pred:
136
+ scores = pred['matching_scores0'][0].cpu().half().numpy()
137
+ grp.create_dataset('matching_scores0', data=scores)
138
+
139
+ matched |= {(name0, name1), (name1, name0)}
140
+
141
+ match_file.close()
142
+ logging.info('Finished exporting matches.')
143
+
144
+ return match_path
145
+
146
+
147
+ if __name__ == '__main__':
148
+ parser = argparse.ArgumentParser()
149
+ parser.add_argument('--export_dir', type=Path, required=True)
150
+ parser.add_argument('--features', type=str, required=True)
151
+ parser.add_argument('--pairs', type=Path, required=True)
152
+ parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
153
+ parser.add_argument('--exhaustive', action='store_true')
154
+ args = parser.parse_args()
155
+ main(confs[args.conf], args.pairs, args.features, args.export_dir,
156
+ exhaustive=args.exhaustive)
third_party/pram/localization/match_features_batch.py ADDED
@@ -0,0 +1,242 @@
1
+ import argparse
2
+ import torch
3
+ from pathlib import Path
4
+ import h5py
5
+ import logging
6
+ from tqdm import tqdm
7
+ import pprint
8
+ from queue import Queue
9
+ from threading import Thread
10
+ from functools import partial
11
+ from typing import Dict, List, Optional, Tuple, Union
12
+
13
+ import localization.matchers as matchers
14
+ from localization.base_model import dynamic_load
15
+ from colmap_utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
16
+
17
+ confs = {
18
+ 'gm': {
19
+ 'output': 'gm',
20
+ 'model': {
21
+ 'name': 'gm',
22
+ 'weight_path': 'weights/imp_gm.900.pth',
23
+ 'sinkhorn_iterations': 20,
24
+ },
25
+ },
26
+ 'gml': {
27
+ 'output': 'gml',
28
+ 'model': {
29
+ 'name': 'gml',
30
+ 'weight_path': 'weights/imp_gml.920.pth',
31
+ 'sinkhorn_iterations': 20,
32
+ },
33
+ },
34
+
35
+ 'adagml': {
36
+ 'output': 'adagml',
37
+ 'model': {
38
+ 'name': 'adagml',
39
+ 'weight_path': 'weights/imp_adagml.80.pth',
40
+ 'sinkhorn_iterations': 20,
41
+ },
42
+ },
43
+
44
+ 'superglue': {
45
+ 'output': 'superglue',
46
+ 'model': {
47
+ 'name': 'superglue',
48
+ 'weights': 'outdoor',
49
+ 'sinkhorn_iterations': 20,
50
+ 'weight_path': 'weights/superglue_outdoor.pth',
51
+ },
52
+ },
53
+ 'NNM': {
54
+ 'output': 'NNM',
55
+ 'model': {
56
+ 'name': 'nearest_neighbor',
57
+ 'do_mutual_check': True,
58
+ 'distance_threshold': None,
59
+ },
60
+ },
61
+ }
62
+
63
+
64
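+ # simple producer/consumer helper: worker threads apply work_fn to queued items until a None sentinel arrives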
+ class WorkQueue:
65
+ def __init__(self, work_fn, num_threads=1):
66
+ self.queue = Queue(num_threads)
67
+ self.threads = [
68
+ Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads)
69
+ ]
70
+ for thread in self.threads:
71
+ thread.start()
72
+
73
+ def join(self):
74
+ for thread in self.threads:
75
+ self.queue.put(None)
76
+ for thread in self.threads:
77
+ thread.join()
78
+
79
+ def thread_fn(self, work_fn):
80
+ item = self.queue.get()
81
+ while item is not None:
82
+ work_fn(item)
83
+ item = self.queue.get()
84
+
85
+ def put(self, data):
86
+ self.queue.put(data)
87
+
88
+
89
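+ # loads the stored features of each (query, reference) pair; descriptors are transposed to [N, D] for the matcher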
+ class FeaturePairsDataset(torch.utils.data.Dataset):
90
+ def __init__(self, pairs, feature_path_q, feature_path_r):
91
+ self.pairs = pairs
92
+ self.feature_path_q = feature_path_q
93
+ self.feature_path_r = feature_path_r
94
+
95
+ def __getitem__(self, idx):
96
+ name0, name1 = self.pairs[idx]
97
+ data = {}
98
+ with h5py.File(self.feature_path_q, "r") as fd:
99
+ grp = fd[name0]
100
+ for k, v in grp.items():
101
+ data[k + "0"] = torch.from_numpy(v.__array__()).float()
102
+ if k == 'descriptors':
103
+ data[k + '0'] = data[k + '0'].t()
104
+ # some matchers might expect an image but only use its size
105
+ data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
106
+ with h5py.File(self.feature_path_r, "r") as fd:
107
+ grp = fd[name1]
108
+ for k, v in grp.items():
109
+ data[k + "1"] = torch.from_numpy(v.__array__()).float()
110
+ if k == 'descriptors':
111
+ data[k + '1'] = data[k + '1'].t()
112
+ data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
113
+ return data
114
+
115
+ def __len__(self):
116
+ return len(self.pairs)
117
+
118
+
119
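+ # write the matches (and scores, if present) of one pair to the output HDF5 file, replacing any existing group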
+ def writer_fn(inp, match_path):
120
+ pair, pred = inp
121
+ with h5py.File(str(match_path), "a", libver="latest") as fd:
122
+ if pair in fd:
123
+ del fd[pair]
124
+ grp = fd.create_group(pair)
125
+ matches = pred["matches0"][0].cpu().short().numpy()
126
+ grp.create_dataset("matches0", data=matches)
127
+ if "matching_scores0" in pred:
128
+ scores = pred["matching_scores0"][0].cpu().half().numpy()
129
+ grp.create_dataset("matching_scores0", data=scores)
130
+
131
+
132
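+ # resolve feature and match file paths (given either as explicit paths or as names inside export_dir) and run the matching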
+ def main(
133
+ conf: Dict,
134
+ pairs: Path,
135
+ features: Union[Path, str],
136
+ export_dir: Optional[Path] = None,
137
+ matches: Optional[Path] = None,
138
+ features_ref: Optional[Path] = None,
139
+ overwrite: bool = False,
140
+ ) -> Path:
141
+ if isinstance(features, Path) or Path(features).exists():
142
+ features_q = features
143
+ if matches is None:
144
+ raise ValueError(
145
+ "Either provide both features and matches as Path" " or both as names."
146
+ )
147
+ else:
148
+ if export_dir is None:
149
+ raise ValueError(
150
+ "Provide an export_dir if features is not" f" a file path: {features}."
151
+ )
152
+ features_q = Path(export_dir, features + ".h5")
153
+ if matches is None:
154
+ matches = Path(export_dir, f'{features}-{conf["output"]}-{pairs.stem}.h5')
155
+
156
+ if features_ref is None:
157
+ features_ref = features_q
158
+ match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite)
159
+
160
+ return matches
161
+
162
+
163
+ def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None):
164
+ """Avoid to recompute duplicates to save time."""
165
+ pairs = set()
166
+ for i, j in pairs_all:
167
+ if (j, i) not in pairs:
168
+ pairs.add((i, j))
169
+ pairs = list(pairs)
170
+ if match_path is not None and match_path.exists():
171
+ with h5py.File(str(match_path), "r", libver="latest") as fd:
172
+ pairs_filtered = []
173
+ for i, j in pairs:
174
+ if (
175
+ names_to_pair(i, j) in fd
176
+ or names_to_pair(j, i) in fd
177
+ or names_to_pair_old(i, j) in fd
178
+ or names_to_pair_old(j, i) in fd
179
+ ):
180
+ continue
181
+ pairs_filtered.append((i, j))
182
+ return pairs_filtered
183
+ return pairs
184
+
185
+
186
+ @torch.no_grad()
187
+ def match_from_paths(
188
+ conf: Dict,
189
+ pairs_path: Path,
190
+ match_path: Path,
191
+ feature_path_q: Path,
192
+ feature_path_ref: Path,
193
+ overwrite: bool = False,
194
+ ) -> Path:
195
+ logging.info(
196
+ "Matching local features with configuration:" f"\n{pprint.pformat(conf)}"
197
+ )
198
+
199
+ if not feature_path_q.exists():
200
+ raise FileNotFoundError(f"Query feature file {feature_path_q}.")
201
+ if not feature_path_ref.exists():
202
+ raise FileNotFoundError(f"Reference feature file {feature_path_ref}.")
203
+ match_path.parent.mkdir(exist_ok=True, parents=True)
204
+
205
+ assert pairs_path.exists(), pairs_path
206
+ pairs = parse_retrieval(pairs_path)
207
+ pairs = [(q, r) for q, rs in pairs.items() for r in rs]
208
+ pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
209
+ if len(pairs) == 0:
210
+ logging.info("Skipping the matching.")
211
+ return
212
+
213
+ device = "cuda" if torch.cuda.is_available() else "cpu"
214
+ Model = dynamic_load(matchers, conf["model"]["name"])
215
+ model = Model(conf["model"]).eval().to(device)
216
+
217
+ dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
218
+ loader = torch.utils.data.DataLoader(
219
+ dataset, num_workers=4, batch_size=1, shuffle=False, pin_memory=True
220
+ )
221
+ writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)
222
+
223
+ for idx, data in enumerate(tqdm(loader, smoothing=0.1)):
224
+ data = {
225
+ k: v if k.startswith("image") else v.to(device, non_blocking=True)
226
+ for k, v in data.items()
227
+ }
228
+ pred = model(data)
229
+ pair = names_to_pair(*pairs[idx])
230
+ writer_queue.put((pair, pred))
231
+ writer_queue.join()
232
+ logging.info("Finished exporting matches.")
233
+
234
+
235
+ if __name__ == '__main__':
236
+ parser = argparse.ArgumentParser()
237
+ parser.add_argument('--export_dir', type=Path, required=True)
238
+ parser.add_argument('--features', type=str, required=True)
239
+ parser.add_argument('--pairs', type=Path, required=True)
240
+ parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
241
+ args = parser.parse_args()
242
+ main(confs[args.conf], args.pairs, args.features, args.export_dir)
third_party/pram/localization/matchers/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ def get_matcher(matcher):
2
+ mod = __import__(f'{__name__}.{matcher}', fromlist=[''])
3
+ return getattr(mod, 'Model')
third_party/pram/localization/matchers/adagml.py ADDED
@@ -0,0 +1,41 @@
1
+ # -*- coding: UTF-8 -*-
2
+ '''=================================================
3
+ @Project -> File pram -> adagml
4
+ @IDE PyCharm
5
+ @Author fx221@cam.ac.uk
6
+ @Date 11/02/2024 14:34
7
+ =================================================='''
8
+ import torch
9
+ from localization.base_model import BaseModel
10
+ from nets.adagml import AdaGML as GMatcher
11
+
12
+
13
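+ # wrapper exposing the AdaGML matcher from nets.adagml through the BaseModel interface used by the localization pipeline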
+ class AdaGML(BaseModel):
14
+ default_config = {
15
+ 'descriptor_dim': 128,
16
+ 'hidden_dim': 256,
17
+ 'weights': 'indoor',
18
+ 'keypoint_encoder': [32, 64, 128, 256],
19
+ 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total
20
+ 'sinkhorn_iterations': 20,
21
+ 'match_threshold': 0.2,
22
+ 'with_pose': False,
23
+ 'n_layers': 9,
24
+ 'n_min_tokens': 256,
25
+ 'with_sinkhorn': True,
26
+ 'weight_path': None,
27
+ }
28
+
29
+ required_inputs = [
30
+ 'image0', 'keypoints0', 'scores0', 'descriptors0',
31
+ 'image1', 'keypoints1', 'scores1', 'descriptors1',
32
+ ]
33
+
34
+ def _init(self, conf):
35
+ self.net = GMatcher(config=conf).eval()
36
+ state_dict = torch.load(conf['weight_path'], map_location='cpu')['model']
37
+ self.net.load_state_dict(state_dict, strict=True)
38
+
39
+ def _forward(self, data):
40
+ with torch.no_grad():
41
+ return self.net(data)