Spaces:
Running
Running
Vincentqyw
commited on
Commit
•
8320ccc
1
Parent(s):
b075789
update: ci
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .flake8 +4 -0
- .github/.stale.yml +17 -0
- .github/ISSUE_TEMPLATE/bug_report.md +30 -0
- .github/ISSUE_TEMPLATE/config.yml +3 -0
- .github/ISSUE_TEMPLATE/feature_request.md +15 -0
- .github/ISSUE_TEMPLATE/question.md +25 -0
- .github/PULL_REQUEST_TEMPLATE.md +7 -0
- .github/release-drafter.yml +24 -0
- .github/workflows/ci.yml +32 -0
- .github/workflows/format.yml +24 -0
- .github/workflows/release-drafter.yml +16 -0
- .gitignore +3 -1
- README.md +2 -0
- common/api.py +19 -22
- common/app_class.py +17 -16
- common/config.yaml +39 -6
- common/utils.py +39 -41
- common/viz.py +6 -4
- docker/build_docker.bat +3 -0
- run_docker.sh → docker/run_docker.bat +0 -0
- docker/run_docker.sh +1 -0
- env-docker.txt +0 -33
- format.sh +3 -0
- hloc/__init__.py +1 -0
- hloc/extract_features.py +27 -9
- hloc/extractors/alike.py +4 -2
- hloc/extractors/d2net.py +8 -7
- hloc/extractors/darkfeat.py +5 -2
- hloc/extractors/dedode.py +10 -8
- hloc/extractors/dir.py +5 -4
- hloc/extractors/disk.py +1 -1
- hloc/extractors/dog.py +4 -5
- hloc/extractors/example.py +3 -3
- hloc/extractors/fire.py +5 -4
- hloc/extractors/fire_local.py +7 -10
- hloc/extractors/lanet.py +9 -6
- hloc/extractors/netvlad.py +4 -3
- hloc/extractors/r2d2.py +6 -5
- hloc/extractors/rekd.py +4 -3
- hloc/extractors/rord.py +3 -3
- hloc/extractors/sfd2.py +43 -0
- hloc/extractors/sift.py +3 -2
- hloc/extractors/superpoint.py +4 -2
- hloc/extractors/xfeat.py +3 -2
- hloc/match_dense.py +4 -2
- hloc/match_features.py +14 -6
- hloc/matchers/adalam.py +2 -3
- hloc/matchers/aspanformer.py +5 -6
- hloc/matchers/cotr.py +10 -10
- hloc/matchers/dkm.py +6 -4
.flake8
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[flake8]
|
2 |
+
max-line-length = 80
|
3 |
+
extend-ignore = E203,E501,E402
|
4 |
+
exclude = .git,__pycache__,build,.venv/,third_party
|
.github/.stale.yml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Number of days of inactivity before an issue becomes stale
|
2 |
+
daysUntilStale: 60
|
3 |
+
# Number of days of inactivity before a stale issue is closed
|
4 |
+
daysUntilClose: 7
|
5 |
+
# Issues with these labels will never be considered stale
|
6 |
+
exemptLabels:
|
7 |
+
- pinned
|
8 |
+
- security
|
9 |
+
# Label to use when marking an issue as stale
|
10 |
+
staleLabel: wontfix
|
11 |
+
# Comment to post when marking an issue as stale. Set to `false` to disable
|
12 |
+
markComment: >
|
13 |
+
This issue has been automatically marked as stale because it has not had
|
14 |
+
recent activity. It will be closed if no further activity occurs. Thank you
|
15 |
+
for your contributions.
|
16 |
+
# Comment to post when closing a stale issue. Set to `false` to disable
|
17 |
+
closeComment: false
|
.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 🐛 Bug report
|
3 |
+
about: If something isn't working 🔧
|
4 |
+
title: ""
|
5 |
+
labels: bug
|
6 |
+
assignees:
|
7 |
+
---
|
8 |
+
|
9 |
+
## 🐛 Bug Report
|
10 |
+
|
11 |
+
<!-- A clear and concise description of what the bug is. -->
|
12 |
+
|
13 |
+
## 🔬 How To Reproduce
|
14 |
+
|
15 |
+
Steps to reproduce the behavior:
|
16 |
+
|
17 |
+
1. ...
|
18 |
+
|
19 |
+
### Environment
|
20 |
+
|
21 |
+
- OS: [e.g. Linux / Windows / macOS]
|
22 |
+
- Python version, get it with:
|
23 |
+
|
24 |
+
```bash
|
25 |
+
python --version
|
26 |
+
```
|
27 |
+
|
28 |
+
## 📎 Additional context
|
29 |
+
|
30 |
+
<!-- Add any other context about the problem here. -->
|
.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository
|
2 |
+
|
3 |
+
blank_issues_enabled: false
|
.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 🚀 Feature request
|
3 |
+
about: Suggest an idea for this project 🏖
|
4 |
+
title: ""
|
5 |
+
labels: enhancement
|
6 |
+
assignees:
|
7 |
+
---
|
8 |
+
|
9 |
+
## 🚀 Feature Request
|
10 |
+
|
11 |
+
<!-- A clear and concise description of the feature proposal. -->
|
12 |
+
|
13 |
+
## 📎 Additional context
|
14 |
+
|
15 |
+
<!-- Add any other context or screenshots about the feature request here. -->
|
.github/ISSUE_TEMPLATE/question.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: ❓ Question
|
3 |
+
about: Ask a question about this project 🎓
|
4 |
+
title: ""
|
5 |
+
labels: question
|
6 |
+
assignees:
|
7 |
+
---
|
8 |
+
|
9 |
+
## Checklist
|
10 |
+
|
11 |
+
<!-- Mark with an `x` all the checkboxes that apply (like `[x]`) -->
|
12 |
+
|
13 |
+
- [ ] I've searched the project's [`issues`]
|
14 |
+
|
15 |
+
## ❓ Question
|
16 |
+
|
17 |
+
<!-- What is your question -->
|
18 |
+
|
19 |
+
How can I [...]?
|
20 |
+
|
21 |
+
Is it possible to [...]?
|
22 |
+
|
23 |
+
## 📎 Additional context
|
24 |
+
|
25 |
+
<!-- Add any other context or screenshots about the feature request here. -->
|
.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Description
|
2 |
+
|
3 |
+
<!-- Add a more detailed description of the changes if needed. -->
|
4 |
+
|
5 |
+
## Related Issue
|
6 |
+
|
7 |
+
<!-- If your PR refers to a related issue, link it here. -->
|
.github/release-drafter.yml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Release drafter configuration https://github.com/release-drafter/release-drafter#configuration
|
2 |
+
# Emojis were chosen to match the https://gitmoji.carloscuesta.me/
|
3 |
+
|
4 |
+
name-template: "v$RESOLVED_VERSION"
|
5 |
+
tag-template: "v$RESOLVED_VERSION"
|
6 |
+
|
7 |
+
categories:
|
8 |
+
- title: ":rocket: Features"
|
9 |
+
labels: [enhancement, feature]
|
10 |
+
- title: ":wrench: Fixes"
|
11 |
+
labels: [bug, bugfix, fix]
|
12 |
+
- title: ":toolbox: Maintenance & Refactor"
|
13 |
+
labels: [refactor, refactoring, chore]
|
14 |
+
- title: ":package: Build System & CI/CD & Test"
|
15 |
+
labels: [build, ci, testing, test]
|
16 |
+
- title: ":pencil: Documentation"
|
17 |
+
labels: [documentation]
|
18 |
+
- title: ":arrow_up: Dependencies updates"
|
19 |
+
labels: [dependencies]
|
20 |
+
|
21 |
+
template: |
|
22 |
+
## What’s Changed
|
23 |
+
|
24 |
+
$CHANGES
|
.github/workflows/ci.yml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: CI CPU
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
pull_request:
|
8 |
+
branches:
|
9 |
+
- main
|
10 |
+
|
11 |
+
jobs:
|
12 |
+
build:
|
13 |
+
runs-on: ubuntu-latest
|
14 |
+
|
15 |
+
steps:
|
16 |
+
- name: Checkout code
|
17 |
+
uses: actions/checkout@v4
|
18 |
+
with:
|
19 |
+
submodules: recursive
|
20 |
+
|
21 |
+
- name: Set up Python
|
22 |
+
uses: actions/setup-python@v4
|
23 |
+
with:
|
24 |
+
python-version: "3.10"
|
25 |
+
|
26 |
+
- name: Install dependencies
|
27 |
+
run: |
|
28 |
+
pip install -r requirements.txt
|
29 |
+
sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
|
30 |
+
|
31 |
+
- name: Run tests
|
32 |
+
run: python test_app_cli.py
|
.github/workflows/format.yml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Format and Lint Checks
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches:
|
5 |
+
- main
|
6 |
+
paths:
|
7 |
+
- '*.py'
|
8 |
+
pull_request:
|
9 |
+
types: [ assigned, opened, synchronize, reopened ]
|
10 |
+
jobs:
|
11 |
+
check:
|
12 |
+
name: Format and Lint Checks
|
13 |
+
runs-on: ubuntu-latest
|
14 |
+
steps:
|
15 |
+
- uses: actions/checkout@v4
|
16 |
+
- uses: actions/setup-python@v4
|
17 |
+
with:
|
18 |
+
python-version: '3.10'
|
19 |
+
cache: 'pip'
|
20 |
+
- run: python -m pip install --upgrade pip
|
21 |
+
- run: python -m pip install .[dev]
|
22 |
+
- run: python -m flake8 common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
|
23 |
+
- run: python -m isort common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py --check-only --diff
|
24 |
+
- run: python -m black common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py --check --diff
|
.github/workflows/release-drafter.yml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release Drafter
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
# branches to consider in the event; optional, defaults to all
|
6 |
+
branches:
|
7 |
+
- master
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
update_release_draft:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
# Drafts your next Release notes as Pull Requests are merged into "master"
|
14 |
+
- uses: release-drafter/release-drafter@v5.23.0
|
15 |
+
env:
|
16 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
.gitignore
CHANGED
@@ -14,9 +14,11 @@ experiments
|
|
14 |
third_party/REKD
|
15 |
hloc/matchers/dedode.py
|
16 |
gradio_cached_examples
|
17 |
-
|
18 |
hloc/matchers/quadtree.py
|
19 |
third_party/QuadTreeAttention
|
20 |
desktop.ini
|
|
|
|
|
21 |
experiments*
|
22 |
gen_example.py
|
|
|
14 |
third_party/REKD
|
15 |
hloc/matchers/dedode.py
|
16 |
gradio_cached_examples
|
17 |
+
*.mp4
|
18 |
hloc/matchers/quadtree.py
|
19 |
third_party/QuadTreeAttention
|
20 |
desktop.ini
|
21 |
+
*.egg-info
|
22 |
+
output.pkl
|
23 |
experiments*
|
24 |
gen_example.py
|
README.md
CHANGED
@@ -44,6 +44,8 @@ The tool currently supports various popular image matching algorithms, namely:
|
|
44 |
- [ ] [DUSt3R](https://github.com/naver/dust3r), arXiv 2023
|
45 |
- [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
|
46 |
- [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
|
|
|
|
|
47 |
- [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
|
48 |
- [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023
|
49 |
- [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023
|
|
|
44 |
- [ ] [DUSt3R](https://github.com/naver/dust3r), arXiv 2023
|
45 |
- [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
|
46 |
- [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
|
47 |
+
- [x] [SFD2](https://github.com/feixue94/sfd2), CVPR 2023
|
48 |
+
- [x] [IMP](https://github.com/feixue94/imp-release), CVPR 2023
|
49 |
- [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
|
50 |
- [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023
|
51 |
- [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023
|
common/api.py
CHANGED
@@ -1,26 +1,23 @@
|
|
1 |
-
import cv2
|
2 |
-
import torch
|
3 |
import warnings
|
4 |
-
import numpy as np
|
5 |
from pathlib import Path
|
6 |
-
from typing import
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
9 |
from hloc.utils.viz import add_text, plot_keypoints
|
|
|
10 |
from .utils import (
|
11 |
-
load_config,
|
12 |
-
get_model,
|
13 |
-
get_feature_model,
|
14 |
-
filter_matches,
|
15 |
-
device,
|
16 |
ROOT,
|
|
|
|
|
|
|
|
|
17 |
)
|
18 |
-
from .viz import
|
19 |
-
fig2im,
|
20 |
-
plot_images,
|
21 |
-
display_matches,
|
22 |
-
)
|
23 |
-
import matplotlib.pyplot as plt
|
24 |
|
25 |
warnings.simplefilter("ignore")
|
26 |
|
@@ -109,7 +106,7 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
109 |
"match_threshold"
|
110 |
] = match_threshold
|
111 |
except TypeError as e:
|
112 |
-
|
113 |
else:
|
114 |
self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
|
115 |
self.conf["feature"]["model"][
|
@@ -137,7 +134,9 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
137 |
self.match_conf["preprocessing"],
|
138 |
device=self.device,
|
139 |
)
|
140 |
-
last_fixed = "{}".format(
|
|
|
|
|
141 |
else:
|
142 |
pred0 = extract_features.extract(
|
143 |
self.extractor, img0, self.extract_conf["preprocessing"]
|
@@ -290,7 +289,5 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
290 |
|
291 |
|
292 |
if __name__ == "__main__":
|
293 |
-
import argparse
|
294 |
-
|
295 |
config = load_config(ROOT / "common/config.yaml")
|
296 |
-
|
|
|
|
|
|
|
1 |
import warnings
|
|
|
2 |
from pathlib import Path
|
3 |
+
from typing import Any, Dict, Optional
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from hloc import extract_features, logger, match_dense, match_features
|
11 |
from hloc.utils.viz import add_text, plot_keypoints
|
12 |
+
|
13 |
from .utils import (
|
|
|
|
|
|
|
|
|
|
|
14 |
ROOT,
|
15 |
+
filter_matches,
|
16 |
+
get_feature_model,
|
17 |
+
get_model,
|
18 |
+
load_config,
|
19 |
)
|
20 |
+
from .viz import display_matches, fig2im, plot_images
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
warnings.simplefilter("ignore")
|
23 |
|
|
|
106 |
"match_threshold"
|
107 |
] = match_threshold
|
108 |
except TypeError as e:
|
109 |
+
logger.error(e)
|
110 |
else:
|
111 |
self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
|
112 |
self.conf["feature"]["model"][
|
|
|
134 |
self.match_conf["preprocessing"],
|
135 |
device=self.device,
|
136 |
)
|
137 |
+
last_fixed = "{}".format( # noqa: F841
|
138 |
+
self.match_conf["model"]["name"]
|
139 |
+
)
|
140 |
else:
|
141 |
pred0 = extract_features.extract(
|
142 |
self.extractor, img0, self.extract_conf["preprocessing"]
|
|
|
289 |
|
290 |
|
291 |
if __name__ == "__main__":
|
|
|
|
|
292 |
config = load_config(ROOT / "common/config.yaml")
|
293 |
+
api = ImageMatchingAPI(config)
|
common/app_class.py
CHANGED
@@ -1,22 +1,21 @@
|
|
1 |
-
import argparse
|
2 |
-
import numpy as np
|
3 |
-
import gradio as gr
|
4 |
from pathlib import Path
|
5 |
-
from typing import
|
|
|
|
|
|
|
|
|
6 |
from common.utils import (
|
7 |
-
|
|
|
8 |
generate_warp_images,
|
9 |
-
load_config,
|
10 |
get_matcher_zoo,
|
|
|
|
|
11 |
run_matching,
|
12 |
run_ransac,
|
13 |
send_to_match,
|
14 |
-
gen_examples,
|
15 |
-
GRADIO_VERSION,
|
16 |
-
ROOT,
|
17 |
)
|
18 |
|
19 |
-
|
20 |
DESCRIPTION = """
|
21 |
# Image Matching WebUI
|
22 |
This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
|
@@ -132,12 +131,14 @@ class ImageMatchingApp:
|
|
132 |
label="Keypoint thres.",
|
133 |
value=0.015,
|
134 |
)
|
135 |
-
detect_line_threshold =
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
141 |
)
|
142 |
# matcher_lists = gr.Radio(
|
143 |
# ["NN-mutual", "Dual-Softmax"],
|
|
|
|
|
|
|
|
|
1 |
from pathlib import Path
|
2 |
+
from typing import Any, Dict, Optional, Tuple
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
from common.utils import (
|
8 |
+
GRADIO_VERSION,
|
9 |
+
gen_examples,
|
10 |
generate_warp_images,
|
|
|
11 |
get_matcher_zoo,
|
12 |
+
load_config,
|
13 |
+
ransac_zoo,
|
14 |
run_matching,
|
15 |
run_ransac,
|
16 |
send_to_match,
|
|
|
|
|
|
|
17 |
)
|
18 |
|
|
|
19 |
DESCRIPTION = """
|
20 |
# Image Matching WebUI
|
21 |
This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
|
|
|
131 |
label="Keypoint thres.",
|
132 |
value=0.015,
|
133 |
)
|
134 |
+
detect_line_threshold = ( # noqa: F841
|
135 |
+
gr.Slider(
|
136 |
+
minimum=0.1,
|
137 |
+
maximum=1,
|
138 |
+
step=0.01,
|
139 |
+
label="Line thres.",
|
140 |
+
value=0.2,
|
141 |
+
)
|
142 |
)
|
143 |
# matcher_lists = gr.Radio(
|
144 |
# ["NN-mutual", "Dual-Softmax"],
|
common/config.yaml
CHANGED
@@ -30,20 +30,22 @@ matcher_zoo:
|
|
30 |
DUSt3R:
|
31 |
# TODO: duster is under development
|
32 |
enable: true
|
|
|
33 |
matcher: duster
|
34 |
dense: true
|
35 |
-
info:
|
36 |
name: DUSt3R #dispaly name
|
37 |
source: "CVPR 2024"
|
38 |
github: https://github.com/naver/dust3r
|
39 |
paper: https://arxiv.org/abs/2312.14132
|
40 |
project: https://dust3r.europe.naverlabs.com
|
41 |
-
display: true
|
42 |
GIM(dkm):
|
43 |
enable: true
|
|
|
44 |
matcher: gim(dkm)
|
45 |
dense: true
|
46 |
-
info:
|
47 |
name: GIM(DKM) #dispaly name
|
48 |
source: "ICLR 2024"
|
49 |
github: https://github.com/xuelunshen/gim
|
@@ -52,8 +54,9 @@ matcher_zoo:
|
|
52 |
display: true
|
53 |
RoMa:
|
54 |
matcher: roma
|
|
|
55 |
dense: true
|
56 |
-
info:
|
57 |
name: RoMa #dispaly name
|
58 |
source: "CVPR 2024"
|
59 |
github: https://github.com/Parskatt/RoMa
|
@@ -62,8 +65,9 @@ matcher_zoo:
|
|
62 |
display: true
|
63 |
dkm:
|
64 |
matcher: dkm
|
|
|
65 |
dense: true
|
66 |
-
info:
|
67 |
name: DKM #dispaly name
|
68 |
source: "CVPR 2023"
|
69 |
github: https://github.com/Parskatt/DKM
|
@@ -73,7 +77,7 @@ matcher_zoo:
|
|
73 |
loftr:
|
74 |
matcher: loftr
|
75 |
dense: true
|
76 |
-
info:
|
77 |
name: LoFTR #dispaly name
|
78 |
source: "CVPR 2021"
|
79 |
github: https://github.com/zju3dv/LoFTR
|
@@ -82,6 +86,7 @@ matcher_zoo:
|
|
82 |
display: true
|
83 |
cotr:
|
84 |
enable: false
|
|
|
85 |
matcher: cotr
|
86 |
dense: true
|
87 |
info:
|
@@ -363,3 +368,31 @@ matcher_zoo:
|
|
363 |
paper: https://arxiv.org/abs/2104.03362
|
364 |
project: null
|
365 |
display: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
DUSt3R:
|
31 |
# TODO: duster is under development
|
32 |
enable: true
|
33 |
+
skip_ci: true
|
34 |
matcher: duster
|
35 |
dense: true
|
36 |
+
info:
|
37 |
name: DUSt3R #dispaly name
|
38 |
source: "CVPR 2024"
|
39 |
github: https://github.com/naver/dust3r
|
40 |
paper: https://arxiv.org/abs/2312.14132
|
41 |
project: https://dust3r.europe.naverlabs.com
|
42 |
+
display: true
|
43 |
GIM(dkm):
|
44 |
enable: true
|
45 |
+
skip_ci: true
|
46 |
matcher: gim(dkm)
|
47 |
dense: true
|
48 |
+
info:
|
49 |
name: GIM(DKM) #dispaly name
|
50 |
source: "ICLR 2024"
|
51 |
github: https://github.com/xuelunshen/gim
|
|
|
54 |
display: true
|
55 |
RoMa:
|
56 |
matcher: roma
|
57 |
+
skip_ci: true
|
58 |
dense: true
|
59 |
+
info:
|
60 |
name: RoMa #dispaly name
|
61 |
source: "CVPR 2024"
|
62 |
github: https://github.com/Parskatt/RoMa
|
|
|
65 |
display: true
|
66 |
dkm:
|
67 |
matcher: dkm
|
68 |
+
skip_ci: true
|
69 |
dense: true
|
70 |
+
info:
|
71 |
name: DKM #dispaly name
|
72 |
source: "CVPR 2023"
|
73 |
github: https://github.com/Parskatt/DKM
|
|
|
77 |
loftr:
|
78 |
matcher: loftr
|
79 |
dense: true
|
80 |
+
info:
|
81 |
name: LoFTR #dispaly name
|
82 |
source: "CVPR 2021"
|
83 |
github: https://github.com/zju3dv/LoFTR
|
|
|
86 |
display: true
|
87 |
cotr:
|
88 |
enable: false
|
89 |
+
skip_ci: true
|
90 |
matcher: cotr
|
91 |
dense: true
|
92 |
info:
|
|
|
368 |
paper: https://arxiv.org/abs/2104.03362
|
369 |
project: null
|
370 |
display: true
|
371 |
+
|
372 |
+
sfd2+imp:
|
373 |
+
matcher: imp
|
374 |
+
feature: sfd2
|
375 |
+
enable: false
|
376 |
+
dense: false
|
377 |
+
skip_ci: true
|
378 |
+
info:
|
379 |
+
name: SFD2+IMP #dispaly name
|
380 |
+
source: "CVPR 2023"
|
381 |
+
github: https://github.com/feixue94/imp-release
|
382 |
+
paper: https://arxiv.org/pdf/2304.14837
|
383 |
+
project: https://feixue94.github.io/
|
384 |
+
display: true
|
385 |
+
|
386 |
+
sfd2+mnn:
|
387 |
+
matcher: NN-mutual
|
388 |
+
feature: sfd2
|
389 |
+
enable: false
|
390 |
+
dense: false
|
391 |
+
skip_ci: true
|
392 |
+
info:
|
393 |
+
name: SFD2+MNN #dispaly name
|
394 |
+
source: "CVPR 2023"
|
395 |
+
github: https://github.com/feixue94/sfd2
|
396 |
+
paper: https://arxiv.org/abs/2304.14845
|
397 |
+
project: https://feixue94.github.io/
|
398 |
+
display: true
|
common/utils.py
CHANGED
@@ -1,35 +1,35 @@
|
|
1 |
import os
|
2 |
-
import
|
3 |
-
import sys
|
4 |
-
import torch
|
5 |
import random
|
6 |
-
import psutil
|
7 |
import shutil
|
8 |
-
import
|
9 |
-
import
|
10 |
-
from
|
11 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
import poselib
|
13 |
-
|
14 |
-
from
|
15 |
-
|
16 |
-
from hloc
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
)
|
25 |
-
import
|
26 |
-
|
27 |
-
import
|
28 |
-
import tempfile
|
29 |
-
import pickle
|
30 |
|
31 |
warnings.simplefilter("ignore")
|
32 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
|
34 |
ROOT = Path(__file__).parent.parent
|
35 |
# some default values
|
@@ -91,14 +91,13 @@ class ModelCache:
|
|
91 |
host_colocation = int(os.environ.get("HOST_COLOCATION", "1"))
|
92 |
vm = psutil.virtual_memory()
|
93 |
du = shutil.disk_usage(".")
|
94 |
-
vm_ratio = host_colocation * vm.used / vm.total
|
95 |
if verbose:
|
96 |
logger.info(
|
97 |
f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB"
|
98 |
)
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
return vm.used / 1e9
|
103 |
|
104 |
def print_memory_usage(self):
|
@@ -173,7 +172,7 @@ def get_model(match_conf: Dict[str, Any]):
|
|
173 |
A matcher model instance.
|
174 |
"""
|
175 |
Model = dynamic_load(matchers, match_conf["model"]["name"])
|
176 |
-
model = Model(match_conf["model"]).eval().to(
|
177 |
return model
|
178 |
|
179 |
|
@@ -188,7 +187,7 @@ def get_feature_model(conf: Dict[str, Dict[str, Any]]):
|
|
188 |
A feature extraction model instance.
|
189 |
"""
|
190 |
Model = dynamic_load(extractors, conf["model"]["name"])
|
191 |
-
model = Model(conf["model"]).eval().to(
|
192 |
return model
|
193 |
|
194 |
|
@@ -423,7 +422,7 @@ def _filter_matches_poselib(
|
|
423 |
elif geometry_type == "Fundamental":
|
424 |
M, info = poselib.estimate_fundamental(kp0, kp1, ransac_options)
|
425 |
else:
|
426 |
-
raise
|
427 |
|
428 |
return M, np.array(info["inliers"])
|
429 |
|
@@ -464,7 +463,7 @@ def proc_ransac_matches(
|
|
464 |
geometry_type,
|
465 |
)
|
466 |
else:
|
467 |
-
raise
|
468 |
|
469 |
|
470 |
def filter_matches(
|
@@ -617,7 +616,9 @@ def compute_geometry(
|
|
617 |
geo_info["H1"] = H1.tolist()
|
618 |
geo_info["H2"] = H2.tolist()
|
619 |
except cv2.error as e:
|
620 |
-
logger.error(
|
|
|
|
|
621 |
return geo_info
|
622 |
else:
|
623 |
return {}
|
@@ -643,7 +644,6 @@ def wrap_images(
|
|
643 |
"""
|
644 |
h0, w0, _ = img0.shape
|
645 |
h1, w1, _ = img1.shape
|
646 |
-
result_matrix: Optional[np.ndarray] = None
|
647 |
if geo_info is not None and len(geo_info) != 0:
|
648 |
rectified_image0 = img0
|
649 |
rectified_image1 = None
|
@@ -656,7 +656,6 @@ def wrap_images(
|
|
656 |
title: List[str] = []
|
657 |
if geom_type == "Homography":
|
658 |
rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0))
|
659 |
-
result_matrix = H
|
660 |
title = ["Image 0", "Image 1 - warped"]
|
661 |
elif geom_type == "Fundamental":
|
662 |
if geom_type not in geo_info:
|
@@ -666,7 +665,6 @@ def wrap_images(
|
|
666 |
H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
|
667 |
rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
|
668 |
rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
|
669 |
-
result_matrix = np.array(geo_info["Fundamental"])
|
670 |
title = ["Image 0 - warped", "Image 1 - warped"]
|
671 |
else:
|
672 |
print("Error: Unknown geometry type")
|
@@ -705,7 +703,7 @@ def generate_warp_images(
|
|
705 |
):
|
706 |
return None, None
|
707 |
geom_info = matches_info["geom_info"]
|
708 |
-
|
709 |
if choice != "No":
|
710 |
wrapped_image_pair, warped_image = wrap_images(
|
711 |
input_image0, input_image1, geom_info, choice
|
@@ -805,7 +803,7 @@ def run_ransac(
|
|
805 |
with open(tmp_state_cache, "wb") as f:
|
806 |
pickle.dump(state_cache, f)
|
807 |
|
808 |
-
logger.info(
|
809 |
|
810 |
return (
|
811 |
output_matches_ransac,
|
@@ -880,7 +878,7 @@ def run_matching(
|
|
880 |
output_matches_ransac = None
|
881 |
|
882 |
# super slow!
|
883 |
-
if "roma" in key.lower() and
|
884 |
gr.Info(
|
885 |
f"Success! Please be patient and allow for about 2-3 minutes."
|
886 |
f" Due to CPU inference, {key} is quiet slow."
|
@@ -905,7 +903,7 @@ def run_matching(
|
|
905 |
|
906 |
if model["dense"]:
|
907 |
pred = match_dense.match_images(
|
908 |
-
matcher, image0, image1, match_conf["preprocessing"], device=
|
909 |
)
|
910 |
del matcher
|
911 |
extract_conf = None
|
@@ -1000,7 +998,7 @@ def run_matching(
|
|
1000 |
tmp_state_cache = "output.pkl"
|
1001 |
with open(tmp_state_cache, "wb") as f:
|
1002 |
pickle.dump(state_cache, f)
|
1003 |
-
logger.info(
|
1004 |
return (
|
1005 |
output_keypoints,
|
1006 |
output_matches_raw,
|
|
|
1 |
import os
|
2 |
+
import pickle
|
|
|
|
|
3 |
import random
|
|
|
4 |
import shutil
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
from itertools import combinations
|
8 |
from pathlib import Path
|
9 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import gradio as gr
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import numpy as np
|
15 |
import poselib
|
16 |
+
import psutil
|
17 |
+
from PIL import Image
|
18 |
+
|
19 |
+
from hloc import (
|
20 |
+
DEVICE,
|
21 |
+
extract_features,
|
22 |
+
extractors,
|
23 |
+
logger,
|
24 |
+
match_dense,
|
25 |
+
match_features,
|
26 |
+
matchers,
|
27 |
)
|
28 |
+
from hloc.utils.base_model import dynamic_load
|
29 |
+
|
30 |
+
from .viz import display_keypoints, display_matches, fig2im, plot_images
|
|
|
|
|
31 |
|
32 |
warnings.simplefilter("ignore")
|
|
|
33 |
|
34 |
ROOT = Path(__file__).parent.parent
|
35 |
# some default values
|
|
|
91 |
host_colocation = int(os.environ.get("HOST_COLOCATION", "1"))
|
92 |
vm = psutil.virtual_memory()
|
93 |
du = shutil.disk_usage(".")
|
|
|
94 |
if verbose:
|
95 |
logger.info(
|
96 |
f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB"
|
97 |
)
|
98 |
+
logger.info(
|
99 |
+
f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB"
|
100 |
+
)
|
101 |
return vm.used / 1e9
|
102 |
|
103 |
def print_memory_usage(self):
|
|
|
172 |
A matcher model instance.
|
173 |
"""
|
174 |
Model = dynamic_load(matchers, match_conf["model"]["name"])
|
175 |
+
model = Model(match_conf["model"]).eval().to(DEVICE)
|
176 |
return model
|
177 |
|
178 |
|
|
|
187 |
A feature extraction model instance.
|
188 |
"""
|
189 |
Model = dynamic_load(extractors, conf["model"]["name"])
|
190 |
+
model = Model(conf["model"]).eval().to(DEVICE)
|
191 |
return model
|
192 |
|
193 |
|
|
|
422 |
elif geometry_type == "Fundamental":
|
423 |
M, info = poselib.estimate_fundamental(kp0, kp1, ransac_options)
|
424 |
else:
|
425 |
+
raise NotImplementedError
|
426 |
|
427 |
return M, np.array(info["inliers"])
|
428 |
|
|
|
463 |
geometry_type,
|
464 |
)
|
465 |
else:
|
466 |
+
raise NotImplementedError
|
467 |
|
468 |
|
469 |
def filter_matches(
|
|
|
616 |
geo_info["H1"] = H1.tolist()
|
617 |
geo_info["H2"] = H2.tolist()
|
618 |
except cv2.error as e:
|
619 |
+
logger.error(
|
620 |
+
f"StereoRectifyUncalibrated failed, skip! error: {e}"
|
621 |
+
)
|
622 |
return geo_info
|
623 |
else:
|
624 |
return {}
|
|
|
644 |
"""
|
645 |
h0, w0, _ = img0.shape
|
646 |
h1, w1, _ = img1.shape
|
|
|
647 |
if geo_info is not None and len(geo_info) != 0:
|
648 |
rectified_image0 = img0
|
649 |
rectified_image1 = None
|
|
|
656 |
title: List[str] = []
|
657 |
if geom_type == "Homography":
|
658 |
rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0))
|
|
|
659 |
title = ["Image 0", "Image 1 - warped"]
|
660 |
elif geom_type == "Fundamental":
|
661 |
if geom_type not in geo_info:
|
|
|
665 |
H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
|
666 |
rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
|
667 |
rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
|
|
|
668 |
title = ["Image 0 - warped", "Image 1 - warped"]
|
669 |
else:
|
670 |
print("Error: Unknown geometry type")
|
|
|
703 |
):
|
704 |
return None, None
|
705 |
geom_info = matches_info["geom_info"]
|
706 |
+
warped_image = None
|
707 |
if choice != "No":
|
708 |
wrapped_image_pair, warped_image = wrap_images(
|
709 |
input_image0, input_image1, geom_info, choice
|
|
|
803 |
with open(tmp_state_cache, "wb") as f:
|
804 |
pickle.dump(state_cache, f)
|
805 |
|
806 |
+
logger.info("Dump results done!")
|
807 |
|
808 |
return (
|
809 |
output_matches_ransac,
|
|
|
878 |
output_matches_ransac = None
|
879 |
|
880 |
# super slow!
|
881 |
+
if "roma" in key.lower() and DEVICE == "cpu":
|
882 |
gr.Info(
|
883 |
f"Success! Please be patient and allow for about 2-3 minutes."
|
884 |
f" Due to CPU inference, {key} is quiet slow."
|
|
|
903 |
|
904 |
if model["dense"]:
|
905 |
pred = match_dense.match_images(
|
906 |
+
matcher, image0, image1, match_conf["preprocessing"], device=DEVICE
|
907 |
)
|
908 |
del matcher
|
909 |
extract_conf = None
|
|
|
998 |
tmp_state_cache = "output.pkl"
|
999 |
with open(tmp_state_cache, "wb") as f:
|
1000 |
pickle.dump(state_cache, f)
|
1001 |
+
logger.info("Dump results done!")
|
1002 |
return (
|
1003 |
output_keypoints,
|
1004 |
output_matches_raw,
|
common/viz.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
-
import cv2
|
2 |
import typing
|
|
|
|
|
|
|
|
|
3 |
import matplotlib
|
|
|
4 |
import numpy as np
|
5 |
import seaborn as sns
|
6 |
-
|
7 |
-
from pathlib import Path
|
8 |
-
from typing import Dict, Any, Optional, Tuple, List, Union
|
9 |
from hloc.utils.viz import add_text, plot_keypoints
|
10 |
|
11 |
|
|
|
|
|
1 |
import typing
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import cv2
|
6 |
import matplotlib
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
import numpy as np
|
9 |
import seaborn as sns
|
10 |
+
|
|
|
|
|
11 |
from hloc.utils.viz import add_text, plot_keypoints
|
12 |
|
13 |
|
docker/build_docker.bat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
docker build -t image-matching-webui:latest . --no-cache
|
2 |
+
# docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest
|
3 |
+
# docker push vincentqin/image-matching-webui:latest
|
run_docker.sh → docker/run_docker.bat
RENAMED
File without changes
|
docker/run_docker.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860
|
env-docker.txt
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
e2cnn==0.2.3
|
2 |
-
einops==0.6.1
|
3 |
-
gdown==4.7.1
|
4 |
-
gradio==4.28.3
|
5 |
-
gradio_client==0.16.0
|
6 |
-
h5py==3.9.0
|
7 |
-
imageio==2.31.1
|
8 |
-
Jinja2==3.1.2
|
9 |
-
kornia==0.6.12
|
10 |
-
loguru==0.7.0
|
11 |
-
matplotlib==3.7.1
|
12 |
-
numpy==1.23.5
|
13 |
-
omegaconf==2.3.0
|
14 |
-
opencv-contrib-python==4.6.0.66
|
15 |
-
opencv-python==4.6.0.66
|
16 |
-
pandas==2.0.3
|
17 |
-
plotly==5.15.0
|
18 |
-
protobuf==4.23.2
|
19 |
-
pycolmap==0.5.0
|
20 |
-
pytlsd==0.0.2
|
21 |
-
pytorch-lightning==1.4.9
|
22 |
-
PyYAML==6.0
|
23 |
-
scikit-image==0.21.0
|
24 |
-
scikit-learn==1.2.2
|
25 |
-
scipy==1.11.1
|
26 |
-
seaborn==0.12.2
|
27 |
-
shapely==2.0.1
|
28 |
-
tensorboardX==2.6.1
|
29 |
-
torchmetrics==0.6.0
|
30 |
-
torchvision==0.17.1
|
31 |
-
tqdm==4.65.0
|
32 |
-
yacs==0.1.8
|
33 |
-
onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
format.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
python -m flake8 common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
|
2 |
+
python -m isort common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
|
3 |
+
python -m black common/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
|
hloc/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import logging
|
|
|
2 |
import torch
|
3 |
from packaging import version
|
4 |
|
|
|
1 |
import logging
|
2 |
+
|
3 |
import torch
|
4 |
from packaging import version
|
5 |
|
hloc/extract_features.py
CHANGED
@@ -1,21 +1,22 @@
|
|
1 |
import argparse
|
2 |
-
import
|
|
|
3 |
from pathlib import Path
|
4 |
-
from typing import Dict, List, Union, Optional
|
5 |
-
import h5py
|
6 |
from types import SimpleNamespace
|
|
|
|
|
7 |
import cv2
|
|
|
8 |
import numpy as np
|
9 |
-
from tqdm import tqdm
|
10 |
-
import pprint
|
11 |
-
import collections.abc as collections
|
12 |
import PIL.Image
|
|
|
13 |
import torchvision.transforms.functional as F
|
|
|
|
|
14 |
from . import extractors, logger
|
15 |
from .utils.base_model import dynamic_load
|
|
|
16 |
from .utils.parsers import parse_image_lists
|
17 |
-
from .utils.io import read_image, list_h5_names
|
18 |
-
|
19 |
|
20 |
"""
|
21 |
A set of standard configurations that can be directly selected from the command
|
@@ -290,6 +291,23 @@ confs = {
|
|
290 |
"dfactor": 8,
|
291 |
},
|
292 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
# Global descriptors
|
294 |
"dir": {
|
295 |
"output": "global-feats-dir",
|
@@ -460,7 +478,7 @@ def extract(model, image_0, conf):
|
|
460 |
# image0 = image_0[:, :, ::-1] # BGR to RGB
|
461 |
data = preprocess(image0, conf)
|
462 |
pred = model({"image": data["image"]})
|
463 |
-
pred["image_size"] =
|
464 |
pred = {**pred, **data}
|
465 |
return pred
|
466 |
|
|
|
1 |
import argparse
|
2 |
+
import collections.abc as collections
|
3 |
+
import pprint
|
4 |
from pathlib import Path
|
|
|
|
|
5 |
from types import SimpleNamespace
|
6 |
+
from typing import Dict, List, Optional, Union
|
7 |
+
|
8 |
import cv2
|
9 |
+
import h5py
|
10 |
import numpy as np
|
|
|
|
|
|
|
11 |
import PIL.Image
|
12 |
+
import torch
|
13 |
import torchvision.transforms.functional as F
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
from . import extractors, logger
|
17 |
from .utils.base_model import dynamic_load
|
18 |
+
from .utils.io import list_h5_names, read_image
|
19 |
from .utils.parsers import parse_image_lists
|
|
|
|
|
20 |
|
21 |
"""
|
22 |
A set of standard configurations that can be directly selected from the command
|
|
|
291 |
"dfactor": 8,
|
292 |
},
|
293 |
},
|
294 |
+
"sfd2": {
|
295 |
+
"output": "feats-sfd2-n4096-r1600",
|
296 |
+
"model": {
|
297 |
+
"name": "sfd2",
|
298 |
+
"max_keypoints": 4096,
|
299 |
+
},
|
300 |
+
"preprocessing": {
|
301 |
+
"grayscale": False,
|
302 |
+
"force_resize": True,
|
303 |
+
"resize_max": 1600,
|
304 |
+
"width": 640,
|
305 |
+
"height": 480,
|
306 |
+
"conf_th": 0.001,
|
307 |
+
"multiscale": False,
|
308 |
+
"scales": [1.0],
|
309 |
+
},
|
310 |
+
},
|
311 |
# Global descriptors
|
312 |
"dir": {
|
313 |
"output": "global-feats-dir",
|
|
|
478 |
# image0 = image_0[:, :, ::-1] # BGR to RGB
|
479 |
data = preprocess(image0, conf)
|
480 |
pred = model({"image": data["image"]})
|
481 |
+
pred["image_size"] = data["original_size"]
|
482 |
pred = {**pred, **data}
|
483 |
return pred
|
484 |
|
hloc/extractors/alike.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
|
|
3 |
import torch
|
4 |
|
5 |
-
from ..utils.base_model import BaseModel
|
6 |
from hloc import logger
|
7 |
|
|
|
|
|
8 |
alike_path = Path(__file__).parent / "../../third_party/ALIKE"
|
9 |
sys.path.append(str(alike_path))
|
10 |
from alike import ALike as Alike_
|
@@ -34,7 +36,7 @@ class Alike(BaseModel):
|
|
34 |
scores_th=conf["detection_threshold"],
|
35 |
n_limit=conf["max_keypoints"],
|
36 |
)
|
37 |
-
logger.info(
|
38 |
|
39 |
def _forward(self, data):
|
40 |
image = data["image"]
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
import torch
|
5 |
|
|
|
6 |
from hloc import logger
|
7 |
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
alike_path = Path(__file__).parent / "../../third_party/ALIKE"
|
11 |
sys.path.append(str(alike_path))
|
12 |
from alike import ALike as Alike_
|
|
|
36 |
scores_th=conf["detection_threshold"],
|
37 |
n_limit=conf["max_keypoints"],
|
38 |
)
|
39 |
+
logger.info("Load Alike model done.")
|
40 |
|
41 |
def _forward(self, data):
|
42 |
image = data["image"]
|
hloc/extractors/d2net.py
CHANGED
@@ -1,17 +1,17 @@
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
4 |
import torch
|
5 |
|
6 |
-
from ..utils.base_model import BaseModel
|
7 |
from hloc import logger
|
8 |
|
9 |
-
|
10 |
-
sys.path.append(str(d2net_path))
|
11 |
-
from d2net.lib.model_test import D2Net as _D2Net
|
12 |
-
from d2net.lib.pyramid import process_multiscale
|
13 |
|
14 |
d2net_path = Path(__file__).parent / "../../third_party/d2net"
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
class D2Net(BaseModel):
|
@@ -30,6 +30,7 @@ class D2Net(BaseModel):
|
|
30 |
model_file.parent.mkdir(exist_ok=True)
|
31 |
cmd = [
|
32 |
"wget",
|
|
|
33 |
"https://dusmanu.com/files/d2-net/" + conf["model_name"],
|
34 |
"-O",
|
35 |
str(model_file),
|
@@ -39,7 +40,7 @@ class D2Net(BaseModel):
|
|
39 |
self.net = _D2Net(
|
40 |
model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
|
41 |
)
|
42 |
-
logger.info(
|
43 |
|
44 |
def _forward(self, data):
|
45 |
image = data["image"]
|
|
|
1 |
+
import subprocess
|
2 |
import sys
|
3 |
from pathlib import Path
|
4 |
+
|
5 |
import torch
|
6 |
|
|
|
7 |
from hloc import logger
|
8 |
|
9 |
+
from ..utils.base_model import BaseModel
|
|
|
|
|
|
|
10 |
|
11 |
d2net_path = Path(__file__).parent / "../../third_party/d2net"
|
12 |
+
sys.path.append(str(d2net_path))
|
13 |
+
from lib.model_test import D2Net as _D2Net
|
14 |
+
from lib.pyramid import process_multiscale
|
15 |
|
16 |
|
17 |
class D2Net(BaseModel):
|
|
|
30 |
model_file.parent.mkdir(exist_ok=True)
|
31 |
cmd = [
|
32 |
"wget",
|
33 |
+
"--quiet",
|
34 |
"https://dusmanu.com/files/d2-net/" + conf["model_name"],
|
35 |
"-O",
|
36 |
str(model_file),
|
|
|
40 |
self.net = _D2Net(
|
41 |
model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
|
42 |
)
|
43 |
+
logger.info("Load D2Net model done.")
|
44 |
|
45 |
def _forward(self, data):
|
46 |
image = data["image"]
|
hloc/extractors/darkfeat.py
CHANGED
@@ -1,9 +1,12 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
|
|
3 |
import subprocess
|
4 |
-
|
5 |
from hloc import logger
|
6 |
|
|
|
|
|
7 |
darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
|
8 |
sys.path.append(str(darkfeat_path))
|
9 |
from darkfeat import DarkFeat as DarkFeat_
|
@@ -43,7 +46,7 @@ class DarkFeat(BaseModel):
|
|
43 |
raise e
|
44 |
|
45 |
self.net = DarkFeat_(model_path)
|
46 |
-
logger.info(
|
47 |
|
48 |
def _forward(self, data):
|
49 |
pred = self.net({"image": data["image"]})
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
import subprocess
|
5 |
+
|
6 |
from hloc import logger
|
7 |
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
|
11 |
sys.path.append(str(darkfeat_path))
|
12 |
from darkfeat import DarkFeat as DarkFeat_
|
|
|
46 |
raise e
|
47 |
|
48 |
self.net = DarkFeat_(model_path)
|
49 |
+
logger.info("Load DarkFeat model done.")
|
50 |
|
51 |
def _forward(self, data):
|
52 |
pred = self.net({"image": data["image"]})
|
hloc/extractors/dedode.py
CHANGED
@@ -1,16 +1,18 @@
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
4 |
import torch
|
5 |
-
from PIL import Image
|
6 |
-
from ..utils.base_model import BaseModel
|
7 |
-
from hloc import logger
|
8 |
import torchvision.transforms as transforms
|
9 |
|
|
|
|
|
|
|
|
|
10 |
dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
|
11 |
sys.path.append(str(dedode_path))
|
12 |
|
13 |
-
from DeDoDe import
|
14 |
from DeDoDe.utils import to_pixel_coords
|
15 |
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -49,14 +51,14 @@ class DeDoDe(BaseModel):
|
|
49 |
if not model_detector_path.exists():
|
50 |
model_detector_path.parent.mkdir(exist_ok=True)
|
51 |
link = self.weight_urls[conf["model_detector_name"]]
|
52 |
-
cmd = ["wget", link, "-O", str(model_detector_path)]
|
53 |
logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.")
|
54 |
subprocess.run(cmd, check=True)
|
55 |
|
56 |
if not model_descriptor_path.exists():
|
57 |
model_descriptor_path.parent.mkdir(exist_ok=True)
|
58 |
link = self.weight_urls[conf["model_descriptor_name"]]
|
59 |
-
cmd = ["wget", link, "-O", str(model_descriptor_path)]
|
60 |
logger.info(
|
61 |
f"Downloading the DeDoDe descriptor model with `{cmd}`."
|
62 |
)
|
@@ -73,7 +75,7 @@ class DeDoDe(BaseModel):
|
|
73 |
self.descriptor = dedode_descriptor_B(
|
74 |
weights=weights_descriptor, device=device
|
75 |
)
|
76 |
-
logger.info(
|
77 |
|
78 |
def _forward(self, data):
|
79 |
"""
|
|
|
1 |
+
import subprocess
|
2 |
import sys
|
3 |
from pathlib import Path
|
4 |
+
|
5 |
import torch
|
|
|
|
|
|
|
6 |
import torchvision.transforms as transforms
|
7 |
|
8 |
+
from hloc import logger
|
9 |
+
|
10 |
+
from ..utils.base_model import BaseModel
|
11 |
+
|
12 |
dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
|
13 |
sys.path.append(str(dedode_path))
|
14 |
|
15 |
+
from DeDoDe import dedode_descriptor_B, dedode_detector_L
|
16 |
from DeDoDe.utils import to_pixel_coords
|
17 |
|
18 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
51 |
if not model_detector_path.exists():
|
52 |
model_detector_path.parent.mkdir(exist_ok=True)
|
53 |
link = self.weight_urls[conf["model_detector_name"]]
|
54 |
+
cmd = ["wget", "--quiet", link, "-O", str(model_detector_path)]
|
55 |
logger.info(f"Downloading the DeDoDe detector model with `{cmd}`.")
|
56 |
subprocess.run(cmd, check=True)
|
57 |
|
58 |
if not model_descriptor_path.exists():
|
59 |
model_descriptor_path.parent.mkdir(exist_ok=True)
|
60 |
link = self.weight_urls[conf["model_descriptor_name"]]
|
61 |
+
cmd = ["wget", "--quiet", link, "-O", str(model_descriptor_path)]
|
62 |
logger.info(
|
63 |
f"Downloading the DeDoDe descriptor model with `{cmd}`."
|
64 |
)
|
|
|
75 |
self.descriptor = dedode_descriptor_B(
|
76 |
weights=weights_descriptor, device=device
|
77 |
)
|
78 |
+
logger.info("Load DeDoDe model done.")
|
79 |
|
80 |
def _forward(self, data):
|
81 |
"""
|
hloc/extractors/dir.py
CHANGED
@@ -1,10 +1,11 @@
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
import torch
|
4 |
from zipfile import ZipFile
|
5 |
-
|
6 |
-
import sklearn
|
7 |
import gdown
|
|
|
|
|
8 |
|
9 |
from ..utils.base_model import BaseModel
|
10 |
|
@@ -13,8 +14,8 @@ sys.path.append(
|
|
13 |
)
|
14 |
os.environ["DB_ROOT"] = "" # required by dirtorch
|
15 |
|
16 |
-
from dirtorch.utils import common # noqa: E402
|
17 |
from dirtorch.extract_features import load_model # noqa: E402
|
|
|
18 |
|
19 |
# The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
|
20 |
# which has been deprecated in sklearn v0.24
|
|
|
1 |
+
import os
|
2 |
import sys
|
3 |
from pathlib import Path
|
|
|
4 |
from zipfile import ZipFile
|
5 |
+
|
|
|
6 |
import gdown
|
7 |
+
import sklearn
|
8 |
+
import torch
|
9 |
|
10 |
from ..utils.base_model import BaseModel
|
11 |
|
|
|
14 |
)
|
15 |
os.environ["DB_ROOT"] = "" # required by dirtorch
|
16 |
|
|
|
17 |
from dirtorch.extract_features import load_model # noqa: E402
|
18 |
+
from dirtorch.utils import common # noqa: E402
|
19 |
|
20 |
# The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
|
21 |
# which has been deprecated in sklearn v0.24
|
hloc/extractors/disk.py
CHANGED
@@ -15,7 +15,7 @@ class DISK(BaseModel):
|
|
15 |
|
16 |
def _init(self, conf):
|
17 |
self.model = kornia.feature.DISK.from_pretrained(conf["weights"])
|
18 |
-
logger.info(
|
19 |
|
20 |
def _forward(self, data):
|
21 |
image = data["image"]
|
|
|
15 |
|
16 |
def _init(self, conf):
|
17 |
self.model = kornia.feature.DISK.from_pretrained(conf["weights"])
|
18 |
+
logger.info("Load DISK model done.")
|
19 |
|
20 |
def _forward(self, data):
|
21 |
image = data["image"]
|
hloc/extractors/dog.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
import kornia
|
|
|
|
|
|
|
2 |
from kornia.feature.laf import (
|
3 |
-
laf_from_center_scale_ori,
|
4 |
extract_patches_from_pyramid,
|
|
|
5 |
)
|
6 |
-
import numpy as np
|
7 |
-
import torch
|
8 |
-
import pycolmap
|
9 |
|
10 |
from ..utils.base_model import BaseModel
|
11 |
|
12 |
-
|
13 |
EPS = 1e-6
|
14 |
|
15 |
|
|
|
1 |
import kornia
|
2 |
+
import numpy as np
|
3 |
+
import pycolmap
|
4 |
+
import torch
|
5 |
from kornia.feature.laf import (
|
|
|
6 |
extract_patches_from_pyramid,
|
7 |
+
laf_from_center_scale_ori,
|
8 |
)
|
|
|
|
|
|
|
9 |
|
10 |
from ..utils.base_model import BaseModel
|
11 |
|
|
|
12 |
EPS = 1e-6
|
13 |
|
14 |
|
hloc/extractors/example.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
4 |
import torch
|
5 |
-
from .. import logger
|
6 |
|
|
|
7 |
from ..utils.base_model import BaseModel
|
8 |
|
9 |
example_path = Path(__file__).parent / "../../third_party/example"
|
@@ -35,7 +35,7 @@ class Example(BaseModel):
|
|
35 |
# self.net = ExampleNet(is_test=True)
|
36 |
state_dict = torch.load(model_path, map_location="cpu")
|
37 |
self.net.load_state_dict(state_dict["model_state"])
|
38 |
-
logger.info(
|
39 |
|
40 |
def _forward(self, data):
|
41 |
# data: dict, keys: 'image'
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
import torch
|
|
|
5 |
|
6 |
+
from .. import logger
|
7 |
from ..utils.base_model import BaseModel
|
8 |
|
9 |
example_path = Path(__file__).parent / "../../third_party/example"
|
|
|
35 |
# self.net = ExampleNet(is_test=True)
|
36 |
state_dict = torch.load(model_path, map_location="cpu")
|
37 |
self.net.load_state_dict(state_dict["model_state"])
|
38 |
+
logger.info("Load example model done.")
|
39 |
|
40 |
def _forward(self, data):
|
41 |
# data: dict, keys: 'image'
|
hloc/extractors/fire.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
import subprocess
|
3 |
import logging
|
|
|
4 |
import sys
|
|
|
|
|
5 |
import torch
|
6 |
import torchvision.transforms as tvf
|
7 |
|
@@ -42,11 +43,11 @@ class FIRe(BaseModel):
|
|
42 |
if not model_path.exists():
|
43 |
model_path.parent.mkdir(exist_ok=True)
|
44 |
link = self.fire_models[conf["model_name"]]
|
45 |
-
cmd = ["wget", link, "-O", str(model_path)]
|
46 |
logger.info(f"Downloading the FIRe model with `{cmd}`.")
|
47 |
subprocess.run(cmd, check=True)
|
48 |
|
49 |
-
logger.info(
|
50 |
|
51 |
# Load net
|
52 |
state = torch.load(model_path)
|
|
|
|
|
|
|
1 |
import logging
|
2 |
+
import subprocess
|
3 |
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
import torch
|
7 |
import torchvision.transforms as tvf
|
8 |
|
|
|
43 |
if not model_path.exists():
|
44 |
model_path.parent.mkdir(exist_ok=True)
|
45 |
link = self.fire_models[conf["model_name"]]
|
46 |
+
cmd = ["wget", "--quiet", link, "-O", str(model_path)]
|
47 |
logger.info(f"Downloading the FIRe model with `{cmd}`.")
|
48 |
subprocess.run(cmd, check=True)
|
49 |
|
50 |
+
logger.info("Loading fire model...")
|
51 |
|
52 |
# Load net
|
53 |
state = torch.load(model_path)
|
hloc/extractors/fire_local.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
-
from pathlib import Path
|
2 |
import subprocess
|
3 |
import sys
|
|
|
|
|
4 |
import torch
|
5 |
import torchvision.transforms as tvf
|
6 |
|
7 |
-
from ..utils.base_model import BaseModel
|
8 |
from .. import logger
|
|
|
9 |
|
10 |
fire_path = Path(__file__).parent / "../../third_party/fire"
|
11 |
|
@@ -13,10 +14,6 @@ sys.path.append(str(fire_path))
|
|
13 |
|
14 |
|
15 |
import fire_network
|
16 |
-
from lib.how.how.stages.evaluate import eval_asmk_fire, load_dataset_fire
|
17 |
-
|
18 |
-
from lib.asmk import asmk
|
19 |
-
from asmk import io_helpers, asmk_method, kernel as kern_pkg
|
20 |
|
21 |
EPS = 1e-6
|
22 |
|
@@ -44,18 +41,18 @@ class FIRe(BaseModel):
|
|
44 |
|
45 |
# Config paths
|
46 |
model_path = fire_path / "model" / conf["model_name"]
|
47 |
-
config_path = fire_path / conf["config_name"]
|
48 |
-
asmk_bin_path = fire_path / "model" / conf["asmk_name"]
|
49 |
|
50 |
# Download the model.
|
51 |
if not model_path.exists():
|
52 |
model_path.parent.mkdir(exist_ok=True)
|
53 |
link = self.fire_models[conf["model_name"]]
|
54 |
-
cmd = ["wget", link, "-O", str(model_path)]
|
55 |
logger.info(f"Downloading the FIRe model with `{cmd}`.")
|
56 |
subprocess.run(cmd, check=True)
|
57 |
|
58 |
-
logger.info(
|
59 |
|
60 |
# Load net
|
61 |
state = torch.load(model_path)
|
|
|
|
|
1 |
import subprocess
|
2 |
import sys
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
import torch
|
6 |
import torchvision.transforms as tvf
|
7 |
|
|
|
8 |
from .. import logger
|
9 |
+
from ..utils.base_model import BaseModel
|
10 |
|
11 |
fire_path = Path(__file__).parent / "../../third_party/fire"
|
12 |
|
|
|
14 |
|
15 |
|
16 |
import fire_network
|
|
|
|
|
|
|
|
|
17 |
|
18 |
EPS = 1e-6
|
19 |
|
|
|
41 |
|
42 |
# Config paths
|
43 |
model_path = fire_path / "model" / conf["model_name"]
|
44 |
+
config_path = fire_path / conf["config_name"] # noqa: F841
|
45 |
+
asmk_bin_path = fire_path / "model" / conf["asmk_name"] # noqa: F841
|
46 |
|
47 |
# Download the model.
|
48 |
if not model_path.exists():
|
49 |
model_path.parent.mkdir(exist_ok=True)
|
50 |
link = self.fire_models[conf["model_name"]]
|
51 |
+
cmd = ["wget", "--quiet", link, "-O", str(model_path)]
|
52 |
logger.info(f"Downloading the FIRe model with `{cmd}`.")
|
53 |
subprocess.run(cmd, check=True)
|
54 |
|
55 |
+
logger.info("Loading fire model...")
|
56 |
|
57 |
# Load net
|
58 |
state = torch.load(model_path)
|
hloc/extractors/lanet.py
CHANGED
@@ -1,14 +1,17 @@
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
4 |
import torch
|
|
|
5 |
|
6 |
from ..utils.base_model import BaseModel
|
7 |
-
|
|
|
|
|
|
|
8 |
|
9 |
lanet_path = Path(__file__).parent / "../../third_party/lanet"
|
10 |
-
sys.path.append(str(lanet_path))
|
11 |
-
from network_v0.model import PointModel
|
12 |
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
|
@@ -26,11 +29,11 @@ class LANet(BaseModel):
|
|
26 |
lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
|
27 |
)
|
28 |
if not model_path.exists():
|
29 |
-
|
30 |
self.net = PointModel(is_test=True)
|
31 |
state_dict = torch.load(model_path, map_location="cpu")
|
32 |
self.net.load_state_dict(state_dict["model_state"])
|
33 |
-
logger.info(
|
34 |
|
35 |
def _forward(self, data):
|
36 |
image = data["image"]
|
|
|
1 |
+
import subprocess
|
2 |
import sys
|
3 |
from pathlib import Path
|
4 |
+
|
5 |
import torch
|
6 |
+
from hloc import logger
|
7 |
|
8 |
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
+
lib_path = Path(__file__).parent / "../../third_party"
|
11 |
+
sys.path.append(str(lib_path))
|
12 |
+
from lanet.network_v0.model import PointModel
|
13 |
|
14 |
lanet_path = Path(__file__).parent / "../../third_party/lanet"
|
|
|
|
|
15 |
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
|
|
|
29 |
lanet_path / "checkpoints" / f'PointModel_{conf["model_name"]}.pth'
|
30 |
)
|
31 |
if not model_path.exists():
|
32 |
+
logger.warning(f"No model found at {model_path}, start downloading")
|
33 |
self.net = PointModel(is_test=True)
|
34 |
state_dict = torch.load(model_path, map_location="cpu")
|
35 |
self.net.load_state_dict(state_dict["model_state"])
|
36 |
+
logger.info("Load LANet model done.")
|
37 |
|
38 |
def _forward(self, data):
|
39 |
image = data["image"]
|
hloc/extractors/netvlad.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
-
from pathlib import Path
|
2 |
import subprocess
|
|
|
|
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
import torch.nn as nn
|
@@ -7,8 +8,8 @@ import torch.nn.functional as F
|
|
7 |
import torchvision.models as models
|
8 |
from scipy.io import loadmat
|
9 |
|
10 |
-
from ..utils.base_model import BaseModel
|
11 |
from .. import logger
|
|
|
12 |
|
13 |
EPS = 1e-6
|
14 |
|
@@ -60,7 +61,7 @@ class NetVLAD(BaseModel):
|
|
60 |
if not checkpoint.exists():
|
61 |
checkpoint.parent.mkdir(exist_ok=True, parents=True)
|
62 |
link = self.dir_models[conf["model_name"]]
|
63 |
-
cmd = ["wget", link, "-O", str(checkpoint)]
|
64 |
logger.info(f"Downloading the NetVLAD model with `{cmd}`.")
|
65 |
subprocess.run(cmd, check=True)
|
66 |
|
|
|
|
|
1 |
import subprocess
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torch.nn as nn
|
|
|
8 |
import torchvision.models as models
|
9 |
from scipy.io import loadmat
|
10 |
|
|
|
11 |
from .. import logger
|
12 |
+
from ..utils.base_model import BaseModel
|
13 |
|
14 |
EPS = 1e-6
|
15 |
|
|
|
61 |
if not checkpoint.exists():
|
62 |
checkpoint.parent.mkdir(exist_ok=True, parents=True)
|
63 |
link = self.dir_models[conf["model_name"]]
|
64 |
+
cmd = ["wget", "--quiet", link, "-O", str(checkpoint)]
|
65 |
logger.info(f"Downloading the NetVLAD model with `{cmd}`.")
|
66 |
subprocess.run(cmd, check=True)
|
67 |
|
hloc/extractors/r2d2.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
|
|
3 |
import torchvision.transforms as tvf
|
4 |
|
5 |
-
from ..utils.base_model import BaseModel
|
6 |
from hloc import logger
|
7 |
|
8 |
-
|
9 |
-
|
10 |
r2d2_path = Path(__file__).parent / "../../third_party/r2d2"
|
11 |
-
|
|
|
12 |
|
13 |
|
14 |
class R2D2(BaseModel):
|
@@ -35,7 +36,7 @@ class R2D2(BaseModel):
|
|
35 |
rel_thr=conf["reliability_threshold"],
|
36 |
rep_thr=conf["repetability_threshold"],
|
37 |
)
|
38 |
-
logger.info(
|
39 |
|
40 |
def _forward(self, data):
|
41 |
img = data["image"]
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
import torchvision.transforms as tvf
|
5 |
|
|
|
6 |
from hloc import logger
|
7 |
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
r2d2_path = Path(__file__).parent / "../../third_party/r2d2"
|
11 |
+
sys.path.append(str(r2d2_path))
|
12 |
+
from extract import NonMaxSuppression, extract_multiscale, load_network
|
13 |
|
14 |
|
15 |
class R2D2(BaseModel):
|
|
|
36 |
rel_thr=conf["reliability_threshold"],
|
37 |
rep_thr=conf["repetability_threshold"],
|
38 |
)
|
39 |
+
logger.info("Load R2D2 model done.")
|
40 |
|
41 |
def _forward(self, data):
|
42 |
img = data["image"]
|
hloc/extractors/rekd.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
-
|
4 |
import torch
|
5 |
|
6 |
-
from ..utils.base_model import BaseModel
|
7 |
from hloc import logger
|
8 |
|
|
|
|
|
9 |
rekd_path = Path(__file__).parent / "../../third_party"
|
10 |
sys.path.append(str(rekd_path))
|
11 |
from REKD.training.model.REKD import REKD as REKD_
|
@@ -29,7 +30,7 @@ class REKD(BaseModel):
|
|
29 |
self.net = REKD_(is_test=True)
|
30 |
state_dict = torch.load(model_path, map_location="cpu")
|
31 |
self.net.load_state_dict(state_dict["model_state"])
|
32 |
-
logger.info(
|
33 |
|
34 |
def _forward(self, data):
|
35 |
image = data["image"]
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
import torch
|
5 |
|
|
|
6 |
from hloc import logger
|
7 |
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
rekd_path = Path(__file__).parent / "../../third_party"
|
11 |
sys.path.append(str(rekd_path))
|
12 |
from REKD.training.model.REKD import REKD as REKD_
|
|
|
30 |
self.net = REKD_(is_test=True)
|
31 |
state_dict = torch.load(model_path, map_location="cpu")
|
32 |
self.net.load_state_dict(state_dict["model_state"])
|
33 |
+
logger.info("Load REKD model done.")
|
34 |
|
35 |
def _forward(self, data):
|
36 |
image = data["image"]
|
hloc/extractors/rord.py
CHANGED
@@ -3,9 +3,10 @@ from pathlib import Path
|
|
3 |
import subprocess
|
4 |
import torch
|
5 |
|
6 |
-
from ..utils.base_model import BaseModel
|
7 |
from hloc import logger
|
8 |
|
|
|
|
|
9 |
rord_path = Path(__file__).parent / "../../third_party"
|
10 |
sys.path.append(str(rord_path))
|
11 |
from RoRD.lib.model_test import D2Net as _RoRD
|
@@ -42,11 +43,10 @@ class RoRD(BaseModel):
|
|
42 |
subprocess.run(cmd, check=True)
|
43 |
except subprocess.CalledProcessError as e:
|
44 |
logger.error(f"Failed to download the RoRD model.")
|
45 |
-
raise e
|
46 |
self.net = _RoRD(
|
47 |
model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
|
48 |
)
|
49 |
-
logger.info(
|
50 |
|
51 |
def _forward(self, data):
|
52 |
image = data["image"]
|
|
|
3 |
import subprocess
|
4 |
import torch
|
5 |
|
|
|
6 |
from hloc import logger
|
7 |
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
rord_path = Path(__file__).parent / "../../third_party"
|
11 |
sys.path.append(str(rord_path))
|
12 |
from RoRD.lib.model_test import D2Net as _RoRD
|
|
|
43 |
subprocess.run(cmd, check=True)
|
44 |
except subprocess.CalledProcessError as e:
|
45 |
logger.error(f"Failed to download the RoRD model.")
|
|
|
46 |
self.net = _RoRD(
|
47 |
model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
|
48 |
)
|
49 |
+
logger.info("Load RoRD model done.")
|
50 |
|
51 |
def _forward(self, data):
|
52 |
image = data["image"]
|
hloc/extractors/sfd2.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: UTF-8 -*-
|
2 |
+
import sys
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torchvision.transforms as tvf
|
6 |
+
|
7 |
+
from .. import logger
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
+
pram_path = Path(__file__).parent / "../../third_party/pram"
|
11 |
+
sys.path.append(str(pram_path))
|
12 |
+
|
13 |
+
from nets.sfd2 import load_sfd2
|
14 |
+
|
15 |
+
|
16 |
+
class SFD2(BaseModel):
|
17 |
+
default_conf = {
|
18 |
+
"max_keypoints": 4096,
|
19 |
+
"model_name": "sfd2_20230511_210205_resnet4x.79.pth",
|
20 |
+
"conf_th": 0.001,
|
21 |
+
}
|
22 |
+
required_inputs = ["image"]
|
23 |
+
|
24 |
+
def _init(self, conf):
|
25 |
+
self.conf = {**self.default_conf, **conf}
|
26 |
+
self.norm_rgb = tvf.Normalize(
|
27 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
28 |
+
)
|
29 |
+
model_fn = pram_path / "weights" / self.conf["model_name"]
|
30 |
+
self.net = load_sfd2(weight_path=model_fn).eval()
|
31 |
+
|
32 |
+
logger.info("Load SFD2 model done.")
|
33 |
+
|
34 |
+
def _forward(self, data):
|
35 |
+
pred = self.net.extract_local_global(
|
36 |
+
data={"image": self.norm_rgb(data["image"])}, config=self.conf
|
37 |
+
)
|
38 |
+
out = {
|
39 |
+
"keypoints": pred["keypoints"][0][None],
|
40 |
+
"scores": pred["scores"][0][None],
|
41 |
+
"descriptors": pred["descriptors"][0][None],
|
42 |
+
}
|
43 |
+
return out
|
hloc/extractors/sift.py
CHANGED
@@ -4,14 +4,15 @@ import cv2
|
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
from kornia.color import rgb_to_grayscale
|
7 |
-
from packaging import version
|
8 |
from omegaconf import OmegaConf
|
|
|
9 |
|
10 |
try:
|
11 |
import pycolmap
|
12 |
except ImportError:
|
13 |
pycolmap = None
|
14 |
from hloc import logger
|
|
|
15 |
from ..utils.base_model import BaseModel
|
16 |
|
17 |
|
@@ -140,7 +141,7 @@ class SIFT(BaseModel):
|
|
140 |
f"Unknown backend: {backend} not in "
|
141 |
f"{{{','.join(backends)}}}."
|
142 |
)
|
143 |
-
logger.info(
|
144 |
|
145 |
def extract_single_image(self, image: torch.Tensor):
|
146 |
image_np = image.cpu().numpy().squeeze(0)
|
|
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
from kornia.color import rgb_to_grayscale
|
|
|
7 |
from omegaconf import OmegaConf
|
8 |
+
from packaging import version
|
9 |
|
10 |
try:
|
11 |
import pycolmap
|
12 |
except ImportError:
|
13 |
pycolmap = None
|
14 |
from hloc import logger
|
15 |
+
|
16 |
from ..utils.base_model import BaseModel
|
17 |
|
18 |
|
|
|
141 |
f"Unknown backend: {backend} not in "
|
142 |
f"{{{','.join(backends)}}}."
|
143 |
)
|
144 |
+
logger.info("Load SIFT model done.")
|
145 |
|
146 |
def extract_single_image(self, image: torch.Tensor):
|
147 |
image_np = image.cpu().numpy().squeeze(0)
|
hloc/extractors/superpoint.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
|
|
3 |
import torch
|
4 |
|
5 |
-
from ..utils.base_model import BaseModel
|
6 |
from hloc import logger
|
7 |
|
|
|
|
|
8 |
sys.path.append(str(Path(__file__).parent / "../../third_party"))
|
9 |
from SuperGluePretrainedNetwork.models import superpoint # noqa E402
|
10 |
|
@@ -43,7 +45,7 @@ class SuperPoint(BaseModel):
|
|
43 |
if conf["fix_sampling"]:
|
44 |
superpoint.sample_descriptors = sample_descriptors_fix_sampling
|
45 |
self.net = superpoint.SuperPoint(conf)
|
46 |
-
logger.info(
|
47 |
|
48 |
def _forward(self, data):
|
49 |
return self.net(data, self.conf)
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
import torch
|
5 |
|
|
|
6 |
from hloc import logger
|
7 |
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
sys.path.append(str(Path(__file__).parent / "../../third_party"))
|
11 |
from SuperGluePretrainedNetwork.models import superpoint # noqa E402
|
12 |
|
|
|
45 |
if conf["fix_sampling"]:
|
46 |
superpoint.sample_descriptors = sample_descriptors_fix_sampling
|
47 |
self.net = superpoint.SuperPoint(conf)
|
48 |
+
logger.info("Load SuperPoint model done.")
|
49 |
|
50 |
def _forward(self, data):
|
51 |
return self.net(data, self.conf)
|
hloc/extractors/xfeat.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import torch
|
2 |
-
|
3 |
from hloc import logger
|
|
|
4 |
from ..utils.base_model import BaseModel
|
5 |
|
6 |
|
@@ -18,7 +19,7 @@ class XFeat(BaseModel):
|
|
18 |
pretrained=True,
|
19 |
top_k=self.conf["max_keypoints"],
|
20 |
)
|
21 |
-
logger.info(
|
22 |
|
23 |
def _forward(self, data):
|
24 |
pred = self.net.detectAndCompute(
|
|
|
1 |
import torch
|
2 |
+
|
3 |
from hloc import logger
|
4 |
+
|
5 |
from ..utils.base_model import BaseModel
|
6 |
|
7 |
|
|
|
19 |
pretrained=True,
|
20 |
top_k=self.conf["max_keypoints"],
|
21 |
)
|
22 |
+
logger.info("Load XFeat(sparse) model done.")
|
23 |
|
24 |
def _forward(self, data):
|
25 |
pred = self.net.detectAndCompute(
|
hloc/match_dense.py
CHANGED
@@ -1,9 +1,11 @@
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
import torchvision.transforms.functional as F
|
4 |
-
|
5 |
from .extract_features import read_image, resize_image
|
6 |
-
import cv2
|
7 |
|
8 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
|
|
|
1 |
+
from types import SimpleNamespace
|
2 |
+
|
3 |
+
import cv2
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torchvision.transforms.functional as F
|
7 |
+
|
8 |
from .extract_features import read_image, resize_image
|
|
|
9 |
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
|
hloc/match_features.py
CHANGED
@@ -1,18 +1,19 @@
|
|
1 |
import argparse
|
2 |
-
from typing import Union, Optional, Dict, List, Tuple
|
3 |
-
from pathlib import Path
|
4 |
import pprint
|
|
|
|
|
5 |
from queue import Queue
|
6 |
from threading import Thread
|
7 |
-
from
|
8 |
-
|
9 |
import h5py
|
|
|
10 |
import torch
|
|
|
11 |
|
12 |
-
from . import
|
13 |
from .utils.base_model import dynamic_load
|
14 |
from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
|
15 |
-
import numpy as np
|
16 |
|
17 |
"""
|
18 |
A set of standard configurations that can be directly selected from the command
|
@@ -162,6 +163,13 @@ confs = {
|
|
162 |
"match_threshold": 0.2,
|
163 |
},
|
164 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
}
|
166 |
|
167 |
|
|
|
1 |
import argparse
|
|
|
|
|
2 |
import pprint
|
3 |
+
from functools import partial
|
4 |
+
from pathlib import Path
|
5 |
from queue import Queue
|
6 |
from threading import Thread
|
7 |
+
from typing import Dict, List, Optional, Tuple, Union
|
8 |
+
|
9 |
import h5py
|
10 |
+
import numpy as np
|
11 |
import torch
|
12 |
+
from tqdm import tqdm
|
13 |
|
14 |
+
from . import logger, matchers
|
15 |
from .utils.base_model import dynamic_load
|
16 |
from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
|
|
|
17 |
|
18 |
"""
|
19 |
A set of standard configurations that can be directly selected from the command
|
|
|
163 |
"match_threshold": 0.2,
|
164 |
},
|
165 |
},
|
166 |
+
"imp": {
|
167 |
+
"output": "matches-imp",
|
168 |
+
"model": {
|
169 |
+
"name": "imp",
|
170 |
+
"match_threshold": 0.2,
|
171 |
+
},
|
172 |
+
},
|
173 |
}
|
174 |
|
175 |
|
hloc/matchers/adalam.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
import torch
|
2 |
-
|
3 |
-
from ..utils.base_model import BaseModel
|
4 |
-
|
5 |
from kornia.feature.adalam import AdalamFilter
|
6 |
from kornia.utils.helpers import get_cuda_device_if_available
|
7 |
|
|
|
|
|
8 |
|
9 |
class AdaLAM(BaseModel):
|
10 |
# See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html.
|
|
|
1 |
import torch
|
|
|
|
|
|
|
2 |
from kornia.feature.adalam import AdalamFilter
|
3 |
from kornia.utils.helpers import get_cuda_device_if_available
|
4 |
|
5 |
+
from ..utils.base_model import BaseModel
|
6 |
+
|
7 |
|
8 |
class AdaLAM(BaseModel):
|
9 |
# See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html.
|
hloc/matchers/aspanformer.py
CHANGED
@@ -1,17 +1,16 @@
|
|
|
|
1 |
import sys
|
2 |
-
import torch
|
3 |
-
from ..utils.base_model import BaseModel
|
4 |
-
from ..utils import do_system
|
5 |
from pathlib import Path
|
6 |
-
|
|
|
7 |
|
8 |
from .. import logger
|
|
|
9 |
|
10 |
sys.path.append(str(Path(__file__).parent / "../../third_party"))
|
11 |
from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
|
12 |
from ASpanFormer.src.config.default import get_cfg_defaults
|
13 |
from ASpanFormer.src.utils.misc import lower_config
|
14 |
-
from ASpanFormer.demo import demo_utils
|
15 |
|
16 |
aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
|
17 |
|
@@ -85,7 +84,7 @@ class ASpanFormer(BaseModel):
|
|
85 |
"state_dict"
|
86 |
]
|
87 |
self.net.load_state_dict(state_dict, strict=False)
|
88 |
-
logger.info(
|
89 |
|
90 |
def _forward(self, data):
|
91 |
data_ = {
|
|
|
1 |
+
import subprocess
|
2 |
import sys
|
|
|
|
|
|
|
3 |
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
|
7 |
from .. import logger
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
|
10 |
sys.path.append(str(Path(__file__).parent / "../../third_party"))
|
11 |
from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
|
12 |
from ASpanFormer.src.config.default import get_cfg_defaults
|
13 |
from ASpanFormer.src.utils.misc import lower_config
|
|
|
14 |
|
15 |
aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
|
16 |
|
|
|
84 |
"state_dict"
|
85 |
]
|
86 |
self.net.load_state_dict(state_dict, strict=False)
|
87 |
+
logger.info("Loaded Aspanformer model")
|
88 |
|
89 |
def _forward(self, data):
|
90 |
data_ = {
|
hloc/matchers/cotr.py
CHANGED
@@ -1,19 +1,19 @@
|
|
1 |
-
import sys
|
2 |
import argparse
|
3 |
-
import
|
4 |
-
import warnings
|
5 |
-
import numpy as np
|
6 |
from pathlib import Path
|
|
|
|
|
|
|
7 |
from torchvision.transforms import ToPILImage
|
|
|
8 |
from ..utils.base_model import BaseModel
|
9 |
|
10 |
sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
|
11 |
-
from COTR.utils import utils as utils_cotr
|
12 |
-
from COTR.models import build_model
|
13 |
-
from COTR.options.options import *
|
14 |
-
from COTR.options.options_utils import *
|
15 |
-
from COTR.inference.inference_helper import triangulate_corr
|
16 |
from COTR.inference.sparse_engine import SparseEngine
|
|
|
|
|
|
|
|
|
17 |
|
18 |
utils_cotr.fix_randomness(0)
|
19 |
torch.set_grad_enabled(False)
|
@@ -33,7 +33,7 @@ class COTR(BaseModel):
|
|
33 |
|
34 |
def _init(self, conf):
|
35 |
parser = argparse.ArgumentParser()
|
36 |
-
set_COTR_arguments(parser)
|
37 |
opt = parser.parse_args()
|
38 |
opt.command = " ".join(sys.argv)
|
39 |
opt.load_weights_path = str(
|
|
|
|
|
1 |
import argparse
|
2 |
+
import sys
|
|
|
|
|
3 |
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
from torchvision.transforms import ToPILImage
|
8 |
+
|
9 |
from ..utils.base_model import BaseModel
|
10 |
|
11 |
sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
|
|
|
|
|
|
|
|
|
|
|
12 |
from COTR.inference.sparse_engine import SparseEngine
|
13 |
+
from COTR.models import build_model
|
14 |
+
from COTR.options.options import * # noqa: F403
|
15 |
+
from COTR.options.options_utils import * # noqa: F403
|
16 |
+
from COTR.utils import utils as utils_cotr
|
17 |
|
18 |
utils_cotr.fix_randomness(0)
|
19 |
torch.set_grad_enabled(False)
|
|
|
33 |
|
34 |
def _init(self, conf):
|
35 |
parser = argparse.ArgumentParser()
|
36 |
+
set_COTR_arguments(parser) # noqa: F405
|
37 |
opt = parser.parse_args()
|
38 |
opt.command = " ".join(sys.argv)
|
39 |
opt.load_weights_path = str(
|
hloc/matchers/dkm.py
CHANGED
@@ -1,10 +1,12 @@
|
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
|
|
3 |
import torch
|
4 |
from PIL import Image
|
5 |
-
|
6 |
-
from ..utils.base_model import BaseModel
|
7 |
from .. import logger
|
|
|
8 |
|
9 |
sys.path.append(str(Path(__file__).parent / "../../third_party"))
|
10 |
from DKM.dkm import DKMv3_outdoor
|
@@ -37,11 +39,11 @@ class DKMv3(BaseModel):
|
|
37 |
if not model_path.exists():
|
38 |
model_path.parent.mkdir(exist_ok=True)
|
39 |
link = self.dkm_models[conf["model_name"]]
|
40 |
-
cmd = ["wget", link, "-O", str(model_path)]
|
41 |
logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
|
42 |
subprocess.run(cmd, check=True)
|
43 |
self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=device)
|
44 |
-
logger.info(
|
45 |
|
46 |
def _forward(self, data):
|
47 |
img0 = data["image0"].cpu().numpy().squeeze() * 255
|
|
|
1 |
+
import subprocess
|
2 |
import sys
|
3 |
from pathlib import Path
|
4 |
+
|
5 |
import torch
|
6 |
from PIL import Image
|
7 |
+
|
|
|
8 |
from .. import logger
|
9 |
+
from ..utils.base_model import BaseModel
|
10 |
|
11 |
sys.path.append(str(Path(__file__).parent / "../../third_party"))
|
12 |
from DKM.dkm import DKMv3_outdoor
|
|
|
39 |
if not model_path.exists():
|
40 |
model_path.parent.mkdir(exist_ok=True)
|
41 |
link = self.dkm_models[conf["model_name"]]
|
42 |
+
cmd = ["wget", "--quiet", link, "-O", str(model_path)]
|
43 |
logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
|
44 |
subprocess.run(cmd, check=True)
|
45 |
self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=device)
|
46 |
+
logger.info("Loading DKMv3 model done")
|
47 |
|
48 |
def _forward(self, data):
|
49 |
img0 = data["image0"].cpu().numpy().squeeze() * 255
|