Vincentqyw commited on
Commit
8320ccc
1 Parent(s): b075789

update: ci

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .flake8 +4 -0
  2. .github/.stale.yml +17 -0
  3. .github/ISSUE_TEMPLATE/bug_report.md +30 -0
  4. .github/ISSUE_TEMPLATE/config.yml +3 -0
  5. .github/ISSUE_TEMPLATE/feature_request.md +15 -0
  6. .github/ISSUE_TEMPLATE/question.md +25 -0
  7. .github/PULL_REQUEST_TEMPLATE.md +7 -0
  8. .github/release-drafter.yml +24 -0
  9. .github/workflows/ci.yml +32 -0
  10. .github/workflows/format.yml +24 -0
  11. .github/workflows/release-drafter.yml +16 -0
  12. .gitignore +3 -1
  13. README.md +2 -0
  14. common/api.py +19 -22
  15. common/app_class.py +17 -16
  16. common/config.yaml +39 -6
  17. common/utils.py +39 -41
  18. common/viz.py +6 -4
  19. docker/build_docker.bat +3 -0
  20. run_docker.sh → docker/run_docker.bat +0 -0
  21. docker/run_docker.sh +1 -0
  22. env-docker.txt +0 -33
  23. format.sh +3 -0
  24. hloc/__init__.py +1 -0
  25. hloc/extract_features.py +27 -9
  26. hloc/extractors/alike.py +4 -2
  27. hloc/extractors/d2net.py +8 -7
  28. hloc/extractors/darkfeat.py +5 -2
  29. hloc/extractors/dedode.py +10 -8
  30. hloc/extractors/dir.py +5 -4
  31. hloc/extractors/disk.py +1 -1
  32. hloc/extractors/dog.py +4 -5
  33. hloc/extractors/example.py +3 -3
  34. hloc/extractors/fire.py +5 -4
  35. hloc/extractors/fire_local.py +7 -10
  36. hloc/extractors/lanet.py +9 -6
  37. hloc/extractors/netvlad.py +4 -3
  38. hloc/extractors/r2d2.py +6 -5
  39. hloc/extractors/rekd.py +4 -3
  40. hloc/extractors/rord.py +3 -3
  41. hloc/extractors/sfd2.py +43 -0
  42. hloc/extractors/sift.py +3 -2
  43. hloc/extractors/superpoint.py +4 -2
  44. hloc/extractors/xfeat.py +3 -2
  45. hloc/match_dense.py +4 -2
  46. hloc/match_features.py +14 -6
  47. hloc/matchers/adalam.py +2 -3
  48. hloc/matchers/aspanformer.py +5 -6
  49. hloc/matchers/cotr.py +10 -10
  50. hloc/matchers/dkm.py +6 -4
.flake8 ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [flake8]
2
+ max-line-length = 80
3
+ extend-ignore = E203,E501,E402
4
+ exclude = .git,__pycache__,build,.venv/,third_party
.github/.stale.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Number of days of inactivity before an issue becomes stale
2
+ daysUntilStale: 60
3
+ # Number of days of inactivity before a stale issue is closed
4
+ daysUntilClose: 7
5
+ # Issues with these labels will never be considered stale
6
+ exemptLabels:
7
+ - pinned
8
+ - security
9
+ # Label to use when marking an issue as stale
10
+ staleLabel: wontfix
11
+ # Comment to post when marking an issue as stale. Set to `false` to disable
12
+ markComment: >
13
+ This issue has been automatically marked as stale because it has not had
14
+ recent activity. It will be closed if no further activity occurs. Thank you
15
+ for your contributions.
16
+ # Comment to post when closing a stale issue. Set to `false` to disable
17
+ closeComment: false
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🐛 Bug report
3
+ about: If something isn't working 🔧
4
+ title: ""
5
+ labels: bug
6
+ assignees:
7
+ ---
8
+
9
+ ## 🐛 Bug Report
10
+
11
+ <!-- A clear and concise description of what the bug is. -->
12
+
13
+ ## 🔬 How To Reproduce
14
+
15
+ Steps to reproduce the behavior:
16
+
17
+ 1. ...
18
+
19
+ ### Environment
20
+
21
+ - OS: [e.g. Linux / Windows / macOS]
22
+ - Python version, get it with:
23
+
24
+ ```bash
25
+ python --version
26
+ ```
27
+
28
+ ## 📎 Additional context
29
+
30
+ <!-- Add any other context about the problem here. -->
.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Configuration: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository
2
+
3
+ blank_issues_enabled: false
.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🚀 Feature request
3
+ about: Suggest an idea for this project 🏖
4
+ title: ""
5
+ labels: enhancement
6
+ assignees:
7
+ ---
8
+
9
+ ## 🚀 Feature Request
10
+
11
+ <!-- A clear and concise description of the feature proposal. -->
12
+
13
+ ## 📎 Additional context
14
+
15
+ <!-- Add any other context or screenshots about the feature request here. -->
.github/ISSUE_TEMPLATE/question.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: ❓ Question
3
+ about: Ask a question about this project 🎓
4
+ title: ""
5
+ labels: question
6
+ assignees:
7
+ ---
8
+
9
+ ## Checklist
10
+
11
+ <!-- Mark with an `x` all the checkboxes that apply (like `[x]`) -->
12
+
13
+ - [ ] I've searched the project's [`issues`]
14
+
15
+ ## ❓ Question
16
+
17
+ <!-- What is your question -->
18
+
19
+ How can I [...]?
20
+
21
+ Is it possible to [...]?
22
+
23
+ ## 📎 Additional context
24
+
25
+ <!-- Add any other context or screenshots about the feature request here. -->
.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ ## Description
2
+
3
+ <!-- Add a more detailed description of the changes if needed. -->
4
+
5
+ ## Related Issue
6
+
7
+ <!-- If your PR refers to a related issue, link it here. -->
.github/release-drafter.yml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Release drafter configuration https://github.com/release-drafter/release-drafter#configuration
2
+ # Emojis were chosen to match the https://gitmoji.carloscuesta.me/
3
+
4
+ name-template: "v$RESOLVED_VERSION"
5
+ tag-template: "v$RESOLVED_VERSION"
6
+
7
+ categories:
8
+ - title: ":rocket: Features"
9
+ labels: [enhancement, feature]
10
+ - title: ":wrench: Fixes"
11
+ labels: [bug, bugfix, fix]
12
+ - title: ":toolbox: Maintenance & Refactor"
13
+ labels: [refactor, refactoring, chore]
14
+ - title: ":package: Build System & CI/CD & Test"
15
+ labels: [build, ci, testing, test]
16
+ - title: ":pencil: Documentation"
17
+ labels: [documentation]
18
+ - title: ":arrow_up: Dependencies updates"
19
+ labels: [dependencies]
20
+
21
+ template: |
22
+ ## What’s Changed
23
+
24
+ $CHANGES
.github/workflows/ci.yml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: CI CPU
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ pull_request:
8
+ branches:
9
+ - main
10
+
11
+ jobs:
12
+ build:
13
+ runs-on: ubuntu-latest
14
+
15
+ steps:
16
+ - name: Checkout code
17
+ uses: actions/checkout@v4
18
+ with:
19
+ submodules: recursive
20
+
21
+ - name: Set up Python
22
+ uses: actions/setup-python@v4
23
+ with:
24
+ python-version: "3.10"
25
+
26
+ - name: Install dependencies
27
+ run: |
28
+ pip install -r requirements.txt
29
+ sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
30
+
31
+ - name: Run tests
32
+ run: python test_app_cli.py
.github/workflows/format.yml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Format and Lint Checks
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ paths:
7
+ - '*.py'
8
+ pull_request:
9
+ types: [ assigned, opened, synchronize, reopened ]
10
+ jobs:
11
+ check:
12
+ name: Format and Lint Checks
13
+ runs-on: ubuntu-latest
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+ - uses: actions/setup-python@v4
17
+ with:
18
+ python-version: '3.10'
19
+ cache: 'pip'
20
+ - run: python -m pip install --upgrade pip
21
+ - run: python -m pip install .[dev]
22
+ - run: python -m flake8 common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
23
+ - run: python -m isort common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py --check-only --diff
24
+ - run: python -m black common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py --check --diff
.github/workflows/release-drafter.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Release Drafter
2
+
3
+ on:
4
+ push:
5
+ # branches to consider in the event; optional, defaults to all
6
+ branches:
7
+ - master
8
+
9
+ jobs:
10
+ update_release_draft:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ # Drafts your next Release notes as Pull Requests are merged into "master"
14
+ - uses: release-drafter/release-drafter@v5.23.0
15
+ env:
16
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.gitignore CHANGED
@@ -14,9 +14,11 @@ experiments
14
  third_party/REKD
15
  hloc/matchers/dedode.py
16
  gradio_cached_examples
17
-
18
  hloc/matchers/quadtree.py
19
  third_party/QuadTreeAttention
20
  desktop.ini
 
 
21
  experiments*
22
  gen_example.py
 
14
  third_party/REKD
15
  hloc/matchers/dedode.py
16
  gradio_cached_examples
17
+ *.mp4
18
  hloc/matchers/quadtree.py
19
  third_party/QuadTreeAttention
20
  desktop.ini
21
+ *.egg-info
22
+ output.pkl
23
  experiments*
24
  gen_example.py
README.md CHANGED
@@ -44,6 +44,8 @@ The tool currently supports various popular image matching algorithms, namely:
44
  - [ ] [DUSt3R](https://github.com/naver/dust3r), arXiv 2023
45
  - [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
46
  - [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
 
 
47
  - [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
48
  - [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023
49
  - [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023
 
44
  - [ ] [DUSt3R](https://github.com/naver/dust3r), arXiv 2023
45
  - [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
46
  - [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
47
+ - [x] [SFD2](https://github.com/feixue94/sfd2), CVPR 2023
48
+ - [x] [IMP](https://github.com/feixue94/imp-release), CVPR 2023
49
  - [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
50
  - [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023
51
  - [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023
common/api.py CHANGED
@@ -1,26 +1,23 @@
1
- import cv2
2
- import torch
3
  import warnings
4
- import numpy as np
5
  from pathlib import Path
6
- from typing import Dict, Any, Optional, Tuple, List, Union
7
- from hloc import logger
8
- from hloc import match_dense, match_features, extract_features
 
 
 
 
 
9
  from hloc.utils.viz import add_text, plot_keypoints
 
10
  from .utils import (
11
- load_config,
12
- get_model,
13
- get_feature_model,
14
- filter_matches,
15
- device,
16
  ROOT,
 
 
 
 
17
  )
18
- from .viz import (
19
- fig2im,
20
- plot_images,
21
- display_matches,
22
- )
23
- import matplotlib.pyplot as plt
24
 
25
  warnings.simplefilter("ignore")
26
 
@@ -109,7 +106,7 @@ class ImageMatchingAPI(torch.nn.Module):
109
  "match_threshold"
110
  ] = match_threshold
111
  except TypeError as e:
112
- breakpoint()
113
  else:
114
  self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
115
  self.conf["feature"]["model"][
@@ -137,7 +134,9 @@ class ImageMatchingAPI(torch.nn.Module):
137
  self.match_conf["preprocessing"],
138
  device=self.device,
139
  )
140
- last_fixed = "{}".format(self.match_conf["model"]["name"])
 
 
141
  else:
142
  pred0 = extract_features.extract(
143
  self.extractor, img0, self.extract_conf["preprocessing"]
@@ -290,7 +289,5 @@ class ImageMatchingAPI(torch.nn.Module):
290
 
291
 
292
  if __name__ == "__main__":
293
- import argparse
294
-
295
  config = load_config(ROOT / "common/config.yaml")
296
- test_api(config)
 
 
 
1
  import warnings
 
2
  from pathlib import Path
3
+ from typing import Any, Dict, Optional
4
+
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torch
9
+
10
+ from hloc import extract_features, logger, match_dense, match_features
11
  from hloc.utils.viz import add_text, plot_keypoints
12
+
13
  from .utils import (
 
 
 
 
 
14
  ROOT,
15
+ filter_matches,
16
+ get_feature_model,
17
+ get_model,
18
+ load_config,
19
  )
20
+ from .viz import display_matches, fig2im, plot_images
 
 
 
 
 
21
 
22
  warnings.simplefilter("ignore")
23
 
 
106
  "match_threshold"
107
  ] = match_threshold
108
  except TypeError as e:
109
+ logger.error(e)
110
  else:
111
  self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
112
  self.conf["feature"]["model"][
 
134
  self.match_conf["preprocessing"],
135
  device=self.device,
136
  )
137
+ last_fixed = "{}".format( # noqa: F841
138
+ self.match_conf["model"]["name"]
139
+ )
140
  else:
141
  pred0 = extract_features.extract(
142
  self.extractor, img0, self.extract_conf["preprocessing"]
 
289
 
290
 
291
  if __name__ == "__main__":
 
 
292
  config = load_config(ROOT / "common/config.yaml")
293
+ api = ImageMatchingAPI(config)
common/app_class.py CHANGED
@@ -1,22 +1,21 @@
1
- import argparse
2
- import numpy as np
3
- import gradio as gr
4
  from pathlib import Path
5
- from typing import Dict, Any, Optional, Tuple, List, Union
 
 
 
 
6
  from common.utils import (
7
- ransac_zoo,
 
8
  generate_warp_images,
9
- load_config,
10
  get_matcher_zoo,
 
 
11
  run_matching,
12
  run_ransac,
13
  send_to_match,
14
- gen_examples,
15
- GRADIO_VERSION,
16
- ROOT,
17
  )
18
 
19
-
20
  DESCRIPTION = """
21
  # Image Matching WebUI
22
  This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
@@ -132,12 +131,14 @@ class ImageMatchingApp:
132
  label="Keypoint thres.",
133
  value=0.015,
134
  )
135
- detect_line_threshold = gr.Slider(
136
- minimum=0.1,
137
- maximum=1,
138
- step=0.01,
139
- label="Line thres.",
140
- value=0.2,
 
 
141
  )
142
  # matcher_lists = gr.Radio(
143
  # ["NN-mutual", "Dual-Softmax"],
 
 
 
 
1
  from pathlib import Path
2
+ from typing import Any, Dict, Optional, Tuple
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+
7
  from common.utils import (
8
+ GRADIO_VERSION,
9
+ gen_examples,
10
  generate_warp_images,
 
11
  get_matcher_zoo,
12
+ load_config,
13
+ ransac_zoo,
14
  run_matching,
15
  run_ransac,
16
  send_to_match,
 
 
 
17
  )
18
 
 
19
  DESCRIPTION = """
20
  # Image Matching WebUI
21
  This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
 
131
  label="Keypoint thres.",
132
  value=0.015,
133
  )
134
+ detect_line_threshold = ( # noqa: F841
135
+ gr.Slider(
136
+ minimum=0.1,
137
+ maximum=1,
138
+ step=0.01,
139
+ label="Line thres.",
140
+ value=0.2,
141
+ )
142
  )
143
  # matcher_lists = gr.Radio(
144
  # ["NN-mutual", "Dual-Softmax"],
common/config.yaml CHANGED
@@ -30,20 +30,22 @@ matcher_zoo:
30
  DUSt3R:
31
  # TODO: duster is under development
32
  enable: true
 
33
  matcher: duster
34
  dense: true
35
- info:
36
  name: DUSt3R #dispaly name
37
  source: "CVPR 2024"
38
  github: https://github.com/naver/dust3r
39
  paper: https://arxiv.org/abs/2312.14132
40
  project: https://dust3r.europe.naverlabs.com
41
- display: true
42
  GIM(dkm):
43
  enable: true
 
44
  matcher: gim(dkm)
45
  dense: true
46
- info:
47
  name: GIM(DKM) #dispaly name
48
  source: "ICLR 2024"
49
  github: https://github.com/xuelunshen/gim
@@ -52,8 +54,9 @@ matcher_zoo:
52
  display: true
53
  RoMa:
54
  matcher: roma
 
55
  dense: true
56
- info:
57
  name: RoMa #dispaly name
58
  source: "CVPR 2024"
59
  github: https://github.com/Parskatt/RoMa
@@ -62,8 +65,9 @@ matcher_zoo:
62
  display: true
63
  dkm:
64
  matcher: dkm
 
65
  dense: true
66
- info:
67
  name: DKM #dispaly name
68
  source: "CVPR 2023"
69
  github: https://github.com/Parskatt/DKM
@@ -73,7 +77,7 @@ matcher_zoo:
73
  loftr:
74
  matcher: loftr
75
  dense: true
76
- info:
77
  name: LoFTR #dispaly name
78
  source: "CVPR 2021"
79
  github: https://github.com/zju3dv/LoFTR
@@ -82,6 +86,7 @@ matcher_zoo:
82
  display: true
83
  cotr:
84
  enable: false
 
85
  matcher: cotr
86
  dense: true
87
  info:
@@ -363,3 +368,31 @@ matcher_zoo:
363
  paper: https://arxiv.org/abs/2104.03362
364
  project: null
365
  display: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  DUSt3R:
31
  # TODO: duster is under development
32
  enable: true
33
+ skip_ci: true
34
  matcher: duster
35
  dense: true
36
+ info:
37
  name: DUSt3R #dispaly name
38
  source: "CVPR 2024"
39
  github: https://github.com/naver/dust3r
40
  paper: https://arxiv.org/abs/2312.14132
41
  project: https://dust3r.europe.naverlabs.com
42
+ display: true
43
  GIM(dkm):
44
  enable: true
45
+ skip_ci: true
46
  matcher: gim(dkm)
47
  dense: true
48
+ info:
49
  name: GIM(DKM) #dispaly name
50
  source: "ICLR 2024"
51
  github: https://github.com/xuelunshen/gim
 
54
  display: true
55
  RoMa:
56
  matcher: roma
57
+ skip_ci: true
58
  dense: true
59
+ info:
60
  name: RoMa #dispaly name
61
  source: "CVPR 2024"
62
  github: https://github.com/Parskatt/RoMa
 
65
  display: true
66
  dkm:
67
  matcher: dkm
68
+ skip_ci: true
69
  dense: true
70
+ info:
71
  name: DKM #dispaly name
72
  source: "CVPR 2023"
73
  github: https://github.com/Parskatt/DKM
 
77
  loftr:
78
  matcher: loftr
79
  dense: true
80
+ info:
81
  name: LoFTR #dispaly name
82
  source: "CVPR 2021"
83
  github: https://github.com/zju3dv/LoFTR
 
86
  display: true
87
  cotr:
88
  enable: false
89
+ skip_ci: true
90
  matcher: cotr
91
  dense: true
92
  info:
 
368
  paper: https://arxiv.org/abs/2104.03362
369
  project: null
370
  display: true
371
+
372
+ sfd2+imp:
373
+ matcher: imp
374
+ feature: sfd2
375
+ enable: false
376
+ dense: false
377
+ skip_ci: true
378
+ info:
379
+ name: SFD2+IMP #dispaly name
380
+ source: "CVPR 2023"
381
+ github: https://github.com/feixue94/imp-release
382
+ paper: https://arxiv.org/pdf/2304.14837
383
+ project: https://feixue94.github.io/
384
+ display: true
385
+
386
+ sfd2+mnn:
387
+ matcher: NN-mutual
388
+ feature: sfd2
389
+ enable: false
390
+ dense: false
391
+ skip_ci: true
392
+ info:
393
+ name: SFD2+MNN #dispaly name
394
+ source: "CVPR 2023"
395
+ github: https://github.com/feixue94/sfd2
396
+ paper: https://arxiv.org/abs/2304.14845
397
+ project: https://feixue94.github.io/
398
+ display: true
common/utils.py CHANGED
@@ -1,35 +1,35 @@
1
  import os
2
- import cv2
3
- import sys
4
- import torch
5
  import random
6
- import psutil
7
  import shutil
8
- import numpy as np
9
- import gradio as gr
10
- from PIL import Image
11
  from pathlib import Path
 
 
 
 
 
 
12
  import poselib
13
- from itertools import combinations
14
- from typing import Callable, Dict, Any, Optional, Tuple, List, Union
15
- from hloc import matchers, extractors, logger
16
- from hloc.utils.base_model import dynamic_load
17
- from hloc import match_dense, match_features, extract_features
18
- from .viz import (
19
- fig2im,
20
- plot_images,
21
- display_matches,
22
- display_keypoints,
23
- plot_color_line_matches,
24
  )
25
- import time
26
- import matplotlib.pyplot as plt
27
- import warnings
28
- import tempfile
29
- import pickle
30
 
31
  warnings.simplefilter("ignore")
32
- device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
  ROOT = Path(__file__).parent.parent
35
  # some default values
@@ -91,14 +91,13 @@ class ModelCache:
91
  host_colocation = int(os.environ.get("HOST_COLOCATION", "1"))
92
  vm = psutil.virtual_memory()
93
  du = shutil.disk_usage(".")
94
- vm_ratio = host_colocation * vm.used / vm.total
95
  if verbose:
96
  logger.info(
97
  f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB"
98
  )
99
- # logger.info(
100
- # f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB"
101
- # )
102
  return vm.used / 1e9
103
 
104
  def print_memory_usage(self):
@@ -173,7 +172,7 @@ def get_model(match_conf: Dict[str, Any]):
173
  A matcher model instance.
174
  """
175
  Model = dynamic_load(matchers, match_conf["model"]["name"])
176
- model = Model(match_conf["model"]).eval().to(device)
177
  return model
178
 
179
 
@@ -188,7 +187,7 @@ def get_feature_model(conf: Dict[str, Dict[str, Any]]):
188
  A feature extraction model instance.
189
  """
190
  Model = dynamic_load(extractors, conf["model"]["name"])
191
- model = Model(conf["model"]).eval().to(device)
192
  return model
193
 
194
 
@@ -423,7 +422,7 @@ def _filter_matches_poselib(
423
  elif geometry_type == "Fundamental":
424
  M, info = poselib.estimate_fundamental(kp0, kp1, ransac_options)
425
  else:
426
- raise notImplementedError("Not Implemented")
427
 
428
  return M, np.array(info["inliers"])
429
 
@@ -464,7 +463,7 @@ def proc_ransac_matches(
464
  geometry_type,
465
  )
466
  else:
467
- raise notImplementedError("Not Implemented")
468
 
469
 
470
  def filter_matches(
@@ -617,7 +616,9 @@ def compute_geometry(
617
  geo_info["H1"] = H1.tolist()
618
  geo_info["H2"] = H2.tolist()
619
  except cv2.error as e:
620
- logger.error(f"StereoRectifyUncalibrated failed, skip!")
 
 
621
  return geo_info
622
  else:
623
  return {}
@@ -643,7 +644,6 @@ def wrap_images(
643
  """
644
  h0, w0, _ = img0.shape
645
  h1, w1, _ = img1.shape
646
- result_matrix: Optional[np.ndarray] = None
647
  if geo_info is not None and len(geo_info) != 0:
648
  rectified_image0 = img0
649
  rectified_image1 = None
@@ -656,7 +656,6 @@ def wrap_images(
656
  title: List[str] = []
657
  if geom_type == "Homography":
658
  rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0))
659
- result_matrix = H
660
  title = ["Image 0", "Image 1 - warped"]
661
  elif geom_type == "Fundamental":
662
  if geom_type not in geo_info:
@@ -666,7 +665,6 @@ def wrap_images(
666
  H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
667
  rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
668
  rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
669
- result_matrix = np.array(geo_info["Fundamental"])
670
  title = ["Image 0 - warped", "Image 1 - warped"]
671
  else:
672
  print("Error: Unknown geometry type")
@@ -705,7 +703,7 @@ def generate_warp_images(
705
  ):
706
  return None, None
707
  geom_info = matches_info["geom_info"]
708
- wrapped_images = None
709
  if choice != "No":
710
  wrapped_image_pair, warped_image = wrap_images(
711
  input_image0, input_image1, geom_info, choice
@@ -805,7 +803,7 @@ def run_ransac(
805
  with open(tmp_state_cache, "wb") as f:
806
  pickle.dump(state_cache, f)
807
 
808
- logger.info(f"Dump results done!")
809
 
810
  return (
811
  output_matches_ransac,
@@ -880,7 +878,7 @@ def run_matching(
880
  output_matches_ransac = None
881
 
882
  # super slow!
883
- if "roma" in key.lower() and device == "cpu":
884
  gr.Info(
885
  f"Success! Please be patient and allow for about 2-3 minutes."
886
  f" Due to CPU inference, {key} is quiet slow."
@@ -905,7 +903,7 @@ def run_matching(
905
 
906
  if model["dense"]:
907
  pred = match_dense.match_images(
908
- matcher, image0, image1, match_conf["preprocessing"], device=device
909
  )
910
  del matcher
911
  extract_conf = None
@@ -1000,7 +998,7 @@ def run_matching(
1000
  tmp_state_cache = "output.pkl"
1001
  with open(tmp_state_cache, "wb") as f:
1002
  pickle.dump(state_cache, f)
1003
- logger.info(f"Dump results done!")
1004
  return (
1005
  output_keypoints,
1006
  output_matches_raw,
 
1
  import os
2
+ import pickle
 
 
3
  import random
 
4
  import shutil
5
+ import time
6
+ import warnings
7
+ from itertools import combinations
8
  from pathlib import Path
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+
11
+ import cv2
12
+ import gradio as gr
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
  import poselib
16
+ import psutil
17
+ from PIL import Image
18
+
19
+ from hloc import (
20
+ DEVICE,
21
+ extract_features,
22
+ extractors,
23
+ logger,
24
+ match_dense,
25
+ match_features,
26
+ matchers,
27
  )
28
+ from hloc.utils.base_model import dynamic_load
29
+
30
+ from .viz import display_keypoints, display_matches, fig2im, plot_images
 
 
31
 
32
  warnings.simplefilter("ignore")
 
33
 
34
  ROOT = Path(__file__).parent.parent
35
  # some default values
 
91
  host_colocation = int(os.environ.get("HOST_COLOCATION", "1"))
92
  vm = psutil.virtual_memory()
93
  du = shutil.disk_usage(".")
 
94
  if verbose:
95
  logger.info(
96
  f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB"
97
  )
98
+ logger.info(
99
+ f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB"
100
+ )
101
  return vm.used / 1e9
102
 
103
  def print_memory_usage(self):
 
172
  A matcher model instance.
173
  """
174
  Model = dynamic_load(matchers, match_conf["model"]["name"])
175
+ model = Model(match_conf["model"]).eval().to(DEVICE)
176
  return model
177
 
178
 
 
187
  A feature extraction model instance.
188
  """
189
  Model = dynamic_load(extractors, conf["model"]["name"])
190
+ model = Model(conf["model"]).eval().to(DEVICE)
191
  return model
192
 
193
 
 
422
  elif geometry_type == "Fundamental":
423
  M, info = poselib.estimate_fundamental(kp0, kp1, ransac_options)
424
  else:
425
+ raise NotImplementedError
426
 
427
  return M, np.array(info["inliers"])
428
 
 
463
  geometry_type,
464
  )
465
  else:
466
+ raise NotImplementedError
467
 
468
 
469
  def filter_matches(
 
616
  geo_info["H1"] = H1.tolist()
617
  geo_info["H2"] = H2.tolist()
618
  except cv2.error as e:
619
+ logger.error(
620
+ f"StereoRectifyUncalibrated failed, skip! error: {e}"
621
+ )
622
  return geo_info
623
  else:
624
  return {}
 
644
  """
645
  h0, w0, _ = img0.shape
646
  h1, w1, _ = img1.shape
 
647
  if geo_info is not None and len(geo_info) != 0:
648
  rectified_image0 = img0
649
  rectified_image1 = None
 
656
  title: List[str] = []
657
  if geom_type == "Homography":
658
  rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0))
 
659
  title = ["Image 0", "Image 1 - warped"]
660
  elif geom_type == "Fundamental":
661
  if geom_type not in geo_info:
 
665
  H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
666
  rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
667
  rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
 
668
  title = ["Image 0 - warped", "Image 1 - warped"]
669
  else:
670
  print("Error: Unknown geometry type")
 
703
  ):
704
  return None, None
705
  geom_info = matches_info["geom_info"]
706
+ warped_image = None
707
  if choice != "No":
708
  wrapped_image_pair, warped_image = wrap_images(
709
  input_image0, input_image1, geom_info, choice
 
803
  with open(tmp_state_cache, "wb") as f:
804
  pickle.dump(state_cache, f)
805
 
806
+ logger.info("Dump results done!")
807
 
808
  return (
809
  output_matches_ransac,
 
878
  output_matches_ransac = None
879
 
880
  # super slow!
881
+ if "roma" in key.lower() and DEVICE == "cpu":
882
  gr.Info(
883
  f"Success! Please be patient and allow for about 2-3 minutes."
884
  f" Due to CPU inference, {key} is quiet slow."
 
903
 
904
  if model["dense"]:
905
  pred = match_dense.match_images(
906
+ matcher, image0, image1, match_conf["preprocessing"], device=DEVICE
907
  )
908
  del matcher
909
  extract_conf = None
 
998
  tmp_state_cache = "output.pkl"
999
  with open(tmp_state_cache, "wb") as f:
1000
  pickle.dump(state_cache, f)
1001
+ logger.info("Dump results done!")
1002
  return (
1003
  output_keypoints,
1004
  output_matches_raw,
common/viz.py CHANGED
@@ -1,11 +1,13 @@
1
- import cv2
2
  import typing
 
 
 
 
3
  import matplotlib
 
4
  import numpy as np
5
  import seaborn as sns
6
- import matplotlib.pyplot as plt
7
- from pathlib import Path
8
- from typing import Dict, Any, Optional, Tuple, List, Union
9
  from hloc.utils.viz import add_text, plot_keypoints
10
 
11
 
 
 
1
  import typing
2
+ from pathlib import Path
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import cv2
6
  import matplotlib
7
+ import matplotlib.pyplot as plt
8
  import numpy as np
9
  import seaborn as sns
10
+
 
 
11
  from hloc.utils.viz import add_text, plot_keypoints
12
 
13
 
docker/build_docker.bat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ docker build -t image-matching-webui:latest . --no-cache
2
+ # docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest
3
+ # docker push vincentqin/image-matching-webui:latest
run_docker.sh → docker/run_docker.bat RENAMED
File without changes
docker/run_docker.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860
env-docker.txt DELETED
@@ -1,33 +0,0 @@
1
- e2cnn==0.2.3
2
- einops==0.6.1
3
- gdown==4.7.1
4
- gradio==4.28.3
5
- gradio_client==0.16.0
6
- h5py==3.9.0
7
- imageio==2.31.1
8
- Jinja2==3.1.2
9
- kornia==0.6.12
10
- loguru==0.7.0
11
- matplotlib==3.7.1
12
- numpy==1.23.5
13
- omegaconf==2.3.0
14
- opencv-contrib-python==4.6.0.66
15
- opencv-python==4.6.0.66
16
- pandas==2.0.3
17
- plotly==5.15.0
18
- protobuf==4.23.2
19
- pycolmap==0.5.0
20
- pytlsd==0.0.2
21
- pytorch-lightning==1.4.9
22
- PyYAML==6.0
23
- scikit-image==0.21.0
24
- scikit-learn==1.2.2
25
- scipy==1.11.1
26
- seaborn==0.12.2
27
- shapely==2.0.1
28
- tensorboardX==2.6.1
29
- torchmetrics==0.6.0
30
- torchvision==0.17.1
31
- tqdm==4.65.0
32
- yacs==0.1.8
33
- onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
format.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ python -m flake8 common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
2
+ python -m isort common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
3
+ python -m black common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
hloc/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
  import torch
3
  from packaging import version
4
 
 
1
  import logging
2
+
3
  import torch
4
  from packaging import version
5
 
hloc/extract_features.py CHANGED
@@ -1,21 +1,22 @@
1
  import argparse
2
- import torch
 
3
  from pathlib import Path
4
- from typing import Dict, List, Union, Optional
5
- import h5py
6
  from types import SimpleNamespace
 
 
7
  import cv2
 
8
  import numpy as np
9
- from tqdm import tqdm
10
- import pprint
11
- import collections.abc as collections
12
  import PIL.Image
 
13
  import torchvision.transforms.functional as F
 
 
14
  from . import extractors, logger
15
  from .utils.base_model import dynamic_load
 
16
  from .utils.parsers import parse_image_lists
17
- from .utils.io import read_image, list_h5_names
18
-
19
 
20
  """
21
  A set of standard configurations that can be directly selected from the command
@@ -290,6 +291,23 @@ confs = {
290
  "dfactor": 8,
291
  },
292
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  # Global descriptors
294
  "dir": {
295
  "output": "global-feats-dir",
@@ -460,7 +478,7 @@ def extract(model, image_0, conf):
460
  # image0 = image_0[:, :, ::-1] # BGR to RGB
461
  data = preprocess(image0, conf)
462
  pred = model({"image": data["image"]})
463
- pred["image_size"] = original_size = data["original_size"]
464
  pred = {**pred, **data}
465
  return pred
466
 
 
1
  import argparse
2
+ import collections.abc as collections
3
+ import pprint
4
  from pathlib import Path
 
 
5
  from types import SimpleNamespace
6
+ from typing import Dict, List, Optional, Union
7
+
8
  import cv2
9
+ import h5py
10
  import numpy as np
 
 
 
11
  import PIL.Image
12
+ import torch
13
  import torchvision.transforms.functional as F
14
+ from tqdm import tqdm
15
+
16
  from . import extractors, logger
17
  from .utils.base_model import dynamic_load
18
+ from .utils.io import list_h5_names, read_image
19
  from .utils.parsers import parse_image_lists
 
 
20
 
21
  """
22
  A set of standard configurations that can be directly selected from the command
 
291
  "dfactor": 8,
292
  },
293
  },
294
+ "sfd2": {
295
+ "output": "feats-sfd2-n4096-r1600",
296
+ "model": {
297
+ "name": "sfd2",
298
+ "max_keypoints": 4096,
299
+ },
300
+ "preprocessing": {
301
+ "grayscale": False,
302
+ "force_resize": True,
303
+ "resize_max": 1600,
304
+ "width": 640,
305
+ "height": 480,
306
+ "conf_th": 0.001,
307
+ "multiscale": False,
308
+ "scales": [1.0],
309
+ },
310
+ },
311
  # Global descriptors
312
  "dir": {
313
  "output": "global-feats-dir",
 
478
  # image0 = image_0[:, :, ::-1] # BGR to RGB
479
  data = preprocess(image0, conf)
480
  pred = model({"image": data["image"]})
481
+ pred["image_size"] = data["original_size"]
482
  pred = {**pred, **data}
483
  return pred
484
 
hloc/extractors/alike.py CHANGED
@@ -1,10 +1,12 @@
1
  import sys
2
  from pathlib import Path
 
3
  import torch
4
 
5
- from ..utils.base_model import BaseModel
6
  from hloc import logger
7
 
 
 
8
  alike_path = Path(__file__).parent / "../../third_party/ALIKE"
9
  sys.path.append(str(alike_path))
10
  from alike import ALike as Alike_
@@ -34,7 +36,7 @@ class Alike(BaseModel):
34
  scores_th=conf["detection_threshold"],
35
  n_limit=conf["max_keypoints"],
36
  )
37
- logger.info(f"Load Alike model done.")
38
 
39
  def _forward(self, data):
40
  image = data["image"]
 
1
  import sys
2
  from pathlib import Path
3
+
4
  import torch
5
 
 
6
  from hloc import logger
7
 
8
+ from ..utils.base_model import BaseModel
9
+
10
  alike_path = Path(__file__).parent / "../../third_party/ALIKE"
11
  sys.path.append(str(alike_path))
12
  from alike import ALike as Alike_
 
36
  scores_th=conf["detection_threshold"],
37
  n_limit=conf["max_keypoints"],
38
  )
39
+ logger.info("Load Alike model done.")
40
 
41
  def _forward(self, data):
42
  image = data["image"]
hloc/extractors/d2net.py CHANGED
@@ -1,17 +1,17 @@
 
1
  import sys
2
  from pathlib import Path
3
- import subprocess
4
  import torch
5
 
6
- from ..utils.base_model import BaseModel
7
  from hloc import logger
8
 
9
- d2net_path = Path(__file__).parent / "../../third_party"
10
- sys.path.append(str(d2net_path))
11
- from d2net.lib.model_test import D2Net as _D2Net
12
- from d2net.lib.pyramid import process_multiscale
13
 
14
  d2net_path = Path(__file__).parent / "../../third_party/d2net"
 
 
 
15
 
16
 
17
  class D2Net(BaseModel):
@@ -30,6 +30,7 @@ class D2Net(BaseModel):
30
  model_file.parent.mkdir(exist_ok=True)
31
  cmd = [
32
  "wget",
 
33
  "https://dusmanu.com/files/d2-net/" + conf["model_name"],
34
  "-O",
35
  str(model_file),
@@ -39,7 +40,7 @@ class D2Net(BaseModel):
39
  self.net = _D2Net(
40
  model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
41
  )
42
- logger.info(f"Load D2Net model done.")
43
 
44
  def _forward(self, data):
45
  image = data["image"]
 
1
+ import subprocess
2
  import sys
3
  from pathlib import Path
4
+
5
  import torch
6
 
 
7
  from hloc import logger
8
 
9
+ from ..utils.base_model import BaseModel
 
 
 
10
 
11
  d2net_path = Path(__file__).parent / "../../third_party/d2net"
12
+ sys.path.append(str(d2net_path))
13
+ from lib.model_test import D2Net as _D2Net
14
+ from lib.pyramid import process_multiscale
15
 
16
 
17
  class D2Net(BaseModel):
 
30
  model_file.parent.mkdir(exist_ok=True)
31
  cmd = [
32
  "wget",
33
+ "--quiet",
34
  "https://dusmanu.com/files/d2-net/" + conf["model_name"],
35
  "-O",
36
  str(model_file),
 
40
  self.net = _D2Net(
41
  model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
42
  )
43
+ logger.info("Load D2Net model done.")
44
 
45
  def _forward(self, data):
46
  image = data["image"]
hloc/extractors/darkfeat.py CHANGED
@@ -1,9 +1,12 @@
1
  import sys
2
  from pathlib import Path
 
3
  import subprocess
4
- from ..utils.base_model import BaseModel
5
  from hloc import logger
6
 
 
 
7
  darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
8
  sys.path.append(str(darkfeat_path))
9
  from darkfeat import DarkFeat as DarkFeat_
@@ -43,7 +46,7 @@ class DarkFeat(BaseModel):
43
  raise e
44
 
45
  self.net = DarkFeat_(model_path)
46
- logger.info(f"Load DarkFeat model done.")
47
 
48
  def _forward(self, data):
49
  pred = self.net({"image": data["image"]})
 
1
  import sys
2
  from pathlib import Path
3
+
4
  import subprocess
5
+
6
  from hloc import logger
7
 
8
+ from ..utils.base_model import BaseModel
9
+
10
  darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
11
  sys.path.append(str(darkfeat_path))
12
  from darkfeat import DarkFeat as DarkFeat_
 
46
  raise e
47
 
48
  self.net = DarkFeat_(model_path)
49
+ logger.info("Load DarkFeat model done.")
50
 
51
  def _forward(self, data):
52
  pred = self.net({"image": data["image"]})
hloc/extractors/dedode.py CHANGED
@@ -1,16 +1,18 @@
 
1
  import sys
2
  from pathlib import Path
3
- import subprocess
4
  import torch
5
- from PIL import Image
6
- from ..utils.base_model import BaseModel
7
- from hloc import logger
8
  import torchvision.transforms as transforms
9
 
 
 
 
 
10
  dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
11
  sys.path.append(str(dedode_path))
12
 
13
- from DeDoDe import dedode_detector_L, dedode_descriptor_B
14
  from DeDoDe.utils import to_pixel_coords
15
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -49,14 +51,14 @@ class DeDoDe(BaseModel):
49
  if not model_detector_path.exists():
50
  model_detector_path.parent.mkdir(exist_ok=True)
51
  link = self.weight_urls[conf["model_detector_name"]]
52
- cmd = ["wget", link, "-O", str(model_detector_path)]
53
  logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.")
54
  subprocess.run(cmd, check=True)
55
 
56
  if not model_descriptor_path.exists():
57
  model_descriptor_path.parent.mkdir(exist_ok=True)
58
  link = self.weight_urls[conf["model_descriptor_name"]]
59
- cmd = ["wget", link, "-O", str(model_descriptor_path)]
60
  logger.info(
61
  f"Downloading the DeDoDe descriptor model with `{cmd}`."
62
  )
@@ -73,7 +75,7 @@ class DeDoDe(BaseModel):
73
  self.descriptor = dedode_descriptor_B(
74
  weights=weights_descriptor, device=device
75
  )
76
- logger.info(f"Load DeDoDe model done.")
77
 
78
  def _forward(self, data):
79
  """
 
1
+ import subprocess
2
  import sys
3
  from pathlib import Path
4
+
5
  import torch
 
 
 
6
  import torchvision.transforms as transforms
7
 
8
+ from hloc import logger
9
+
10
+ from ..utils.base_model import BaseModel
11
+
12
  dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
13
  sys.path.append(str(dedode_path))
14
 
15
+ from DeDoDe import dedode_descriptor_B, dedode_detector_L
16
  from DeDoDe.utils import to_pixel_coords
17
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
51
  if not model_detector_path.exists():
52
  model_detector_path.parent.mkdir(exist_ok=True)
53
  link = self.weight_urls[conf["model_detector_name"]]
54
+ cmd = ["wget", "--quiet", link, "-O", str(model_detector_path)]
55
  logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.")
56
  subprocess.run(cmd, check=True)
57
 
58
  if not model_descriptor_path.exists():
59
  model_descriptor_path.parent.mkdir(exist_ok=True)
60
  link = self.weight_urls[conf["model_descriptor_name"]]
61
+ cmd = ["wget", "--quiet", link, "-O", str(model_descriptor_path)]
62
  logger.info(
63
  f"Downloading the DeDoDe descriptor model with `{cmd}`."
64
  )
 
75
  self.descriptor = dedode_descriptor_B(
76
  weights=weights_descriptor, device=device
77
  )
78
+ logger.info("Load DeDoDe model done.")
79
 
80
  def _forward(self, data):
81
  """
hloc/extractors/dir.py CHANGED
@@ -1,10 +1,11 @@
 
1
  import sys
2
  from pathlib import Path
3
- import torch
4
  from zipfile import ZipFile
5
- import os
6
- import sklearn
7
  import gdown
 
 
8
 
9
  from ..utils.base_model import BaseModel
10
 
@@ -13,8 +14,8 @@ sys.path.append(
13
  )
14
  os.environ["DB_ROOT"] = "" # required by dirtorch
15
 
16
- from dirtorch.utils import common # noqa: E402
17
  from dirtorch.extract_features import load_model # noqa: E402
 
18
 
19
  # The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
20
  # which has been deprecated in sklearn v0.24
 
1
+ import os
2
  import sys
3
  from pathlib import Path
 
4
  from zipfile import ZipFile
5
+
 
6
  import gdown
7
+ import sklearn
8
+ import torch
9
 
10
  from ..utils.base_model import BaseModel
11
 
 
14
  )
15
  os.environ["DB_ROOT"] = "" # required by dirtorch
16
 
 
17
  from dirtorch.extract_features import load_model # noqa: E402
18
+ from dirtorch.utils import common # noqa: E402
19
 
20
  # The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
21
  # which has been deprecated in sklearn v0.24
hloc/extractors/disk.py CHANGED
@@ -15,7 +15,7 @@ class DISK(BaseModel):
15
 
16
  def _init(self, conf):
17
  self.model = kornia.feature.DISK.from_pretrained(conf["weights"])
18
- logger.info(f"Load DISK model done.")
19
 
20
  def _forward(self, data):
21
  image = data["image"]
 
15
 
16
  def _init(self, conf):
17
  self.model = kornia.feature.DISK.from_pretrained(conf["weights"])
18
+ logger.info("Load DISK model done.")
19
 
20
  def _forward(self, data):
21
  image = data["image"]
hloc/extractors/dog.py CHANGED
@@ -1,15 +1,14 @@
1
  import kornia
 
 
 
2
  from kornia.feature.laf import (
3
- laf_from_center_scale_ori,
4
  extract_patches_from_pyramid,
 
5
  )
6
- import numpy as np
7
- import torch
8
- import pycolmap
9
 
10
  from ..utils.base_model import BaseModel
11
 
12
-
13
  EPS = 1e-6
14
 
15
 
 
1
  import kornia
2
+ import numpy as np
3
+ import pycolmap
4
+ import torch
5
  from kornia.feature.laf import (
 
6
  extract_patches_from_pyramid,
7
+ laf_from_center_scale_ori,
8
  )
 
 
 
9
 
10
  from ..utils.base_model import BaseModel
11
 
 
12
  EPS = 1e-6
13
 
14
 
hloc/extractors/example.py CHANGED
@@ -1,9 +1,9 @@
1
  import sys
2
  from pathlib import Path
3
- import subprocess
4
  import torch
5
- from .. import logger
6
 
 
7
  from ..utils.base_model import BaseModel
8
 
9
  example_path = Path(__file__).parent / "../../third_party/example"
@@ -35,7 +35,7 @@ class Example(BaseModel):
35
  # self.net = ExampleNet(is_test=True)
36
  state_dict = torch.load(model_path, map_location="cpu")
37
  self.net.load_state_dict(state_dict["model_state"])
38
- logger.info(f"Load example model done.")
39
 
40
  def _forward(self, data):
41
  # data: dict, keys: 'image'
 
1
  import sys
2
  from pathlib import Path
3
+
4
  import torch
 
5
 
6
+ from .. import logger
7
  from ..utils.base_model import BaseModel
8
 
9
  example_path = Path(__file__).parent / "../../third_party/example"
 
35
  # self.net = ExampleNet(is_test=True)
36
  state_dict = torch.load(model_path, map_location="cpu")
37
  self.net.load_state_dict(state_dict["model_state"])
38
+ logger.info("Load example model done.")
39
 
40
  def _forward(self, data):
41
  # data: dict, keys: 'image'
hloc/extractors/fire.py CHANGED
@@ -1,7 +1,8 @@
1
- from pathlib import Path
2
- import subprocess
3
  import logging
 
4
  import sys
 
 
5
  import torch
6
  import torchvision.transforms as tvf
7
 
@@ -42,11 +43,11 @@ class FIRe(BaseModel):
42
  if not model_path.exists():
43
  model_path.parent.mkdir(exist_ok=True)
44
  link = self.fire_models[conf["model_name"]]
45
- cmd = ["wget", link, "-O", str(model_path)]
46
  logger.info(f"Downloading the FIRe model with `{cmd}`.")
47
  subprocess.run(cmd, check=True)
48
 
49
- logger.info(f"Loading fire model...")
50
 
51
  # Load net
52
  state = torch.load(model_path)
 
 
 
1
  import logging
2
+ import subprocess
3
  import sys
4
+ from pathlib import Path
5
+
6
  import torch
7
  import torchvision.transforms as tvf
8
 
 
43
  if not model_path.exists():
44
  model_path.parent.mkdir(exist_ok=True)
45
  link = self.fire_models[conf["model_name"]]
46
+ cmd = ["wget", "--quiet", link, "-O", str(model_path)]
47
  logger.info(f"Downloading the FIRe model with `{cmd}`.")
48
  subprocess.run(cmd, check=True)
49
 
50
+ logger.info("Loading fire model...")
51
 
52
  # Load net
53
  state = torch.load(model_path)
hloc/extractors/fire_local.py CHANGED
@@ -1,11 +1,12 @@
1
- from pathlib import Path
2
  import subprocess
3
  import sys
 
 
4
  import torch
5
  import torchvision.transforms as tvf
6
 
7
- from ..utils.base_model import BaseModel
8
  from .. import logger
 
9
 
10
  fire_path = Path(__file__).parent / "../../third_party/fire"
11
 
@@ -13,10 +14,6 @@ sys.path.append(str(fire_path))
13
 
14
 
15
  import fire_network
16
- from lib.how.how.stages.evaluate import eval_asmk_fire, load_dataset_fire
17
-
18
- from lib.asmk import asmk
19
- from asmk import io_helpers, asmk_method, kernel as kern_pkg
20
 
21
  EPS = 1e-6
22
 
@@ -44,18 +41,18 @@ class FIRe(BaseModel):
44
 
45
  # Config paths
46
  model_path = fire_path / "model" / conf["model_name"]
47
- config_path = fire_path / conf["config_name"]
48
- asmk_bin_path = fire_path / "model" / conf["asmk_name"]
49
 
50
  # Download the model.
51
  if not model_path.exists():
52
  model_path.parent.mkdir(exist_ok=True)
53
  link = self.fire_models[conf["model_name"]]
54
- cmd = ["wget", link, "-O", str(model_path)]
55
  logger.info(f"Downloading the FIRe model with `{cmd}`.")
56
  subprocess.run(cmd, check=True)
57
 
58
- logger.info(f"Loading fire model...")
59
 
60
  # Load net
61
  state = torch.load(model_path)
 
 
1
  import subprocess
2
  import sys
3
+ from pathlib import Path
4
+
5
  import torch
6
  import torchvision.transforms as tvf
7
 
 
8
  from .. import logger
9
+ from ..utils.base_model import BaseModel
10
 
11
  fire_path = Path(__file__).parent / "../../third_party/fire"
12
 
 
14
 
15
 
16
  import fire_network
 
 
 
 
17
 
18
  EPS = 1e-6
19
 
 
41
 
42
  # Config paths
43
  model_path = fire_path / "model" / conf["model_name"]
44
+ config_path = fire_path / conf["config_name"] # noqa: F841
45
+ asmk_bin_path = fire_path / "model" / conf["asmk_name"] # noqa: F841
46
 
47
  # Download the model.
48
  if not model_path.exists():
49
  model_path.parent.mkdir(exist_ok=True)
50
  link = self.fire_models[conf["model_name"]]
51
+ cmd = ["wget", "--quiet", link, "-O", str(model_path)]
52
  logger.info(f"Downloading the FIRe model with `{cmd}`.")
53
  subprocess.run(cmd, check=True)
54
 
55
+ logger.info("Loading fire model...")
56
 
57
  # Load net
58
  state = torch.load(model_path)
hloc/extractors/lanet.py CHANGED
@@ -1,14 +1,17 @@
 
1
  import sys
2
  from pathlib import Path
3
- import subprocess
4
  import torch
 
5
 
6
  from ..utils.base_model import BaseModel
7
- from hloc import logger
 
 
 
8
 
9
  lanet_path = Path(__file__).parent / "../../third_party/lanet"
10
- sys.path.append(str(lanet_path))
11
- from network_v0.model import PointModel
12
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
@@ -26,11 +29,11 @@ class LANet(BaseModel):
26
  lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
27
  )
28
  if not model_path.exists():
29
- print(f"No model found at {model_path}")
30
  self.net = PointModel(is_test=True)
31
  state_dict = torch.load(model_path, map_location="cpu")
32
  self.net.load_state_dict(state_dict["model_state"])
33
- logger.info(f"Load LANet model done.")
34
 
35
  def _forward(self, data):
36
  image = data["image"]
 
1
+ import subprocess
2
  import sys
3
  from pathlib import Path
4
+
5
  import torch
6
+ from hloc import logger
7
 
8
  from ..utils.base_model import BaseModel
9
+
10
+ lib_path = Path(__file__).parent / "../../third_party"
11
+ sys.path.append(str(lib_path))
12
+ from lanet.network_v0.model import PointModel
13
 
14
  lanet_path = Path(__file__).parent / "../../third_party/lanet"
 
 
15
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
 
29
  lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
30
  )
31
  if not model_path.exists():
32
+ logger.warning(f"No model found at {model_path}, start downloading")
33
  self.net = PointModel(is_test=True)
34
  state_dict = torch.load(model_path, map_location="cpu")
35
  self.net.load_state_dict(state_dict["model_state"])
36
+ logger.info("Load LANet model done.")
37
 
38
  def _forward(self, data):
39
  image = data["image"]
hloc/extractors/netvlad.py CHANGED
@@ -1,5 +1,6 @@
1
- from pathlib import Path
2
  import subprocess
 
 
3
  import numpy as np
4
  import torch
5
  import torch.nn as nn
@@ -7,8 +8,8 @@ import torch.nn.functional as F
7
  import torchvision.models as models
8
  from scipy.io import loadmat
9
 
10
- from ..utils.base_model import BaseModel
11
  from .. import logger
 
12
 
13
  EPS = 1e-6
14
 
@@ -60,7 +61,7 @@ class NetVLAD(BaseModel):
60
  if not checkpoint.exists():
61
  checkpoint.parent.mkdir(exist_ok=True, parents=True)
62
  link = self.dir_models[conf["model_name"]]
63
- cmd = ["wget", link, "-O", str(checkpoint)]
64
  logger.info(f"Downloading the NetVLAD model with `{cmd}`.")
65
  subprocess.run(cmd, check=True)
66
 
 
 
1
  import subprocess
2
+ from pathlib import Path
3
+
4
  import numpy as np
5
  import torch
6
  import torch.nn as nn
 
8
  import torchvision.models as models
9
  from scipy.io import loadmat
10
 
 
11
  from .. import logger
12
+ from ..utils.base_model import BaseModel
13
 
14
  EPS = 1e-6
15
 
 
61
  if not checkpoint.exists():
62
  checkpoint.parent.mkdir(exist_ok=True, parents=True)
63
  link = self.dir_models[conf["model_name"]]
64
+ cmd = ["wget", "--quiet", link, "-O", str(checkpoint)]
65
  logger.info(f"Downloading the NetVLAD model with `{cmd}`.")
66
  subprocess.run(cmd, check=True)
67
 
hloc/extractors/r2d2.py CHANGED
@@ -1,14 +1,15 @@
1
  import sys
2
  from pathlib import Path
 
3
  import torchvision.transforms as tvf
4
 
5
- from ..utils.base_model import BaseModel
6
  from hloc import logger
7
 
8
- base_path = Path(__file__).parent / "../../third_party"
9
- sys.path.append(str(base_path))
10
  r2d2_path = Path(__file__).parent / "../../third_party/r2d2"
11
- from r2d2.extract import load_network, NonMaxSuppression, extract_multiscale
 
12
 
13
 
14
  class R2D2(BaseModel):
@@ -35,7 +36,7 @@ class R2D2(BaseModel):
35
  rel_thr=conf["reliability_threshold"],
36
  rep_thr=conf["repetability_threshold"],
37
  )
38
- logger.info(f"Load R2D2 model done.")
39
 
40
  def _forward(self, data):
41
  img = data["image"]
 
1
  import sys
2
  from pathlib import Path
3
+
4
  import torchvision.transforms as tvf
5
 
 
6
  from hloc import logger
7
 
8
+ from ..utils.base_model import BaseModel
9
+
10
  r2d2_path = Path(__file__).parent / "../../third_party/r2d2"
11
+ sys.path.append(str(r2d2_path))
12
+ from extract import NonMaxSuppression, extract_multiscale, load_network
13
 
14
 
15
  class R2D2(BaseModel):
 
36
  rel_thr=conf["reliability_threshold"],
37
  rep_thr=conf["repetability_threshold"],
38
  )
39
+ logger.info("Load R2D2 model done.")
40
 
41
  def _forward(self, data):
42
  img = data["image"]
hloc/extractors/rekd.py CHANGED
@@ -1,11 +1,12 @@
1
  import sys
2
  from pathlib import Path
3
- import subprocess
4
  import torch
5
 
6
- from ..utils.base_model import BaseModel
7
  from hloc import logger
8
 
 
 
9
  rekd_path = Path(__file__).parent / "../../third_party"
10
  sys.path.append(str(rekd_path))
11
  from REKD.training.model.REKD import REKD as REKD_
@@ -29,7 +30,7 @@ class REKD(BaseModel):
29
  self.net = REKD_(is_test=True)
30
  state_dict = torch.load(model_path, map_location="cpu")
31
  self.net.load_state_dict(state_dict["model_state"])
32
- logger.info(f"Load REKD model done.")
33
 
34
  def _forward(self, data):
35
  image = data["image"]
 
1
  import sys
2
  from pathlib import Path
3
+
4
  import torch
5
 
 
6
  from hloc import logger
7
 
8
+ from ..utils.base_model import BaseModel
9
+
10
  rekd_path = Path(__file__).parent / "../../third_party"
11
  sys.path.append(str(rekd_path))
12
  from REKD.training.model.REKD import REKD as REKD_
 
30
  self.net = REKD_(is_test=True)
31
  state_dict = torch.load(model_path, map_location="cpu")
32
  self.net.load_state_dict(state_dict["model_state"])
33
+ logger.info("Load REKD model done.")
34
 
35
  def _forward(self, data):
36
  image = data["image"]
hloc/extractors/rord.py CHANGED
@@ -3,9 +3,10 @@ from pathlib import Path
3
  import subprocess
4
  import torch
5
 
6
- from ..utils.base_model import BaseModel
7
  from hloc import logger
8
 
 
 
9
  rord_path = Path(__file__).parent / "../../third_party"
10
  sys.path.append(str(rord_path))
11
  from RoRD.lib.model_test import D2Net as _RoRD
@@ -42,11 +43,10 @@ class RoRD(BaseModel):
42
  subprocess.run(cmd, check=True)
43
  except subprocess.CalledProcessError as e:
44
  logger.error(f"Failed to download the RoRD model.")
45
- raise e
46
  self.net = _RoRD(
47
  model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
48
  )
49
- logger.info(f"Load RoRD model done.")
50
 
51
  def _forward(self, data):
52
  image = data["image"]
 
3
  import subprocess
4
  import torch
5
 
 
6
  from hloc import logger
7
 
8
+ from ..utils.base_model import BaseModel
9
+
10
  rord_path = Path(__file__).parent / "../../third_party"
11
  sys.path.append(str(rord_path))
12
  from RoRD.lib.model_test import D2Net as _RoRD
 
43
  subprocess.run(cmd, check=True)
44
  except subprocess.CalledProcessError as e:
45
  logger.error(f"Failed to download the RoRD model.")
 
46
  self.net = _RoRD(
47
  model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
48
  )
49
+ logger.info("Load RoRD model done.")
50
 
51
  def _forward(self, data):
52
  image = data["image"]
hloc/extractors/sfd2.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: UTF-8 -*-
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ import torchvision.transforms as tvf
6
+
7
+ from .. import logger
8
+ from ..utils.base_model import BaseModel
9
+
10
+ pram_path = Path(__file__).parent / "../../third_party/pram"
11
+ sys.path.append(str(pram_path))
12
+
13
+ from nets.sfd2 import load_sfd2
14
+
15
+
16
+ class SFD2(BaseModel):
17
+ default_conf = {
18
+ "max_keypoints": 4096,
19
+ "model_name": "sfd2_20230511_210205_resnet4x.79.pth",
20
+ "conf_th": 0.001,
21
+ }
22
+ required_inputs = ["image"]
23
+
24
+ def _init(self, conf):
25
+ self.conf = {**self.default_conf, **conf}
26
+ self.norm_rgb = tvf.Normalize(
27
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
28
+ )
29
+ model_fn = pram_path / "weights" / self.conf["model_name"]
30
+ self.net = load_sfd2(weight_path=model_fn).eval()
31
+
32
+ logger.info("Load SFD2 model done.")
33
+
34
+ def _forward(self, data):
35
+ pred = self.net.extract_local_global(
36
+ data={"image": self.norm_rgb(data["image"])}, config=self.conf
37
+ )
38
+ out = {
39
+ "keypoints": pred["keypoints"][0][None],
40
+ "scores": pred["scores"][0][None],
41
+ "descriptors": pred["descriptors"][0][None],
42
+ }
43
+ return out
hloc/extractors/sift.py CHANGED
@@ -4,14 +4,15 @@ import cv2
4
  import numpy as np
5
  import torch
6
  from kornia.color import rgb_to_grayscale
7
- from packaging import version
8
  from omegaconf import OmegaConf
 
9
 
10
  try:
11
  import pycolmap
12
  except ImportError:
13
  pycolmap = None
14
  from hloc import logger
 
15
  from ..utils.base_model import BaseModel
16
 
17
 
@@ -140,7 +141,7 @@ class SIFT(BaseModel):
140
  f"Unknown backend: {backend} not in "
141
  f"{{{','.join(backends)}}}."
142
  )
143
- logger.info(f"Load SIFT model done.")
144
 
145
  def extract_single_image(self, image: torch.Tensor):
146
  image_np = image.cpu().numpy().squeeze(0)
 
4
  import numpy as np
5
  import torch
6
  from kornia.color import rgb_to_grayscale
 
7
  from omegaconf import OmegaConf
8
+ from packaging import version
9
 
10
  try:
11
  import pycolmap
12
  except ImportError:
13
  pycolmap = None
14
  from hloc import logger
15
+
16
  from ..utils.base_model import BaseModel
17
 
18
 
 
141
  f"Unknown backend: {backend} not in "
142
  f"{{{','.join(backends)}}}."
143
  )
144
+ logger.info("Load SIFT model done.")
145
 
146
  def extract_single_image(self, image: torch.Tensor):
147
  image_np = image.cpu().numpy().squeeze(0)
hloc/extractors/superpoint.py CHANGED
@@ -1,10 +1,12 @@
1
  import sys
2
  from pathlib import Path
 
3
  import torch
4
 
5
- from ..utils.base_model import BaseModel
6
  from hloc import logger
7
 
 
 
8
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
9
  from SuperGluePretrainedNetwork.models import superpoint # noqa E402
10
 
@@ -43,7 +45,7 @@ class SuperPoint(BaseModel):
43
  if conf["fix_sampling"]:
44
  superpoint.sample_descriptors = sample_descriptors_fix_sampling
45
  self.net = superpoint.SuperPoint(conf)
46
- logger.info(f"Load SuperPoint model done.")
47
 
48
  def _forward(self, data):
49
  return self.net(data, self.conf)
 
1
  import sys
2
  from pathlib import Path
3
+
4
  import torch
5
 
 
6
  from hloc import logger
7
 
8
+ from ..utils.base_model import BaseModel
9
+
10
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
11
  from SuperGluePretrainedNetwork.models import superpoint # noqa E402
12
 
 
45
  if conf["fix_sampling"]:
46
  superpoint.sample_descriptors = sample_descriptors_fix_sampling
47
  self.net = superpoint.SuperPoint(conf)
48
+ logger.info("Load SuperPoint model done.")
49
 
50
  def _forward(self, data):
51
  return self.net(data, self.conf)
hloc/extractors/xfeat.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
- from pathlib import Path
3
  from hloc import logger
 
4
  from ..utils.base_model import BaseModel
5
 
6
 
@@ -18,7 +19,7 @@ class XFeat(BaseModel):
18
  pretrained=True,
19
  top_k=self.conf["max_keypoints"],
20
  )
21
- logger.info(f"Load XFeat(sparse) model done.")
22
 
23
  def _forward(self, data):
24
  pred = self.net.detectAndCompute(
 
1
  import torch
2
+
3
  from hloc import logger
4
+
5
  from ..utils.base_model import BaseModel
6
 
7
 
 
19
  pretrained=True,
20
  top_k=self.conf["max_keypoints"],
21
  )
22
+ logger.info("Load XFeat(sparse) model done.")
23
 
24
  def _forward(self, data):
25
  pred = self.net.detectAndCompute(
hloc/match_dense.py CHANGED
@@ -1,9 +1,11 @@
 
 
 
1
  import numpy as np
2
  import torch
3
  import torchvision.transforms.functional as F
4
- from types import SimpleNamespace
5
  from .extract_features import read_image, resize_image
6
- import cv2
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
 
1
+ from types import SimpleNamespace
2
+
3
+ import cv2
4
  import numpy as np
5
  import torch
6
  import torchvision.transforms.functional as F
7
+
8
  from .extract_features import read_image, resize_image
 
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
hloc/match_features.py CHANGED
@@ -1,18 +1,19 @@
1
  import argparse
2
- from typing import Union, Optional, Dict, List, Tuple
3
- from pathlib import Path
4
  import pprint
 
 
5
  from queue import Queue
6
  from threading import Thread
7
- from functools import partial
8
- from tqdm import tqdm
9
  import h5py
 
10
  import torch
 
11
 
12
- from . import matchers, logger
13
  from .utils.base_model import dynamic_load
14
  from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
15
- import numpy as np
16
 
17
  """
18
  A set of standard configurations that can be directly selected from the command
@@ -162,6 +163,13 @@ confs = {
162
  "match_threshold": 0.2,
163
  },
164
  },
 
 
 
 
 
 
 
165
  }
166
 
167
 
 
1
  import argparse
 
 
2
  import pprint
3
+ from functools import partial
4
+ from pathlib import Path
5
  from queue import Queue
6
  from threading import Thread
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
  import h5py
10
+ import numpy as np
11
  import torch
12
+ from tqdm import tqdm
13
 
14
+ from . import logger, matchers
15
  from .utils.base_model import dynamic_load
16
  from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
 
17
 
18
  """
19
  A set of standard configurations that can be directly selected from the command
 
163
  "match_threshold": 0.2,
164
  },
165
  },
166
+ "imp": {
167
+ "output": "matches-imp",
168
+ "model": {
169
+ "name": "imp",
170
+ "match_threshold": 0.2,
171
+ },
172
+ },
173
  }
174
 
175
 
hloc/matchers/adalam.py CHANGED
@@ -1,10 +1,9 @@
1
  import torch
2
-
3
- from ..utils.base_model import BaseModel
4
-
5
  from kornia.feature.adalam import AdalamFilter
6
  from kornia.utils.helpers import get_cuda_device_if_available
7
 
 
 
8
 
9
  class AdaLAM(BaseModel):
10
  # See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html.
 
1
  import torch
 
 
 
2
  from kornia.feature.adalam import AdalamFilter
3
  from kornia.utils.helpers import get_cuda_device_if_available
4
 
5
+ from ..utils.base_model import BaseModel
6
+
7
 
8
  class AdaLAM(BaseModel):
9
  # See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html.
hloc/matchers/aspanformer.py CHANGED
@@ -1,17 +1,16 @@
 
1
  import sys
2
- import torch
3
- from ..utils.base_model import BaseModel
4
- from ..utils import do_system
5
  from pathlib import Path
6
- import subprocess
 
7
 
8
  from .. import logger
 
9
 
10
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
11
  from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
12
  from ASpanFormer.src.config.default import get_cfg_defaults
13
  from ASpanFormer.src.utils.misc import lower_config
14
- from ASpanFormer.demo import demo_utils
15
 
16
  aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
17
 
@@ -85,7 +84,7 @@ class ASpanFormer(BaseModel):
85
  "state_dict"
86
  ]
87
  self.net.load_state_dict(state_dict, strict=False)
88
- logger.info(f"Loaded Aspanformer model")
89
 
90
  def _forward(self, data):
91
  data_ = {
 
1
+ import subprocess
2
  import sys
 
 
 
3
  from pathlib import Path
4
+
5
+ import torch
6
 
7
  from .. import logger
8
+ from ..utils.base_model import BaseModel
9
 
10
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
11
  from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
12
  from ASpanFormer.src.config.default import get_cfg_defaults
13
  from ASpanFormer.src.utils.misc import lower_config
 
14
 
15
  aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
16
 
 
84
  "state_dict"
85
  ]
86
  self.net.load_state_dict(state_dict, strict=False)
87
+ logger.info("Loaded Aspanformer model")
88
 
89
  def _forward(self, data):
90
  data_ = {
hloc/matchers/cotr.py CHANGED
@@ -1,19 +1,19 @@
1
- import sys
2
  import argparse
3
- import torch
4
- import warnings
5
- import numpy as np
6
  from pathlib import Path
 
 
 
7
  from torchvision.transforms import ToPILImage
 
8
  from ..utils.base_model import BaseModel
9
 
10
  sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
11
- from COTR.utils import utils as utils_cotr
12
- from COTR.models import build_model
13
- from COTR.options.options import *
14
- from COTR.options.options_utils import *
15
- from COTR.inference.inference_helper import triangulate_corr
16
  from COTR.inference.sparse_engine import SparseEngine
 
 
 
 
17
 
18
  utils_cotr.fix_randomness(0)
19
  torch.set_grad_enabled(False)
@@ -33,7 +33,7 @@ class COTR(BaseModel):
33
 
34
  def _init(self, conf):
35
  parser = argparse.ArgumentParser()
36
- set_COTR_arguments(parser)
37
  opt = parser.parse_args()
38
  opt.command = " ".join(sys.argv)
39
  opt.load_weights_path = str(
 
 
1
  import argparse
2
+ import sys
 
 
3
  from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
  from torchvision.transforms import ToPILImage
8
+
9
  from ..utils.base_model import BaseModel
10
 
11
  sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
 
 
 
 
 
12
  from COTR.inference.sparse_engine import SparseEngine
13
+ from COTR.models import build_model
14
+ from COTR.options.options import * # noqa: F403
15
+ from COTR.options.options_utils import * # noqa: F403
16
+ from COTR.utils import utils as utils_cotr
17
 
18
  utils_cotr.fix_randomness(0)
19
  torch.set_grad_enabled(False)
 
33
 
34
  def _init(self, conf):
35
  parser = argparse.ArgumentParser()
36
+ set_COTR_arguments(parser) # noqa: F405
37
  opt = parser.parse_args()
38
  opt.command = " ".join(sys.argv)
39
  opt.load_weights_path = str(
hloc/matchers/dkm.py CHANGED
@@ -1,10 +1,12 @@
 
1
  import sys
2
  from pathlib import Path
 
3
  import torch
4
  from PIL import Image
5
- import subprocess
6
- from ..utils.base_model import BaseModel
7
  from .. import logger
 
8
 
9
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
10
  from DKM.dkm import DKMv3_outdoor
@@ -37,11 +39,11 @@ class DKMv3(BaseModel):
37
  if not model_path.exists():
38
  model_path.parent.mkdir(exist_ok=True)
39
  link = self.dkm_models[conf["model_name"]]
40
- cmd = ["wget", link, "-O", str(model_path)]
41
  logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
42
  subprocess.run(cmd, check=True)
43
  self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=device)
44
- logger.info(f"Loading DKMv3 model done")
45
 
46
  def _forward(self, data):
47
  img0 = data["image0"].cpu().numpy().squeeze() * 255
 
1
+ import subprocess
2
  import sys
3
  from pathlib import Path
4
+
5
  import torch
6
  from PIL import Image
7
+
 
8
  from .. import logger
9
+ from ..utils.base_model import BaseModel
10
 
11
  sys.path.append(str(Path(__file__).parent / "../../third_party"))
12
  from DKM.dkm import DKMv3_outdoor
 
39
  if not model_path.exists():
40
  model_path.parent.mkdir(exist_ok=True)
41
  link = self.dkm_models[conf["model_name"]]
42
+ cmd = ["wget", "--quiet", link, "-O", str(model_path)]
43
  logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
44
  subprocess.run(cmd, check=True)
45
  self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=device)
46
+ logger.info("Loading DKMv3 model done")
47
 
48
  def _forward(self, data):
49
  img0 = data["image0"].cpu().numpy().squeeze() * 255