Vincentqyw committed
Commit: e8fe67e
Parent(s): 8811cfe
update: roma and dust3r
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- third_party/RoMa +0 -1
- third_party/RoMa/.gitignore +11 -0
- third_party/RoMa/LICENSE +21 -0
- third_party/RoMa/README.md +123 -0
- third_party/RoMa/assets/sacre_coeur_A.jpg +3 -0
- third_party/RoMa/assets/sacre_coeur_B.jpg +3 -0
- third_party/RoMa/assets/toronto_A.jpg +3 -0
- third_party/RoMa/assets/toronto_B.jpg +3 -0
- third_party/RoMa/data/.gitignore +2 -0
- third_party/RoMa/demo/demo_3D_effect.py +46 -0
- third_party/RoMa/demo/demo_fundamental.py +33 -0
- third_party/RoMa/demo/demo_match.py +47 -0
- third_party/RoMa/demo/demo_match_opencv_sift.py +43 -0
- third_party/RoMa/demo/gif/.gitignore +2 -0
- third_party/RoMa/requirements.txt +14 -0
- third_party/RoMa/romatch/__init__.py +8 -0
- third_party/RoMa/romatch/benchmarks/__init__.py +6 -0
- third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
- third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
- third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
- third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py +119 -0
- third_party/RoMa/romatch/benchmarks/scannet_benchmark.py +143 -0
- third_party/RoMa/romatch/checkpointing/__init__.py +1 -0
- third_party/RoMa/romatch/checkpointing/checkpoint.py +60 -0
- third_party/RoMa/romatch/datasets/__init__.py +2 -0
- third_party/RoMa/romatch/datasets/megadepth.py +232 -0
- third_party/RoMa/romatch/datasets/scannet.py +160 -0
- third_party/RoMa/romatch/losses/__init__.py +1 -0
- third_party/RoMa/romatch/losses/robust_loss.py +161 -0
- third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py +160 -0
- third_party/RoMa/romatch/models/__init__.py +1 -0
- third_party/RoMa/romatch/models/encoders.py +119 -0
- third_party/RoMa/romatch/models/matcher.py +772 -0
- third_party/RoMa/romatch/models/model_zoo/__init__.py +70 -0
- third_party/RoMa/romatch/models/model_zoo/roma_models.py +170 -0
- third_party/RoMa/romatch/models/tiny.py +304 -0
- third_party/RoMa/romatch/models/transformer/__init__.py +47 -0
- third_party/RoMa/romatch/models/transformer/dinov2.py +359 -0
- third_party/RoMa/romatch/models/transformer/layers/__init__.py +12 -0
- third_party/RoMa/romatch/models/transformer/layers/attention.py +81 -0
- third_party/RoMa/romatch/models/transformer/layers/block.py +252 -0
- third_party/RoMa/romatch/models/transformer/layers/dino_head.py +59 -0
- third_party/RoMa/romatch/models/transformer/layers/drop_path.py +35 -0
- third_party/RoMa/romatch/models/transformer/layers/layer_scale.py +28 -0
- third_party/RoMa/romatch/models/transformer/layers/mlp.py +41 -0
- third_party/RoMa/romatch/models/transformer/layers/patch_embed.py +89 -0
- third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py +63 -0
- third_party/RoMa/romatch/train/__init__.py +1 -0
- third_party/RoMa/romatch/train/train.py +102 -0
- third_party/RoMa/romatch/utils/__init__.py +16 -0
third_party/RoMa
DELETED
@@ -1 +0,0 @@
-Subproject commit 116537a1849f26b1ecf8e1b7ac0980a51befdc07

third_party/RoMa/.gitignore
ADDED
@@ -0,0 +1,11 @@
+*.egg-info*
+*.vscode*
+*__pycache__*
+vis*
+workspace*
+.venv
+.DS_Store
+jobs/*
+*ignore_me*
+*.pth
+wandb*

third_party/RoMa/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Johan Edstedt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

third_party/RoMa/README.md
ADDED
@@ -0,0 +1,123 @@
+#
+<p align="center">
+  <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
+  <p align="center">
+    <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
+  </p>
+  <h2 align="center"><p>
+    <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
+    <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
+  </p></h2>
+  <div align="center"></div>
+</p>
+<br/>
+<p align="center">
+    <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
+    <br>
+    <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
+</p>
+
+## Setup/Install
+In your python environment (tested on Linux python 3.10), run:
+```bash
+pip install -e .
+```
+## Demo / How to Use
+We provide two demos in the [demos folder](demo).
+Here's the gist of it:
+```python
+from romatch import roma_outdoor
+roma_model = roma_outdoor(device=device)
+# Match
+warp, certainty = roma_model.match(imA_path, imB_path, device=device)
+# Sample matches for estimation
+matches, certainty = roma_model.sample(warp, certainty)
+# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
+kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+# Find a fundamental matrix (or anything else of interest)
+F, mask = cv2.findFundamentalMat(
+    kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+)
+```
+
+**New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
+
+## Settings
+
+### Resolution
+By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
+You can change this at construction (see roma_outdoor kwargs).
+You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
+
+### Sampling
+roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
+
+
+## Reproducing Results
+The experiments in the paper are provided in the [experiments folder](experiments).
+
+### Training
+1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
+2. Run the relevant experiment, e.g.,
+```bash
+torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
+```
+### Testing
+```bash
+python experiments/roma_outdoor.py --only_test --benchmark mega-1500
+```
+## License
+All our code except DINOv2 is MIT license.
+DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
+
+## Acknowledgement
+Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
+
+## Tiny RoMa
+If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
+```python
+from romatch import tiny_roma_v1_outdoor
+tiny_roma_model = tiny_roma_v1_outdoor(device=device)
+```
+Mega1500:
+|  | AUC@5 | AUC@10 | AUC@20 |
+|----------|----------|----------|----------|
+| XFeat | 46.4 | 58.9 | 69.2 |
+| XFeat* | 51.9 | 67.2 | 78.9 |
+| Tiny RoMa v1 | 56.4 | 69.5 | 79.5 |
+| RoMa | - | - | - |
+
+Mega-8-Scenes (See DKM):
+|  | AUC@5 | AUC@10 | AUC@20 |
+|----------|----------|----------|----------|
+| XFeat | - | - | - |
+| XFeat* | 50.1 | 64.4 | 75.2 |
+| Tiny RoMa v1 | 57.7 | 70.5 | 79.6 |
+| RoMa | - | - | - |
+
+IMC22 :'):
+|  | mAA@10 |
+|----------|----------|
+| XFeat | 42.1 |
+| XFeat* | - |
+| Tiny RoMa v1 | 42.2 |
+| RoMa | - |
+
+## BibTeX
+If you find our models useful, please consider citing our paper!
+```
+@article{edstedt2024roma,
+title={{RoMa: Robust Dense Feature Matching}},
+author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
+journal={IEEE Conference on Computer Vision and Pattern Recognition},
+year={2024}
+}
+```

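Editor's note: the README's Settings section says the matching resolution (roma_model.w_resized, roma_model.h_resized, roma_model.upsample_res) and the sampling threshold (roma_model.sample_thresh) can also be changed after construction. A minimal sketch of what that could look like, using only the attribute names quoted in the README; the concrete values below are illustrative, not recommendations from the upstream project:

```python
import torch
from romatch import roma_outdoor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Defaults per the README: coarse matching at (560, 560), upsampled to (864, 864).
roma_model = roma_outdoor(device=device)

# Adjust the working resolutions after construction, as the README describes.
# (Illustrative values; the non-square upsample_res mirrors the demo scripts.)
roma_model.h_resized, roma_model.w_resized = 560, 560
roma_model.upsample_res = (864, 1152)

# sample_thresh gates which matches are kept when sampling for estimation;
# the README notes a lower or higher threshold may help in certain cases.
roma_model.sample_thresh = 0.1  # illustrative value
```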
third_party/RoMa/assets/sacre_coeur_A.jpg
ADDED
Git LFS Details

third_party/RoMa/assets/sacre_coeur_B.jpg
ADDED
Git LFS Details

third_party/RoMa/assets/toronto_A.jpg
ADDED
Git LFS Details

third_party/RoMa/assets/toronto_B.jpg
ADDED
Git LFS Details
third_party/RoMa/data/.gitignore
ADDED
@@ -0,0 +1,2 @@
+*
+!.gitignore

third_party/RoMa/demo/demo_3D_effect.py
ADDED
@@ -0,0 +1,46 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from romatch.utils.utils import tensor_to_pil
+
+from romatch import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    # Create model
+    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
+    roma_model.symmetric = False
+
+    H, W = roma_model.get_output_resolution()
+
+    im1 = Image.open(im1_path).resize((W, H))
+    im2 = Image.open(im2_path).resize((W, H))
+
+    # Match
+    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+    # Sampling not needed, but can be done with model.sample(warp, certainty)
+    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+    coords_A, coords_B = warp[...,:2], warp[...,2:]
+    for i, x in enumerate(np.linspace(0,2*np.pi,200)):
+        t = (1 + np.cos(x))/2
+        interp_warp = (1-t)*coords_A + t*coords_B
+        im2_transfer_rgb = F.grid_sample(
+            x2[None], interp_warp[None], mode="bilinear", align_corners=False
+        )[0]
+        tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")

third_party/RoMa/demo/demo_fundamental.py
ADDED
@@ -0,0 +1,33 @@
+from PIL import Image
+import torch
+import cv2
+from romatch import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+
+    # Create model
+    roma_model = roma_outdoor(device=device)
+
+
+    W_A, H_A = Image.open(im1_path).size
+    W_B, H_B = Image.open(im2_path).size
+
+    # Match
+    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+    # Sample matches for estimation
+    matches, certainty = roma_model.sample(warp, certainty)
+    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+    F, mask = cv2.findFundamentalMat(
+        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+    )

third_party/RoMa/demo/demo_match.py
ADDED
@@ -0,0 +1,47 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from romatch.utils.utils import tensor_to_pil
+
+from romatch import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    # Create model
+    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
+
+    H, W = roma_model.get_output_resolution()
+
+    im1 = Image.open(im1_path).resize((W, H))
+    im2 = Image.open(im2_path).resize((W, H))
+
+    # Match
+    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+    # Sampling not needed, but can be done with model.sample(warp, certainty)
+    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+    im2_transfer_rgb = F.grid_sample(
+        x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+    )[0]
+    im1_transfer_rgb = F.grid_sample(
+        x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+    )[0]
+    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
+    white_im = torch.ones((H,2*W),device=device)
+    vis_im = certainty * warp_im + (1 - certainty) * white_im
+    tensor_to_pil(vis_im, unnormalize=False).save(save_path)

third_party/RoMa/demo/demo_match_opencv_sift.py
ADDED
@@ -0,0 +1,43 @@
+from PIL import Image
+import numpy as np
+
+import numpy as np
+import cv2 as cv
+import matplotlib.pyplot as plt
+
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE)  # queryImage
+    img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE)  # trainImage
+    # Initiate SIFT detector
+    sift = cv.SIFT_create()
+    # find the keypoints and descriptors with SIFT
+    kp1, des1 = sift.detectAndCompute(img1,None)
+    kp2, des2 = sift.detectAndCompute(img2,None)
+    # BFMatcher with default params
+    bf = cv.BFMatcher()
+    matches = bf.knnMatch(des1,des2,k=2)
+    # Apply ratio test
+    good = []
+    for m,n in matches:
+        if m.distance < 0.75*n.distance:
+            good.append([m])
+    # cv.drawMatchesKnn expects list of lists as matches.
+    draw_params = dict(matchColor = (255,0,0), # draw matches in red color
+                       singlePointColor = None,
+                       flags = 2)
+
+    img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
+    Image.fromarray(img3).save("demo/sift_matches.png")

third_party/RoMa/demo/gif/.gitignore
ADDED
@@ -0,0 +1,2 @@
+*
+!.gitignore

third_party/RoMa/requirements.txt
ADDED
@@ -0,0 +1,14 @@
+torch
+einops
+torchvision
+opencv-python
+kornia
+albumentations
+loguru
+tqdm
+matplotlib
+h5py
+wandb
+timm
+poselib
+#xformers # Optional, used for memefficient attention

third_party/RoMa/romatch/__init__.py
ADDED
@@ -0,0 +1,8 @@
+import os
+from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
+
+DEBUG_MODE = False
+RANK = int(os.environ.get('RANK', default = 0))
+GLOBAL_STEP = 0
+STEP_SIZE = 1
+LOCAL_RANK = -1

third_party/RoMa/romatch/benchmarks/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
+from .scannet_benchmark import ScanNetBenchmark
+from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
+from .megadepth_dense_benchmark import MegadepthDenseBenchmark
+from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
+from .scannet_benchmark_poselib import ScanNetPoselibBenchmark

third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py
ADDED
@@ -0,0 +1,113 @@
+from PIL import Image
+import numpy as np
+
+import os
+
+from tqdm import tqdm
+from romatch.utils import pose_auc
+import cv2
+
+
+class HpatchesHomogBenchmark:
+    """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
+
+    def __init__(self, dataset_path) -> None:
+        seqs_dir = "hpatches-sequences-release"
+        self.seqs_path = os.path.join(dataset_path, seqs_dir)
+        self.seq_names = sorted(os.listdir(self.seqs_path))
+        # Ignore seqs is same as LoFTR.
+        self.ignore_seqs = set(
+            [
+                "i_contruction",
+                "i_crownnight",
+                "i_dc",
+                "i_pencils",
+                "i_whitebuilding",
+                "v_artisans",
+                "v_astronautis",
+                "v_talent",
+            ]
+        )
+
+    def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
+        offset = 0.5  # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
+        im_A_coords = (
+            np.stack(
+                (
+                    wq * (im_A_coords[..., 0] + 1) / 2,
+                    hq * (im_A_coords[..., 1] + 1) / 2,
+                ),
+                axis=-1,
+            )
+            - offset
+        )
+        im_A_to_im_B = (
+            np.stack(
+                (
+                    wsup * (im_A_to_im_B[..., 0] + 1) / 2,
+                    hsup * (im_A_to_im_B[..., 1] + 1) / 2,
+                ),
+                axis=-1,
+            )
+            - offset
+        )
+        return im_A_coords, im_A_to_im_B
+
+    def benchmark(self, model, model_name = None):
+        n_matches = []
+        homog_dists = []
+        for seq_idx, seq_name in tqdm(
+            enumerate(self.seq_names), total=len(self.seq_names)
+        ):
+            im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
+            im_A = Image.open(im_A_path)
+            w1, h1 = im_A.size
+            for im_idx in range(2, 7):
+                im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
+                im_B = Image.open(im_B_path)
+                w2, h2 = im_B.size
+                H = np.loadtxt(
+                    os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
+                )
+                dense_matches, dense_certainty = model.match(
+                    im_A_path, im_B_path
+                )
+                good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
+                pos_a, pos_b = self.convert_coordinates(
+                    good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
+                )
+                try:
+                    H_pred, inliers = cv2.findHomography(
+                        pos_a,
+                        pos_b,
+                        method = cv2.RANSAC,
+                        confidence = 0.99999,
+                        ransacReprojThreshold = 3 * min(w2, h2) / 480,
+                    )
+                except:
+                    H_pred = None
+                if H_pred is None:
+                    H_pred = np.zeros((3, 3))
+                    H_pred[2, 2] = 1.0
+                corners = np.array(
+                    [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
+                )
+                real_warped_corners = np.dot(corners, np.transpose(H))
+                real_warped_corners = (
+                    real_warped_corners[:, :2] / real_warped_corners[:, 2:]
+                )
+                warped_corners = np.dot(corners, np.transpose(H_pred))
+                warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
+                mean_dist = np.mean(
+                    np.linalg.norm(real_warped_corners - warped_corners, axis=1)
+                ) / (min(w2, h2) / 480.0)
+                homog_dists.append(mean_dist)
+
+        n_matches = np.array(n_matches)
+        thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        auc = pose_auc(np.array(homog_dists), thresholds)
+        return {
+            "hpatches_homog_auc_3": auc[2],
+            "hpatches_homog_auc_5": auc[4],
+            "hpatches_homog_auc_10": auc[9],
+        }

ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import tqdm
|
4 |
+
from romatch.datasets import MegadepthBuilder
|
5 |
+
from romatch.utils import warp_kpts
|
6 |
+
from torch.utils.data import ConcatDataset
|
7 |
+
import romatch
|
8 |
+
|
9 |
+
class MegadepthDenseBenchmark:
|
10 |
+
def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
|
11 |
+
mega = MegadepthBuilder(data_root=data_root)
|
12 |
+
self.dataset = ConcatDataset(
|
13 |
+
mega.build_scenes(split="test_loftr", ht=h, wt=w)
|
14 |
+
) # fixed resolution of 384,512
|
15 |
+
self.num_samples = num_samples
|
16 |
+
|
17 |
+
def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
|
18 |
+
b, h1, w1, d = dense_matches.shape
|
19 |
+
with torch.no_grad():
|
20 |
+
x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
|
21 |
+
mask, x2 = warp_kpts(
|
22 |
+
x1.double(),
|
23 |
+
depth1.double(),
|
24 |
+
depth2.double(),
|
25 |
+
T_1to2.double(),
|
26 |
+
K1.double(),
|
27 |
+
K2.double(),
|
28 |
+
)
|
29 |
+
x2 = torch.stack(
|
30 |
+
(w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
|
31 |
+
)
|
32 |
+
prob = mask.float().reshape(b, h1, w1)
|
33 |
+
x2_hat = dense_matches[..., 2:]
|
34 |
+
x2_hat = torch.stack(
|
35 |
+
(w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
|
36 |
+
)
|
37 |
+
gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
|
38 |
+
gd = gd[prob == 1]
|
39 |
+
pck_1 = (gd < 1.0).float().mean()
|
40 |
+
pck_3 = (gd < 3.0).float().mean()
|
41 |
+
pck_5 = (gd < 5.0).float().mean()
|
42 |
+
return gd, pck_1, pck_3, pck_5, prob
|
43 |
+
|
44 |
+
def benchmark(self, model, batch_size=8):
|
45 |
+
model.train(False)
|
46 |
+
with torch.no_grad():
|
47 |
+
gd_tot = 0.0
|
48 |
+
pck_1_tot = 0.0
|
49 |
+
pck_3_tot = 0.0
|
50 |
+
pck_5_tot = 0.0
|
51 |
+
sampler = torch.utils.data.WeightedRandomSampler(
|
52 |
+
torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
|
53 |
+
)
|
54 |
+
B = batch_size
|
55 |
+
dataloader = torch.utils.data.DataLoader(
|
56 |
+
self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
|
57 |
+
)
|
58 |
+
for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0):
|
59 |
+
im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
|
60 |
+
data["im_A"].cuda(),
|
61 |
+
data["im_B"].cuda(),
|
62 |
+
data["im_A_depth"].cuda(),
|
63 |
+
data["im_B_depth"].cuda(),
|
64 |
+
data["T_1to2"].cuda(),
|
65 |
+
data["K1"].cuda(),
|
66 |
+
data["K2"].cuda(),
|
67 |
+
)
|
68 |
+
matches, certainty = model.match(im_A, im_B, batched=True)
|
69 |
+
gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
|
70 |
+
depth1, depth2, T_1to2, K1, K2, matches
|
71 |
+
)
|
72 |
+
if romatch.DEBUG_MODE:
|
73 |
+
from romatch.utils.utils import tensor_to_pil
|
74 |
+
import torch.nn.functional as F
|
75 |
+
path = "vis"
|
76 |
+
H, W = model.get_output_resolution()
|
77 |
+
white_im = torch.ones((B,1,H,W),device="cuda")
|
78 |
+
im_B_transfer_rgb = F.grid_sample(
|
79 |
+
im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
|
80 |
+
)
|
81 |
+
warp_im = im_B_transfer_rgb
|
82 |
+
c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
|
83 |
+
vis_im = c_b * warp_im + (1 - c_b) * white_im
|
84 |
+
for b in range(B):
|
85 |
+
import os
|
86 |
+
os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
|
87 |
+
tensor_to_pil(vis_im[b], unnormalize=True).save(
|
88 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
|
89 |
+
tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
|
90 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
|
91 |
+
tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
|
92 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
|
93 |
+
|
94 |
+
|
95 |
+
gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
|
96 |
+
gd_tot + gd.mean(),
|
97 |
+
pck_1_tot + pck_1,
|
98 |
+
pck_3_tot + pck_3,
|
99 |
+
pck_5_tot + pck_5,
|
100 |
+
)
|
101 |
+
return {
|
102 |
+
"epe": gd_tot.item() / len(dataloader),
|
103 |
+
"mega_pck_1": pck_1_tot.item() / len(dataloader),
|
104 |
+
"mega_pck_3": pck_3_tot.item() / len(dataloader),
|
105 |
+
"mega_pck_5": pck_5_tot.item() / len(dataloader),
|
106 |
+
}
|
third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py
ADDED
@@ -0,0 +1,118 @@
+import numpy as np
+import torch
+from romatch.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+import romatch
+import kornia.geometry.epipolar as kepi
+
+class MegaDepthPoseEstimationBenchmark:
+    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+        if scene_names is None:
+            self.scene_names = [
+                "0015_0.1_0.3.npz",
+                "0015_0.3_0.5.npz",
+                "0022_0.1_0.3.npz",
+                "0022_0.3_0.5.npz",
+                "0022_0.5_0.7.npz",
+            ]
+        else:
+            self.scene_names = scene_names
+        self.scenes = [
+            np.load(f"{data_root}/{scene}", allow_pickle=True)
+            for scene in self.scene_names
+        ]
+        self.data_root = data_root
+
+    def benchmark(self, model, model_name = None):
+        with torch.no_grad():
+            data_root = self.data_root
+            tot_e_t, tot_e_R, tot_e_pose = [], [], []
+            thresholds = [5, 10, 20]
+            for scene_ind in range(len(self.scenes)):
+                import os
+                scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+                scene = self.scenes[scene_ind]
+                pairs = scene["pair_infos"]
+                intrinsics = scene["intrinsics"]
+                poses = scene["poses"]
+                im_paths = scene["image_paths"]
+                pair_inds = range(len(pairs))
+                for pairind in tqdm(pair_inds):
+                    idx1, idx2 = pairs[pairind][0]
+                    K1 = intrinsics[idx1].copy()
+                    T1 = poses[idx1].copy()
+                    R1, t1 = T1[:3, :3], T1[:3, 3]
+                    K2 = intrinsics[idx2].copy()
+                    T2 = poses[idx2].copy()
+                    R2, t2 = T2[:3, :3], T2[:3, 3]
+                    R, t = compute_relative_pose(R1, t1, R2, t2)
+                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+                    im_A_path = f"{data_root}/{im_paths[idx1]}"
+                    im_B_path = f"{data_root}/{im_paths[idx2]}"
+                    dense_matches, dense_certainty = model.match(
+                        im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
+                    )
+                    sparse_matches,_ = model.sample(
+                        dense_matches, dense_certainty, 5_000
+                    )
+
+                    im_A = Image.open(im_A_path)
+                    w1, h1 = im_A.size
+                    im_B = Image.open(im_B_path)
+                    w2, h2 = im_B.size
+                    if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
+                        scale1 = 1200 / max(w1, h1)
+                        scale2 = 1200 / max(w2, h2)
+                        w1, h1 = scale1 * w1, scale1 * h1
+                        w2, h2 = scale2 * w2, scale2 * h2
+                        K1, K2 = K1.copy(), K2.copy()
+                        K1[:2] = K1[:2] * scale1
+                        K2[:2] = K2[:2] * scale2
+
+                    kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
+                    kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
+                    for _ in range(5):
+                        shuffling = np.random.permutation(np.arange(len(kpts1)))
+                        kpts1 = kpts1[shuffling]
+                        kpts2 = kpts2[shuffling]
+                        try:
+                            threshold = 0.5
+                            norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                            R_est, t_est, mask = estimate_pose(
+                                kpts1,
+                                kpts2,
+                                K1,
+                                K2,
+                                norm_threshold,
+                                conf=0.99999,
+                            )
+                            T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
+                            e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+                            e_pose = max(e_t, e_R)
+                        except Exception as e:
+                            print(repr(e))
+                            e_t, e_R = 90, 90
+                            e_pose = max(e_t, e_R)
+                        tot_e_t.append(e_t)
+                        tot_e_R.append(e_R)
+                        tot_e_pose.append(e_pose)
+            tot_e_pose = np.array(tot_e_pose)
+            auc = pose_auc(tot_e_pose, thresholds)
+            acc_5 = (tot_e_pose < 5).mean()
+            acc_10 = (tot_e_pose < 10).mean()
+            acc_15 = (tot_e_pose < 15).mean()
+            acc_20 = (tot_e_pose < 20).mean()
+            map_5 = acc_5
+            map_10 = np.mean([acc_5, acc_10])
+            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+            print(f"{model_name} auc: {auc}")
+            return {
+                "auc_5": auc[0],
+                "auc_10": auc[1],
+                "auc_20": auc[2],
+                "map_5": map_5,
+                "map_10": map_10,
+                "map_20": map_20,
+            }

third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py
ADDED
@@ -0,0 +1,119 @@
+import numpy as np
+import torch
+from romatch.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+import romatch
+import kornia.geometry.epipolar as kepi
+
+# wrap cause pyposelib is still in dev
+# will add in deps later
+import poselib
+
+class Mega1500PoseLibBenchmark:
+    def __init__(self, data_root="data/megadepth", scene_names = None, num_ransac_iter = 5, test_every = 1) -> None:
+        if scene_names is None:
+            self.scene_names = [
+                "0015_0.1_0.3.npz",
+                "0015_0.3_0.5.npz",
+                "0022_0.1_0.3.npz",
+                "0022_0.3_0.5.npz",
+                "0022_0.5_0.7.npz",
+            ]
+        else:
+            self.scene_names = scene_names
+        self.scenes = [
+            np.load(f"{data_root}/{scene}", allow_pickle=True)
+            for scene in self.scene_names
+        ]
+        self.data_root = data_root
+        self.num_ransac_iter = num_ransac_iter
+        self.test_every = test_every
+
+    def benchmark(self, model, model_name = None):
+        with torch.no_grad():
+            data_root = self.data_root
+            tot_e_t, tot_e_R, tot_e_pose = [], [], []
+            thresholds = [5, 10, 20]
+            for scene_ind in range(len(self.scenes)):
+                import os
+                scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+                scene = self.scenes[scene_ind]
+                pairs = scene["pair_infos"]
+                intrinsics = scene["intrinsics"]
+                poses = scene["poses"]
+                im_paths = scene["image_paths"]
+                pair_inds = range(len(pairs))[::self.test_every]
+                for pairind in (pbar := tqdm(pair_inds, desc = "Current AUC: ?")):
+                    idx1, idx2 = pairs[pairind][0]
+                    K1 = intrinsics[idx1].copy()
+                    T1 = poses[idx1].copy()
+                    R1, t1 = T1[:3, :3], T1[:3, 3]
+                    K2 = intrinsics[idx2].copy()
+                    T2 = poses[idx2].copy()
+                    R2, t2 = T2[:3, :3], T2[:3, 3]
+                    R, t = compute_relative_pose(R1, t1, R2, t2)
+                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+                    im_A_path = f"{data_root}/{im_paths[idx1]}"
+                    im_B_path = f"{data_root}/{im_paths[idx2]}"
+                    dense_matches, dense_certainty = model.match(
+                        im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
+                    )
+                    sparse_matches,_ = model.sample(
+                        dense_matches, dense_certainty, 5_000
+                    )
+
+                    im_A = Image.open(im_A_path)
+                    w1, h1 = im_A.size
+                    im_B = Image.open(im_B_path)
+                    w2, h2 = im_B.size
+                    kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
+                    kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
+                    for _ in range(self.num_ransac_iter):
+                        shuffling = np.random.permutation(np.arange(len(kpts1)))
+                        kpts1 = kpts1[shuffling]
+                        kpts2 = kpts2[shuffling]
+                        try:
+                            threshold = 1
+                            camera1 = {'model': 'PINHOLE', 'width': w1, 'height': h1, 'params': K1[[0,1,0,1], [0,1,2,2]]}
+                            camera2 = {'model': 'PINHOLE', 'width': w2, 'height': h2, 'params': K2[[0,1,0,1], [0,1,2,2]]}
+                            relpose, res = poselib.estimate_relative_pose(
+                                kpts1,
+                                kpts2,
+                                camera1,
+                                camera2,
+                                ransac_opt = {"max_reproj_error": 2*threshold, "max_epipolar_error": threshold, "min_inliers": 8, "max_iterations": 10_000},
+                            )
+                            Rt_est = relpose.Rt
+                            R_est, t_est = Rt_est[:3,:3], Rt_est[:3,3:]
+                            mask = np.array(res['inliers']).astype(np.float32)
+                            T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
+                            e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+                            e_pose = max(e_t, e_R)
+                        except Exception as e:
+                            print(repr(e))
+                            e_t, e_R = 90, 90
+                            e_pose = max(e_t, e_R)
+                        tot_e_t.append(e_t)
+                        tot_e_R.append(e_R)
+                        tot_e_pose.append(e_pose)
+                        pbar.set_description(f"Current AUC: {pose_auc(tot_e_pose, thresholds)}")
+            tot_e_pose = np.array(tot_e_pose)
+            auc = pose_auc(tot_e_pose, thresholds)
+            acc_5 = (tot_e_pose < 5).mean()
+            acc_10 = (tot_e_pose < 10).mean()
+            acc_15 = (tot_e_pose < 15).mean()
+            acc_20 = (tot_e_pose < 20).mean()
+            map_5 = acc_5
+            map_10 = np.mean([acc_5, acc_10])
+            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+            print(f"{model_name} auc: {auc}")
+            return {
+                "auc_5": auc[0],
+                "auc_10": auc[1],
+                "auc_20": auc[2],
+                "map_5": map_5,
+                "map_10": map_10,
+                "map_20": map_20,
+            }

third_party/RoMa/romatch/benchmarks/scannet_benchmark.py
ADDED
@@ -0,0 +1,143 @@
+import os.path as osp
+import numpy as np
+import torch
+from romatch.utils import *
+from PIL import Image
+from tqdm import tqdm
+
+
+class ScanNetBenchmark:
+    def __init__(self, data_root="data/scannet") -> None:
+        self.data_root = data_root
+
+    def benchmark(self, model, model_name = None):
+        model.train(False)
+        with torch.no_grad():
+            data_root = self.data_root
+            tmp = np.load(osp.join(data_root, "test.npz"))
+            pairs, rel_pose = tmp["name"], tmp["rel_pose"]
+            tot_e_t, tot_e_R, tot_e_pose = [], [], []
+            pair_inds = np.random.choice(
+                range(len(pairs)), size=len(pairs), replace=False
+            )
+            for pairind in tqdm(pair_inds, smoothing=0.9):
+                scene = pairs[pairind]
+                scene_name = f"scene0{scene[0]}_00"
+                im_A_path = osp.join(
+                    self.data_root,
+                    "scans_test",
+                    scene_name,
+                    "color",
+                    f"{scene[2]}.jpg",
+                )
+                im_A = Image.open(im_A_path)
+                im_B_path = osp.join(
+                    self.data_root,
+                    "scans_test",
+                    scene_name,
+                    "color",
+                    f"{scene[3]}.jpg",
+                )
+                im_B = Image.open(im_B_path)
+                T_gt = rel_pose[pairind].reshape(3, 4)
+                R, t = T_gt[:3, :3], T_gt[:3, 3]
+                K = np.stack(
+                    [
+                        np.array([float(i) for i in r.split()])
+                        for r in open(
+                            osp.join(
+                                self.data_root,
+                                "scans_test",
+                                scene_name,
+                                "intrinsic",
+                                "intrinsic_color.txt",
+                            ),
+                            "r",
+                        )
+                        .read()
+                        .split("\n")
+                        if r
+                    ]
+                )
+                w1, h1 = im_A.size
+                w2, h2 = im_B.size
+                K1 = K.copy()
+                K2 = K.copy()
+                dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
+                sparse_matches, sparse_certainty = model.sample(
+                    dense_matches, dense_certainty, 5000
+                )
+                scale1 = 480 / min(w1, h1)
+                scale2 = 480 / min(w2, h2)
+                w1, h1 = scale1 * w1, scale1 * h1
+                w2, h2 = scale2 * w2, scale2 * h2
+                K1 = K1 * scale1
+                K2 = K2 * scale2
+
+                offset = 0.5
+                kpts1 = sparse_matches[:, :2]
+                kpts1 = (
+                    np.stack(
+                        (
+                            w1 * (kpts1[:, 0] + 1) / 2 - offset,
+                            h1 * (kpts1[:, 1] + 1) / 2 - offset,
+                        ),
+                        axis=-1,
+                    )
+                )
+                kpts2 = sparse_matches[:, 2:]
+                kpts2 = (
+                    np.stack(
+                        (
+                            w2 * (kpts2[:, 0] + 1) / 2 - offset,
+                            h2 * (kpts2[:, 1] + 1) / 2 - offset,
+                        ),
+                        axis=-1,
+                    )
+                )
+                for _ in range(5):
+                    shuffling = np.random.permutation(np.arange(len(kpts1)))
+                    kpts1 = kpts1[shuffling]
+                    kpts2 = kpts2[shuffling]
+                    try:
+                        norm_threshold = 0.5 / (
+                            np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                        R_est, t_est, mask = estimate_pose(
+                            kpts1,
+                            kpts2,
+                            K1,
+                            K2,
+                            norm_threshold,
+                            conf=0.99999,
+                        )
+                        T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
+                        e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+                        e_pose = max(e_t, e_R)
+                    except Exception as e:
+                        print(repr(e))
+                        e_t, e_R = 90, 90
+                        e_pose = max(e_t, e_R)
+                    tot_e_t.append(e_t)
+                    tot_e_R.append(e_R)
+                    tot_e_pose.append(e_pose)
+                tot_e_t.append(e_t)
+                tot_e_R.append(e_R)
+                tot_e_pose.append(e_pose)
+            tot_e_pose = np.array(tot_e_pose)
+            thresholds = [5, 10, 20]
+            auc = pose_auc(tot_e_pose, thresholds)
+            acc_5 = (tot_e_pose < 5).mean()
+            acc_10 = (tot_e_pose < 10).mean()
+            acc_15 = (tot_e_pose < 15).mean()
+            acc_20 = (tot_e_pose < 20).mean()
+            map_5 = acc_5
+            map_10 = np.mean([acc_5, acc_10])
+            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+            return {
+                "auc_5": auc[0],
+                "auc_10": auc[1],
+                "auc_20": auc[2],
+                "map_5": map_5,
+                "map_10": map_10,
+                "map_20": map_20,
+            }

third_party/RoMa/romatch/checkpointing/__init__.py
ADDED
@@ -0,0 +1 @@
+from .checkpoint import CheckPoint

third_party/RoMa/romatch/checkpointing/checkpoint.py
ADDED
@@ -0,0 +1,60 @@
+import os
+import torch
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+from loguru import logger
+import gc
+
+import romatch
+
+class CheckPoint:
+    def __init__(self, dir=None, name="tmp"):
+        self.name = name
+        self.dir = dir
+        os.makedirs(self.dir, exist_ok=True)
+
+    def save(
+        self,
+        model,
+        optimizer,
+        lr_scheduler,
+        n,
+    ):
+        if romatch.RANK == 0:
+            assert model is not None
+            if isinstance(model, (DataParallel, DistributedDataParallel)):
+                model = model.module
+            states = {
+                "model": model.state_dict(),
+                "n": n,
+                "optimizer": optimizer.state_dict(),
+                "lr_scheduler": lr_scheduler.state_dict(),
+            }
+            torch.save(states, self.dir + self.name + f"_latest.pth")
+            logger.info(f"Saved states {list(states.keys())}, at step {n}")
+
+    def load(
+        self,
+        model,
+        optimizer,
+        lr_scheduler,
+        n,
+    ):
+        if os.path.exists(self.dir + self.name + f"_latest.pth") and romatch.RANK == 0:
+            states = torch.load(self.dir + self.name + f"_latest.pth")
+            if "model" in states:
+                model.load_state_dict(states["model"])
+            if "n" in states:
+                n = states["n"] if states["n"] else n
+            if "optimizer" in states:
+                try:
+                    optimizer.load_state_dict(states["optimizer"])
+                except Exception as e:
+                    print(f"Failed to load states for optimizer, with error {e}")
+            if "lr_scheduler" in states:
+                lr_scheduler.load_state_dict(states["lr_scheduler"])
+            print(f"Loaded states {list(states.keys())}, at step {n}")
+            del states
+            gc.collect()
+            torch.cuda.empty_cache()
+        return model, optimizer, lr_scheduler, n

third_party/RoMa/romatch/datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .megadepth import MegadepthBuilder
+from .scannet import ScanNetBuilder

third_party/RoMa/romatch/datasets/megadepth.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import h5py
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms.functional as tvf
|
7 |
+
import kornia.augmentation as K
|
8 |
+
from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
|
9 |
+
import romatch
|
10 |
+
from romatch.utils import *
|
11 |
+
import math
|
12 |
+
|
13 |
+
class MegadepthScene:
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
data_root,
|
17 |
+
scene_info,
|
18 |
+
ht=384,
|
19 |
+
wt=512,
|
20 |
+
min_overlap=0.0,
|
21 |
+
max_overlap=1.0,
|
22 |
+
shake_t=0,
|
23 |
+
rot_prob=0.0,
|
24 |
+
normalize=True,
|
25 |
+
max_num_pairs = 100_000,
|
26 |
+
scene_name = None,
|
27 |
+
use_horizontal_flip_aug = False,
|
28 |
+
use_single_horizontal_flip_aug = False,
|
29 |
+
colorjiggle_params = None,
|
30 |
+
random_eraser = None,
|
31 |
+
use_randaug = False,
|
32 |
+
randaug_params = None,
|
33 |
+
randomize_size = False,
|
34 |
+
) -> None:
|
35 |
+
self.data_root = data_root
|
36 |
+
self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
|
37 |
+
self.image_paths = scene_info["image_paths"]
|
38 |
+
self.depth_paths = scene_info["depth_paths"]
|
39 |
+
self.intrinsics = scene_info["intrinsics"]
|
40 |
+
self.poses = scene_info["poses"]
|
41 |
+
self.pairs = scene_info["pairs"]
|
42 |
+
self.overlaps = scene_info["overlaps"]
|
43 |
+
threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
|
44 |
+
self.pairs = self.pairs[threshold]
|
45 |
+
self.overlaps = self.overlaps[threshold]
|
46 |
+
if len(self.pairs) > max_num_pairs:
|
47 |
+
pairinds = np.random.choice(
|
48 |
+
np.arange(0, len(self.pairs)), max_num_pairs, replace=False
|
49 |
+
)
|
50 |
+
self.pairs = self.pairs[pairinds]
|
51 |
+
self.overlaps = self.overlaps[pairinds]
|
52 |
+
if randomize_size:
|
53 |
+
area = ht * wt
|
54 |
+
s = int(16 * (math.sqrt(area)//16))
|
55 |
+
sizes = ((ht,wt), (s,s), (wt,ht))
|
56 |
+
choice = romatch.RANK % 3
|
57 |
+
ht, wt = sizes[choice]
|
58 |
+
# counts, bins = np.histogram(self.overlaps,20)
|
59 |
+
# print(counts)
|
60 |
+
self.im_transform_ops = get_tuple_transform_ops(
|
61 |
+
resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
|
62 |
+
)
|
63 |
+
self.depth_transform_ops = get_depth_tuple_transform_ops(
|
64 |
+
resize=(ht, wt)
|
65 |
+
)
|
66 |
+
self.wt, self.ht = wt, ht
|
67 |
+
self.shake_t = shake_t
|
68 |
+
self.random_eraser = random_eraser
|
69 |
+
if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
|
70 |
+
raise ValueError("Can't both flip both images and only flip one")
|
71 |
+
self.use_horizontal_flip_aug = use_horizontal_flip_aug
|
72 |
+
self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
|
73 |
+
self.use_randaug = use_randaug
|
74 |
+
|
75 |
+
def load_im(self, im_path):
|
76 |
+
im = Image.open(im_path)
|
77 |
+
return im
|
78 |
+
|
79 |
+
def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
|
80 |
+
im_A = im_A.flip(-1)
|
81 |
+
im_B = im_B.flip(-1)
|
82 |
+
depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
|
83 |
+
flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
|
84 |
+
K_A = flip_mat@K_A
|
85 |
+
K_B = flip_mat@K_B
|
86 |
+
|
87 |
+
return im_A, im_B, depth_A, depth_B, K_A, K_B
|
88 |
+
|
89 |
+
def load_depth(self, depth_ref, crop=None):
|
90 |
+
depth = np.array(h5py.File(depth_ref, "r")["depth"])
|
91 |
+
return torch.from_numpy(depth)
|
92 |
+
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.pairs)
|
95 |
+
|
96 |
+
def scale_intrinsic(self, K, wi, hi):
|
97 |
+
sx, sy = self.wt / wi, self.ht / hi
|
98 |
+
sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
|
99 |
+
return sK @ K
|
100 |
+
|
101 |
+
def rand_shake(self, *things):
|
102 |
+
t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
|
103 |
+
return [
|
104 |
+
tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
|
105 |
+
for thing in things
|
106 |
+
], t
|
107 |
+
|
108 |
+
def __getitem__(self, pair_idx):
|
109 |
+
# read intrinsics of original size
|
110 |
+
idx1, idx2 = self.pairs[pair_idx]
|
111 |
+
K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
|
112 |
+
K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
|
113 |
+
|
114 |
+
# read and compute relative poses
|
115 |
+
T1 = self.poses[idx1]
|
116 |
+
T2 = self.poses[idx2]
|
117 |
+
T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
|
118 |
+
:4, :4
|
119 |
+
] # (4, 4)
|
120 |
+
|
121 |
+
# Load positive pair data
|
122 |
+
im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
|
123 |
+
depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
|
124 |
+
im_A_ref = os.path.join(self.data_root, im_A)
|
125 |
+
+        im_B_ref = os.path.join(self.data_root, im_B)
+        depth_A_ref = os.path.join(self.data_root, depth1)
+        depth_B_ref = os.path.join(self.data_root, depth2)
+        im_A = self.load_im(im_A_ref)
+        im_B = self.load_im(im_B_ref)
+        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+
+        if self.use_randaug:
+            im_A, im_B = self.rand_augment(im_A, im_B)
+
+        depth_A = self.load_depth(depth_A_ref)
+        depth_B = self.load_depth(depth_B_ref)
+        # Process images
+        im_A, im_B = self.im_transform_ops((im_A, im_B))
+        depth_A, depth_B = self.depth_transform_ops(
+            (depth_A[None, None], depth_B[None, None])
+        )
+
+        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
+        K1[:2, 2] += t
+        K2[:2, 2] += t
+
+        im_A, im_B = im_A[None], im_B[None]
+        if self.random_eraser is not None:
+            im_A, depth_A = self.random_eraser(im_A, depth_A)
+            im_B, depth_B = self.random_eraser(im_B, depth_B)
+
+        if self.use_horizontal_flip_aug:
+            if np.random.rand() > 0.5:
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+        if self.use_single_horizontal_flip_aug:
+            if np.random.rand() > 0.5:
+                im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
+
+        if romatch.DEBUG_MODE:
+            tensor_to_pil(im_A[0], unnormalize=True).save(
+                f"vis/im_A.jpg")
+            tensor_to_pil(im_B[0], unnormalize=True).save(
+                f"vis/im_B.jpg")
+
+        data_dict = {
+            "im_A": im_A[0],
+            "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
+            "im_B": im_B[0],
+            "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
+            "im_A_depth": depth_A[0, 0],
+            "im_B_depth": depth_B[0, 0],
+            "K1": K1,
+            "K2": K2,
+            "T_1to2": T_1to2,
+            "im_A_path": im_A_ref,
+            "im_B_path": im_B_ref,
+
+        }
+        return data_dict
+
+
+class MegadepthBuilder:
+    def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
+        self.data_root = data_root
+        self.scene_info_root = os.path.join(data_root, "prep_scene_info")
+        self.all_scenes = os.listdir(self.scene_info_root)
+        self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
+        # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
+        self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
+        self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
+        self.test_scenes_loftr = ["0015.npy", "0022.npy"]
+        self.loftr_ignore = loftr_ignore
+        self.imc21_ignore = imc21_ignore
+
+    def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
+        if split == "train":
+            scene_names = set(self.all_scenes) - set(self.test_scenes)
+        elif split == "train_loftr":
+            scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
+        elif split == "test":
+            scene_names = self.test_scenes
+        elif split == "test_loftr":
+            scene_names = self.test_scenes_loftr
+        elif split == "custom":
+            scene_names = scene_names
+        else:
+            raise ValueError(f"Split {split} not available")
+        scenes = []
+        for scene_name in scene_names:
+            if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
+                continue
+            if self.imc21_ignore and scene_name in self.imc21_scenes:
+                continue
+            if ".npy" not in scene_name:
+                continue
+            scene_info = np.load(
+                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+            ).item()
+            scenes.append(
+                MegadepthScene(
+                    self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
+                )
+            )
+        return scenes
+
+    def weight_scenes(self, concat_dataset, alpha=0.5):
+        ns = []
+        for d in concat_dataset.datasets:
+            ns.append(len(d))
+        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
+        return ws
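For orientation, a minimal sketch of how these builders are typically wired into a training loader; the split name and parameter values are illustrative, and it assumes the MegaDepth prep_scene_info files are in place under data/megadepth:

import torch
from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler
from romatch.datasets.megadepth import MegadepthBuilder

mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True)
scenes = mega.build_scenes(split="train_loftr", min_overlap=0.01)   # list of MegadepthScene datasets
dataset = ConcatDataset(scenes)
weights = mega.weight_scenes(dataset, alpha=0.75)                   # down-weights very large scenes
sampler = WeightedRandomSampler(weights, num_samples=8, replacement=False)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
batch = next(iter(loader))  # keys include im_A, im_B, im_A_depth, im_B_depth, K1, K2, T_1to2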
third_party/RoMa/romatch/datasets/scannet.py
ADDED
@@ -0,0 +1,160 @@
+import os
+import random
+from PIL import Image
+import cv2
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import (
+    Dataset,
+    DataLoader,
+    ConcatDataset)
+
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+import os.path as osp
+import matplotlib.pyplot as plt
+import romatch
+from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+from romatch.utils.transforms import GeometricSequential
+from tqdm import tqdm
+
+class ScanNetScene:
+    def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
+    ) -> None:
+        self.scene_root = osp.join(data_root,"scans","scans_train")
+        self.data_names = scene_info['name']
+        self.overlaps = scene_info['score']
+        # Only sample 10s
+        valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+        self.overlaps = self.overlaps[valid]
+        self.data_names = self.data_names[valid]
+        if len(self.data_names) > 10000:
+            pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+            self.data_names = self.data_names[pairinds]
+            self.overlaps = self.overlaps[pairinds]
+        self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
+        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+        self.wt, self.ht = wt, ht
+        self.shake_t = shake_t
+        self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+        self.use_horizontal_flip_aug = use_horizontal_flip_aug
+
+    def load_im(self, im_B, crop=None):
+        im = Image.open(im_B)
+        return im
+
+    def load_depth(self, depth_ref, crop=None):
+        depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
+        depth = depth / 1000
+        depth = torch.from_numpy(depth).float()  # (h, w)
+        return depth
+
+    def __len__(self):
+        return len(self.data_names)
+
+    def scale_intrinsic(self, K, wi, hi):
+        sx, sy = self.wt / wi, self.ht / hi
+        sK = torch.tensor([[sx, 0, 0],
+                           [0, sy, 0],
+                           [0, 0, 1]])
+        return sK@K
+
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
+        im_A = im_A.flip(-1)
+        im_B = im_B.flip(-1)
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+        K_A = flip_mat@K_A
+        K_B = flip_mat@K_B
+
+        return im_A, im_B, depth_A, depth_B, K_A, K_B
+    def read_scannet_pose(self,path):
+        """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+        Returns:
+            pose_w2c (np.ndarray): (4, 4)
+        """
+        cam2world = np.loadtxt(path, delimiter=' ')
+        world2cam = np.linalg.inv(cam2world)
+        return world2cam
+
+
+    def read_scannet_intrinsic(self,path):
+        """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+        """
+        intrinsic = np.loadtxt(path, delimiter=' ')
+        return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
+
+    def __getitem__(self, pair_idx):
+        # read intrinsics of original size
+        data_name = self.data_names[pair_idx]
+        scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
+        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+        # read the intrinsic of depthmap
+        K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
+                                                       scene_name,
+                                                       'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
+        # read and compute relative poses
+        T1 = self.read_scannet_pose(osp.join(self.scene_root,
+                                             scene_name,
+                                             'pose', f'{stem_name_1}.txt'))
+        T2 = self.read_scannet_pose(osp.join(self.scene_root,
+                                             scene_name,
+                                             'pose', f'{stem_name_2}.txt'))
+        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4]  # (4, 4)
+
+        # Load positive pair data
+        im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
+        im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
+        depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
+        depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+
+        im_A = self.load_im(im_A_ref)
+        im_B = self.load_im(im_B_ref)
+        depth_A = self.load_depth(depth_A_ref)
+        depth_B = self.load_depth(depth_B_ref)
+
+        # Recompute camera intrinsic matrix due to the resize
+        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+        # Process images
+        im_A, im_B = self.im_transform_ops((im_A, im_B))
+        depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
+        if self.use_horizontal_flip_aug:
+            if np.random.rand() > 0.5:
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+
+        data_dict = {'im_A': im_A,
+                     'im_B': im_B,
+                     'im_A_depth': depth_A[0,0],
+                     'im_B_depth': depth_B[0,0],
+                     'K1': K1,
+                     'K2': K2,
+                     'T_1to2':T_1to2,
+                     }
+        return data_dict
+
+
+class ScanNetBuilder:
+    def __init__(self, data_root = 'data/scannet') -> None:
+        self.data_root = data_root
+        self.scene_info_root = os.path.join(data_root,'scannet_indices')
+        self.all_scenes = os.listdir(self.scene_info_root)
+
+    def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+        # Note: split doesn't matter here as we always use same scannet_train scenes
+        scene_names = self.all_scenes
+        scenes = []
+        for scene_name in tqdm(scene_names, disable = romatch.RANK > 0):
+            scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
+            scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+        return scenes
+
+    def weight_scenes(self, concat_dataset, alpha=.5):
+        ns = []
+        for d in concat_dataset.datasets:
+            ns.append(len(d))
+        ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+        return ws
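The ScanNet builder above follows the same pattern as the MegaDepth one; a minimal sketch (illustrative values, assuming the scannet_indices files and the scans/scans_train directory exist under data/scannet):

from torch.utils.data import ConcatDataset
from romatch.datasets.scannet import ScanNetBuilder

scannet = ScanNetBuilder(data_root="data/scannet")
scenes = scannet.build_scenes(split="train", ht=384, wt=512, use_horizontal_flip_aug=True)
dataset = ConcatDataset(scenes)
sample = dataset[0]  # dict with im_A, im_B, im_A_depth, im_B_depth, K1, K2, T_1to2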
third_party/RoMa/romatch/losses/__init__.py
ADDED
@@ -0,0 +1 @@
+from .robust_loss import RobustLosses
third_party/RoMa/romatch/losses/robust_loss.py
ADDED
@@ -0,0 +1,161 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from romatch.utils.utils import get_gt_warp
+import wandb
+import romatch
+import math
+
+class RobustLosses(nn.Module):
+    def __init__(
+        self,
+        robust=False,
+        center_coords=False,
+        scale_normalize=False,
+        ce_weight=0.01,
+        local_loss=True,
+        local_dist=4.0,
+        local_largest_scale=8,
+        smooth_mask = False,
+        depth_interpolation_mode = "bilinear",
+        mask_depth_loss = False,
+        relative_depth_error_threshold = 0.05,
+        alpha = 1.,
+        c = 1e-3,
+    ):
+        super().__init__()
+        self.robust = robust  # measured in pixels
+        self.center_coords = center_coords
+        self.scale_normalize = scale_normalize
+        self.ce_weight = ce_weight
+        self.local_loss = local_loss
+        self.local_dist = local_dist
+        self.local_largest_scale = local_largest_scale
+        self.smooth_mask = smooth_mask
+        self.depth_interpolation_mode = depth_interpolation_mode
+        self.mask_depth_loss = mask_depth_loss
+        self.relative_depth_error_threshold = relative_depth_error_threshold
+        self.avg_overlap = dict()
+        self.alpha = alpha
+        self.c = c
+
+    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
+        with torch.no_grad():
+            B, C, H, W = scale_gm_cls.shape
+            device = x2.device
+            cls_res = round(math.sqrt(C))
+            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
+            GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
+        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
+        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
+        if not torch.any(cls_loss):
+            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+
+        losses = {
+            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
+            f"gm_cls_loss_{scale}": cls_loss.mean(),
+        }
+        wandb.log(losses, step = romatch.GLOBAL_STEP)
+        return losses
+
+    def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
+        with torch.no_grad():
+            B, C, H, W = delta_cls.shape
+            device = x2.device
+            cls_res = round(math.sqrt(C))
+            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
+            GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
+        cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99]
+        if not torch.any(cls_loss):
+            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
+        losses = {
+            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
+            f"delta_cls_loss_{scale}": cls_loss.mean(),
+        }
+        wandb.log(losses, step = romatch.GLOBAL_STEP)
+        return losses
+
+    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
+        epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
+        if scale == 1:
+            pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
+            wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
+
+        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
+        a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
+        cs = self.c * scale
+        x = epe[prob > 0.99]
+        reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
+        if not torch.any(reg_loss):
+            reg_loss = (ce_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+        losses = {
+            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
+            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
+        }
+        wandb.log(losses, step = romatch.GLOBAL_STEP)
+        return losses
+
+    def forward(self, corresps, batch):
+        scales = list(corresps.keys())
+        tot_loss = 0.0
+        # scale_weights due to differences in scale for regression gradients and classification gradients
+        scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
+        for scale in scales:
+            scale_corresps = corresps[scale]
+            scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
+                scale_corresps["certainty"],
+                scale_corresps.get("flow_pre_delta"),
+                scale_corresps.get("delta_cls"),
+                scale_corresps.get("offset_scale"),
+                scale_corresps.get("gm_cls"),
+                scale_corresps.get("gm_certainty"),
+                scale_corresps["flow"],
+                scale_corresps.get("gm_flow"),
+
+            )
+            if flow_pre_delta is not None:
+                flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
+                b, h, w, d = flow_pre_delta.shape
+            else:
+                # _ = 1
+                b, _, h, w = scale_certainty.shape
+            gt_warp, gt_prob = get_gt_warp(
+                batch["im_A_depth"],
+                batch["im_B_depth"],
+                batch["T_1to2"],
+                batch["K1"],
+                batch["K2"],
+                H=h,
+                W=w,
+            )
+            x2 = gt_warp.float()
+            prob = gt_prob
+
+            if self.local_largest_scale >= scale:
+                prob = prob * (
+                    F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
+                    < (2 / 512) * (self.local_dist[scale] * scale))
+
+            if scale_gm_cls is not None:
+                gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
+                gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * gm_loss
+            elif scale_gm_flow is not None:
+                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
+                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * gm_loss
+
+            if delta_cls is not None:
+                delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
+                delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
+            else:
+                delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
+                reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * reg_loss
+            prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
+        return tot_loss
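The per-scale regression term above is a generalized Charbonnier-style penalty on the end-point error: reg_loss = cs**a * ((x/cs)**2 + 1)**(a/2) with cs = c * scale. A standalone restatement of that formula, for intuition only (the function name here is illustrative):

import torch

def robust_penalty(epe: torch.Tensor, scale: float, alpha: float = 1.0, c: float = 1e-3) -> torch.Tensor:
    # Same form as reg_loss in regression_loss above: cs**a * ((x / cs)**2 + 1)**(a / 2), cs = c * scale.
    cs = c * scale
    return cs**alpha * ((epe / cs) ** 2 + 1) ** (alpha / 2)

epe = torch.tensor([0.0, 1e-3, 1e-2, 1e-1])
print(robust_penalty(epe, scale=1))  # for alpha = 1 the penalty grows roughly linearly once the error exceeds c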
third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py
ADDED
@@ -0,0 +1,160 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from romatch.utils.utils import get_gt_warp
+import wandb
+import romatch
+import math
+
+# This is slightly different than regular romatch due to significantly worse corresps
+# The confidence loss is quite tricky here //Johan
+
+class RobustLosses(nn.Module):
+    def __init__(
+        self,
+        robust=False,
+        center_coords=False,
+        scale_normalize=False,
+        ce_weight=0.01,
+        local_loss=True,
+        local_dist=None,
+        smooth_mask = False,
+        depth_interpolation_mode = "bilinear",
+        mask_depth_loss = False,
+        relative_depth_error_threshold = 0.05,
+        alpha = 1.,
+        c = 1e-3,
+        epe_mask_prob_th = None,
+        cert_only_on_consistent_depth = False,
+    ):
+        super().__init__()
+        if local_dist is None:
+            local_dist = {}
+        self.robust = robust  # measured in pixels
+        self.center_coords = center_coords
+        self.scale_normalize = scale_normalize
+        self.ce_weight = ce_weight
+        self.local_loss = local_loss
+        self.local_dist = local_dist
+        self.smooth_mask = smooth_mask
+        self.depth_interpolation_mode = depth_interpolation_mode
+        self.mask_depth_loss = mask_depth_loss
+        self.relative_depth_error_threshold = relative_depth_error_threshold
+        self.avg_overlap = dict()
+        self.alpha = alpha
+        self.c = c
+        self.epe_mask_prob_th = epe_mask_prob_th
+        self.cert_only_on_consistent_depth = cert_only_on_consistent_depth
+
+    def corr_volume_loss(self, mnn:torch.Tensor, corr_volume:torch.Tensor, scale):
+        b, h,w, h,w = corr_volume.shape
+        inv_temp = 10
+        corr_volume = corr_volume.reshape(-1, h*w, h*w)
+        nll = -(inv_temp*corr_volume).log_softmax(dim = 1) - (inv_temp*corr_volume).log_softmax(dim = 2)
+        corr_volume_loss = nll[mnn[:,0], mnn[:,1], mnn[:,2]].mean()
+
+        losses = {
+            f"gm_corr_volume_loss_{scale}": corr_volume_loss.mean(),
+        }
+        wandb.log(losses, step = romatch.GLOBAL_STEP)
+        return losses
+
+
+
+    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
+        epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
+        if scale in self.local_dist:
+            prob = prob * (epe < (2 / 512) * (self.local_dist[scale] * scale)).float()
+        if scale == 1:
+            pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
+            wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
+        if self.epe_mask_prob_th is not None:
+            # if too far away from gt, certainty should be 0
+            gt_cert = prob * (epe < scale * self.epe_mask_prob_th)
+        else:
+            gt_cert = prob
+        if self.cert_only_on_consistent_depth:
+            ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0][prob > 0], gt_cert[prob > 0])
+        else:
+            ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], gt_cert)
+        a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
+        cs = self.c * scale
+        x = epe[prob > 0.99]
+        reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
+        if not torch.any(reg_loss):
+            reg_loss = (ce_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+        losses = {
+            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
+            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
+        }
+        wandb.log(losses, step = romatch.GLOBAL_STEP)
+        return losses
+
+    def forward(self, corresps, batch):
+        scales = list(corresps.keys())
+        tot_loss = 0.0
+        # scale_weights due to differences in scale for regression gradients and classification gradients
+        for scale in scales:
+            scale_corresps = corresps[scale]
+            scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_corr_volume, scale_gm_certainty, flow, scale_gm_flow = (
+                scale_corresps["certainty"],
+                scale_corresps.get("flow_pre_delta"),
+                scale_corresps.get("delta_cls"),
+                scale_corresps.get("offset_scale"),
+                scale_corresps.get("corr_volume"),
+                scale_corresps.get("gm_certainty"),
+                scale_corresps["flow"],
+                scale_corresps.get("gm_flow"),
+
+            )
+            if flow_pre_delta is not None:
+                flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
+                b, h, w, d = flow_pre_delta.shape
+            else:
+                # _ = 1
+                b, _, h, w = scale_certainty.shape
+            gt_warp, gt_prob = get_gt_warp(
+                batch["im_A_depth"],
+                batch["im_B_depth"],
+                batch["T_1to2"],
+                batch["K1"],
+                batch["K2"],
+                H=h,
+                W=w,
+            )
+            x2 = gt_warp.float()
+            prob = gt_prob
+
+            if scale_gm_corr_volume is not None:
+                gt_warp_back, _ = get_gt_warp(
+                    batch["im_B_depth"],
+                    batch["im_A_depth"],
+                    batch["T_1to2"].inverse(),
+                    batch["K2"],
+                    batch["K1"],
+                    H=h,
+                    W=w,
+                )
+                grid = torch.stack(torch.meshgrid(torch.linspace(-1+1/w, 1-1/w, w), torch.linspace(-1+1/h, 1-1/h, h), indexing='xy'), dim =-1).to(gt_warp.device)
+                #fwd_bck = F.grid_sample(gt_warp_back.permute(0,3,1,2), gt_warp, align_corners=False, mode = 'bilinear').permute(0,2,3,1)
+                #diff = (fwd_bck - grid).norm(dim = -1)
+                with torch.no_grad():
+                    D_B = torch.cdist(gt_warp.float().reshape(-1,h*w,2), grid.reshape(-1,h*w,2))
+                    D_A = torch.cdist(grid.reshape(-1,h*w,2), gt_warp_back.float().reshape(-1,h*w,2))
+                    inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values)
+                                         * (D_A == D_A.min(dim=-2, keepdim = True).values)
+                                         * (D_B < 0.01)
+                                         * (D_A < 0.01))
+
+                gm_cls_losses = self.corr_volume_loss(inds, scale_gm_corr_volume, scale)
+                gm_loss = gm_cls_losses[f"gm_corr_volume_loss_{scale}"]
+                tot_loss = tot_loss + gm_loss
+            elif scale_gm_flow is not None:
+                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
+                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
+                tot_loss = tot_loss + gm_loss
+            delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
+            reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
+            tot_loss = tot_loss + reg_loss
+        return tot_loss
third_party/RoMa/romatch/models/__init__.py
ADDED
@@ -0,0 +1 @@
+from .model_zoo import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
third_party/RoMa/romatch/models/encoders.py
ADDED
@@ -0,0 +1,119 @@
+from typing import Optional, Union
+import torch
+from torch import device
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as tvm
+import gc
+
+
+class ResNet50(nn.Module):
+    def __init__(self, pretrained=False, high_res = False, weights = None,
+                 dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None:
+        super().__init__()
+        if dilation is None:
+            dilation = [False,False,False]
+        if anti_aliased:
+            pass
+        else:
+            if weights is not None:
+                self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
+            else:
+                self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+
+        self.high_res = high_res
+        self.freeze_bn = freeze_bn
+        self.early_exit = early_exit
+        self.amp = amp
+        self.amp_dtype = amp_dtype
+
+    def forward(self, x, **kwargs):
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+            net = self.net
+            feats = {1:x}
+            x = net.conv1(x)
+            x = net.bn1(x)
+            x = net.relu(x)
+            feats[2] = x
+            x = net.maxpool(x)
+            x = net.layer1(x)
+            feats[4] = x
+            x = net.layer2(x)
+            feats[8] = x
+            if self.early_exit:
+                return feats
+            x = net.layer3(x)
+            feats[16] = x
+            x = net.layer4(x)
+            feats[32] = x
+            return feats
+
+    def train(self, mode=True):
+        super().train(mode)
+        if self.freeze_bn:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                pass
+
+class VGG19(nn.Module):
+    def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+        self.amp = amp
+        self.amp_dtype = amp_dtype
+
+    def forward(self, x, **kwargs):
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+            feats = {}
+            scale = 1
+            for layer in self.layers:
+                if isinstance(layer, nn.MaxPool2d):
+                    feats[scale] = x
+                    scale = scale*2
+                x = layer(x)
+            return feats
+
+class CNNandDinov2(nn.Module):
+    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16):
+        super().__init__()
+        if dinov2_weights is None:
+            dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
+        from .transformer import vit_large
+        vit_kwargs = dict(img_size= 518,
+            patch_size= 14,
+            init_values = 1.0,
+            ffn_layer = "mlp",
+            block_chunks = 0,
+        )
+
+        dinov2_vitl14 = vit_large(**vit_kwargs).eval()
+        dinov2_vitl14.load_state_dict(dinov2_weights)
+        cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
+        if not use_vgg:
+            self.cnn = ResNet50(**cnn_kwargs)
+        else:
+            self.cnn = VGG19(**cnn_kwargs)
+        self.amp = amp
+        self.amp_dtype = amp_dtype
+        if self.amp:
+            dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
+        self.dinov2_vitl14 = [dinov2_vitl14]  # ugly hack to not show parameters to DDP
+
+
+    def train(self, mode: bool = True):
+        return self.cnn.train(mode)
+
+    def forward(self, x, upsample = False):
+        B,C,H,W = x.shape
+        feature_pyramid = self.cnn(x)
+
+        if not upsample:
+            with torch.no_grad():
+                if self.dinov2_vitl14[0].device != x.device:
+                    self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
+                del dinov2_features_16
+                feature_pyramid[16] = features_16
+        return feature_pyramid
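A minimal sketch of what this encoder returns, assuming a CUDA device and an input whose side lengths are divisible by 14 (the DINOv2 patch size); the exact pyramid keys depend on the CNN configuration, and the keyword values below are illustrative:

import torch
from romatch.models.encoders import CNNandDinov2

encoder = CNNandDinov2(cnn_kwargs=dict(pretrained=True), amp=True).cuda()
x = torch.randn(1, 3, 560, 560).cuda()       # 560 = 14 * 40, so the ViT patch grid is 40x40
feats = encoder(x)                           # dict keyed by stride, e.g. 1, 2, 4, 8, 16, ...
print({k: tuple(v.shape) for k, v in feats.items()})  # feats[16] holds the 1024-d DINOv2 patch tokens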
third_party/RoMa/romatch/models/matcher.py
ADDED
@@ -0,0 +1,772 @@
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
import warnings
|
9 |
+
from warnings import warn
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
import romatch
|
13 |
+
from romatch.utils import get_tuple_transform_ops
|
14 |
+
from romatch.utils.local_correlation import local_correlation
|
15 |
+
from romatch.utils.utils import cls_to_flow_refine
|
16 |
+
from romatch.utils.kde import kde
|
17 |
+
from typing import Union
|
18 |
+
|
19 |
+
class ConvRefiner(nn.Module):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
in_dim=6,
|
23 |
+
hidden_dim=16,
|
24 |
+
out_dim=2,
|
25 |
+
dw=False,
|
26 |
+
kernel_size=5,
|
27 |
+
hidden_blocks=3,
|
28 |
+
displacement_emb = None,
|
29 |
+
displacement_emb_dim = None,
|
30 |
+
local_corr_radius = None,
|
31 |
+
corr_in_other = None,
|
32 |
+
no_im_B_fm = False,
|
33 |
+
amp = False,
|
34 |
+
concat_logits = False,
|
35 |
+
use_bias_block_1 = True,
|
36 |
+
use_cosine_corr = False,
|
37 |
+
disable_local_corr_grad = False,
|
38 |
+
is_classifier = False,
|
39 |
+
sample_mode = "bilinear",
|
40 |
+
norm_type = nn.BatchNorm2d,
|
41 |
+
bn_momentum = 0.1,
|
42 |
+
amp_dtype = torch.float16,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
self.bn_momentum = bn_momentum
|
46 |
+
self.block1 = self.create_block(
|
47 |
+
in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
|
48 |
+
)
|
49 |
+
self.hidden_blocks = nn.Sequential(
|
50 |
+
*[
|
51 |
+
self.create_block(
|
52 |
+
hidden_dim,
|
53 |
+
hidden_dim,
|
54 |
+
dw=dw,
|
55 |
+
kernel_size=kernel_size,
|
56 |
+
norm_type=norm_type,
|
57 |
+
)
|
58 |
+
for hb in range(hidden_blocks)
|
59 |
+
]
|
60 |
+
)
|
61 |
+
self.hidden_blocks = self.hidden_blocks
|
62 |
+
self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
|
63 |
+
if displacement_emb:
|
64 |
+
self.has_displacement_emb = True
|
65 |
+
self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
|
66 |
+
else:
|
67 |
+
self.has_displacement_emb = False
|
68 |
+
self.local_corr_radius = local_corr_radius
|
69 |
+
self.corr_in_other = corr_in_other
|
70 |
+
self.no_im_B_fm = no_im_B_fm
|
71 |
+
self.amp = amp
|
72 |
+
self.concat_logits = concat_logits
|
73 |
+
self.use_cosine_corr = use_cosine_corr
|
74 |
+
self.disable_local_corr_grad = disable_local_corr_grad
|
75 |
+
self.is_classifier = is_classifier
|
76 |
+
self.sample_mode = sample_mode
|
77 |
+
self.amp_dtype = amp_dtype
|
78 |
+
|
79 |
+
def create_block(
|
80 |
+
self,
|
81 |
+
in_dim,
|
82 |
+
out_dim,
|
83 |
+
dw=False,
|
84 |
+
kernel_size=5,
|
85 |
+
bias = True,
|
86 |
+
norm_type = nn.BatchNorm2d,
|
87 |
+
):
|
88 |
+
num_groups = 1 if not dw else in_dim
|
89 |
+
if dw:
|
90 |
+
assert (
|
91 |
+
out_dim % in_dim == 0
|
92 |
+
), "outdim must be divisible by indim for depthwise"
|
93 |
+
conv1 = nn.Conv2d(
|
94 |
+
in_dim,
|
95 |
+
out_dim,
|
96 |
+
kernel_size=kernel_size,
|
97 |
+
stride=1,
|
98 |
+
padding=kernel_size // 2,
|
99 |
+
groups=num_groups,
|
100 |
+
bias=bias,
|
101 |
+
)
|
102 |
+
norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
|
103 |
+
relu = nn.ReLU(inplace=True)
|
104 |
+
conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
|
105 |
+
return nn.Sequential(conv1, norm, relu, conv2)
|
106 |
+
|
107 |
+
def forward(self, x, y, flow, scale_factor = 1, logits = None):
|
108 |
+
b,c,hs,ws = x.shape
|
109 |
+
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
110 |
+
with torch.no_grad():
|
111 |
+
x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
|
112 |
+
if self.has_displacement_emb:
|
113 |
+
im_A_coords = torch.meshgrid(
|
114 |
+
(
|
115 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
|
116 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
|
117 |
+
)
|
118 |
+
)
|
119 |
+
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
120 |
+
im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
|
121 |
+
in_displacement = flow-im_A_coords
|
122 |
+
emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
|
123 |
+
if self.local_corr_radius:
|
124 |
+
if self.corr_in_other:
|
125 |
+
# Corr in other means take a kxk grid around the predicted coordinate in other image
|
126 |
+
local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
|
127 |
+
sample_mode = self.sample_mode)
|
128 |
+
else:
|
129 |
+
raise NotImplementedError("Local corr in own frame should not be used.")
|
130 |
+
if self.no_im_B_fm:
|
131 |
+
x_hat = torch.zeros_like(x)
|
132 |
+
d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
|
133 |
+
else:
|
134 |
+
d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
|
135 |
+
else:
|
136 |
+
if self.no_im_B_fm:
|
137 |
+
x_hat = torch.zeros_like(x)
|
138 |
+
d = torch.cat((x, x_hat), dim=1)
|
139 |
+
if self.concat_logits:
|
140 |
+
d = torch.cat((d, logits), dim=1)
|
141 |
+
d = self.block1(d)
|
142 |
+
d = self.hidden_blocks(d)
|
143 |
+
d = self.out_conv(d.float())
|
144 |
+
displacement, certainty = d[:, :-1], d[:, -1:]
|
145 |
+
return displacement, certainty
|
146 |
+
|
147 |
+
class CosKernel(nn.Module): # similar to softmax kernel
|
148 |
+
def __init__(self, T, learn_temperature=False):
|
149 |
+
super().__init__()
|
150 |
+
self.learn_temperature = learn_temperature
|
151 |
+
if self.learn_temperature:
|
152 |
+
self.T = nn.Parameter(torch.tensor(T))
|
153 |
+
else:
|
154 |
+
self.T = T
|
155 |
+
|
156 |
+
def __call__(self, x, y, eps=1e-6):
|
157 |
+
c = torch.einsum("bnd,bmd->bnm", x, y) / (
|
158 |
+
x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
|
159 |
+
)
|
160 |
+
if self.learn_temperature:
|
161 |
+
T = self.T.abs() + 0.01
|
162 |
+
else:
|
163 |
+
T = torch.tensor(self.T, device=c.device)
|
164 |
+
K = ((c - 1.0) / T).exp()
|
165 |
+
return K
|
166 |
+
|
167 |
+
class GP(nn.Module):
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
kernel,
|
171 |
+
T=1,
|
172 |
+
learn_temperature=False,
|
173 |
+
only_attention=False,
|
174 |
+
gp_dim=64,
|
175 |
+
basis="fourier",
|
176 |
+
covar_size=5,
|
177 |
+
only_nearest_neighbour=False,
|
178 |
+
sigma_noise=0.1,
|
179 |
+
no_cov=False,
|
180 |
+
predict_features = False,
|
181 |
+
):
|
182 |
+
super().__init__()
|
183 |
+
self.K = kernel(T=T, learn_temperature=learn_temperature)
|
184 |
+
self.sigma_noise = sigma_noise
|
185 |
+
self.covar_size = covar_size
|
186 |
+
self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
|
187 |
+
self.only_attention = only_attention
|
188 |
+
self.only_nearest_neighbour = only_nearest_neighbour
|
189 |
+
self.basis = basis
|
190 |
+
self.no_cov = no_cov
|
191 |
+
self.dim = gp_dim
|
192 |
+
self.predict_features = predict_features
|
193 |
+
|
194 |
+
def get_local_cov(self, cov):
|
195 |
+
K = self.covar_size
|
196 |
+
b, h, w, h, w = cov.shape
|
197 |
+
hw = h * w
|
198 |
+
cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
|
199 |
+
delta = torch.stack(
|
200 |
+
torch.meshgrid(
|
201 |
+
torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
|
202 |
+
),
|
203 |
+
dim=-1,
|
204 |
+
)
|
205 |
+
positions = torch.stack(
|
206 |
+
torch.meshgrid(
|
207 |
+
torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
|
208 |
+
),
|
209 |
+
dim=-1,
|
210 |
+
)
|
211 |
+
neighbours = positions[:, :, None, None, :] + delta[None, :, :]
|
212 |
+
points = torch.arange(hw)[:, None].expand(hw, K**2)
|
213 |
+
local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
|
214 |
+
:,
|
215 |
+
points.flatten(),
|
216 |
+
neighbours[..., 0].flatten(),
|
217 |
+
neighbours[..., 1].flatten(),
|
218 |
+
].reshape(b, h, w, K**2)
|
219 |
+
return local_cov
|
220 |
+
|
221 |
+
def reshape(self, x):
|
222 |
+
return rearrange(x, "b d h w -> b (h w) d")
|
223 |
+
|
224 |
+
def project_to_basis(self, x):
|
225 |
+
if self.basis == "fourier":
|
226 |
+
return torch.cos(8 * math.pi * self.pos_conv(x))
|
227 |
+
elif self.basis == "linear":
|
228 |
+
return self.pos_conv(x)
|
229 |
+
else:
|
230 |
+
raise ValueError(
|
231 |
+
"No other bases other than fourier and linear currently im_Bed in public release"
|
232 |
+
)
|
233 |
+
|
234 |
+
def get_pos_enc(self, y):
|
235 |
+
b, c, h, w = y.shape
|
236 |
+
coarse_coords = torch.meshgrid(
|
237 |
+
(
|
238 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
|
239 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
|
240 |
+
)
|
241 |
+
)
|
242 |
+
|
243 |
+
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
244 |
+
None
|
245 |
+
].expand(b, h, w, 2)
|
246 |
+
coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
247 |
+
coarse_embedded_coords = self.project_to_basis(coarse_coords)
|
248 |
+
return coarse_embedded_coords
|
249 |
+
|
250 |
+
def forward(self, x, y, **kwargs):
|
251 |
+
b, c, h1, w1 = x.shape
|
252 |
+
b, c, h2, w2 = y.shape
|
253 |
+
f = self.get_pos_enc(y)
|
254 |
+
b, d, h2, w2 = f.shape
|
255 |
+
x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
|
256 |
+
K_xx = self.K(x, x)
|
257 |
+
K_yy = self.K(y, y)
|
258 |
+
K_xy = self.K(x, y)
|
259 |
+
K_yx = K_xy.permute(0, 2, 1)
|
260 |
+
sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
|
261 |
+
with warnings.catch_warnings():
|
262 |
+
K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
|
263 |
+
|
264 |
+
mu_x = K_xy.matmul(K_yy_inv.matmul(f))
|
265 |
+
mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
|
266 |
+
if not self.no_cov:
|
267 |
+
cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
|
268 |
+
cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
|
269 |
+
local_cov_x = self.get_local_cov(cov_x)
|
270 |
+
local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
|
271 |
+
gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
|
272 |
+
else:
|
273 |
+
gp_feats = mu_x
|
274 |
+
return gp_feats
|
275 |
+
|
276 |
+
class Decoder(nn.Module):
|
277 |
+
def __init__(
|
278 |
+
self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
|
279 |
+
num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
|
280 |
+
flow_upsample_mode = "bilinear", amp_dtype = torch.float16,
|
281 |
+
):
|
282 |
+
super().__init__()
|
283 |
+
self.embedding_decoder = embedding_decoder
|
284 |
+
self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
|
285 |
+
self.gps = gps
|
286 |
+
self.proj = proj
|
287 |
+
self.conv_refiner = conv_refiner
|
288 |
+
self.detach = detach
|
289 |
+
if pos_embeddings is None:
|
290 |
+
self.pos_embeddings = {}
|
291 |
+
else:
|
292 |
+
self.pos_embeddings = pos_embeddings
|
293 |
+
if scales == "all":
|
294 |
+
self.scales = ["32", "16", "8", "4", "2", "1"]
|
295 |
+
else:
|
296 |
+
self.scales = scales
|
297 |
+
self.warp_noise_std = warp_noise_std
|
298 |
+
self.refine_init = 4
|
299 |
+
self.displacement_dropout_p = displacement_dropout_p
|
300 |
+
self.gm_warp_dropout_p = gm_warp_dropout_p
|
301 |
+
self.flow_upsample_mode = flow_upsample_mode
|
302 |
+
self.amp_dtype = amp_dtype
|
303 |
+
|
304 |
+
def get_placeholder_flow(self, b, h, w, device):
|
305 |
+
coarse_coords = torch.meshgrid(
|
306 |
+
(
|
307 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
308 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
309 |
+
)
|
310 |
+
)
|
311 |
+
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
312 |
+
None
|
313 |
+
].expand(b, h, w, 2)
|
314 |
+
coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
315 |
+
return coarse_coords
|
316 |
+
|
317 |
+
def get_positional_embedding(self, b, h ,w, device):
|
318 |
+
coarse_coords = torch.meshgrid(
|
319 |
+
(
|
320 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
321 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
322 |
+
)
|
323 |
+
)
|
324 |
+
|
325 |
+
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
326 |
+
None
|
327 |
+
].expand(b, h, w, 2)
|
328 |
+
coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
329 |
+
coarse_embedded_coords = self.pos_embedding(coarse_coords)
|
330 |
+
return coarse_embedded_coords
|
331 |
+
|
332 |
+
def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
|
333 |
+
coarse_scales = self.embedding_decoder.scales()
|
334 |
+
all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
|
335 |
+
sizes = {scale: f1[scale].shape[-2:] for scale in f1}
|
336 |
+
h, w = sizes[1]
|
337 |
+
b = f1[1].shape[0]
|
338 |
+
device = f1[1].device
|
339 |
+
coarsest_scale = int(all_scales[0])
|
340 |
+
old_stuff = torch.zeros(
|
341 |
+
b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
|
342 |
+
)
|
343 |
+
corresps = {}
|
344 |
+
if not upsample:
|
345 |
+
flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
|
346 |
+
certainty = 0.0
|
347 |
+
else:
|
348 |
+
flow = F.interpolate(
|
349 |
+
flow,
|
350 |
+
size=sizes[coarsest_scale],
|
351 |
+
align_corners=False,
|
352 |
+
mode="bilinear",
|
353 |
+
)
|
354 |
+
certainty = F.interpolate(
|
355 |
+
certainty,
|
356 |
+
size=sizes[coarsest_scale],
|
357 |
+
align_corners=False,
|
358 |
+
mode="bilinear",
|
359 |
+
)
|
360 |
+
displacement = 0.0
|
361 |
+
for new_scale in all_scales:
|
362 |
+
ins = int(new_scale)
|
363 |
+
corresps[ins] = {}
|
364 |
+
f1_s, f2_s = f1[ins], f2[ins]
|
365 |
+
if new_scale in self.proj:
|
366 |
+
with torch.autocast("cuda", dtype = self.amp_dtype):
|
367 |
+
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
|
368 |
+
|
369 |
+
if ins in coarse_scales:
|
370 |
+
old_stuff = F.interpolate(
|
371 |
+
old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
|
372 |
+
)
|
373 |
+
gp_posterior = self.gps[new_scale](f1_s, f2_s)
|
374 |
+
gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
|
375 |
+
gp_posterior, f1_s, old_stuff, new_scale
|
376 |
+
)
|
377 |
+
|
378 |
+
if self.embedding_decoder.is_classifier:
|
379 |
+
flow = cls_to_flow_refine(
|
380 |
+
gm_warp_or_cls,
|
381 |
+
).permute(0,3,1,2)
|
382 |
+
corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
|
383 |
+
else:
|
384 |
+
corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
|
385 |
+
flow = gm_warp_or_cls.detach()
|
386 |
+
|
387 |
+
if new_scale in self.conv_refiner:
|
388 |
+
corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
|
389 |
+
delta_flow, delta_certainty = self.conv_refiner[new_scale](
|
390 |
+
f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
|
391 |
+
)
|
392 |
+
corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
|
393 |
+
displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
|
394 |
+
delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
|
395 |
+
flow = flow + displacement
|
396 |
+
certainty = (
|
397 |
+
certainty + delta_certainty
|
398 |
+
) # predict both certainty and displacement
|
399 |
+
corresps[ins].update({
|
400 |
+
"certainty": certainty,
|
401 |
+
"flow": flow,
|
402 |
+
})
|
403 |
+
if new_scale != "1":
|
404 |
+
flow = F.interpolate(
|
405 |
+
flow,
|
406 |
+
size=sizes[ins // 2],
|
407 |
+
mode=self.flow_upsample_mode,
|
408 |
+
)
|
409 |
+
certainty = F.interpolate(
|
410 |
+
certainty,
|
411 |
+
size=sizes[ins // 2],
|
412 |
+
mode=self.flow_upsample_mode,
|
413 |
+
)
|
414 |
+
if self.detach:
|
415 |
+
flow = flow.detach()
|
416 |
+
certainty = certainty.detach()
|
417 |
+
#torch.cuda.empty_cache()
|
418 |
+
return corresps
|
419 |
+
|
420 |
+
|
421 |
+
class RegressionMatcher(nn.Module):
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
encoder,
|
425 |
+
decoder,
|
426 |
+
h=448,
|
427 |
+
w=448,
|
428 |
+
sample_mode = "threshold_balanced",
|
429 |
+
upsample_preds = False,
|
430 |
+
symmetric = False,
|
431 |
+
name = None,
|
432 |
+
attenuate_cert = None,
|
433 |
+
recrop_upsample = False,
|
434 |
+
):
|
435 |
+
super().__init__()
|
436 |
+
self.attenuate_cert = attenuate_cert
|
437 |
+
self.encoder = encoder
|
438 |
+
self.decoder = decoder
|
439 |
+
self.name = name
|
440 |
+
self.w_resized = w
|
441 |
+
self.h_resized = h
|
442 |
+
self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
|
443 |
+
self.sample_mode = sample_mode
|
444 |
+
self.upsample_preds = upsample_preds
|
445 |
+
self.upsample_res = (14*16*6, 14*16*6)
|
446 |
+
self.symmetric = symmetric
|
447 |
+
self.sample_thresh = 0.05
|
448 |
+
self.recrop_upsample = recrop_upsample
|
449 |
+
|
450 |
+
def get_output_resolution(self):
|
451 |
+
if not self.upsample_preds:
|
452 |
+
return self.h_resized, self.w_resized
|
453 |
+
else:
|
454 |
+
return self.upsample_res
|
455 |
+
|
456 |
+
def extract_backbone_features(self, batch, batched = True, upsample = False):
|
457 |
+
x_q = batch["im_A"]
|
458 |
+
x_s = batch["im_B"]
|
459 |
+
if batched:
|
460 |
+
X = torch.cat((x_q, x_s), dim = 0)
|
461 |
+
feature_pyramid = self.encoder(X, upsample = upsample)
|
462 |
+
else:
|
463 |
+
feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
|
464 |
+
return feature_pyramid
|
465 |
+
|
466 |
+
def sample(
|
467 |
+
self,
|
468 |
+
matches,
|
469 |
+
certainty,
|
470 |
+
num=10000,
|
471 |
+
):
|
472 |
+
if "threshold" in self.sample_mode:
|
473 |
+
upper_thresh = self.sample_thresh
|
474 |
+
certainty = certainty.clone()
|
475 |
+
certainty[certainty > upper_thresh] = 1
|
476 |
+
matches, certainty = (
|
477 |
+
matches.reshape(-1, 4),
|
478 |
+
certainty.reshape(-1),
|
479 |
+
)
|
480 |
+
expansion_factor = 4 if "balanced" in self.sample_mode else 1
|
481 |
+
good_samples = torch.multinomial(certainty,
|
482 |
+
num_samples = min(expansion_factor*num, len(certainty)),
|
483 |
+
replacement=False)
|
484 |
+
good_matches, good_certainty = matches[good_samples], certainty[good_samples]
|
485 |
+
if "balanced" not in self.sample_mode:
|
486 |
+
return good_matches, good_certainty
|
487 |
+
density = kde(good_matches, std=0.1)
|
488 |
+
p = 1 / (density+1)
|
489 |
+
p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
|
490 |
+
balanced_samples = torch.multinomial(p,
|
491 |
+
num_samples = min(num,len(good_certainty)),
|
492 |
+
replacement=False)
|
493 |
+
return good_matches[balanced_samples], good_certainty[balanced_samples]
|
494 |
+
|
495 |
+
def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
|
496 |
+
feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
|
497 |
+
if batched:
|
498 |
+
f_q_pyramid = {
|
499 |
+
scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
|
500 |
+
}
|
501 |
+
f_s_pyramid = {
|
502 |
+
scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
|
503 |
+
}
|
504 |
+
else:
|
505 |
+
f_q_pyramid, f_s_pyramid = feature_pyramid
|
506 |
+
corresps = self.decoder(f_q_pyramid,
|
507 |
+
f_s_pyramid,
|
508 |
+
upsample = upsample,
|
509 |
+
**(batch["corresps"] if "corresps" in batch else {}),
|
510 |
+
scale_factor=scale_factor)
|
511 |
+
|
512 |
+
return corresps
|
513 |
+
|
514 |
+
def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
|
515 |
+
feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
|
516 |
+
f_q_pyramid = feature_pyramid
|
517 |
+
f_s_pyramid = {
|
518 |
+
scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
|
519 |
+
for scale, f_scale in feature_pyramid.items()
|
520 |
+
}
|
521 |
+
corresps = self.decoder(f_q_pyramid,
|
522 |
+
f_s_pyramid,
|
523 |
+
upsample = upsample,
|
524 |
+
**(batch["corresps"] if "corresps" in batch else {}),
|
525 |
+
scale_factor=scale_factor)
|
526 |
+
return corresps
|
527 |
+
|
528 |
+
def conf_from_fb_consistency(self, flow_forward, flow_backward, th = 2):
|
529 |
+
# assumes that flow forward is of shape (..., H, W, 2)
|
530 |
+
has_batch = False
|
531 |
+
if len(flow_forward.shape) == 3:
|
532 |
+
flow_forward, flow_backward = flow_forward[None], flow_backward[None]
|
533 |
+
else:
|
534 |
+
has_batch = True
|
535 |
+
H,W = flow_forward.shape[-3:-1]
|
536 |
+
th_n = 2 * th / max(H,W)
|
537 |
+
coords = torch.stack(torch.meshgrid(
|
538 |
+
torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
|
539 |
+
torch.linspace(-1 + 1 / H, 1 - 1 / H, H), indexing = "xy"),
|
540 |
+
dim = -1).to(flow_forward.device)
|
541 |
+
coords_fb = F.grid_sample(
|
542 |
+
flow_backward.permute(0, 3, 1, 2),
|
543 |
+
flow_forward,
|
544 |
+
align_corners=False, mode="bilinear").permute(0, 2, 3, 1)
|
545 |
+
diff = (coords - coords_fb).norm(dim=-1)
|
546 |
+
in_th = (diff < th_n).float()
|
547 |
+
if not has_batch:
|
548 |
+
in_th = in_th[0]
|
549 |
+
return in_th
|
550 |
+
|
551 |
+
def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
|
552 |
+
if coords.shape[-1] == 2:
|
553 |
+
return self._to_pixel_coordinates(coords, H_A, W_A)
|
554 |
+
|
555 |
+
if isinstance(coords, (list, tuple)):
|
556 |
+
kpts_A, kpts_B = coords[0], coords[1]
|
557 |
+
else:
|
558 |
+
kpts_A, kpts_B = coords[...,:2], coords[...,2:]
|
559 |
+
return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
|
560 |
+
|
561 |
+
def _to_pixel_coordinates(self, coords, H, W):
|
562 |
+
kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
|
563 |
+
return kpts
|
564 |
+
|
565 |
+
def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
|
566 |
+
if isinstance(coords, (list, tuple)):
|
567 |
+
kpts_A, kpts_B = coords[0], coords[1]
|
568 |
+
else:
|
569 |
+
kpts_A, kpts_B = coords[...,:2], coords[...,2:]
|
570 |
+
kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
|
571 |
+
kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
|
572 |
+
return kpts_A, kpts_B
|
573 |
+
|
574 |
+
def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
|
575 |
+
x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
|
576 |
+
cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
|
577 |
+
D = torch.cdist(x_A_to_B, x_B)
|
578 |
+
inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
|
579 |
+
|
580 |
+
if return_tuple:
|
581 |
+
if return_inds:
|
582 |
+
return inds_A, inds_B
|
583 |
+
else:
|
584 |
+
return x_A[inds_A], x_B[inds_B]
|
585 |
+
else:
|
586 |
+
if return_inds:
|
587 |
+
return torch.cat((inds_A, inds_B),dim=-1)
|
588 |
+
else:
|
589 |
+
return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
|
590 |
+
|
591 |
+
def get_roi(self, certainty, W, H, thr = 0.025):
|
592 |
+
raise NotImplementedError("WIP, disable for now")
|
593 |
+
hs,ws = certainty.shape
|
594 |
+
certainty = certainty/certainty.sum(dim=(-1,-2))
|
595 |
+
cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
|
596 |
+
cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
|
597 |
+
print(cum_certainty_w)
|
598 |
+
print(torch.min(torch.nonzero(cum_certainty_w > thr)))
|
599 |
+
print(torch.min(torch.nonzero(cum_certainty_w < thr)))
|
600 |
+
left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
|
601 |
+
right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
|
602 |
+
top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
|
603 |
+
bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
|
604 |
+
print(left, right, top, bottom)
|
605 |
+
return left, top, right, bottom
|
606 |
+
|
607 |
+
def recrop(self, certainty, image_path):
|
608 |
+
roi = self.get_roi(certainty, *Image.open(image_path).size)
|
609 |
+
return Image.open(image_path).convert("RGB").crop(roi)
|
610 |
+
|
611 |
+
@torch.inference_mode()
|
612 |
+
def match(
|
613 |
+
self,
|
614 |
+
im_A_path: Union[str, os.PathLike, Image.Image],
|
615 |
+
im_B_path: Union[str, os.PathLike, Image.Image],
|
616 |
+
*args,
|
617 |
+
batched=False,
|
618 |
+
device = None,
|
619 |
+
):
|
620 |
+
if device is None:
|
621 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
622 |
+
if isinstance(im_A_path, (str, os.PathLike)):
|
623 |
+
im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
|
624 |
+
else:
|
625 |
+
im_A, im_B = im_A_path, im_B_path
|
626 |
+
|
627 |
+
symmetric = self.symmetric
|
628 |
+
self.train(False)
|
629 |
+
with torch.no_grad():
|
630 |
+
if not batched:
|
631 |
+
b = 1
|
632 |
+
w, h = im_A.size
|
633 |
+
w2, h2 = im_B.size
|
634 |
+
# Get images in good format
|
635 |
+
ws = self.w_resized
|
636 |
+
hs = self.h_resized
|
637 |
+
|
638 |
+
test_transform = get_tuple_transform_ops(
|
639 |
+
resize=(hs, ws), normalize=True, clahe = False
|
640 |
+
)
|
641 |
+
im_A, im_B = test_transform((im_A, im_B))
|
642 |
+
batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
|
643 |
+
else:
|
644 |
+
b, c, h, w = im_A.shape
|
645 |
+
b, c, h2, w2 = im_B.shape
|
646 |
+
assert w == w2 and h == h2, "For batched images we assume same size"
|
647 |
+
batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
|
648 |
+
if h != self.h_resized or self.w_resized != w:
|
649 |
+
warn("Model resolution and batch resolution differ, may produce unexpected results")
|
650 |
+
hs, ws = h, w
|
651 |
+
finest_scale = 1
|
652 |
+
# Run matcher
|
653 |
+
if symmetric:
|
654 |
+
corresps = self.forward_symmetric(batch)
|
655 |
+
else:
|
656 |
+
corresps = self.forward(batch, batched = True)
|
657 |
+
|
658 |
+
if self.upsample_preds:
|
659 |
+
hs, ws = self.upsample_res
|
660 |
+
|
661 |
+
if self.attenuate_cert:
|
662 |
+
low_res_certainty = F.interpolate(
|
663 |
+
corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
664 |
+
)
|
665 |
+
cert_clamp = 0
|
666 |
+
factor = 0.5
|
667 |
+
low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
|
668 |
+
|
669 |
+
if self.upsample_preds:
|
670 |
+
finest_corresps = corresps[finest_scale]
|
671 |
+
torch.cuda.empty_cache()
|
672 |
+
test_transform = get_tuple_transform_ops(
|
673 |
+
resize=(hs, ws), normalize=True
|
674 |
+
)
|
675 |
+
if self.recrop_upsample:
|
676 |
+
raise NotImplementedError("recrop_upsample not implemented")
|
677 |
+
certainty = corresps[finest_scale]["certainty"]
|
678 |
+
print(certainty.shape)
|
679 |
+
im_A = self.recrop(certainty[0,0], im_A_path)
|
680 |
+
im_B = self.recrop(certainty[1,0], im_B_path)
|
681 |
+
#TODO: need to adjust corresps when doing this
|
682 |
+
im_A, im_B = test_transform((im_A, im_B))
|
683 |
+
im_A, im_B = im_A[None].to(device), im_B[None].to(device)
|
684 |
+
scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
|
685 |
+
batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
|
686 |
+
if symmetric:
|
687 |
+
corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
|
688 |
+
else:
|
689 |
+
corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
|
690 |
+
|
691 |
+
im_A_to_im_B = corresps[finest_scale]["flow"]
|
692 |
+
certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
|
693 |
+
if finest_scale != 1:
|
694 |
+
im_A_to_im_B = F.interpolate(
|
695 |
+
im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
|
696 |
+
)
|
697 |
+
certainty = F.interpolate(
|
698 |
+
certainty, size=(hs, ws), align_corners=False, mode="bilinear"
|
699 |
+
)
|
700 |
+
im_A_to_im_B = im_A_to_im_B.permute(
|
701 |
+
0, 2, 3, 1
|
702 |
+
)
|
703 |
+
# Create im_A meshgrid
|
704 |
+
im_A_coords = torch.meshgrid(
|
705 |
+
(
|
706 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
707 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
708 |
+
)
|
709 |
+
)
|
710 |
+
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
711 |
+
im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
|
712 |
+
certainty = certainty.sigmoid() # logits -> probs
|
713 |
+
im_A_coords = im_A_coords.permute(0, 2, 3, 1)
|
714 |
+
if (im_A_to_im_B.abs() > 1).any() and True:
|
715 |
+
wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
|
716 |
+
certainty[wrong[:,None]] = 0
|
717 |
+
im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
|
718 |
+
if symmetric:
|
719 |
+
A_to_B, B_to_A = im_A_to_im_B.chunk(2)
|
720 |
+
q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
|
721 |
+
im_B_coords = im_A_coords
|
722 |
+
s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
|
723 |
+
warp = torch.cat((q_warp, s_warp),dim=2)
|
724 |
+
certainty = torch.cat(certainty.chunk(2), dim=3)
|
725 |
+
else:
|
726 |
+
warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
|
727 |
+
if batched:
|
728 |
+
return (
|
729 |
+
warp,
|
730 |
+
certainty[:, 0]
|
731 |
+
)
|
732 |
+
else:
|
733 |
+
return (
|
734 |
+
warp[0],
|
735 |
+
certainty[0, 0],
|
736 |
+
)
|
737 |
+
|
738 |
+
def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
|
739 |
+
im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None, unnormalize = False):
|
740 |
+
#assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
|
741 |
+
H,W2,_ = warp.shape
|
742 |
+
W = W2//2 if symmetric else W2
|
743 |
+
if im_A is None:
|
744 |
+
from PIL import Image
|
745 |
+
im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
|
746 |
+
if not isinstance(im_A, torch.Tensor):
|
747 |
+
im_A = im_A.resize((W,H))
|
748 |
+
im_B = im_B.resize((W,H))
|
749 |
+
x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
|
750 |
+
if symmetric:
|
751 |
+
x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
|
752 |
+
else:
|
753 |
+
if symmetric:
|
754 |
+
x_A = im_A
|
755 |
+
x_B = im_B
|
756 |
+
im_A_transfer_rgb = F.grid_sample(
|
757 |
+
x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
|
758 |
+
)[0]
|
759 |
+
if symmetric:
|
760 |
+
im_B_transfer_rgb = F.grid_sample(
|
761 |
+
x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
|
762 |
+
)[0]
|
763 |
+
warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
|
764 |
+
white_im = torch.ones((H,2*W),device=device)
|
765 |
+
else:
|
766 |
+
warp_im = im_A_transfer_rgb
|
767 |
+
white_im = torch.ones((H, W), device = device)
|
768 |
+
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
769 |
+
if save_path is not None:
|
770 |
+
from romatch.utils import tensor_to_pil
|
771 |
+
tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
|
772 |
+
return vis_im
|
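
A hedged usage sketch of the matcher API above (not part of the diff): it assumes the romatch package is importable, two example images at hypothetical paths, and the sample() helper defined earlier in matcher.py outside this hunk; roma_outdoor comes from the model zoo added below, and the first call downloads checkpoints via torch.hub.

# Hedged usage sketch, not part of the diff. Paths and image sizes are hypothetical.
import torch
from romatch.models.model_zoo import roma_outdoor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
matcher = roma_outdoor(device)                              # symmetric RegressionMatcher
warp, certainty = matcher.match("im_A.jpg", "im_B.jpg")     # dense warp (H, 2W, 4) and certainty (H, 2W)
matches, cert = matcher.sample(warp, certainty)             # sparse, certainty-weighted matches (assumed helper)
kpts_A, kpts_B = matcher.to_pixel_coordinates(matches, 480, 640, 480, 640)  # hypothetical H_A, W_A, H_B, W_B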
third_party/RoMa/romatch/models/model_zoo/__init__.py
ADDED
@@ -0,0 +1,70 @@
from typing import Union
import torch
from .roma_models import roma_model, tiny_roma_v1_model

weight_urls = {
    "romatch": {
        "outdoor": "https://github.com/Parskatt/storage/releases/download/romatch/roma_outdoor.pth",
        "indoor": "https://github.com/Parskatt/storage/releases/download/romatch/roma_indoor.pth",
    },
    "tiny_roma_v1": {
        "outdoor": "https://github.com/Parskatt/storage/releases/download/romatch/tiny_roma_v1_outdoor.pth",
    },
    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
}

def tiny_roma_v1_outdoor(device, weights = None, xfeat = None):
    if weights is None:
        weights = torch.hub.load_state_dict_from_url(
            weight_urls["tiny_roma_v1"]["outdoor"],
            map_location=device)
    if xfeat is None:
        xfeat = torch.hub.load(
            'verlab/accelerated_features',
            'XFeat',
            pretrained = True,
            top_k = 4096).net

    return tiny_roma_v1_model(weights = weights, xfeat = xfeat).to(device)

def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
    if isinstance(coarse_res, int):
        coarse_res = (coarse_res, coarse_res)
    if isinstance(upsample_res, int):
        upsample_res = (upsample_res, upsample_res)

    assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
    assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"

    if weights is None:
        weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["outdoor"],
                                                     map_location=device)
    if dinov2_weights is None:
        dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
                                                            map_location=device)
    model = roma_model(resolution=coarse_res, upsample_preds=True,
        weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
    model.upsample_res = upsample_res
    print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
    return model

def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
    if isinstance(coarse_res, int):
        coarse_res = (coarse_res, coarse_res)
    if isinstance(upsample_res, int):
        upsample_res = (upsample_res, upsample_res)

    assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
    assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"

    if weights is None:
        weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["indoor"],
                                                     map_location=device)
    if dinov2_weights is None:
        dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
                                                            map_location=device)
    model = roma_model(resolution=coarse_res, upsample_preds=True,
        weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
    model.upsample_res = upsample_res
    print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
    return model
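
The factories above only reach for torch.hub when no weights are supplied, so pre-downloaded checkpoints can be passed in directly. A minimal offline-loading sketch (not part of the diff, local paths are hypothetical):

# Hedged sketch, not part of the diff: explicit checkpoints avoid the torch.hub download.
import torch
from romatch.models.model_zoo import roma_indoor

device = "cpu"
weights = torch.load("roma_indoor.pth", map_location=device)                  # hypothetical local path
dinov2_weights = torch.load("dinov2_vitl14_pretrain.pth", map_location=device)  # hypothetical local path
model = roma_indoor(device, weights=weights, dinov2_weights=dinov2_weights, coarse_res=560)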
third_party/RoMa/romatch/models/model_zoo/roma_models.py
ADDED
@@ -0,0 +1,170 @@
import warnings
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
from romatch.models.matcher import *
|
5 |
+
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
|
6 |
+
from romatch.models.encoders import *
|
7 |
+
from romatch.models.tiny import TinyRoMa
|
8 |
+
|
9 |
+
def tiny_roma_v1_model(weights = None, freeze_xfeat=False, exact_softmax=False, xfeat = None):
|
10 |
+
model = TinyRoMa(
|
11 |
+
xfeat = xfeat,
|
12 |
+
freeze_xfeat=freeze_xfeat,
|
13 |
+
exact_softmax=exact_softmax)
|
14 |
+
if weights is not None:
|
15 |
+
model.load_state_dict(weights)
|
16 |
+
return model
|
17 |
+
|
18 |
+
def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs):
|
19 |
+
# romatch weights and dinov2 weights are loaded separately, as dinov2 weights are not parameters
|
20 |
+
#torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul TODO: these probably ruin stuff, should be careful
|
21 |
+
#torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
22 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
23 |
+
gp_dim = 512
|
24 |
+
feat_dim = 512
|
25 |
+
decoder_dim = gp_dim + feat_dim
|
26 |
+
cls_to_coord_res = 64
|
27 |
+
coordinate_decoder = TransformerDecoder(
|
28 |
+
nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
|
29 |
+
decoder_dim,
|
30 |
+
cls_to_coord_res**2 + 1,
|
31 |
+
is_classifier=True,
|
32 |
+
amp = True,
|
33 |
+
pos_enc = False,)
|
34 |
+
dw = True
|
35 |
+
hidden_blocks = 8
|
36 |
+
kernel_size = 5
|
37 |
+
displacement_emb = "linear"
|
38 |
+
disable_local_corr_grad = True
|
39 |
+
|
40 |
+
conv_refiner = nn.ModuleDict(
|
41 |
+
{
|
42 |
+
"16": ConvRefiner(
|
43 |
+
2 * 512+128+(2*7+1)**2,
|
44 |
+
2 * 512+128+(2*7+1)**2,
|
45 |
+
2 + 1,
|
46 |
+
kernel_size=kernel_size,
|
47 |
+
dw=dw,
|
48 |
+
hidden_blocks=hidden_blocks,
|
49 |
+
displacement_emb=displacement_emb,
|
50 |
+
displacement_emb_dim=128,
|
51 |
+
local_corr_radius = 7,
|
52 |
+
corr_in_other = True,
|
53 |
+
amp = True,
|
54 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
55 |
+
bn_momentum = 0.01,
|
56 |
+
),
|
57 |
+
"8": ConvRefiner(
|
58 |
+
2 * 512+64+(2*3+1)**2,
|
59 |
+
2 * 512+64+(2*3+1)**2,
|
60 |
+
2 + 1,
|
61 |
+
kernel_size=kernel_size,
|
62 |
+
dw=dw,
|
63 |
+
hidden_blocks=hidden_blocks,
|
64 |
+
displacement_emb=displacement_emb,
|
65 |
+
displacement_emb_dim=64,
|
66 |
+
local_corr_radius = 3,
|
67 |
+
corr_in_other = True,
|
68 |
+
amp = True,
|
69 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
70 |
+
bn_momentum = 0.01,
|
71 |
+
),
|
72 |
+
"4": ConvRefiner(
|
73 |
+
2 * 256+32+(2*2+1)**2,
|
74 |
+
2 * 256+32+(2*2+1)**2,
|
75 |
+
2 + 1,
|
76 |
+
kernel_size=kernel_size,
|
77 |
+
dw=dw,
|
78 |
+
hidden_blocks=hidden_blocks,
|
79 |
+
displacement_emb=displacement_emb,
|
80 |
+
displacement_emb_dim=32,
|
81 |
+
local_corr_radius = 2,
|
82 |
+
corr_in_other = True,
|
83 |
+
amp = True,
|
84 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
85 |
+
bn_momentum = 0.01,
|
86 |
+
),
|
87 |
+
"2": ConvRefiner(
|
88 |
+
2 * 64+16,
|
89 |
+
128+16,
|
90 |
+
2 + 1,
|
91 |
+
kernel_size=kernel_size,
|
92 |
+
dw=dw,
|
93 |
+
hidden_blocks=hidden_blocks,
|
94 |
+
displacement_emb=displacement_emb,
|
95 |
+
displacement_emb_dim=16,
|
96 |
+
amp = True,
|
97 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
98 |
+
bn_momentum = 0.01,
|
99 |
+
),
|
100 |
+
"1": ConvRefiner(
|
101 |
+
2 * 9 + 6,
|
102 |
+
24,
|
103 |
+
2 + 1,
|
104 |
+
kernel_size=kernel_size,
|
105 |
+
dw=dw,
|
106 |
+
hidden_blocks = hidden_blocks,
|
107 |
+
displacement_emb = displacement_emb,
|
108 |
+
displacement_emb_dim = 6,
|
109 |
+
amp = True,
|
110 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
111 |
+
bn_momentum = 0.01,
|
112 |
+
),
|
113 |
+
}
|
114 |
+
)
|
115 |
+
kernel_temperature = 0.2
|
116 |
+
learn_temperature = False
|
117 |
+
no_cov = True
|
118 |
+
kernel = CosKernel
|
119 |
+
only_attention = False
|
120 |
+
basis = "fourier"
|
121 |
+
gp16 = GP(
|
122 |
+
kernel,
|
123 |
+
T=kernel_temperature,
|
124 |
+
learn_temperature=learn_temperature,
|
125 |
+
only_attention=only_attention,
|
126 |
+
gp_dim=gp_dim,
|
127 |
+
basis=basis,
|
128 |
+
no_cov=no_cov,
|
129 |
+
)
|
130 |
+
gps = nn.ModuleDict({"16": gp16})
|
131 |
+
proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
|
132 |
+
proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
|
133 |
+
proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
|
134 |
+
proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
|
135 |
+
proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
|
136 |
+
proj = nn.ModuleDict({
|
137 |
+
"16": proj16,
|
138 |
+
"8": proj8,
|
139 |
+
"4": proj4,
|
140 |
+
"2": proj2,
|
141 |
+
"1": proj1,
|
142 |
+
})
|
143 |
+
displacement_dropout_p = 0.0
|
144 |
+
gm_warp_dropout_p = 0.0
|
145 |
+
decoder = Decoder(coordinate_decoder,
|
146 |
+
gps,
|
147 |
+
proj,
|
148 |
+
conv_refiner,
|
149 |
+
detach=True,
|
150 |
+
scales=["16", "8", "4", "2", "1"],
|
151 |
+
displacement_dropout_p = displacement_dropout_p,
|
152 |
+
gm_warp_dropout_p = gm_warp_dropout_p)
|
153 |
+
|
154 |
+
encoder = CNNandDinov2(
|
155 |
+
cnn_kwargs = dict(
|
156 |
+
pretrained=False,
|
157 |
+
amp = True),
|
158 |
+
amp = True,
|
159 |
+
use_vgg = True,
|
160 |
+
dinov2_weights = dinov2_weights,
|
161 |
+
amp_dtype=amp_dtype,
|
162 |
+
)
|
163 |
+
h,w = resolution
|
164 |
+
symmetric = True
|
165 |
+
attenuate_cert = True
|
166 |
+
sample_mode = "threshold_balanced"
|
167 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
|
168 |
+
symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
|
169 |
+
matcher.load_state_dict(weights)
|
170 |
+
return matcher
|
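
The ConvRefiner input widths in roma_model are derived from the projected feature dimension at each scale, the displacement-embedding dimension, and the local correlation window. A small sketch of that arithmetic for the stride-16 refiner (not part of the diff):

# Hedged sketch, not part of the diff: where the "16"-scale ConvRefiner input width comes from.
# proj16 maps both feature maps to 512 channels, the displacement embedding adds 128, and the
# radius-7 local correlation adds (2*7+1)**2 channels.
feat_dim_16 = 512
disp_emb_16 = 128
local_corr_16 = (2 * 7 + 1) ** 2          # 225
in_dim_16 = 2 * feat_dim_16 + disp_emb_16 + local_corr_16
print(in_dim_16)                          # 1377, matching 2 * 512 + 128 + (2*7+1)**2 above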
third_party/RoMa/romatch/models/tiny.py
ADDED
@@ -0,0 +1,304 @@
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from pathlib import Path
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from torch import nn
|
11 |
+
from PIL import Image
|
12 |
+
from torchvision.transforms import ToTensor
|
13 |
+
from romatch.utils.kde import kde
|
14 |
+
|
15 |
+
class BasicLayer(nn.Module):
|
16 |
+
"""
|
17 |
+
Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
|
18 |
+
"""
|
19 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
|
20 |
+
super().__init__()
|
21 |
+
self.layer = nn.Sequential(
|
22 |
+
nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
|
23 |
+
nn.BatchNorm2d(out_channels, affine=False),
|
24 |
+
nn.ReLU(inplace = True) if relu else nn.Identity()
|
25 |
+
)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
return self.layer(x)
|
29 |
+
|
30 |
+
class TinyRoMa(nn.Module):
|
31 |
+
"""
|
32 |
+
Implementation of architecture described in
|
33 |
+
"XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, xfeat = None,
|
37 |
+
freeze_xfeat = True,
|
38 |
+
sample_mode = "threshold_balanced",
|
39 |
+
symmetric = False,
|
40 |
+
exact_softmax = False):
|
41 |
+
super().__init__()
|
42 |
+
del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
|
43 |
+
if freeze_xfeat:
|
44 |
+
xfeat.train(False)
|
45 |
+
self.xfeat = [xfeat]# hide params from ddp
|
46 |
+
else:
|
47 |
+
self.xfeat = nn.ModuleList([xfeat])
|
48 |
+
self.freeze_xfeat = freeze_xfeat
|
49 |
+
match_dim = 256
|
50 |
+
self.coarse_matcher = nn.Sequential(
|
51 |
+
BasicLayer(64+64+2, match_dim,),
|
52 |
+
BasicLayer(match_dim, match_dim,),
|
53 |
+
BasicLayer(match_dim, match_dim,),
|
54 |
+
BasicLayer(match_dim, match_dim,),
|
55 |
+
nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
|
56 |
+
fine_match_dim = 64
|
57 |
+
self.fine_matcher = nn.Sequential(
|
58 |
+
BasicLayer(24+24+2, fine_match_dim,),
|
59 |
+
BasicLayer(fine_match_dim, fine_match_dim,),
|
60 |
+
BasicLayer(fine_match_dim, fine_match_dim,),
|
61 |
+
BasicLayer(fine_match_dim, fine_match_dim,),
|
62 |
+
nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
|
63 |
+
self.sample_mode = sample_mode
|
64 |
+
self.sample_thresh = 0.05
|
65 |
+
self.symmetric = symmetric
|
66 |
+
self.exact_softmax = exact_softmax
|
67 |
+
|
68 |
+
@property
|
69 |
+
def device(self):
|
70 |
+
return self.fine_matcher[-1].weight.device
|
71 |
+
|
72 |
+
def preprocess_tensor(self, x):
|
73 |
+
""" Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
|
74 |
+
H, W = x.shape[-2:]
|
75 |
+
_H, _W = (H//32) * 32, (W//32) * 32
|
76 |
+
rh, rw = H/_H, W/_W
|
77 |
+
|
78 |
+
x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
|
79 |
+
return x, rh, rw
|
80 |
+
|
81 |
+
def forward_single(self, x):
|
82 |
+
with torch.inference_mode(self.freeze_xfeat or not self.training):
|
83 |
+
xfeat = self.xfeat[0]
|
84 |
+
with torch.no_grad():
|
85 |
+
x = x.mean(dim=1, keepdim = True)
|
86 |
+
x = xfeat.norm(x)
|
87 |
+
|
88 |
+
#main backbone
|
89 |
+
x1 = xfeat.block1(x)
|
90 |
+
x2 = xfeat.block2(x1 + xfeat.skip1(x))
|
91 |
+
x3 = xfeat.block3(x2)
|
92 |
+
x4 = xfeat.block4(x3)
|
93 |
+
x5 = xfeat.block5(x4)
|
94 |
+
x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
|
95 |
+
x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
|
96 |
+
feats = xfeat.block_fusion( x3 + x4 + x5 )
|
97 |
+
if self.freeze_xfeat:
|
98 |
+
return x2.clone(), feats.clone()
|
99 |
+
return x2, feats
|
100 |
+
|
101 |
+
def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
|
102 |
+
if coords.shape[-1] == 2:
|
103 |
+
return self._to_pixel_coordinates(coords, H_A, W_A)
|
104 |
+
|
105 |
+
if isinstance(coords, (list, tuple)):
|
106 |
+
kpts_A, kpts_B = coords[0], coords[1]
|
107 |
+
else:
|
108 |
+
kpts_A, kpts_B = coords[...,:2], coords[...,2:]
|
109 |
+
return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
|
110 |
+
|
111 |
+
def _to_pixel_coordinates(self, coords, H, W):
|
112 |
+
kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
|
113 |
+
return kpts
|
114 |
+
|
115 |
+
def pos_embed(self, corr_volume: torch.Tensor):
|
116 |
+
B, H1, W1, H0, W0 = corr_volume.shape
|
117 |
+
grid = torch.stack(
|
118 |
+
torch.meshgrid(
|
119 |
+
torch.linspace(-1+1/W1,1-1/W1, W1),
|
120 |
+
torch.linspace(-1+1/H1,1-1/H1, H1),
|
121 |
+
indexing = "xy"),
|
122 |
+
dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
|
123 |
+
down = 4
|
124 |
+
if not self.training and not self.exact_softmax:
|
125 |
+
grid_lr = torch.stack(
|
126 |
+
torch.meshgrid(
|
127 |
+
torch.linspace(-1+down/W1,1-down/W1, W1//down),
|
128 |
+
torch.linspace(-1+down/H1,1-down/H1, H1//down),
|
129 |
+
indexing = "xy"),
|
130 |
+
dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
|
131 |
+
cv = corr_volume
|
132 |
+
best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, HW, H, W
|
133 |
+
P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
|
134 |
+
pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
|
135 |
+
pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
|
136 |
+
#print("hej")
|
137 |
+
else:
|
138 |
+
P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
|
139 |
+
pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
|
140 |
+
return pos_embeddings
|
141 |
+
|
142 |
+
def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
|
143 |
+
im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
|
144 |
+
device = warp.device
|
145 |
+
H,W2,_ = warp.shape
|
146 |
+
W = W2//2 if symmetric else W2
|
147 |
+
if im_A is None:
|
148 |
+
from PIL import Image
|
149 |
+
im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
|
150 |
+
if not isinstance(im_A, torch.Tensor):
|
151 |
+
im_A = im_A.resize((W,H))
|
152 |
+
im_B = im_B.resize((W,H))
|
153 |
+
x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
|
154 |
+
if symmetric:
|
155 |
+
x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
|
156 |
+
else:
|
157 |
+
if symmetric:
|
158 |
+
x_A = im_A
|
159 |
+
x_B = im_B
|
160 |
+
im_A_transfer_rgb = F.grid_sample(
|
161 |
+
x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
|
162 |
+
)[0]
|
163 |
+
if symmetric:
|
164 |
+
im_B_transfer_rgb = F.grid_sample(
|
165 |
+
x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
|
166 |
+
)[0]
|
167 |
+
warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
|
168 |
+
white_im = torch.ones((H,2*W),device=device)
|
169 |
+
else:
|
170 |
+
warp_im = im_A_transfer_rgb
|
171 |
+
white_im = torch.ones((H, W), device = device)
|
172 |
+
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
173 |
+
if save_path is not None:
|
174 |
+
from romatch.utils import tensor_to_pil
|
175 |
+
tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
|
176 |
+
return vis_im
|
177 |
+
|
178 |
+
def corr_volume(self, feat0, feat1):
|
179 |
+
"""
|
180 |
+
input:
|
181 |
+
feat0 -> torch.Tensor(B, C, H, W)
|
182 |
+
feat1 -> torch.Tensor(B, C, H, W)
|
183 |
+
return:
|
184 |
+
corr_volume -> torch.Tensor(B, H, W, H, W)
|
185 |
+
"""
|
186 |
+
B, C, H0, W0 = feat0.shape
|
187 |
+
B, C, H1, W1 = feat1.shape
|
188 |
+
feat0 = feat0.view(B, C, H0*W0)
|
189 |
+
feat1 = feat1.view(B, C, H1*W1)
|
190 |
+
corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
|
191 |
+
return corr_volume
|
192 |
+
|
193 |
+
@torch.inference_mode()
|
194 |
+
def match_from_path(self, im0_path, im1_path):
|
195 |
+
device = self.device
|
196 |
+
im0 = ToTensor()(Image.open(im0_path))[None].to(device)
|
197 |
+
im1 = ToTensor()(Image.open(im1_path))[None].to(device)
|
198 |
+
return self.match(im0, im1, batched = False)
|
199 |
+
|
200 |
+
@torch.inference_mode()
|
201 |
+
def match(self, im0, im1, *args, batched = True):
|
202 |
+
# convenience dispatch: accept file paths or PIL images as well as batched tensors
|
203 |
+
if isinstance(im0, (str, Path)):
|
204 |
+
return self.match_from_path(im0, im1)
|
205 |
+
elif isinstance(im0, Image.Image):
|
206 |
+
batched = False
|
207 |
+
device = self.device
|
208 |
+
im0 = ToTensor()(im0)[None].to(device)
|
209 |
+
im1 = ToTensor()(im1)[None].to(device)
|
210 |
+
|
211 |
+
B,C,H0,W0 = im0.shape
|
212 |
+
B,C,H1,W1 = im1.shape
|
213 |
+
self.train(False)
|
214 |
+
corresps = self.forward({"im_A":im0, "im_B":im1})
|
215 |
+
#return 1,1
|
216 |
+
flow = F.interpolate(
|
217 |
+
corresps[4]["flow"],
|
218 |
+
size = (H0, W0),
|
219 |
+
mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
|
220 |
+
grid = torch.stack(
|
221 |
+
torch.meshgrid(
|
222 |
+
torch.linspace(-1+1/W0,1-1/W0, W0),
|
223 |
+
torch.linspace(-1+1/H0,1-1/H0, H0),
|
224 |
+
indexing = "xy"),
|
225 |
+
dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
|
226 |
+
|
227 |
+
certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
|
228 |
+
warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
|
229 |
+
if batched:
|
230 |
+
return warp, cert
|
231 |
+
else:
|
232 |
+
return warp[0], cert[0]
|
233 |
+
|
234 |
+
def sample(
|
235 |
+
self,
|
236 |
+
matches,
|
237 |
+
certainty,
|
238 |
+
num=5_000,
|
239 |
+
):
|
240 |
+
H,W,_ = matches.shape
|
241 |
+
if "threshold" in self.sample_mode:
|
242 |
+
upper_thresh = self.sample_thresh
|
243 |
+
certainty = certainty.clone()
|
244 |
+
certainty[certainty > upper_thresh] = 1
|
245 |
+
matches, certainty = (
|
246 |
+
matches.reshape(-1, 4),
|
247 |
+
certainty.reshape(-1),
|
248 |
+
)
|
249 |
+
expansion_factor = 4 if "balanced" in self.sample_mode else 1
|
250 |
+
good_samples = torch.multinomial(certainty,
|
251 |
+
num_samples = min(expansion_factor*num, len(certainty)),
|
252 |
+
replacement=False)
|
253 |
+
good_matches, good_certainty = matches[good_samples], certainty[good_samples]
|
254 |
+
if "balanced" not in self.sample_mode:
|
255 |
+
return good_matches, good_certainty
|
256 |
+
use_half = True if matches.device.type == "cuda" else False
|
257 |
+
down = 1 if matches.device.type == "cuda" else 8
|
258 |
+
density = kde(good_matches, std=0.1, half = use_half, down = down)
|
259 |
+
p = 1 / (density+1)
|
260 |
+
p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
|
261 |
+
balanced_samples = torch.multinomial(p,
|
262 |
+
num_samples = min(num,len(good_certainty)),
|
263 |
+
replacement=False)
|
264 |
+
return good_matches[balanced_samples], good_certainty[balanced_samples]
|
265 |
+
|
266 |
+
|
267 |
+
def forward(self, batch):
|
268 |
+
"""
|
269 |
+
input:
|
270 |
+
x -> torch.Tensor(B, C, H, W) grayscale or rgb images
|
271 |
+
return:
|
272 |
+
|
273 |
+
"""
|
274 |
+
im0 = batch["im_A"]
|
275 |
+
im1 = batch["im_B"]
|
276 |
+
corresps = {}
|
277 |
+
im0, rh0, rw0 = self.preprocess_tensor(im0)
|
278 |
+
im1, rh1, rw1 = self.preprocess_tensor(im1)
|
279 |
+
B, C, H0, W0 = im0.shape
|
280 |
+
B, C, H1, W1 = im1.shape
|
281 |
+
to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
|
282 |
+
|
283 |
+
if im0.shape[-2:] == im1.shape[-2:]:
|
284 |
+
x = torch.cat([im0, im1], dim=0)
|
285 |
+
x = self.forward_single(x)
|
286 |
+
feats_x0_c, feats_x1_c = x[1].chunk(2)
|
287 |
+
feats_x0_f, feats_x1_f = x[0].chunk(2)
|
288 |
+
else:
|
289 |
+
feats_x0_f, feats_x0_c = self.forward_single(im0)
|
290 |
+
feats_x1_f, feats_x1_c = self.forward_single(im1)
|
291 |
+
corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
|
292 |
+
coarse_warp = self.pos_embed(corr_volume)
|
293 |
+
coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
|
294 |
+
feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
|
295 |
+
coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
|
296 |
+
coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
|
297 |
+
corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
|
298 |
+
coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
|
299 |
+
coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
|
300 |
+
feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
|
301 |
+
fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
|
302 |
+
fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
|
303 |
+
corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
|
304 |
+
return corresps
|
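
A hedged usage sketch for TinyRoMa (not part of the diff): it assumes the romatch package is importable and two example images at hypothetical paths; tiny_roma_v1_outdoor is the factory added in the model zoo above and pulls both checkpoints via torch.hub on first use.

# Hedged usage sketch, not part of the diff. Paths and image sizes are hypothetical.
import torch
from romatch.models.model_zoo import tiny_roma_v1_outdoor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tiny = tiny_roma_v1_outdoor(device)
warp, certainty = tiny.match("im_A.jpg", "im_B.jpg")      # dense warp at im_A resolution
matches, cert = tiny.sample(warp, certainty, num=2000)    # thresholded + KDE-balanced sampling
kpts_A, kpts_B = tiny.to_pixel_coordinates(matches, 480, 640, 480, 640)  # hypothetical H_A, W_A, H_B, W_B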
third_party/RoMa/romatch/models/transformer/__init__.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from romatch.utils.utils import get_grid
from .layers.block import Block
from .layers.attention import MemEffAttention
from .dinov2 import vit_large

class TransformerDecoder(nn.Module):
    def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args,
                 amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.blocks = blocks
        self.to_out = nn.Linear(hidden_dim, out_dim)
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self._scales = [16]
        self.is_classifier = is_classifier
        self.amp = amp
        self.amp_dtype = amp_dtype
        self.pos_enc = pos_enc
        self.learned_embeddings = learned_embeddings
        if self.learned_embeddings:
            self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))

    def scales(self):
        return self._scales.copy()

    def forward(self, gp_posterior, features, old_stuff, new_scale):
        with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
            B,C,H,W = gp_posterior.shape
            x = torch.cat((gp_posterior, features), dim = 1)
            B,C,H,W = x.shape
            grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
            if self.learned_embeddings:
                pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
            else:
                pos_enc = 0
            tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
            z = self.blocks(tokens)
            out = self.to_out(z)
            out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
            warp, certainty = out[:, :-1], out[:, -1:]
            return warp, certainty, None
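
TransformerDecoder flattens the concatenated GP posterior and projected features into H*W tokens, runs the supplied blocks, and reshapes the linear head back to a map whose last channel is the certainty logit. A hedged shape sketch with the dimensions roma_models.py passes in (not part of the diff, assumes romatch is importable):

# Hedged shape sketch, not part of the diff, using the roma_models.py dimensions
# (decoder_dim = gp_dim + feat_dim = 1024, out_dim = 64**2 + 1); amp stays disabled.
import torch
import torch.nn as nn
from romatch.models.transformer import TransformerDecoder, Block, MemEffAttention

decoder_dim, out_dim = 1024, 64 ** 2 + 1
dec = TransformerDecoder(
    nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(2)]),
    decoder_dim, out_dim, is_classifier=True, amp=False, pos_enc=False)
gp_posterior = torch.randn(1, 512, 35, 35)    # stride-16 map for a 560x560 input
features = torch.randn(1, 512, 35, 35)
warp_logits, certainty, _ = dec(gp_posterior, features, None, None)
print(warp_logits.shape, certainty.shape)     # (1, 4096, 35, 35) and (1, 1, 35, 35)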
third_party/RoMa/romatch/models/transformer/dinov2.py
ADDED
@@ -0,0 +1,359 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
10 |
+
|
11 |
+
from functools import partial
|
12 |
+
import math
|
13 |
+
import logging
|
14 |
+
from typing import Sequence, Tuple, Union, Callable
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.utils.checkpoint
|
19 |
+
from torch.nn.init import trunc_normal_
|
20 |
+
|
21 |
+
from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
26 |
+
if not depth_first and include_root:
|
27 |
+
fn(module=module, name=name)
|
28 |
+
for child_name, child_module in module.named_children():
|
29 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
30 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
31 |
+
if depth_first and include_root:
|
32 |
+
fn(module=module, name=name)
|
33 |
+
return module
|
34 |
+
|
35 |
+
|
36 |
+
class BlockChunk(nn.ModuleList):
|
37 |
+
def forward(self, x):
|
38 |
+
for b in self:
|
39 |
+
x = b(x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class DinoVisionTransformer(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
img_size=224,
|
47 |
+
patch_size=16,
|
48 |
+
in_chans=3,
|
49 |
+
embed_dim=768,
|
50 |
+
depth=12,
|
51 |
+
num_heads=12,
|
52 |
+
mlp_ratio=4.0,
|
53 |
+
qkv_bias=True,
|
54 |
+
ffn_bias=True,
|
55 |
+
proj_bias=True,
|
56 |
+
drop_path_rate=0.0,
|
57 |
+
drop_path_uniform=False,
|
58 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
59 |
+
embed_layer=PatchEmbed,
|
60 |
+
act_layer=nn.GELU,
|
61 |
+
block_fn=Block,
|
62 |
+
ffn_layer="mlp",
|
63 |
+
block_chunks=1,
|
64 |
+
):
|
65 |
+
"""
|
66 |
+
Args:
|
67 |
+
img_size (int, tuple): input image size
|
68 |
+
patch_size (int, tuple): patch size
|
69 |
+
in_chans (int): number of input channels
|
70 |
+
embed_dim (int): embedding dimension
|
71 |
+
depth (int): depth of transformer
|
72 |
+
num_heads (int): number of attention heads
|
73 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
74 |
+
qkv_bias (bool): enable bias for qkv if True
|
75 |
+
proj_bias (bool): enable bias for proj in attn if True
|
76 |
+
ffn_bias (bool): enable bias for ffn if True
|
77 |
+
drop_path_rate (float): stochastic depth rate
|
78 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
79 |
+
weight_init (str): weight init scheme
|
80 |
+
init_values (float): layer-scale init values
|
81 |
+
embed_layer (nn.Module): patch embedding layer
|
82 |
+
act_layer (nn.Module): MLP activation layer
|
83 |
+
block_fn (nn.Module): transformer block class
|
84 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
85 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
86 |
+
"""
|
87 |
+
super().__init__()
|
88 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
89 |
+
|
90 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
91 |
+
self.num_tokens = 1
|
92 |
+
self.n_blocks = depth
|
93 |
+
self.num_heads = num_heads
|
94 |
+
self.patch_size = patch_size
|
95 |
+
|
96 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
97 |
+
num_patches = self.patch_embed.num_patches
|
98 |
+
|
99 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
100 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
101 |
+
|
102 |
+
if drop_path_uniform is True:
|
103 |
+
dpr = [drop_path_rate] * depth
|
104 |
+
else:
|
105 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
106 |
+
|
107 |
+
if ffn_layer == "mlp":
|
108 |
+
ffn_layer = Mlp
|
109 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
110 |
+
ffn_layer = SwiGLUFFNFused
|
111 |
+
elif ffn_layer == "identity":
|
112 |
+
|
113 |
+
def f(*args, **kwargs):
|
114 |
+
return nn.Identity()
|
115 |
+
|
116 |
+
ffn_layer = f
|
117 |
+
else:
|
118 |
+
raise NotImplementedError
|
119 |
+
|
120 |
+
blocks_list = [
|
121 |
+
block_fn(
|
122 |
+
dim=embed_dim,
|
123 |
+
num_heads=num_heads,
|
124 |
+
mlp_ratio=mlp_ratio,
|
125 |
+
qkv_bias=qkv_bias,
|
126 |
+
proj_bias=proj_bias,
|
127 |
+
ffn_bias=ffn_bias,
|
128 |
+
drop_path=dpr[i],
|
129 |
+
norm_layer=norm_layer,
|
130 |
+
act_layer=act_layer,
|
131 |
+
ffn_layer=ffn_layer,
|
132 |
+
init_values=init_values,
|
133 |
+
)
|
134 |
+
for i in range(depth)
|
135 |
+
]
|
136 |
+
if block_chunks > 0:
|
137 |
+
self.chunked_blocks = True
|
138 |
+
chunked_blocks = []
|
139 |
+
chunksize = depth // block_chunks
|
140 |
+
for i in range(0, depth, chunksize):
|
141 |
+
# this is to keep the block index consistent if we chunk the block list
|
142 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
143 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
144 |
+
else:
|
145 |
+
self.chunked_blocks = False
|
146 |
+
self.blocks = nn.ModuleList(blocks_list)
|
147 |
+
|
148 |
+
self.norm = norm_layer(embed_dim)
|
149 |
+
self.head = nn.Identity()
|
150 |
+
|
151 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
152 |
+
|
153 |
+
self.init_weights()
|
154 |
+
for param in self.parameters():
|
155 |
+
param.requires_grad = False
|
156 |
+
|
157 |
+
@property
|
158 |
+
def device(self):
|
159 |
+
return self.cls_token.device
|
160 |
+
|
161 |
+
def init_weights(self):
|
162 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
163 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
164 |
+
named_apply(init_weights_vit_timm, self)
|
165 |
+
|
166 |
+
def interpolate_pos_encoding(self, x, w, h):
|
167 |
+
previous_dtype = x.dtype
|
168 |
+
npatch = x.shape[1] - 1
|
169 |
+
N = self.pos_embed.shape[1] - 1
|
170 |
+
if npatch == N and w == h:
|
171 |
+
return self.pos_embed
|
172 |
+
pos_embed = self.pos_embed.float()
|
173 |
+
class_pos_embed = pos_embed[:, 0]
|
174 |
+
patch_pos_embed = pos_embed[:, 1:]
|
175 |
+
dim = x.shape[-1]
|
176 |
+
w0 = w // self.patch_size
|
177 |
+
h0 = h // self.patch_size
|
178 |
+
# we add a small number to avoid floating point error in the interpolation
|
179 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
180 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
181 |
+
|
182 |
+
patch_pos_embed = nn.functional.interpolate(
|
183 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
184 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
185 |
+
mode="bicubic",
|
186 |
+
)
|
187 |
+
|
188 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
189 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
190 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
191 |
+
|
192 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
193 |
+
B, nc, w, h = x.shape
|
194 |
+
x = self.patch_embed(x)
|
195 |
+
if masks is not None:
|
196 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
197 |
+
|
198 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
199 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
200 |
+
|
201 |
+
return x
|
202 |
+
|
203 |
+
def forward_features_list(self, x_list, masks_list):
|
204 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
205 |
+
for blk in self.blocks:
|
206 |
+
x = blk(x)
|
207 |
+
|
208 |
+
all_x = x
|
209 |
+
output = []
|
210 |
+
for x, masks in zip(all_x, masks_list):
|
211 |
+
x_norm = self.norm(x)
|
212 |
+
output.append(
|
213 |
+
{
|
214 |
+
"x_norm_clstoken": x_norm[:, 0],
|
215 |
+
"x_norm_patchtokens": x_norm[:, 1:],
|
216 |
+
"x_prenorm": x,
|
217 |
+
"masks": masks,
|
218 |
+
}
|
219 |
+
)
|
220 |
+
return output
|
221 |
+
|
222 |
+
def forward_features(self, x, masks=None):
|
223 |
+
if isinstance(x, list):
|
224 |
+
return self.forward_features_list(x, masks)
|
225 |
+
|
226 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
227 |
+
|
228 |
+
for blk in self.blocks:
|
229 |
+
x = blk(x)
|
230 |
+
|
231 |
+
x_norm = self.norm(x)
|
232 |
+
return {
|
233 |
+
"x_norm_clstoken": x_norm[:, 0],
|
234 |
+
"x_norm_patchtokens": x_norm[:, 1:],
|
235 |
+
"x_prenorm": x,
|
236 |
+
"masks": masks,
|
237 |
+
}
|
238 |
+
|
239 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
240 |
+
x = self.prepare_tokens_with_masks(x)
|
241 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
242 |
+
output, total_block_len = [], len(self.blocks)
|
243 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
244 |
+
for i, blk in enumerate(self.blocks):
|
245 |
+
x = blk(x)
|
246 |
+
if i in blocks_to_take:
|
247 |
+
output.append(x)
|
248 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
249 |
+
return output
|
250 |
+
|
251 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
252 |
+
x = self.prepare_tokens_with_masks(x)
|
253 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
254 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
255 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
256 |
+
for block_chunk in self.blocks:
|
257 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
258 |
+
x = blk(x)
|
259 |
+
if i in blocks_to_take:
|
260 |
+
output.append(x)
|
261 |
+
i += 1
|
262 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
263 |
+
return output
|
264 |
+
|
265 |
+
def get_intermediate_layers(
|
266 |
+
self,
|
267 |
+
x: torch.Tensor,
|
268 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
269 |
+
reshape: bool = False,
|
270 |
+
return_class_token: bool = False,
|
271 |
+
norm=True,
|
272 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
273 |
+
if self.chunked_blocks:
|
274 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
275 |
+
else:
|
276 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
277 |
+
if norm:
|
278 |
+
outputs = [self.norm(out) for out in outputs]
|
279 |
+
class_tokens = [out[:, 0] for out in outputs]
|
280 |
+
outputs = [out[:, 1:] for out in outputs]
|
281 |
+
if reshape:
|
282 |
+
B, _, w, h = x.shape
|
283 |
+
outputs = [
|
284 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
285 |
+
for out in outputs
|
286 |
+
]
|
287 |
+
if return_class_token:
|
288 |
+
return tuple(zip(outputs, class_tokens))
|
289 |
+
return tuple(outputs)
|
290 |
+
|
291 |
+
def forward(self, *args, is_training=False, **kwargs):
|
292 |
+
ret = self.forward_features(*args, **kwargs)
|
293 |
+
if is_training:
|
294 |
+
return ret
|
295 |
+
else:
|
296 |
+
return self.head(ret["x_norm_clstoken"])
|
297 |
+
|
298 |
+
|
299 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
300 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
301 |
+
if isinstance(module, nn.Linear):
|
302 |
+
trunc_normal_(module.weight, std=0.02)
|
303 |
+
if module.bias is not None:
|
304 |
+
nn.init.zeros_(module.bias)
|
305 |
+
|
306 |
+
|
307 |
+
def vit_small(patch_size=16, **kwargs):
|
308 |
+
model = DinoVisionTransformer(
|
309 |
+
patch_size=patch_size,
|
310 |
+
embed_dim=384,
|
311 |
+
depth=12,
|
312 |
+
num_heads=6,
|
313 |
+
mlp_ratio=4,
|
314 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
315 |
+
**kwargs,
|
316 |
+
)
|
317 |
+
return model
|
318 |
+
|
319 |
+
|
320 |
+
def vit_base(patch_size=16, **kwargs):
|
321 |
+
model = DinoVisionTransformer(
|
322 |
+
patch_size=patch_size,
|
323 |
+
embed_dim=768,
|
324 |
+
depth=12,
|
325 |
+
num_heads=12,
|
326 |
+
mlp_ratio=4,
|
327 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
328 |
+
**kwargs,
|
329 |
+
)
|
330 |
+
return model
|
331 |
+
|
332 |
+
|
333 |
+
def vit_large(patch_size=16, **kwargs):
|
334 |
+
model = DinoVisionTransformer(
|
335 |
+
patch_size=patch_size,
|
336 |
+
embed_dim=1024,
|
337 |
+
depth=24,
|
338 |
+
num_heads=16,
|
339 |
+
mlp_ratio=4,
|
340 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
341 |
+
**kwargs,
|
342 |
+
)
|
343 |
+
return model
|
344 |
+
|
345 |
+
|
346 |
+
def vit_giant2(patch_size=16, **kwargs):
|
347 |
+
"""
|
348 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
349 |
+
"""
|
350 |
+
model = DinoVisionTransformer(
|
351 |
+
patch_size=patch_size,
|
352 |
+
embed_dim=1536,
|
353 |
+
depth=40,
|
354 |
+
num_heads=24,
|
355 |
+
mlp_ratio=4,
|
356 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
357 |
+
**kwargs,
|
358 |
+
)
|
359 |
+
return model
|
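
roma_models.py projects 1024-dim DINOv2 patch tokens with proj16, and those tokens come from the frozen ViT-L defined above. A hedged sketch of that step (not part of the diff): patch_size=14 is an assumption matching the dinov2_vitl14 checkpoint referenced in the model zoo, and the weights here are random.

# Hedged sketch, not part of the diff: the frozen vit_large(...) backbone produces the
# 1024-dim patch tokens that roma_models.py projects with proj16. patch_size=14 is an
# assumption for the dinov2_vitl14 checkpoint; no pretrained weights are loaded here.
import torch
from romatch.models.transformer.dinov2 import vit_large

vit = vit_large(patch_size=14, img_size=518, block_chunks=0)
x = torch.randn(1, 3, 560, 560)                      # coarse_res used by roma_outdoor
out = vit.forward_features(x)
tokens = out["x_norm_patchtokens"]                   # (1, (560/14)**2, 1024) = (1, 1600, 1024)
feat_map = tokens.permute(0, 2, 1).reshape(1, 1024, 40, 40)
print(feat_map.shape)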
third_party/RoMa/romatch/models/transformer/layers/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
third_party/RoMa/romatch/models/transformer/layers/attention.py
ADDED
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging

from torch import Tensor
from torch import nn


logger = logging.getLogger("dinov2")


try:
    from xformers.ops import memory_efficient_attention, unbind, fmha

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
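
When xformers is not installed, MemEffAttention silently falls back to the plain Attention.forward above as long as no attn_bias is passed. A minimal sketch (not part of the diff):

# Hedged sketch, not part of the diff: the same call works with or without xformers,
# provided attn_bias stays None.
import torch
from romatch.models.transformer.layers.attention import MemEffAttention

attn = MemEffAttention(dim=64, num_heads=8, qkv_bias=True)
tokens = torch.randn(2, 10, 64)         # (batch, sequence, dim)
out = attn(tokens)
print(out.shape)                        # torch.Size([2, 10, 64])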
third_party/RoMa/romatch/models/transformer/layers/block.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn, Tensor
|
16 |
+
|
17 |
+
from .attention import Attention, MemEffAttention
|
18 |
+
from .drop_path import DropPath
|
19 |
+
from .layer_scale import LayerScale
|
20 |
+
from .mlp import Mlp
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
try:
|
27 |
+
from xformers.ops import fmha
|
28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
29 |
+
|
30 |
+
XFORMERS_AVAILABLE = True
|
31 |
+
except ImportError:
|
32 |
+
logger.warning("xFormers not available")
|
33 |
+
XFORMERS_AVAILABLE = False
|
34 |
+
|
35 |
+
|
36 |
+
class Block(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim: int,
|
40 |
+
num_heads: int,
|
41 |
+
mlp_ratio: float = 4.0,
|
42 |
+
qkv_bias: bool = False,
|
43 |
+
proj_bias: bool = True,
|
44 |
+
ffn_bias: bool = True,
|
45 |
+
drop: float = 0.0,
|
46 |
+
attn_drop: float = 0.0,
|
47 |
+
init_values=None,
|
48 |
+
drop_path: float = 0.0,
|
49 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
50 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
51 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
52 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
53 |
+
) -> None:
|
54 |
+
super().__init__()
|
55 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
56 |
+
self.norm1 = norm_layer(dim)
|
57 |
+
self.attn = attn_class(
|
58 |
+
dim,
|
59 |
+
num_heads=num_heads,
|
60 |
+
qkv_bias=qkv_bias,
|
61 |
+
proj_bias=proj_bias,
|
62 |
+
attn_drop=attn_drop,
|
63 |
+
proj_drop=drop,
|
64 |
+
)
|
65 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
66 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
67 |
+
|
68 |
+
self.norm2 = norm_layer(dim)
|
69 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
70 |
+
self.mlp = ffn_layer(
|
71 |
+
in_features=dim,
|
72 |
+
hidden_features=mlp_hidden_dim,
|
73 |
+
act_layer=act_layer,
|
74 |
+
drop=drop,
|
75 |
+
bias=ffn_bias,
|
76 |
+
)
|
77 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
78 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
79 |
+
|
80 |
+
self.sample_drop_ratio = drop_path
|
81 |
+
|
82 |
+
def forward(self, x: Tensor) -> Tensor:
|
83 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
84 |
+
return self.ls1(self.attn(self.norm1(x)))
|
85 |
+
|
86 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
87 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
88 |
+
|
89 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
90 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
91 |
+
x = drop_add_residual_stochastic_depth(
|
92 |
+
x,
|
93 |
+
residual_func=attn_residual_func,
|
94 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
95 |
+
)
|
96 |
+
x = drop_add_residual_stochastic_depth(
|
97 |
+
x,
|
98 |
+
residual_func=ffn_residual_func,
|
99 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
100 |
+
)
|
101 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
102 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
103 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
104 |
+
else:
|
105 |
+
x = x + attn_residual_func(x)
|
106 |
+
x = x + ffn_residual_func(x)
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
def drop_add_residual_stochastic_depth(
|
111 |
+
x: Tensor,
|
112 |
+
residual_func: Callable[[Tensor], Tensor],
|
113 |
+
sample_drop_ratio: float = 0.0,
|
114 |
+
) -> Tensor:
|
115 |
+
# 1) extract subset using permutation
|
116 |
+
b, n, d = x.shape
|
117 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
118 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
119 |
+
x_subset = x[brange]
|
120 |
+
|
121 |
+
# 2) apply residual_func to get residual
|
122 |
+
residual = residual_func(x_subset)
|
123 |
+
|
124 |
+
x_flat = x.flatten(1)
|
125 |
+
residual = residual.flatten(1)
|
126 |
+
|
127 |
+
residual_scale_factor = b / sample_subset_size
|
128 |
+
|
129 |
+
# 3) add the residual
|
130 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
131 |
+
return x_plus_residual.view_as(x)
|
132 |
+
|
133 |
+
|
134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
135 |
+
b, n, d = x.shape
|
136 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
137 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
138 |
+
residual_scale_factor = b / sample_subset_size
|
139 |
+
return brange, residual_scale_factor
|
140 |
+
|
141 |
+
|
142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
143 |
+
if scaling_vector is None:
|
144 |
+
x_flat = x.flatten(1)
|
145 |
+
residual = residual.flatten(1)
|
146 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
147 |
+
else:
|
148 |
+
x_plus_residual = scaled_index_add(
|
149 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
150 |
+
)
|
151 |
+
return x_plus_residual
|
152 |
+
|
153 |
+
|
154 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
155 |
+
|
156 |
+
|
157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
158 |
+
"""
|
159 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
160 |
+
"""
|
161 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
162 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
163 |
+
if all_shapes not in attn_bias_cache.keys():
|
164 |
+
seqlens = []
|
165 |
+
for b, x in zip(batch_sizes, x_list):
|
166 |
+
for _ in range(b):
|
167 |
+
seqlens.append(x.shape[1])
|
168 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
169 |
+
attn_bias._batch_sizes = batch_sizes
|
170 |
+
attn_bias_cache[all_shapes] = attn_bias
|
171 |
+
|
172 |
+
if branges is not None:
|
173 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
174 |
+
else:
|
175 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
176 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
177 |
+
|
178 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
179 |
+
|
180 |
+
|
181 |
+
def drop_add_residual_stochastic_depth_list(
|
182 |
+
x_list: List[Tensor],
|
183 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
184 |
+
sample_drop_ratio: float = 0.0,
|
185 |
+
scaling_vector=None,
|
186 |
+
) -> Tensor:
|
187 |
+
# 1) generate random set of indices for dropping samples in the batch
|
188 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
189 |
+
branges = [s[0] for s in branges_scales]
|
190 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
191 |
+
|
192 |
+
# 2) get attention bias and index+concat the tensors
|
193 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
194 |
+
|
195 |
+
# 3) apply residual_func to get residual, and split the result
|
196 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
197 |
+
|
198 |
+
outputs = []
|
199 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
200 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
class NestedTensorBlock(Block):
|
205 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
206 |
+
"""
|
207 |
+
x_list contains a list of tensors to nest together and run
|
208 |
+
"""
|
209 |
+
assert isinstance(self.attn, MemEffAttention)
|
210 |
+
|
211 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
212 |
+
|
213 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
214 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
215 |
+
|
216 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
217 |
+
return self.mlp(self.norm2(x))
|
218 |
+
|
219 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
220 |
+
x_list,
|
221 |
+
residual_func=attn_residual_func,
|
222 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
223 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
224 |
+
)
|
225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
226 |
+
x_list,
|
227 |
+
residual_func=ffn_residual_func,
|
228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
229 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
230 |
+
)
|
231 |
+
return x_list
|
232 |
+
else:
|
233 |
+
|
234 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
235 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
236 |
+
|
237 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
238 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
239 |
+
|
240 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
241 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
242 |
+
x = x + ffn_residual_func(x)
|
243 |
+
return attn_bias.split(x)
|
244 |
+
|
245 |
+
def forward(self, x_or_x_list):
|
246 |
+
if isinstance(x_or_x_list, Tensor):
|
247 |
+
return super().forward(x_or_x_list)
|
248 |
+
elif isinstance(x_or_x_list, list):
|
249 |
+
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
|
250 |
+
return self.forward_nested(x_or_x_list)
|
251 |
+
else:
|
252 |
+
raise AssertionError
|
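As a rough usage sketch (sizes are illustrative, and the list-input path additionally requires xFormers plus MemEffAttention), Block accepts a single token tensor, while NestedTensorBlock can also batch a list of differently sized token tensors via a block-diagonal attention bias:

import torch
from romatch.models.transformer.layers.block import Block, NestedTensorBlock
from romatch.models.transformer.layers.attention import MemEffAttention

block = Block(dim=384, num_heads=6)
x = torch.randn(2, 196, 384)
y = block(x)  # pre-norm attention residual followed by an MLP residual
assert y.shape == x.shape

nested = NestedTensorBlock(dim=384, num_heads=6, attn_class=MemEffAttention)
# nested([torch.randn(1, 196, 384), torch.randn(1, 256, 384)])  # list path needs xFormers installed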
third_party/RoMa/romatch/models/transformer/layers/dino_head.py
ADDED
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        use_bn=False,
+        nlayers=3,
+        hidden_dim=2048,
+        bottleneck_dim=256,
+        mlp_bias=True,
+    ):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+        self.apply(self._init_weights)
+        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        self.last_layer.weight_g.data.fill_(1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.mlp(x)
+        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+        x = self.last_layer(x)
+        return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+    if nlayers == 1:
+        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+    else:
+        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+        if use_bn:
+            layers.append(nn.BatchNorm1d(hidden_dim))
+        layers.append(nn.GELU())
+        for _ in range(nlayers - 2):
+            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+        return nn.Sequential(*layers)
third_party/RoMa/romatch/models/transformer/layers/drop_path.py
ADDED
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0:
+        random_tensor.div_(keep_prob)
+    output = x * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
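A small sketch of the stochastic-depth behaviour: in training mode each sample's residual branch is dropped with probability drop_prob and the survivors are rescaled by 1 / (1 - drop_prob), while in eval mode the layer is the identity:

import torch
from romatch.models.transformer.layers.drop_path import DropPath

dp = DropPath(drop_prob=0.5)
x = torch.ones(8, 4, 16)

dp.train()
out = dp(x)  # roughly half the samples become all-zero, the rest are scaled by 2.0

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time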
third_party/RoMa/romatch/models/transformer/layers/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: Union[float, Tensor] = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
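For intuition, LayerScale multiplies every channel by a learnable per-channel factor gamma that starts at a small value, so a freshly initialised residual branch contributes almost nothing:

import torch
from romatch.models.transformer.layers.layer_scale import LayerScale

ls = LayerScale(dim=384, init_values=1e-5)
x = torch.randn(2, 196, 384)
assert torch.allclose(ls(x), x * 1e-5)  # gamma is initialised to 1e-5 * ones(384)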
third_party/RoMa/romatch/models/transformer/layers/mlp.py
ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
third_party/RoMa/romatch/models/transformer/layers/patch_embed.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+    if isinstance(x, tuple):
+        assert len(x) == 2
+        return x
+
+    assert isinstance(x, int)
+    return (x, x)
+
+
+class PatchEmbed(nn.Module):
+    """
+    2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+    Args:
+        img_size: Image size.
+        patch_size: Patch token size.
+        in_chans: Number of input image channels.
+        embed_dim: Number of linear projection output channels.
+        norm_layer: Normalization layer.
+    """
+
+    def __init__(
+        self,
+        img_size: Union[int, Tuple[int, int]] = 224,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        norm_layer: Optional[Callable] = None,
+        flatten_embedding: bool = True,
+    ) -> None:
+        super().__init__()
+
+        image_HW = make_2tuple(img_size)
+        patch_HW = make_2tuple(patch_size)
+        patch_grid_size = (
+            image_HW[0] // patch_HW[0],
+            image_HW[1] // patch_HW[1],
+        )
+
+        self.img_size = image_HW
+        self.patch_size = patch_HW
+        self.patches_resolution = patch_grid_size
+        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.flatten_embedding = flatten_embedding
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        _, _, H, W = x.shape
+        patch_H, patch_W = self.patch_size
+
+        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+        x = self.proj(x)  # B C H W
+        H, W = x.size(2), x.size(3)
+        x = x.flatten(2).transpose(1, 2)  # B HW C
+        x = self.norm(x)
+        if not self.flatten_embedding:
+            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+        return x
+
+    def flops(self) -> float:
+        Ho, Wo = self.patches_resolution
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
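A quick shape check with illustrative values: a 224x224 RGB image with 16x16 patches yields 14 * 14 = 196 patch tokens of dimension embed_dim:

import torch
from romatch.models.transformer.layers.patch_embed import PatchEmbed

pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(2, 3, 224, 224)
tokens = pe(img)
assert tokens.shape == (2, (224 // 16) ** 2, 768)  # (B, 196, 768)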
third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+try:
+    from xformers.ops import SwiGLU
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    SwiGLU = SwiGLUFFN
+    XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+        super().__init__(
+            in_features=in_features,
+            hidden_features=hidden_features,
+            out_features=out_features,
+            bias=bias,
+        )
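As a sketch, SwiGLUFFN keeps the requested hidden width, while SwiGLUFFNFused shrinks it to two thirds and rounds up to a multiple of 8 so the fused xFormers kernel (or the pure-PyTorch fallback) sees well-aligned shapes:

import torch
from romatch.models.transformer.layers.swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused

ffn = SwiGLUFFN(in_features=384, hidden_features=1536)
x = torch.randn(2, 196, 384)
assert ffn(x).shape == x.shape

fused = SwiGLUFFNFused(in_features=384, hidden_features=1536)
# effective hidden width: (int(1536 * 2 / 3) + 7) // 8 * 8 == 1024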
third_party/RoMa/romatch/train/__init__.py
ADDED
@@ -0,0 +1 @@
+from .train import train_k_epochs
third_party/RoMa/romatch/train/train.py
ADDED
@@ -0,0 +1,102 @@
+from tqdm import tqdm
+from romatch.utils.utils import to_cuda
+import romatch
+import torch
+import wandb
+
+def log_param_statistics(named_parameters, norm_type = 2):
+    named_parameters = list(named_parameters)
+    grads = [p.grad for n, p in named_parameters if p.grad is not None]
+    weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
+    names = [n for n,p in named_parameters if p.grad is not None]
+    param_norm = torch.stack(weight_norms).norm(p=norm_type)
+    device = grads[0].device
+    grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
+    nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
+    nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
+    total_grad_norm = torch.norm(grad_norms, norm_type)
+    if torch.any(nans_or_infs):
+        print(f"These params have nan or inf grads: {nan_inf_names}")
+    wandb.log({"grad_norm": total_grad_norm.item()}, step = romatch.GLOBAL_STEP)
+    wandb.log({"param_norm": param_norm.item()}, step = romatch.GLOBAL_STEP)
+
+def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
+    optimizer.zero_grad()
+    out = model(train_batch)
+    l = objective(out, train_batch)
+    grad_scaler.scale(l).backward()
+    grad_scaler.unscale_(optimizer)
+    log_param_statistics(model.named_parameters())
+    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
+    grad_scaler.step(optimizer)
+    grad_scaler.update()
+    wandb.log({"grad_scale": grad_scaler._scale.item()}, step = romatch.GLOBAL_STEP)
+    if grad_scaler._scale < 1.:
+        grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
+    romatch.GLOBAL_STEP = romatch.GLOBAL_STEP + romatch.STEP_SIZE # increment global step
+    return {"train_out": out, "train_loss": l.item()}
+
+
+def train_k_steps(
+    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None, pbar_n_seconds = 1,
+):
+    for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or romatch.RANK > 0, mininterval=pbar_n_seconds):
+        batch = next(dataloader)
+        model.train(True)
+        batch = to_cuda(batch)
+        train_step(
+            train_batch=batch,
+            model=model,
+            objective=objective,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            grad_scaler=grad_scaler,
+            n=n,
+            grad_clip_norm = grad_clip_norm,
+        )
+        if ema_model is not None:
+            ema_model.update()
+        if warmup is not None:
+            with warmup.dampening():
+                lr_scheduler.step()
+        else:
+            lr_scheduler.step()
+        [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
+
+
+def train_epoch(
+    dataloader=None,
+    model=None,
+    objective=None,
+    optimizer=None,
+    lr_scheduler=None,
+    epoch=None,
+):
+    model.train(True)
+    print(f"At epoch {epoch}")
+    for batch in tqdm(dataloader, mininterval=5.0):
+        batch = to_cuda(batch)
+        train_step(
+            train_batch=batch, model=model, objective=objective, optimizer=optimizer
+        )
+    lr_scheduler.step()
+    return {
+        "model": model,
+        "optimizer": optimizer,
+        "lr_scheduler": lr_scheduler,
+        "epoch": epoch,
+    }
+
+
+def train_k_epochs(
+    start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
+):
+    for epoch in range(start_epoch, end_epoch + 1):
+        train_epoch(
+            dataloader=dataloader,
+            model=model,
+            objective=objective,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            epoch=epoch,
+        )
third_party/RoMa/romatch/utils/__init__.py
ADDED
@@ -0,0 +1,16 @@
+from .utils import (
+    pose_auc,
+    get_pose,
+    compute_relative_pose,
+    compute_pose_error,
+    estimate_pose,
+    estimate_pose_uncalibrated,
+    rotate_intrinsic,
+    get_tuple_transform_ops,
+    get_depth_tuple_transform_ops,
+    warp_kpts,
+    numpy_to_pil,
+    tensor_to_pil,
+    recover_pose,
+    signed_left_to_right_epipolar_distance,
+)