Vincentqyw commited on
Commit
e8fe67e
1 Parent(s): 8811cfe

update: roma and dust3r

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. third_party/RoMa +0 -1
  2. third_party/RoMa/.gitignore +11 -0
  3. third_party/RoMa/LICENSE +21 -0
  4. third_party/RoMa/README.md +123 -0
  5. third_party/RoMa/assets/sacre_coeur_A.jpg +3 -0
  6. third_party/RoMa/assets/sacre_coeur_B.jpg +3 -0
  7. third_party/RoMa/assets/toronto_A.jpg +3 -0
  8. third_party/RoMa/assets/toronto_B.jpg +3 -0
  9. third_party/RoMa/data/.gitignore +2 -0
  10. third_party/RoMa/demo/demo_3D_effect.py +46 -0
  11. third_party/RoMa/demo/demo_fundamental.py +33 -0
  12. third_party/RoMa/demo/demo_match.py +47 -0
  13. third_party/RoMa/demo/demo_match_opencv_sift.py +43 -0
  14. third_party/RoMa/demo/gif/.gitignore +2 -0
  15. third_party/RoMa/requirements.txt +14 -0
  16. third_party/RoMa/romatch/__init__.py +8 -0
  17. third_party/RoMa/romatch/benchmarks/__init__.py +6 -0
  18. third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
  19. third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
  20. third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
  21. third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py +119 -0
  22. third_party/RoMa/romatch/benchmarks/scannet_benchmark.py +143 -0
  23. third_party/RoMa/romatch/checkpointing/__init__.py +1 -0
  24. third_party/RoMa/romatch/checkpointing/checkpoint.py +60 -0
  25. third_party/RoMa/romatch/datasets/__init__.py +2 -0
  26. third_party/RoMa/romatch/datasets/megadepth.py +232 -0
  27. third_party/RoMa/romatch/datasets/scannet.py +160 -0
  28. third_party/RoMa/romatch/losses/__init__.py +1 -0
  29. third_party/RoMa/romatch/losses/robust_loss.py +161 -0
  30. third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py +160 -0
  31. third_party/RoMa/romatch/models/__init__.py +1 -0
  32. third_party/RoMa/romatch/models/encoders.py +119 -0
  33. third_party/RoMa/romatch/models/matcher.py +772 -0
  34. third_party/RoMa/romatch/models/model_zoo/__init__.py +70 -0
  35. third_party/RoMa/romatch/models/model_zoo/roma_models.py +170 -0
  36. third_party/RoMa/romatch/models/tiny.py +304 -0
  37. third_party/RoMa/romatch/models/transformer/__init__.py +47 -0
  38. third_party/RoMa/romatch/models/transformer/dinov2.py +359 -0
  39. third_party/RoMa/romatch/models/transformer/layers/__init__.py +12 -0
  40. third_party/RoMa/romatch/models/transformer/layers/attention.py +81 -0
  41. third_party/RoMa/romatch/models/transformer/layers/block.py +252 -0
  42. third_party/RoMa/romatch/models/transformer/layers/dino_head.py +59 -0
  43. third_party/RoMa/romatch/models/transformer/layers/drop_path.py +35 -0
  44. third_party/RoMa/romatch/models/transformer/layers/layer_scale.py +28 -0
  45. third_party/RoMa/romatch/models/transformer/layers/mlp.py +41 -0
  46. third_party/RoMa/romatch/models/transformer/layers/patch_embed.py +89 -0
  47. third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py +63 -0
  48. third_party/RoMa/romatch/train/__init__.py +1 -0
  49. third_party/RoMa/romatch/train/train.py +102 -0
  50. third_party/RoMa/romatch/utils/__init__.py +16 -0
third_party/RoMa DELETED
@@ -1 +0,0 @@
1
- Subproject commit 116537a1849f26b1ecf8e1b7ac0980a51befdc07
 
 
third_party/RoMa/.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.egg-info*
2
+ *.vscode*
3
+ *__pycache__*
4
+ vis*
5
+ workspace*
6
+ .venv
7
+ .DS_Store
8
+ jobs/*
9
+ *ignore_me*
10
+ *.pth
11
+ wandb*
third_party/RoMa/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Johan Edstedt
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
third_party/RoMa/README.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ <p align="center">
3
+ <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
4
+ <p align="center">
5
+ <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
6
+ ·
7
+ <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
8
+ ·
9
+ <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
10
+ ·
11
+ <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
12
+ ·
13
+ <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
14
+ </p>
15
+ <h2 align="center"><p>
16
+ <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
17
+ <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
18
+ </p></h2>
19
+ <div align="center"></div>
20
+ </p>
21
+ <br/>
22
+ <p align="center">
23
+ <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
24
+ <br>
25
+ <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
26
+ </p>
27
+
28
+ ## Setup/Install
29
+ In your python environment (tested on Linux python 3.10), run:
30
+ ```bash
31
+ pip install -e .
32
+ ```
33
+ ## Demo / How to Use
34
+ We provide two demos in the [demos folder](demo).
35
+ Here's the gist of it:
36
+ ```python
37
+ from romatch import roma_outdoor
38
+ roma_model = roma_outdoor(device=device)
39
+ # Match
40
+ warp, certainty = roma_model.match(imA_path, imB_path, device=device)
41
+ # Sample matches for estimation
42
+ matches, certainty = roma_model.sample(warp, certainty)
43
+ # Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
44
+ kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
45
+ # Find a fundamental matrix (or anything else of interest)
46
+ F, mask = cv2.findFundamentalMat(
47
+ kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
48
+ )
49
+ ```
50
+
51
+ **New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
52
+
53
+ ## Settings
54
+
55
+ ### Resolution
56
+ By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
57
+ You can change this at construction (see roma_outdoor kwargs).
58
+ You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
59
+
60
+ ### Sampling
61
+ roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
62
+
63
+
64
+ ## Reproducing Results
65
+ The experiments in the paper are provided in the [experiments folder](experiments).
66
+
67
+ ### Training
68
+ 1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
69
+ 2. Run the relevant experiment, e.g.,
70
+ ```bash
71
+ torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
72
+ ```
73
+ ### Testing
74
+ ```bash
75
+ python experiments/roma_outdoor.py --only_test --benchmark mega-1500
76
+ ```
77
+ ## License
78
+ All our code except DINOv2 is MIT license.
79
+ DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
80
+
81
+ ## Acknowledgement
82
+ Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
83
+
84
+ ## Tiny RoMa
85
+ If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
86
+ ```python
87
+ from romatch import tiny_roma_v1_outdoor
88
+ tiny_roma_model = tiny_roma_v1_outdoor(device=device)
89
+ ```
90
+ Mega1500:
91
+ | | AUC@5 | AUC@10 | AUC@20 |
92
+ |----------|----------|----------|----------|
93
+ | XFeat | 46.4 | 58.9 | 69.2 |
94
+ | XFeat* | 51.9 | 67.2 | 78.9 |
95
+ | Tiny RoMa v1 | 56.4 | 69.5 | 79.5 |
96
+ | RoMa | - | - | - |
97
+
98
+ Mega-8-Scenes (See DKM):
99
+ | | AUC@5 | AUC@10 | AUC@20 |
100
+ |----------|----------|----------|----------|
101
+ | XFeat | - | - | - |
102
+ | XFeat* | 50.1 | 64.4 | 75.2 |
103
+ | Tiny RoMa v1 | 57.7 | 70.5 | 79.6 |
104
+ | RoMa | - | - | - |
105
+
106
+ IMC22 :'):
107
+ | | mAA@10 |
108
+ |----------|----------|
109
+ | XFeat | 42.1 |
110
+ | XFeat* | - |
111
+ | Tiny RoMa v1 | 42.2 |
112
+ | RoMa | - |
113
+
114
+ ## BibTeX
115
+ If you find our models useful, please consider citing our paper!
116
+ ```
117
+ @article{edstedt2024roma,
118
+ title={{RoMa: Robust Dense Feature Matching}},
119
+ author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
120
+ journal={IEEE Conference on Computer Vision and Pattern Recognition},
121
+ year={2024}
122
+ }
123
+ ```
third_party/RoMa/assets/sacre_coeur_A.jpg ADDED

Git LFS Details

  • SHA256: 90d9c5f5a4d76425624989215120fba6f2899190a1d5654b88fa380c64cf6b2c
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
third_party/RoMa/assets/sacre_coeur_B.jpg ADDED

Git LFS Details

  • SHA256: 2f1eb9bdd4d80e480f672d6a729689ac77f9fd5c8deb90f59b377590f3ca4799
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
third_party/RoMa/assets/toronto_A.jpg ADDED

Git LFS Details

  • SHA256: 40270c227df93f0f31b55e0f2ff38eb24f47940c4800c83758a74a5dfd7346ec
  • Pointer size: 131 Bytes
  • Size of remote file: 525 kB
third_party/RoMa/assets/toronto_B.jpg ADDED

Git LFS Details

  • SHA256: a2c07550ed87e40fca8c38076eb3a81395d760a88bf0b8615167704107deff2f
  • Pointer size: 131 Bytes
  • Size of remote file: 286 kB
third_party/RoMa/data/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
third_party/RoMa/demo/demo_3D_effect.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from romatch.utils.utils import tensor_to_pil
6
+
7
+ from romatch import roma_outdoor
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+
12
+ if __name__ == "__main__":
13
+ from argparse import ArgumentParser
14
+ parser = ArgumentParser()
15
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
16
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
17
+ parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
18
+
19
+ args, _ = parser.parse_known_args()
20
+ im1_path = args.im_A_path
21
+ im2_path = args.im_B_path
22
+ save_path = args.save_path
23
+
24
+ # Create model
25
+ roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
26
+ roma_model.symmetric = False
27
+
28
+ H, W = roma_model.get_output_resolution()
29
+
30
+ im1 = Image.open(im1_path).resize((W, H))
31
+ im2 = Image.open(im2_path).resize((W, H))
32
+
33
+ # Match
34
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
35
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
36
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
37
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
38
+
39
+ coords_A, coords_B = warp[...,:2], warp[...,2:]
40
+ for i, x in enumerate(np.linspace(0,2*np.pi,200)):
41
+ t = (1 + np.cos(x))/2
42
+ interp_warp = (1-t)*coords_A + t*coords_B
43
+ im2_transfer_rgb = F.grid_sample(
44
+ x2[None], interp_warp[None], mode="bilinear", align_corners=False
45
+ )[0]
46
+ tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
third_party/RoMa/demo/demo_fundamental.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import cv2
4
+ from romatch import roma_outdoor
5
+
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+
8
+
9
+ if __name__ == "__main__":
10
+ from argparse import ArgumentParser
11
+ parser = ArgumentParser()
12
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
13
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
14
+
15
+ args, _ = parser.parse_known_args()
16
+ im1_path = args.im_A_path
17
+ im2_path = args.im_B_path
18
+
19
+ # Create model
20
+ roma_model = roma_outdoor(device=device)
21
+
22
+
23
+ W_A, H_A = Image.open(im1_path).size
24
+ W_B, H_B = Image.open(im2_path).size
25
+
26
+ # Match
27
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
28
+ # Sample matches for estimation
29
+ matches, certainty = roma_model.sample(warp, certainty)
30
+ kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
31
+ F, mask = cv2.findFundamentalMat(
32
+ kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
33
+ )
third_party/RoMa/demo/demo_match.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from romatch.utils.utils import tensor_to_pil
6
+
7
+ from romatch import roma_outdoor
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+
12
+ if __name__ == "__main__":
13
+ from argparse import ArgumentParser
14
+ parser = ArgumentParser()
15
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
16
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
17
+ parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
18
+
19
+ args, _ = parser.parse_known_args()
20
+ im1_path = args.im_A_path
21
+ im2_path = args.im_B_path
22
+ save_path = args.save_path
23
+
24
+ # Create model
25
+ roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
26
+
27
+ H, W = roma_model.get_output_resolution()
28
+
29
+ im1 = Image.open(im1_path).resize((W, H))
30
+ im2 = Image.open(im2_path).resize((W, H))
31
+
32
+ # Match
33
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
34
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
35
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
36
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
37
+
38
+ im2_transfer_rgb = F.grid_sample(
39
+ x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
40
+ )[0]
41
+ im1_transfer_rgb = F.grid_sample(
42
+ x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
43
+ )[0]
44
+ warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
45
+ white_im = torch.ones((H,2*W),device=device)
46
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
47
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
third_party/RoMa/demo/demo_match_opencv_sift.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import numpy as np
5
+ import cv2 as cv
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+
10
+ if __name__ == "__main__":
11
+ from argparse import ArgumentParser
12
+ parser = ArgumentParser()
13
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
14
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
15
+ parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
16
+
17
+ args, _ = parser.parse_known_args()
18
+ im1_path = args.im_A_path
19
+ im2_path = args.im_B_path
20
+ save_path = args.save_path
21
+
22
+ img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE) # queryImage
23
+ img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
24
+ # Initiate SIFT detector
25
+ sift = cv.SIFT_create()
26
+ # find the keypoints and descriptors with SIFT
27
+ kp1, des1 = sift.detectAndCompute(img1,None)
28
+ kp2, des2 = sift.detectAndCompute(img2,None)
29
+ # BFMatcher with default params
30
+ bf = cv.BFMatcher()
31
+ matches = bf.knnMatch(des1,des2,k=2)
32
+ # Apply ratio test
33
+ good = []
34
+ for m,n in matches:
35
+ if m.distance < 0.75*n.distance:
36
+ good.append([m])
37
+ # cv.drawMatchesKnn expects list of lists as matches.
38
+ draw_params = dict(matchColor = (255,0,0), # draw matches in red color
39
+ singlePointColor = None,
40
+ flags = 2)
41
+
42
+ img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
43
+ Image.fromarray(img3).save("demo/sift_matches.png")
third_party/RoMa/demo/gif/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
third_party/RoMa/requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ einops
3
+ torchvision
4
+ opencv-python
5
+ kornia
6
+ albumentations
7
+ loguru
8
+ tqdm
9
+ matplotlib
10
+ h5py
11
+ wandb
12
+ timm
13
+ poselib
14
+ #xformers # Optional, used for memefficient attention
third_party/RoMa/romatch/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
3
+
4
+ DEBUG_MODE = False
5
+ RANK = int(os.environ.get('RANK', default = 0))
6
+ GLOBAL_STEP = 0
7
+ STEP_SIZE = 1
8
+ LOCAL_RANK = -1
third_party/RoMa/romatch/benchmarks/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
2
+ from .scannet_benchmark import ScanNetBenchmark
3
+ from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
4
+ from .megadepth_dense_benchmark import MegadepthDenseBenchmark
5
+ from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
6
+ from .scannet_benchmark_poselib import ScanNetPoselibBenchmark
third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import os
5
+
6
+ from tqdm import tqdm
7
+ from romatch.utils import pose_auc
8
+ import cv2
9
+
10
+
11
+ class HpatchesHomogBenchmark:
12
+ """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
13
+
14
+ def __init__(self, dataset_path) -> None:
15
+ seqs_dir = "hpatches-sequences-release"
16
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
17
+ self.seq_names = sorted(os.listdir(self.seqs_path))
18
+ # Ignore seqs is same as LoFTR.
19
+ self.ignore_seqs = set(
20
+ [
21
+ "i_contruction",
22
+ "i_crownnight",
23
+ "i_dc",
24
+ "i_pencils",
25
+ "i_whitebuilding",
26
+ "v_artisans",
27
+ "v_astronautis",
28
+ "v_talent",
29
+ ]
30
+ )
31
+
32
+ def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
33
+ offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
34
+ im_A_coords = (
35
+ np.stack(
36
+ (
37
+ wq * (im_A_coords[..., 0] + 1) / 2,
38
+ hq * (im_A_coords[..., 1] + 1) / 2,
39
+ ),
40
+ axis=-1,
41
+ )
42
+ - offset
43
+ )
44
+ im_A_to_im_B = (
45
+ np.stack(
46
+ (
47
+ wsup * (im_A_to_im_B[..., 0] + 1) / 2,
48
+ hsup * (im_A_to_im_B[..., 1] + 1) / 2,
49
+ ),
50
+ axis=-1,
51
+ )
52
+ - offset
53
+ )
54
+ return im_A_coords, im_A_to_im_B
55
+
56
+ def benchmark(self, model, model_name = None):
57
+ n_matches = []
58
+ homog_dists = []
59
+ for seq_idx, seq_name in tqdm(
60
+ enumerate(self.seq_names), total=len(self.seq_names)
61
+ ):
62
+ im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
63
+ im_A = Image.open(im_A_path)
64
+ w1, h1 = im_A.size
65
+ for im_idx in range(2, 7):
66
+ im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
67
+ im_B = Image.open(im_B_path)
68
+ w2, h2 = im_B.size
69
+ H = np.loadtxt(
70
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
71
+ )
72
+ dense_matches, dense_certainty = model.match(
73
+ im_A_path, im_B_path
74
+ )
75
+ good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
76
+ pos_a, pos_b = self.convert_coordinates(
77
+ good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
78
+ )
79
+ try:
80
+ H_pred, inliers = cv2.findHomography(
81
+ pos_a,
82
+ pos_b,
83
+ method = cv2.RANSAC,
84
+ confidence = 0.99999,
85
+ ransacReprojThreshold = 3 * min(w2, h2) / 480,
86
+ )
87
+ except:
88
+ H_pred = None
89
+ if H_pred is None:
90
+ H_pred = np.zeros((3, 3))
91
+ H_pred[2, 2] = 1.0
92
+ corners = np.array(
93
+ [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
94
+ )
95
+ real_warped_corners = np.dot(corners, np.transpose(H))
96
+ real_warped_corners = (
97
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
98
+ )
99
+ warped_corners = np.dot(corners, np.transpose(H_pred))
100
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
101
+ mean_dist = np.mean(
102
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
103
+ ) / (min(w2, h2) / 480.0)
104
+ homog_dists.append(mean_dist)
105
+
106
+ n_matches = np.array(n_matches)
107
+ thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
108
+ auc = pose_auc(np.array(homog_dists), thresholds)
109
+ return {
110
+ "hpatches_homog_auc_3": auc[2],
111
+ "hpatches_homog_auc_5": auc[4],
112
+ "hpatches_homog_auc_10": auc[9],
113
+ }
third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import tqdm
4
+ from romatch.datasets import MegadepthBuilder
5
+ from romatch.utils import warp_kpts
6
+ from torch.utils.data import ConcatDataset
7
+ import romatch
8
+
9
+ class MegadepthDenseBenchmark:
10
+ def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
11
+ mega = MegadepthBuilder(data_root=data_root)
12
+ self.dataset = ConcatDataset(
13
+ mega.build_scenes(split="test_loftr", ht=h, wt=w)
14
+ ) # fixed resolution of 384,512
15
+ self.num_samples = num_samples
16
+
17
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
18
+ b, h1, w1, d = dense_matches.shape
19
+ with torch.no_grad():
20
+ x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
21
+ mask, x2 = warp_kpts(
22
+ x1.double(),
23
+ depth1.double(),
24
+ depth2.double(),
25
+ T_1to2.double(),
26
+ K1.double(),
27
+ K2.double(),
28
+ )
29
+ x2 = torch.stack(
30
+ (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
31
+ )
32
+ prob = mask.float().reshape(b, h1, w1)
33
+ x2_hat = dense_matches[..., 2:]
34
+ x2_hat = torch.stack(
35
+ (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
36
+ )
37
+ gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
38
+ gd = gd[prob == 1]
39
+ pck_1 = (gd < 1.0).float().mean()
40
+ pck_3 = (gd < 3.0).float().mean()
41
+ pck_5 = (gd < 5.0).float().mean()
42
+ return gd, pck_1, pck_3, pck_5, prob
43
+
44
+ def benchmark(self, model, batch_size=8):
45
+ model.train(False)
46
+ with torch.no_grad():
47
+ gd_tot = 0.0
48
+ pck_1_tot = 0.0
49
+ pck_3_tot = 0.0
50
+ pck_5_tot = 0.0
51
+ sampler = torch.utils.data.WeightedRandomSampler(
52
+ torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
53
+ )
54
+ B = batch_size
55
+ dataloader = torch.utils.data.DataLoader(
56
+ self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
57
+ )
58
+ for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0):
59
+ im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
60
+ data["im_A"].cuda(),
61
+ data["im_B"].cuda(),
62
+ data["im_A_depth"].cuda(),
63
+ data["im_B_depth"].cuda(),
64
+ data["T_1to2"].cuda(),
65
+ data["K1"].cuda(),
66
+ data["K2"].cuda(),
67
+ )
68
+ matches, certainty = model.match(im_A, im_B, batched=True)
69
+ gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
70
+ depth1, depth2, T_1to2, K1, K2, matches
71
+ )
72
+ if romatch.DEBUG_MODE:
73
+ from romatch.utils.utils import tensor_to_pil
74
+ import torch.nn.functional as F
75
+ path = "vis"
76
+ H, W = model.get_output_resolution()
77
+ white_im = torch.ones((B,1,H,W),device="cuda")
78
+ im_B_transfer_rgb = F.grid_sample(
79
+ im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
80
+ )
81
+ warp_im = im_B_transfer_rgb
82
+ c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
83
+ vis_im = c_b * warp_im + (1 - c_b) * white_im
84
+ for b in range(B):
85
+ import os
86
+ os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
87
+ tensor_to_pil(vis_im[b], unnormalize=True).save(
88
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
89
+ tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
90
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
91
+ tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
92
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
93
+
94
+
95
+ gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
96
+ gd_tot + gd.mean(),
97
+ pck_1_tot + pck_1,
98
+ pck_3_tot + pck_3,
99
+ pck_5_tot + pck_5,
100
+ )
101
+ return {
102
+ "epe": gd_tot.item() / len(dataloader),
103
+ "mega_pck_1": pck_1_tot.item() / len(dataloader),
104
+ "mega_pck_3": pck_3_tot.item() / len(dataloader),
105
+ "mega_pck_5": pck_5_tot.item() / len(dataloader),
106
+ }
third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from romatch.utils import *
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import torch.nn.functional as F
7
+ import romatch
8
+ import kornia.geometry.epipolar as kepi
9
+
10
+ class MegaDepthPoseEstimationBenchmark:
11
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
12
+ if scene_names is None:
13
+ self.scene_names = [
14
+ "0015_0.1_0.3.npz",
15
+ "0015_0.3_0.5.npz",
16
+ "0022_0.1_0.3.npz",
17
+ "0022_0.3_0.5.npz",
18
+ "0022_0.5_0.7.npz",
19
+ ]
20
+ else:
21
+ self.scene_names = scene_names
22
+ self.scenes = [
23
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
24
+ for scene in self.scene_names
25
+ ]
26
+ self.data_root = data_root
27
+
28
+ def benchmark(self, model, model_name = None):
29
+ with torch.no_grad():
30
+ data_root = self.data_root
31
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
32
+ thresholds = [5, 10, 20]
33
+ for scene_ind in range(len(self.scenes)):
34
+ import os
35
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
36
+ scene = self.scenes[scene_ind]
37
+ pairs = scene["pair_infos"]
38
+ intrinsics = scene["intrinsics"]
39
+ poses = scene["poses"]
40
+ im_paths = scene["image_paths"]
41
+ pair_inds = range(len(pairs))
42
+ for pairind in tqdm(pair_inds):
43
+ idx1, idx2 = pairs[pairind][0]
44
+ K1 = intrinsics[idx1].copy()
45
+ T1 = poses[idx1].copy()
46
+ R1, t1 = T1[:3, :3], T1[:3, 3]
47
+ K2 = intrinsics[idx2].copy()
48
+ T2 = poses[idx2].copy()
49
+ R2, t2 = T2[:3, :3], T2[:3, 3]
50
+ R, t = compute_relative_pose(R1, t1, R2, t2)
51
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
52
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
53
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
54
+ dense_matches, dense_certainty = model.match(
55
+ im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
56
+ )
57
+ sparse_matches,_ = model.sample(
58
+ dense_matches, dense_certainty, 5_000
59
+ )
60
+
61
+ im_A = Image.open(im_A_path)
62
+ w1, h1 = im_A.size
63
+ im_B = Image.open(im_B_path)
64
+ w2, h2 = im_B.size
65
+ if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
66
+ scale1 = 1200 / max(w1, h1)
67
+ scale2 = 1200 / max(w2, h2)
68
+ w1, h1 = scale1 * w1, scale1 * h1
69
+ w2, h2 = scale2 * w2, scale2 * h2
70
+ K1, K2 = K1.copy(), K2.copy()
71
+ K1[:2] = K1[:2] * scale1
72
+ K2[:2] = K2[:2] * scale2
73
+
74
+ kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
75
+ kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
76
+ for _ in range(5):
77
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
78
+ kpts1 = kpts1[shuffling]
79
+ kpts2 = kpts2[shuffling]
80
+ try:
81
+ threshold = 0.5
82
+ norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
83
+ R_est, t_est, mask = estimate_pose(
84
+ kpts1,
85
+ kpts2,
86
+ K1,
87
+ K2,
88
+ norm_threshold,
89
+ conf=0.99999,
90
+ )
91
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
92
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
93
+ e_pose = max(e_t, e_R)
94
+ except Exception as e:
95
+ print(repr(e))
96
+ e_t, e_R = 90, 90
97
+ e_pose = max(e_t, e_R)
98
+ tot_e_t.append(e_t)
99
+ tot_e_R.append(e_R)
100
+ tot_e_pose.append(e_pose)
101
+ tot_e_pose = np.array(tot_e_pose)
102
+ auc = pose_auc(tot_e_pose, thresholds)
103
+ acc_5 = (tot_e_pose < 5).mean()
104
+ acc_10 = (tot_e_pose < 10).mean()
105
+ acc_15 = (tot_e_pose < 15).mean()
106
+ acc_20 = (tot_e_pose < 20).mean()
107
+ map_5 = acc_5
108
+ map_10 = np.mean([acc_5, acc_10])
109
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
110
+ print(f"{model_name} auc: {auc}")
111
+ return {
112
+ "auc_5": auc[0],
113
+ "auc_10": auc[1],
114
+ "auc_20": auc[2],
115
+ "map_5": map_5,
116
+ "map_10": map_10,
117
+ "map_20": map_20,
118
+ }
third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from romatch.utils import *
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import torch.nn.functional as F
7
+ import romatch
8
+ import kornia.geometry.epipolar as kepi
9
+
10
+ # wrap cause pyposelib is still in dev
11
+ # will add in deps later
12
+ import poselib
13
+
14
+ class Mega1500PoseLibBenchmark:
15
+ def __init__(self, data_root="data/megadepth", scene_names = None, num_ransac_iter = 5, test_every = 1) -> None:
16
+ if scene_names is None:
17
+ self.scene_names = [
18
+ "0015_0.1_0.3.npz",
19
+ "0015_0.3_0.5.npz",
20
+ "0022_0.1_0.3.npz",
21
+ "0022_0.3_0.5.npz",
22
+ "0022_0.5_0.7.npz",
23
+ ]
24
+ else:
25
+ self.scene_names = scene_names
26
+ self.scenes = [
27
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
28
+ for scene in self.scene_names
29
+ ]
30
+ self.data_root = data_root
31
+ self.num_ransac_iter = num_ransac_iter
32
+ self.test_every = test_every
33
+
34
+ def benchmark(self, model, model_name = None):
35
+ with torch.no_grad():
36
+ data_root = self.data_root
37
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
38
+ thresholds = [5, 10, 20]
39
+ for scene_ind in range(len(self.scenes)):
40
+ import os
41
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
42
+ scene = self.scenes[scene_ind]
43
+ pairs = scene["pair_infos"]
44
+ intrinsics = scene["intrinsics"]
45
+ poses = scene["poses"]
46
+ im_paths = scene["image_paths"]
47
+ pair_inds = range(len(pairs))[::self.test_every]
48
+ for pairind in (pbar := tqdm(pair_inds, desc = "Current AUC: ?")):
49
+ idx1, idx2 = pairs[pairind][0]
50
+ K1 = intrinsics[idx1].copy()
51
+ T1 = poses[idx1].copy()
52
+ R1, t1 = T1[:3, :3], T1[:3, 3]
53
+ K2 = intrinsics[idx2].copy()
54
+ T2 = poses[idx2].copy()
55
+ R2, t2 = T2[:3, :3], T2[:3, 3]
56
+ R, t = compute_relative_pose(R1, t1, R2, t2)
57
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
58
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
59
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
60
+ dense_matches, dense_certainty = model.match(
61
+ im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
62
+ )
63
+ sparse_matches,_ = model.sample(
64
+ dense_matches, dense_certainty, 5_000
65
+ )
66
+
67
+ im_A = Image.open(im_A_path)
68
+ w1, h1 = im_A.size
69
+ im_B = Image.open(im_B_path)
70
+ w2, h2 = im_B.size
71
+ kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
72
+ kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
73
+ for _ in range(self.num_ransac_iter):
74
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
75
+ kpts1 = kpts1[shuffling]
76
+ kpts2 = kpts2[shuffling]
77
+ try:
78
+ threshold = 1
79
+ camera1 = {'model': 'PINHOLE', 'width': w1, 'height': h1, 'params': K1[[0,1,0,1], [0,1,2,2]]}
80
+ camera2 = {'model': 'PINHOLE', 'width': w2, 'height': h2, 'params': K2[[0,1,0,1], [0,1,2,2]]}
81
+ relpose, res = poselib.estimate_relative_pose(
82
+ kpts1,
83
+ kpts2,
84
+ camera1,
85
+ camera2,
86
+ ransac_opt = {"max_reproj_error": 2*threshold, "max_epipolar_error": threshold, "min_inliers": 8, "max_iterations": 10_000},
87
+ )
88
+ Rt_est = relpose.Rt
89
+ R_est, t_est = Rt_est[:3,:3], Rt_est[:3,3:]
90
+ mask = np.array(res['inliers']).astype(np.float32)
91
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
92
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
93
+ e_pose = max(e_t, e_R)
94
+ except Exception as e:
95
+ print(repr(e))
96
+ e_t, e_R = 90, 90
97
+ e_pose = max(e_t, e_R)
98
+ tot_e_t.append(e_t)
99
+ tot_e_R.append(e_R)
100
+ tot_e_pose.append(e_pose)
101
+ pbar.set_description(f"Current AUC: {pose_auc(tot_e_pose, thresholds)}")
102
+ tot_e_pose = np.array(tot_e_pose)
103
+ auc = pose_auc(tot_e_pose, thresholds)
104
+ acc_5 = (tot_e_pose < 5).mean()
105
+ acc_10 = (tot_e_pose < 10).mean()
106
+ acc_15 = (tot_e_pose < 15).mean()
107
+ acc_20 = (tot_e_pose < 20).mean()
108
+ map_5 = acc_5
109
+ map_10 = np.mean([acc_5, acc_10])
110
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
111
+ print(f"{model_name} auc: {auc}")
112
+ return {
113
+ "auc_5": auc[0],
114
+ "auc_10": auc[1],
115
+ "auc_20": auc[2],
116
+ "map_5": map_5,
117
+ "map_10": map_10,
118
+ "map_20": map_20,
119
+ }
third_party/RoMa/romatch/benchmarks/scannet_benchmark.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import torch
4
+ from romatch.utils import *
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+
8
+
9
+ class ScanNetBenchmark:
10
+ def __init__(self, data_root="data/scannet") -> None:
11
+ self.data_root = data_root
12
+
13
+ def benchmark(self, model, model_name = None):
14
+ model.train(False)
15
+ with torch.no_grad():
16
+ data_root = self.data_root
17
+ tmp = np.load(osp.join(data_root, "test.npz"))
18
+ pairs, rel_pose = tmp["name"], tmp["rel_pose"]
19
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
20
+ pair_inds = np.random.choice(
21
+ range(len(pairs)), size=len(pairs), replace=False
22
+ )
23
+ for pairind in tqdm(pair_inds, smoothing=0.9):
24
+ scene = pairs[pairind]
25
+ scene_name = f"scene0{scene[0]}_00"
26
+ im_A_path = osp.join(
27
+ self.data_root,
28
+ "scans_test",
29
+ scene_name,
30
+ "color",
31
+ f"{scene[2]}.jpg",
32
+ )
33
+ im_A = Image.open(im_A_path)
34
+ im_B_path = osp.join(
35
+ self.data_root,
36
+ "scans_test",
37
+ scene_name,
38
+ "color",
39
+ f"{scene[3]}.jpg",
40
+ )
41
+ im_B = Image.open(im_B_path)
42
+ T_gt = rel_pose[pairind].reshape(3, 4)
43
+ R, t = T_gt[:3, :3], T_gt[:3, 3]
44
+ K = np.stack(
45
+ [
46
+ np.array([float(i) for i in r.split()])
47
+ for r in open(
48
+ osp.join(
49
+ self.data_root,
50
+ "scans_test",
51
+ scene_name,
52
+ "intrinsic",
53
+ "intrinsic_color.txt",
54
+ ),
55
+ "r",
56
+ )
57
+ .read()
58
+ .split("\n")
59
+ if r
60
+ ]
61
+ )
62
+ w1, h1 = im_A.size
63
+ w2, h2 = im_B.size
64
+ K1 = K.copy()
65
+ K2 = K.copy()
66
+ dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
67
+ sparse_matches, sparse_certainty = model.sample(
68
+ dense_matches, dense_certainty, 5000
69
+ )
70
+ scale1 = 480 / min(w1, h1)
71
+ scale2 = 480 / min(w2, h2)
72
+ w1, h1 = scale1 * w1, scale1 * h1
73
+ w2, h2 = scale2 * w2, scale2 * h2
74
+ K1 = K1 * scale1
75
+ K2 = K2 * scale2
76
+
77
+ offset = 0.5
78
+ kpts1 = sparse_matches[:, :2]
79
+ kpts1 = (
80
+ np.stack(
81
+ (
82
+ w1 * (kpts1[:, 0] + 1) / 2 - offset,
83
+ h1 * (kpts1[:, 1] + 1) / 2 - offset,
84
+ ),
85
+ axis=-1,
86
+ )
87
+ )
88
+ kpts2 = sparse_matches[:, 2:]
89
+ kpts2 = (
90
+ np.stack(
91
+ (
92
+ w2 * (kpts2[:, 0] + 1) / 2 - offset,
93
+ h2 * (kpts2[:, 1] + 1) / 2 - offset,
94
+ ),
95
+ axis=-1,
96
+ )
97
+ )
98
+ for _ in range(5):
99
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
100
+ kpts1 = kpts1[shuffling]
101
+ kpts2 = kpts2[shuffling]
102
+ try:
103
+ norm_threshold = 0.5 / (
104
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
105
+ R_est, t_est, mask = estimate_pose(
106
+ kpts1,
107
+ kpts2,
108
+ K1,
109
+ K2,
110
+ norm_threshold,
111
+ conf=0.99999,
112
+ )
113
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
114
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
115
+ e_pose = max(e_t, e_R)
116
+ except Exception as e:
117
+ print(repr(e))
118
+ e_t, e_R = 90, 90
119
+ e_pose = max(e_t, e_R)
120
+ tot_e_t.append(e_t)
121
+ tot_e_R.append(e_R)
122
+ tot_e_pose.append(e_pose)
123
+ tot_e_t.append(e_t)
124
+ tot_e_R.append(e_R)
125
+ tot_e_pose.append(e_pose)
126
+ tot_e_pose = np.array(tot_e_pose)
127
+ thresholds = [5, 10, 20]
128
+ auc = pose_auc(tot_e_pose, thresholds)
129
+ acc_5 = (tot_e_pose < 5).mean()
130
+ acc_10 = (tot_e_pose < 10).mean()
131
+ acc_15 = (tot_e_pose < 15).mean()
132
+ acc_20 = (tot_e_pose < 20).mean()
133
+ map_5 = acc_5
134
+ map_10 = np.mean([acc_5, acc_10])
135
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
136
+ return {
137
+ "auc_5": auc[0],
138
+ "auc_10": auc[1],
139
+ "auc_20": auc[2],
140
+ "map_5": map_5,
141
+ "map_10": map_10,
142
+ "map_20": map_20,
143
+ }
third_party/RoMa/romatch/checkpointing/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .checkpoint import CheckPoint
third_party/RoMa/romatch/checkpointing/checkpoint.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.nn.parallel.data_parallel import DataParallel
4
+ from torch.nn.parallel.distributed import DistributedDataParallel
5
+ from loguru import logger
6
+ import gc
7
+
8
+ import romatch
9
+
10
+ class CheckPoint:
11
+ def __init__(self, dir=None, name="tmp"):
12
+ self.name = name
13
+ self.dir = dir
14
+ os.makedirs(self.dir, exist_ok=True)
15
+
16
+ def save(
17
+ self,
18
+ model,
19
+ optimizer,
20
+ lr_scheduler,
21
+ n,
22
+ ):
23
+ if romatch.RANK == 0:
24
+ assert model is not None
25
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
26
+ model = model.module
27
+ states = {
28
+ "model": model.state_dict(),
29
+ "n": n,
30
+ "optimizer": optimizer.state_dict(),
31
+ "lr_scheduler": lr_scheduler.state_dict(),
32
+ }
33
+ torch.save(states, self.dir + self.name + f"_latest.pth")
34
+ logger.info(f"Saved states {list(states.keys())}, at step {n}")
35
+
36
+ def load(
37
+ self,
38
+ model,
39
+ optimizer,
40
+ lr_scheduler,
41
+ n,
42
+ ):
43
+ if os.path.exists(self.dir + self.name + f"_latest.pth") and romatch.RANK == 0:
44
+ states = torch.load(self.dir + self.name + f"_latest.pth")
45
+ if "model" in states:
46
+ model.load_state_dict(states["model"])
47
+ if "n" in states:
48
+ n = states["n"] if states["n"] else n
49
+ if "optimizer" in states:
50
+ try:
51
+ optimizer.load_state_dict(states["optimizer"])
52
+ except Exception as e:
53
+ print(f"Failed to load states for optimizer, with error {e}")
54
+ if "lr_scheduler" in states:
55
+ lr_scheduler.load_state_dict(states["lr_scheduler"])
56
+ print(f"Loaded states {list(states.keys())}, at step {n}")
57
+ del states
58
+ gc.collect()
59
+ torch.cuda.empty_cache()
60
+ return model, optimizer, lr_scheduler, n
third_party/RoMa/romatch/datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .megadepth import MegadepthBuilder
2
+ from .scannet import ScanNetBuilder
third_party/RoMa/romatch/datasets/megadepth.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import h5py
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms.functional as tvf
7
+ import kornia.augmentation as K
8
+ from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
9
+ import romatch
10
+ from romatch.utils import *
11
+ import math
12
+
13
+ class MegadepthScene:
14
+ def __init__(
15
+ self,
16
+ data_root,
17
+ scene_info,
18
+ ht=384,
19
+ wt=512,
20
+ min_overlap=0.0,
21
+ max_overlap=1.0,
22
+ shake_t=0,
23
+ rot_prob=0.0,
24
+ normalize=True,
25
+ max_num_pairs = 100_000,
26
+ scene_name = None,
27
+ use_horizontal_flip_aug = False,
28
+ use_single_horizontal_flip_aug = False,
29
+ colorjiggle_params = None,
30
+ random_eraser = None,
31
+ use_randaug = False,
32
+ randaug_params = None,
33
+ randomize_size = False,
34
+ ) -> None:
35
+ self.data_root = data_root
36
+ self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
37
+ self.image_paths = scene_info["image_paths"]
38
+ self.depth_paths = scene_info["depth_paths"]
39
+ self.intrinsics = scene_info["intrinsics"]
40
+ self.poses = scene_info["poses"]
41
+ self.pairs = scene_info["pairs"]
42
+ self.overlaps = scene_info["overlaps"]
43
+ threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
44
+ self.pairs = self.pairs[threshold]
45
+ self.overlaps = self.overlaps[threshold]
46
+ if len(self.pairs) > max_num_pairs:
47
+ pairinds = np.random.choice(
48
+ np.arange(0, len(self.pairs)), max_num_pairs, replace=False
49
+ )
50
+ self.pairs = self.pairs[pairinds]
51
+ self.overlaps = self.overlaps[pairinds]
52
+ if randomize_size:
53
+ area = ht * wt
54
+ s = int(16 * (math.sqrt(area)//16))
55
+ sizes = ((ht,wt), (s,s), (wt,ht))
56
+ choice = romatch.RANK % 3
57
+ ht, wt = sizes[choice]
58
+ # counts, bins = np.histogram(self.overlaps,20)
59
+ # print(counts)
60
+ self.im_transform_ops = get_tuple_transform_ops(
61
+ resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
62
+ )
63
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
64
+ resize=(ht, wt)
65
+ )
66
+ self.wt, self.ht = wt, ht
67
+ self.shake_t = shake_t
68
+ self.random_eraser = random_eraser
69
+ if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
70
+ raise ValueError("Can't both flip both images and only flip one")
71
+ self.use_horizontal_flip_aug = use_horizontal_flip_aug
72
+ self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
73
+ self.use_randaug = use_randaug
74
+
75
+ def load_im(self, im_path):
76
+ im = Image.open(im_path)
77
+ return im
78
+
79
+ def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
80
+ im_A = im_A.flip(-1)
81
+ im_B = im_B.flip(-1)
82
+ depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
83
+ flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
84
+ K_A = flip_mat@K_A
85
+ K_B = flip_mat@K_B
86
+
87
+ return im_A, im_B, depth_A, depth_B, K_A, K_B
88
+
89
+ def load_depth(self, depth_ref, crop=None):
90
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
91
+ return torch.from_numpy(depth)
92
+
93
+ def __len__(self):
94
+ return len(self.pairs)
95
+
96
+ def scale_intrinsic(self, K, wi, hi):
97
+ sx, sy = self.wt / wi, self.ht / hi
98
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
99
+ return sK @ K
100
+
101
+ def rand_shake(self, *things):
102
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
103
+ return [
104
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
105
+ for thing in things
106
+ ], t
107
+
108
+ def __getitem__(self, pair_idx):
109
+ # read intrinsics of original size
110
+ idx1, idx2 = self.pairs[pair_idx]
111
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
112
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
113
+
114
+ # read and compute relative poses
115
+ T1 = self.poses[idx1]
116
+ T2 = self.poses[idx2]
117
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
118
+ :4, :4
119
+ ] # (4, 4)
120
+
121
+ # Load positive pair data
122
+ im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
123
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
124
+ im_A_ref = os.path.join(self.data_root, im_A)
125
+ im_B_ref = os.path.join(self.data_root, im_B)
126
+ depth_A_ref = os.path.join(self.data_root, depth1)
127
+ depth_B_ref = os.path.join(self.data_root, depth2)
128
+ im_A = self.load_im(im_A_ref)
129
+ im_B = self.load_im(im_B_ref)
130
+ K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
131
+ K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
132
+
133
+ if self.use_randaug:
134
+ im_A, im_B = self.rand_augment(im_A, im_B)
135
+
136
+ depth_A = self.load_depth(depth_A_ref)
137
+ depth_B = self.load_depth(depth_B_ref)
138
+ # Process images
139
+ im_A, im_B = self.im_transform_ops((im_A, im_B))
140
+ depth_A, depth_B = self.depth_transform_ops(
141
+ (depth_A[None, None], depth_B[None, None])
142
+ )
143
+
144
+ [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
145
+ K1[:2, 2] += t
146
+ K2[:2, 2] += t
147
+
148
+ im_A, im_B = im_A[None], im_B[None]
149
+ if self.random_eraser is not None:
150
+ im_A, depth_A = self.random_eraser(im_A, depth_A)
151
+ im_B, depth_B = self.random_eraser(im_B, depth_B)
152
+
153
+ if self.use_horizontal_flip_aug:
154
+ if np.random.rand() > 0.5:
155
+ im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
156
+ if self.use_single_horizontal_flip_aug:
157
+ if np.random.rand() > 0.5:
158
+ im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
159
+
160
+ if romatch.DEBUG_MODE:
161
+ tensor_to_pil(im_A[0], unnormalize=True).save(
162
+ f"vis/im_A.jpg")
163
+ tensor_to_pil(im_B[0], unnormalize=True).save(
164
+ f"vis/im_B.jpg")
165
+
166
+ data_dict = {
167
+ "im_A": im_A[0],
168
+ "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
169
+ "im_B": im_B[0],
170
+ "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
171
+ "im_A_depth": depth_A[0, 0],
172
+ "im_B_depth": depth_B[0, 0],
173
+ "K1": K1,
174
+ "K2": K2,
175
+ "T_1to2": T_1to2,
176
+ "im_A_path": im_A_ref,
177
+ "im_B_path": im_B_ref,
178
+
179
+ }
180
+ return data_dict
181
+
182
+
183
+ class MegadepthBuilder:
184
+ def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
185
+ self.data_root = data_root
186
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
187
+ self.all_scenes = os.listdir(self.scene_info_root)
188
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
189
+ # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
190
+ self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
191
+ self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
192
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
193
+ self.loftr_ignore = loftr_ignore
194
+ self.imc21_ignore = imc21_ignore
195
+
196
+ def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
197
+ if split == "train":
198
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
199
+ elif split == "train_loftr":
200
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
201
+ elif split == "test":
202
+ scene_names = self.test_scenes
203
+ elif split == "test_loftr":
204
+ scene_names = self.test_scenes_loftr
205
+ elif split == "custom":
206
+ scene_names = scene_names
207
+ else:
208
+ raise ValueError(f"Split {split} not available")
209
+ scenes = []
210
+ for scene_name in scene_names:
211
+ if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
212
+ continue
213
+ if self.imc21_ignore and scene_name in self.imc21_scenes:
214
+ continue
215
+ if ".npy" not in scene_name:
216
+ continue
217
+ scene_info = np.load(
218
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
219
+ ).item()
220
+ scenes.append(
221
+ MegadepthScene(
222
+ self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
223
+ )
224
+ )
225
+ return scenes
226
+
227
+ def weight_scenes(self, concat_dataset, alpha=0.5):
228
+ ns = []
229
+ for d in concat_dataset.datasets:
230
+ ns.append(len(d))
231
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
232
+ return ws
third_party/RoMa/romatch/datasets/scannet.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from PIL import Image
4
+ import cv2
5
+ import h5py
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import (
9
+ Dataset,
10
+ DataLoader,
11
+ ConcatDataset)
12
+
13
+ import torchvision.transforms.functional as tvf
14
+ import kornia.augmentation as K
15
+ import os.path as osp
16
+ import matplotlib.pyplot as plt
17
+ import romatch
18
+ from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
19
+ from romatch.utils.transforms import GeometricSequential
20
+ from tqdm import tqdm
21
+
22
+ class ScanNetScene:
23
+ def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
24
+ ) -> None:
25
+ self.scene_root = osp.join(data_root,"scans","scans_train")
26
+ self.data_names = scene_info['name']
27
+ self.overlaps = scene_info['score']
28
+ # Only sample 10s
29
+ valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
30
+ self.overlaps = self.overlaps[valid]
31
+ self.data_names = self.data_names[valid]
32
+ if len(self.data_names) > 10000:
33
+ pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
34
+ self.data_names = self.data_names[pairinds]
35
+ self.overlaps = self.overlaps[pairinds]
36
+ self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
37
+ self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
38
+ self.wt, self.ht = wt, ht
39
+ self.shake_t = shake_t
40
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
41
+ self.use_horizontal_flip_aug = use_horizontal_flip_aug
42
+
43
+ def load_im(self, im_B, crop=None):
44
+ im = Image.open(im_B)
45
+ return im
46
+
47
+ def load_depth(self, depth_ref, crop=None):
48
+ depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
49
+ depth = depth / 1000
50
+ depth = torch.from_numpy(depth).float() # (h, w)
51
+ return depth
52
+
53
+ def __len__(self):
54
+ return len(self.data_names)
55
+
56
+ def scale_intrinsic(self, K, wi, hi):
57
+ sx, sy = self.wt / wi, self.ht / hi
58
+ sK = torch.tensor([[sx, 0, 0],
59
+ [0, sy, 0],
60
+ [0, 0, 1]])
61
+ return sK@K
62
+
63
+ def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
64
+ im_A = im_A.flip(-1)
65
+ im_B = im_B.flip(-1)
66
+ depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
67
+ flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
68
+ K_A = flip_mat@K_A
69
+ K_B = flip_mat@K_B
70
+
71
+ return im_A, im_B, depth_A, depth_B, K_A, K_B
72
+ def read_scannet_pose(self,path):
73
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
74
+
75
+ Returns:
76
+ pose_w2c (np.ndarray): (4, 4)
77
+ """
78
+ cam2world = np.loadtxt(path, delimiter=' ')
79
+ world2cam = np.linalg.inv(cam2world)
80
+ return world2cam
81
+
82
+
83
+ def read_scannet_intrinsic(self,path):
84
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
85
+ """
86
+ intrinsic = np.loadtxt(path, delimiter=' ')
87
+ return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
88
+
89
+ def __getitem__(self, pair_idx):
90
+ # read intrinsics of original size
91
+ data_name = self.data_names[pair_idx]
92
+ scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
93
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
94
+
95
+ # read the intrinsic of depthmap
96
+ K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
97
+ scene_name,
98
+ 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
99
+ # read and compute relative poses
100
+ T1 = self.read_scannet_pose(osp.join(self.scene_root,
101
+ scene_name,
102
+ 'pose', f'{stem_name_1}.txt'))
103
+ T2 = self.read_scannet_pose(osp.join(self.scene_root,
104
+ scene_name,
105
+ 'pose', f'{stem_name_2}.txt'))
106
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
107
+
108
+ # Load positive pair data
109
+ im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
110
+ im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
111
+ depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
112
+ depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
113
+
114
+ im_A = self.load_im(im_A_ref)
115
+ im_B = self.load_im(im_B_ref)
116
+ depth_A = self.load_depth(depth_A_ref)
117
+ depth_B = self.load_depth(depth_B_ref)
118
+
119
+ # Recompute camera intrinsic matrix due to the resize
120
+ K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
121
+ K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
122
+ # Process images
123
+ im_A, im_B = self.im_transform_ops((im_A, im_B))
124
+ depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
125
+ if self.use_horizontal_flip_aug:
126
+ if np.random.rand() > 0.5:
127
+ im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
128
+
129
+ data_dict = {'im_A': im_A,
130
+ 'im_B': im_B,
131
+ 'im_A_depth': depth_A[0,0],
132
+ 'im_B_depth': depth_B[0,0],
133
+ 'K1': K1,
134
+ 'K2': K2,
135
+ 'T_1to2':T_1to2,
136
+ }
137
+ return data_dict
138
+
139
+
140
+ class ScanNetBuilder:
141
+ def __init__(self, data_root = 'data/scannet') -> None:
142
+ self.data_root = data_root
143
+ self.scene_info_root = os.path.join(data_root,'scannet_indices')
144
+ self.all_scenes = os.listdir(self.scene_info_root)
145
+
146
+ def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
147
+ # Note: split doesn't matter here as we always use same scannet_train scenes
148
+ scene_names = self.all_scenes
149
+ scenes = []
150
+ for scene_name in tqdm(scene_names, disable = romatch.RANK > 0):
151
+ scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
152
+ scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
153
+ return scenes
154
+
155
+ def weight_scenes(self, concat_dataset, alpha=.5):
156
+ ns = []
157
+ for d in concat_dataset.datasets:
158
+ ns.append(len(d))
159
+ ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
160
+ return ws
third_party/RoMa/romatch/losses/__init__.py ADDED
@@ -0,0 +1 @@
+ from .robust_loss import RobustLosses
third_party/RoMa/romatch/losses/robust_loss.py ADDED
@@ -0,0 +1,161 @@
+ from einops.einops import rearrange
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from romatch.utils.utils import get_gt_warp
6
+ import wandb
7
+ import romatch
8
+ import math
9
+
10
+ class RobustLosses(nn.Module):
11
+ def __init__(
12
+ self,
13
+ robust=False,
14
+ center_coords=False,
15
+ scale_normalize=False,
16
+ ce_weight=0.01,
17
+ local_loss=True,
18
+ local_dist=4.0,
19
+ local_largest_scale=8,
20
+ smooth_mask = False,
21
+ depth_interpolation_mode = "bilinear",
22
+ mask_depth_loss = False,
23
+ relative_depth_error_threshold = 0.05,
24
+ alpha = 1.,
25
+ c = 1e-3,
26
+ ):
27
+ super().__init__()
28
+ self.robust = robust # measured in pixels
29
+ self.center_coords = center_coords
30
+ self.scale_normalize = scale_normalize
31
+ self.ce_weight = ce_weight
32
+ self.local_loss = local_loss
33
+ self.local_dist = local_dist
34
+ self.local_largest_scale = local_largest_scale
35
+ self.smooth_mask = smooth_mask
36
+ self.depth_interpolation_mode = depth_interpolation_mode
37
+ self.mask_depth_loss = mask_depth_loss
38
+ self.relative_depth_error_threshold = relative_depth_error_threshold
39
+ self.avg_overlap = dict()
40
+ self.alpha = alpha
41
+ self.c = c
42
+
43
+ def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
44
+ with torch.no_grad():
45
+ B, C, H, W = scale_gm_cls.shape
46
+ device = x2.device
47
+ cls_res = round(math.sqrt(C))
48
+ G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
49
+ G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
50
+ GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
51
+ cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
52
+ certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
53
+ if not torch.any(cls_loss):
54
+ cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
55
+
56
+ losses = {
57
+ f"gm_certainty_loss_{scale}": certainty_loss.mean(),
58
+ f"gm_cls_loss_{scale}": cls_loss.mean(),
59
+ }
60
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
61
+ return losses
62
+
63
+ def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
64
+ with torch.no_grad():
65
+ B, C, H, W = delta_cls.shape
66
+ device = x2.device
67
+ cls_res = round(math.sqrt(C))
68
+ G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
69
+ G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
70
+ GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
71
+ cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99]
+ certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
+ if not torch.any(cls_loss):
+ cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
75
+ losses = {
76
+ f"delta_certainty_loss_{scale}": certainty_loss.mean(),
77
+ f"delta_cls_loss_{scale}": cls_loss.mean(),
78
+ }
79
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
80
+ return losses
81
+
82
+ def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
83
+ epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
84
+ if scale == 1:
85
+ pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
86
+ wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
87
+
88
+ ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
89
+ a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
90
+ cs = self.c * scale
91
+ x = epe[prob > 0.99]
92
+ reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
93
+ if not torch.any(reg_loss):
94
+ reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
95
+ losses = {
96
+ f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
97
+ f"{mode}_regression_loss_{scale}": reg_loss.mean(),
98
+ }
99
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
100
+ return losses
101
+
102
+ def forward(self, corresps, batch):
103
+ scales = list(corresps.keys())
104
+ tot_loss = 0.0
105
+ # scale_weights due to differences in scale for regression gradients and classification gradients
106
+ scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
107
+ for scale in scales:
108
+ scale_corresps = corresps[scale]
109
+ scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
110
+ scale_corresps["certainty"],
111
+ scale_corresps.get("flow_pre_delta"),
112
+ scale_corresps.get("delta_cls"),
113
+ scale_corresps.get("offset_scale"),
114
+ scale_corresps.get("gm_cls"),
115
+ scale_corresps.get("gm_certainty"),
116
+ scale_corresps["flow"],
117
+ scale_corresps.get("gm_flow"),
118
+
119
+ )
120
+ if flow_pre_delta is not None:
121
+ flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
122
+ b, h, w, d = flow_pre_delta.shape
123
+ else:
124
+ # _ = 1
125
+ b, _, h, w = scale_certainty.shape
126
+ gt_warp, gt_prob = get_gt_warp(
127
+ batch["im_A_depth"],
128
+ batch["im_B_depth"],
129
+ batch["T_1to2"],
130
+ batch["K1"],
131
+ batch["K2"],
132
+ H=h,
133
+ W=w,
134
+ )
135
+ x2 = gt_warp.float()
136
+ prob = gt_prob
137
+
138
+ if self.local_largest_scale >= scale:
139
+ prob = prob * (
140
+ F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
141
+ < (2 / 512) * (self.local_dist[scale] * scale))
142
+
143
+ if scale_gm_cls is not None:
144
+ gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
145
+ gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
146
+ tot_loss = tot_loss + scale_weights[scale] * gm_loss
147
+ elif scale_gm_flow is not None:
148
+ gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
149
+ gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
150
+ tot_loss = tot_loss + scale_weights[scale] * gm_loss
151
+
152
+ if delta_cls is not None:
153
+ delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
154
+ delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
155
+ tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
156
+ else:
157
+ delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
158
+ reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
159
+ tot_loss = tot_loss + scale_weights[scale] * reg_loss
160
+ prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
161
+ return tot_loss
third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py ADDED
@@ -0,0 +1,160 @@
+ from einops.einops import rearrange
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from romatch.utils.utils import get_gt_warp
6
+ import wandb
7
+ import romatch
8
+ import math
9
+
10
+ # This is slightly different than regular romatch due to significantly worse corresps
11
+ # The confidence loss is quite tricky here //Johan
12
+
13
+ class RobustLosses(nn.Module):
14
+ def __init__(
15
+ self,
16
+ robust=False,
17
+ center_coords=False,
18
+ scale_normalize=False,
19
+ ce_weight=0.01,
20
+ local_loss=True,
21
+ local_dist=None,
22
+ smooth_mask = False,
23
+ depth_interpolation_mode = "bilinear",
24
+ mask_depth_loss = False,
25
+ relative_depth_error_threshold = 0.05,
26
+ alpha = 1.,
27
+ c = 1e-3,
28
+ epe_mask_prob_th = None,
29
+ cert_only_on_consistent_depth = False,
30
+ ):
31
+ super().__init__()
32
+ if local_dist is None:
33
+ local_dist = {}
34
+ self.robust = robust # measured in pixels
35
+ self.center_coords = center_coords
36
+ self.scale_normalize = scale_normalize
37
+ self.ce_weight = ce_weight
38
+ self.local_loss = local_loss
39
+ self.local_dist = local_dist
40
+ self.smooth_mask = smooth_mask
41
+ self.depth_interpolation_mode = depth_interpolation_mode
42
+ self.mask_depth_loss = mask_depth_loss
43
+ self.relative_depth_error_threshold = relative_depth_error_threshold
44
+ self.avg_overlap = dict()
45
+ self.alpha = alpha
46
+ self.c = c
47
+ self.epe_mask_prob_th = epe_mask_prob_th
48
+ self.cert_only_on_consistent_depth = cert_only_on_consistent_depth
49
+
50
+ def corr_volume_loss(self, mnn:torch.Tensor, corr_volume:torch.Tensor, scale):
51
+ b, h,w, h,w = corr_volume.shape
52
+ inv_temp = 10
53
+ corr_volume = corr_volume.reshape(-1, h*w, h*w)
54
+ nll = -(inv_temp*corr_volume).log_softmax(dim = 1) - (inv_temp*corr_volume).log_softmax(dim = 2)
55
+ corr_volume_loss = nll[mnn[:,0], mnn[:,1], mnn[:,2]].mean()
56
+
57
+ losses = {
58
+ f"gm_corr_volume_loss_{scale}": corr_volume_loss.mean(),
59
+ }
60
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
61
+ return losses
62
+
63
+
64
+
65
+ def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
66
+ epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
67
+ if scale in self.local_dist:
68
+ prob = prob * (epe < (2 / 512) * (self.local_dist[scale] * scale)).float()
69
+ if scale == 1:
70
+ pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
71
+ wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
72
+ if self.epe_mask_prob_th is not None:
73
+ # if too far away from gt, certainty should be 0
74
+ gt_cert = prob * (epe < scale * self.epe_mask_prob_th)
75
+ else:
76
+ gt_cert = prob
77
+ if self.cert_only_on_consistent_depth:
78
+ ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0][prob > 0], gt_cert[prob > 0])
79
+ else:
80
+ ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], gt_cert)
81
+ a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
82
+ cs = self.c * scale
83
+ x = epe[prob > 0.99]
84
+ reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
85
+ if not torch.any(reg_loss):
86
+ reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
87
+ losses = {
88
+ f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
89
+ f"{mode}_regression_loss_{scale}": reg_loss.mean(),
90
+ }
91
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
92
+ return losses
93
+
94
+ def forward(self, corresps, batch):
95
+ scales = list(corresps.keys())
96
+ tot_loss = 0.0
97
+ # scale_weights due to differences in scale for regression gradients and classification gradients
98
+ for scale in scales:
99
+ scale_corresps = corresps[scale]
100
+ scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_corr_volume, scale_gm_certainty, flow, scale_gm_flow = (
101
+ scale_corresps["certainty"],
102
+ scale_corresps.get("flow_pre_delta"),
103
+ scale_corresps.get("delta_cls"),
104
+ scale_corresps.get("offset_scale"),
105
+ scale_corresps.get("corr_volume"),
106
+ scale_corresps.get("gm_certainty"),
107
+ scale_corresps["flow"],
108
+ scale_corresps.get("gm_flow"),
109
+
110
+ )
111
+ if flow_pre_delta is not None:
112
+ flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
113
+ b, h, w, d = flow_pre_delta.shape
114
+ else:
115
+ # _ = 1
116
+ b, _, h, w = scale_certainty.shape
117
+ gt_warp, gt_prob = get_gt_warp(
118
+ batch["im_A_depth"],
119
+ batch["im_B_depth"],
120
+ batch["T_1to2"],
121
+ batch["K1"],
122
+ batch["K2"],
123
+ H=h,
124
+ W=w,
125
+ )
126
+ x2 = gt_warp.float()
127
+ prob = gt_prob
128
+
129
+ if scale_gm_corr_volume is not None:
130
+ gt_warp_back, _ = get_gt_warp(
131
+ batch["im_B_depth"],
132
+ batch["im_A_depth"],
133
+ batch["T_1to2"].inverse(),
134
+ batch["K2"],
135
+ batch["K1"],
136
+ H=h,
137
+ W=w,
138
+ )
139
+ grid = torch.stack(torch.meshgrid(torch.linspace(-1+1/w, 1-1/w, w), torch.linspace(-1+1/h, 1-1/h, h), indexing='xy'), dim =-1).to(gt_warp.device)
140
+ #fwd_bck = F.grid_sample(gt_warp_back.permute(0,3,1,2), gt_warp, align_corners=False, mode = 'bilinear').permute(0,2,3,1)
141
+ #diff = (fwd_bck - grid).norm(dim = -1)
142
+ with torch.no_grad():
143
+ D_B = torch.cdist(gt_warp.float().reshape(-1,h*w,2), grid.reshape(-1,h*w,2))
144
+ D_A = torch.cdist(grid.reshape(-1,h*w,2), gt_warp_back.float().reshape(-1,h*w,2))
145
+ inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values)
146
+ * (D_A == D_A.min(dim=-2, keepdim = True).values)
147
+ * (D_B < 0.01)
148
+ * (D_A < 0.01))
149
+
150
+ gm_cls_losses = self.corr_volume_loss(inds, scale_gm_corr_volume, scale)
151
+ gm_loss = gm_cls_losses[f"gm_corr_volume_loss_{scale}"]
152
+ tot_loss = tot_loss + gm_loss
153
+ elif scale_gm_flow is not None:
154
+ gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
155
+ gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
156
+ tot_loss = tot_loss + gm_loss
157
+ delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
158
+ reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
159
+ tot_loss = tot_loss + reg_loss
160
+ return tot_loss
third_party/RoMa/romatch/models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model_zoo import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
third_party/RoMa/romatch/models/encoders.py ADDED
@@ -0,0 +1,119 @@
+ from typing import Optional, Union
2
+ import torch
3
+ from torch import device
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.models as tvm
7
+ import gc
8
+
9
+
10
+ class ResNet50(nn.Module):
11
+ def __init__(self, pretrained=False, high_res = False, weights = None,
12
+ dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None:
13
+ super().__init__()
14
+ if dilation is None:
15
+ dilation = [False,False,False]
16
+ if anti_aliased:
17
+ pass
18
+ else:
19
+ if weights is not None:
20
+ self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
21
+ else:
22
+ self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
23
+
24
+ self.high_res = high_res
25
+ self.freeze_bn = freeze_bn
26
+ self.early_exit = early_exit
27
+ self.amp = amp
28
+ self.amp_dtype = amp_dtype
29
+
30
+ def forward(self, x, **kwargs):
31
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
32
+ net = self.net
33
+ feats = {1:x}
34
+ x = net.conv1(x)
35
+ x = net.bn1(x)
36
+ x = net.relu(x)
37
+ feats[2] = x
38
+ x = net.maxpool(x)
39
+ x = net.layer1(x)
40
+ feats[4] = x
41
+ x = net.layer2(x)
42
+ feats[8] = x
43
+ if self.early_exit:
44
+ return feats
45
+ x = net.layer3(x)
46
+ feats[16] = x
47
+ x = net.layer4(x)
48
+ feats[32] = x
49
+ return feats
50
+
51
+ def train(self, mode=True):
52
+ super().train(mode)
53
+ if self.freeze_bn:
54
+ for m in self.modules():
55
+ if isinstance(m, nn.BatchNorm2d):
56
+ m.eval()
57
+ pass
58
+
59
+ class VGG19(nn.Module):
60
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
61
+ super().__init__()
62
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
63
+ self.amp = amp
64
+ self.amp_dtype = amp_dtype
65
+
66
+ def forward(self, x, **kwargs):
67
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
68
+ feats = {}
69
+ scale = 1
70
+ for layer in self.layers:
71
+ if isinstance(layer, nn.MaxPool2d):
72
+ feats[scale] = x
73
+ scale = scale*2
74
+ x = layer(x)
75
+ return feats
76
+
77
+ class CNNandDinov2(nn.Module):
78
+ def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16):
79
+ super().__init__()
80
+ if dinov2_weights is None:
81
+ dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
82
+ from .transformer import vit_large
83
+ vit_kwargs = dict(img_size= 518,
84
+ patch_size= 14,
85
+ init_values = 1.0,
86
+ ffn_layer = "mlp",
87
+ block_chunks = 0,
88
+ )
89
+
90
+ dinov2_vitl14 = vit_large(**vit_kwargs).eval()
91
+ dinov2_vitl14.load_state_dict(dinov2_weights)
92
+ cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
93
+ if not use_vgg:
94
+ self.cnn = ResNet50(**cnn_kwargs)
95
+ else:
96
+ self.cnn = VGG19(**cnn_kwargs)
97
+ self.amp = amp
98
+ self.amp_dtype = amp_dtype
99
+ if self.amp:
100
+ dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
101
+ self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
102
+
103
+
104
+ def train(self, mode: bool = True):
105
+ return self.cnn.train(mode)
106
+
107
+ def forward(self, x, upsample = False):
108
+ B,C,H,W = x.shape
109
+ feature_pyramid = self.cnn(x)
110
+
111
+ if not upsample:
112
+ with torch.no_grad():
113
+ if self.dinov2_vitl14[0].device != x.device:
114
+ self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
115
+ dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
116
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
117
+ del dinov2_features_16
118
+ feature_pyramid[16] = features_16
119
+ return feature_pyramid
third_party/RoMa/romatch/models/matcher.py ADDED
@@ -0,0 +1,772 @@
+ import os
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ import warnings
9
+ from warnings import warn
10
+ from PIL import Image
11
+
12
+ import romatch
13
+ from romatch.utils import get_tuple_transform_ops
14
+ from romatch.utils.local_correlation import local_correlation
15
+ from romatch.utils.utils import cls_to_flow_refine
16
+ from romatch.utils.kde import kde
17
+ from typing import Union
18
+
19
+ class ConvRefiner(nn.Module):
20
+ def __init__(
21
+ self,
22
+ in_dim=6,
23
+ hidden_dim=16,
24
+ out_dim=2,
25
+ dw=False,
26
+ kernel_size=5,
27
+ hidden_blocks=3,
28
+ displacement_emb = None,
29
+ displacement_emb_dim = None,
30
+ local_corr_radius = None,
31
+ corr_in_other = None,
32
+ no_im_B_fm = False,
33
+ amp = False,
34
+ concat_logits = False,
35
+ use_bias_block_1 = True,
36
+ use_cosine_corr = False,
37
+ disable_local_corr_grad = False,
38
+ is_classifier = False,
39
+ sample_mode = "bilinear",
40
+ norm_type = nn.BatchNorm2d,
41
+ bn_momentum = 0.1,
42
+ amp_dtype = torch.float16,
43
+ ):
44
+ super().__init__()
45
+ self.bn_momentum = bn_momentum
46
+ self.block1 = self.create_block(
47
+ in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
48
+ )
49
+ self.hidden_blocks = nn.Sequential(
50
+ *[
51
+ self.create_block(
52
+ hidden_dim,
53
+ hidden_dim,
54
+ dw=dw,
55
+ kernel_size=kernel_size,
56
+ norm_type=norm_type,
57
+ )
58
+ for hb in range(hidden_blocks)
59
+ ]
60
+ )
61
+ self.hidden_blocks = self.hidden_blocks
62
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
63
+ if displacement_emb:
64
+ self.has_displacement_emb = True
65
+ self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
66
+ else:
67
+ self.has_displacement_emb = False
68
+ self.local_corr_radius = local_corr_radius
69
+ self.corr_in_other = corr_in_other
70
+ self.no_im_B_fm = no_im_B_fm
71
+ self.amp = amp
72
+ self.concat_logits = concat_logits
73
+ self.use_cosine_corr = use_cosine_corr
74
+ self.disable_local_corr_grad = disable_local_corr_grad
75
+ self.is_classifier = is_classifier
76
+ self.sample_mode = sample_mode
77
+ self.amp_dtype = amp_dtype
78
+
79
+ def create_block(
80
+ self,
81
+ in_dim,
82
+ out_dim,
83
+ dw=False,
84
+ kernel_size=5,
85
+ bias = True,
86
+ norm_type = nn.BatchNorm2d,
87
+ ):
88
+ num_groups = 1 if not dw else in_dim
89
+ if dw:
90
+ assert (
91
+ out_dim % in_dim == 0
92
+ ), "outdim must be divisible by indim for depthwise"
93
+ conv1 = nn.Conv2d(
94
+ in_dim,
95
+ out_dim,
96
+ kernel_size=kernel_size,
97
+ stride=1,
98
+ padding=kernel_size // 2,
99
+ groups=num_groups,
100
+ bias=bias,
101
+ )
102
+ norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
103
+ relu = nn.ReLU(inplace=True)
104
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
105
+ return nn.Sequential(conv1, norm, relu, conv2)
106
+
107
+ def forward(self, x, y, flow, scale_factor = 1, logits = None):
108
+ b,c,hs,ws = x.shape
109
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
110
+ with torch.no_grad():
111
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
112
+ if self.has_displacement_emb:
113
+ im_A_coords = torch.meshgrid(
114
+ (
115
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
116
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
117
+ )
118
+ )
119
+ im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
120
+ im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
121
+ in_displacement = flow-im_A_coords
122
+ emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
123
+ if self.local_corr_radius:
124
+ if self.corr_in_other:
125
+ # Corr in other means take a kxk grid around the predicted coordinate in other image
126
+ local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
127
+ sample_mode = self.sample_mode)
128
+ else:
129
+ raise NotImplementedError("Local corr in own frame should not be used.")
130
+ if self.no_im_B_fm:
131
+ x_hat = torch.zeros_like(x)
132
+ d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
133
+ else:
134
+ d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
135
+ else:
136
+ if self.no_im_B_fm:
137
+ x_hat = torch.zeros_like(x)
138
+ d = torch.cat((x, x_hat), dim=1)
139
+ if self.concat_logits:
140
+ d = torch.cat((d, logits), dim=1)
141
+ d = self.block1(d)
142
+ d = self.hidden_blocks(d)
143
+ d = self.out_conv(d.float())
144
+ displacement, certainty = d[:, :-1], d[:, -1:]
145
+ return displacement, certainty
146
+
147
+ class CosKernel(nn.Module): # similar to softmax kernel
148
+ def __init__(self, T, learn_temperature=False):
149
+ super().__init__()
150
+ self.learn_temperature = learn_temperature
151
+ if self.learn_temperature:
152
+ self.T = nn.Parameter(torch.tensor(T))
153
+ else:
154
+ self.T = T
155
+
156
+ def __call__(self, x, y, eps=1e-6):
157
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
158
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
159
+ )
160
+ if self.learn_temperature:
161
+ T = self.T.abs() + 0.01
162
+ else:
163
+ T = torch.tensor(self.T, device=c.device)
164
+ K = ((c - 1.0) / T).exp()
165
+ return K
166
+
167
+ class GP(nn.Module):
168
+ def __init__(
169
+ self,
170
+ kernel,
171
+ T=1,
172
+ learn_temperature=False,
173
+ only_attention=False,
174
+ gp_dim=64,
175
+ basis="fourier",
176
+ covar_size=5,
177
+ only_nearest_neighbour=False,
178
+ sigma_noise=0.1,
179
+ no_cov=False,
180
+ predict_features = False,
181
+ ):
182
+ super().__init__()
183
+ self.K = kernel(T=T, learn_temperature=learn_temperature)
184
+ self.sigma_noise = sigma_noise
185
+ self.covar_size = covar_size
186
+ self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
187
+ self.only_attention = only_attention
188
+ self.only_nearest_neighbour = only_nearest_neighbour
189
+ self.basis = basis
190
+ self.no_cov = no_cov
191
+ self.dim = gp_dim
192
+ self.predict_features = predict_features
193
+
194
+ def get_local_cov(self, cov):
195
+ K = self.covar_size
196
+ b, h, w, h, w = cov.shape
197
+ hw = h * w
198
+ cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
199
+ delta = torch.stack(
200
+ torch.meshgrid(
201
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
202
+ ),
203
+ dim=-1,
204
+ )
205
+ positions = torch.stack(
206
+ torch.meshgrid(
207
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
208
+ ),
209
+ dim=-1,
210
+ )
211
+ neighbours = positions[:, :, None, None, :] + delta[None, :, :]
212
+ points = torch.arange(hw)[:, None].expand(hw, K**2)
213
+ local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
214
+ :,
215
+ points.flatten(),
216
+ neighbours[..., 0].flatten(),
217
+ neighbours[..., 1].flatten(),
218
+ ].reshape(b, h, w, K**2)
219
+ return local_cov
220
+
221
+ def reshape(self, x):
222
+ return rearrange(x, "b d h w -> b (h w) d")
223
+
224
+ def project_to_basis(self, x):
225
+ if self.basis == "fourier":
226
+ return torch.cos(8 * math.pi * self.pos_conv(x))
227
+ elif self.basis == "linear":
228
+ return self.pos_conv(x)
229
+ else:
230
+ raise ValueError(
231
+ "No other bases other than fourier and linear currently im_Bed in public release"
232
+ )
233
+
234
+ def get_pos_enc(self, y):
235
+ b, c, h, w = y.shape
236
+ coarse_coords = torch.meshgrid(
237
+ (
238
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
239
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
240
+ )
241
+ )
242
+
243
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
244
+ None
245
+ ].expand(b, h, w, 2)
246
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
247
+ coarse_embedded_coords = self.project_to_basis(coarse_coords)
248
+ return coarse_embedded_coords
249
+
250
+ def forward(self, x, y, **kwargs):
251
+ b, c, h1, w1 = x.shape
252
+ b, c, h2, w2 = y.shape
253
+ f = self.get_pos_enc(y)
254
+ b, d, h2, w2 = f.shape
255
+ x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
256
+ K_xx = self.K(x, x)
257
+ K_yy = self.K(y, y)
258
+ K_xy = self.K(x, y)
259
+ K_yx = K_xy.permute(0, 2, 1)
260
+ sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
261
+ with warnings.catch_warnings():
262
+ K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
263
+
264
+ mu_x = K_xy.matmul(K_yy_inv.matmul(f))
265
+ mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
266
+ if not self.no_cov:
267
+ cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
268
+ cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
269
+ local_cov_x = self.get_local_cov(cov_x)
270
+ local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
271
+ gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
272
+ else:
273
+ gp_feats = mu_x
274
+ return gp_feats
275
+
276
+ class Decoder(nn.Module):
277
+ def __init__(
278
+ self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
279
+ num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
280
+ flow_upsample_mode = "bilinear", amp_dtype = torch.float16,
281
+ ):
282
+ super().__init__()
283
+ self.embedding_decoder = embedding_decoder
284
+ self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
285
+ self.gps = gps
286
+ self.proj = proj
287
+ self.conv_refiner = conv_refiner
288
+ self.detach = detach
289
+ if pos_embeddings is None:
290
+ self.pos_embeddings = {}
291
+ else:
292
+ self.pos_embeddings = pos_embeddings
293
+ if scales == "all":
294
+ self.scales = ["32", "16", "8", "4", "2", "1"]
295
+ else:
296
+ self.scales = scales
297
+ self.warp_noise_std = warp_noise_std
298
+ self.refine_init = 4
299
+ self.displacement_dropout_p = displacement_dropout_p
300
+ self.gm_warp_dropout_p = gm_warp_dropout_p
301
+ self.flow_upsample_mode = flow_upsample_mode
302
+ self.amp_dtype = amp_dtype
303
+
304
+ def get_placeholder_flow(self, b, h, w, device):
305
+ coarse_coords = torch.meshgrid(
306
+ (
307
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
308
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
309
+ )
310
+ )
311
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
312
+ None
313
+ ].expand(b, h, w, 2)
314
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
315
+ return coarse_coords
316
+
317
+ def get_positional_embedding(self, b, h ,w, device):
318
+ coarse_coords = torch.meshgrid(
319
+ (
320
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
321
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
322
+ )
323
+ )
324
+
325
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
326
+ None
327
+ ].expand(b, h, w, 2)
328
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
329
+ coarse_embedded_coords = self.pos_embedding(coarse_coords)
330
+ return coarse_embedded_coords
331
+
332
+ def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
333
+ coarse_scales = self.embedding_decoder.scales()
334
+ all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
335
+ sizes = {scale: f1[scale].shape[-2:] for scale in f1}
336
+ h, w = sizes[1]
337
+ b = f1[1].shape[0]
338
+ device = f1[1].device
339
+ coarsest_scale = int(all_scales[0])
340
+ old_stuff = torch.zeros(
341
+ b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
342
+ )
343
+ corresps = {}
344
+ if not upsample:
345
+ flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
346
+ certainty = 0.0
347
+ else:
348
+ flow = F.interpolate(
349
+ flow,
350
+ size=sizes[coarsest_scale],
351
+ align_corners=False,
352
+ mode="bilinear",
353
+ )
354
+ certainty = F.interpolate(
355
+ certainty,
356
+ size=sizes[coarsest_scale],
357
+ align_corners=False,
358
+ mode="bilinear",
359
+ )
360
+ displacement = 0.0
361
+ for new_scale in all_scales:
362
+ ins = int(new_scale)
363
+ corresps[ins] = {}
364
+ f1_s, f2_s = f1[ins], f2[ins]
365
+ if new_scale in self.proj:
366
+ with torch.autocast("cuda", dtype = self.amp_dtype):
367
+ f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
368
+
369
+ if ins in coarse_scales:
370
+ old_stuff = F.interpolate(
371
+ old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
372
+ )
373
+ gp_posterior = self.gps[new_scale](f1_s, f2_s)
374
+ gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
375
+ gp_posterior, f1_s, old_stuff, new_scale
376
+ )
377
+
378
+ if self.embedding_decoder.is_classifier:
379
+ flow = cls_to_flow_refine(
380
+ gm_warp_or_cls,
381
+ ).permute(0,3,1,2)
382
+ corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
383
+ else:
384
+ corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
385
+ flow = gm_warp_or_cls.detach()
386
+
387
+ if new_scale in self.conv_refiner:
388
+ corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
389
+ delta_flow, delta_certainty = self.conv_refiner[new_scale](
390
+ f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
391
+ )
392
+ corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
393
+ displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
394
+ delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
395
+ flow = flow + displacement
396
+ certainty = (
397
+ certainty + delta_certainty
398
+ ) # predict both certainty and displacement
399
+ corresps[ins].update({
400
+ "certainty": certainty,
401
+ "flow": flow,
402
+ })
403
+ if new_scale != "1":
404
+ flow = F.interpolate(
405
+ flow,
406
+ size=sizes[ins // 2],
407
+ mode=self.flow_upsample_mode,
408
+ )
409
+ certainty = F.interpolate(
410
+ certainty,
411
+ size=sizes[ins // 2],
412
+ mode=self.flow_upsample_mode,
413
+ )
414
+ if self.detach:
415
+ flow = flow.detach()
416
+ certainty = certainty.detach()
417
+ #torch.cuda.empty_cache()
418
+ return corresps
419
+
420
+
421
+ class RegressionMatcher(nn.Module):
422
+ def __init__(
423
+ self,
424
+ encoder,
425
+ decoder,
426
+ h=448,
427
+ w=448,
428
+ sample_mode = "threshold_balanced",
429
+ upsample_preds = False,
430
+ symmetric = False,
431
+ name = None,
432
+ attenuate_cert = None,
433
+ recrop_upsample = False,
434
+ ):
435
+ super().__init__()
436
+ self.attenuate_cert = attenuate_cert
437
+ self.encoder = encoder
438
+ self.decoder = decoder
439
+ self.name = name
440
+ self.w_resized = w
441
+ self.h_resized = h
442
+ self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
443
+ self.sample_mode = sample_mode
444
+ self.upsample_preds = upsample_preds
445
+ self.upsample_res = (14*16*6, 14*16*6)
446
+ self.symmetric = symmetric
447
+ self.sample_thresh = 0.05
448
+ self.recrop_upsample = recrop_upsample
449
+
450
+ def get_output_resolution(self):
451
+ if not self.upsample_preds:
452
+ return self.h_resized, self.w_resized
453
+ else:
454
+ return self.upsample_res
455
+
456
+ def extract_backbone_features(self, batch, batched = True, upsample = False):
457
+ x_q = batch["im_A"]
458
+ x_s = batch["im_B"]
459
+ if batched:
460
+ X = torch.cat((x_q, x_s), dim = 0)
461
+ feature_pyramid = self.encoder(X, upsample = upsample)
462
+ else:
463
+ feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
464
+ return feature_pyramid
465
+
466
+ def sample(
467
+ self,
468
+ matches,
469
+ certainty,
470
+ num=10000,
471
+ ):
472
+ if "threshold" in self.sample_mode:
473
+ upper_thresh = self.sample_thresh
474
+ certainty = certainty.clone()
475
+ certainty[certainty > upper_thresh] = 1
476
+ matches, certainty = (
477
+ matches.reshape(-1, 4),
478
+ certainty.reshape(-1),
479
+ )
480
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
481
+ good_samples = torch.multinomial(certainty,
482
+ num_samples = min(expansion_factor*num, len(certainty)),
483
+ replacement=False)
484
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
485
+ if "balanced" not in self.sample_mode:
486
+ return good_matches, good_certainty
487
+ density = kde(good_matches, std=0.1)
488
+ p = 1 / (density+1)
489
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
490
+ balanced_samples = torch.multinomial(p,
491
+ num_samples = min(num,len(good_certainty)),
492
+ replacement=False)
493
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
494
+
495
+ def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
496
+ feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
497
+ if batched:
498
+ f_q_pyramid = {
499
+ scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
500
+ }
501
+ f_s_pyramid = {
502
+ scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
503
+ }
504
+ else:
505
+ f_q_pyramid, f_s_pyramid = feature_pyramid
506
+ corresps = self.decoder(f_q_pyramid,
507
+ f_s_pyramid,
508
+ upsample = upsample,
509
+ **(batch["corresps"] if "corresps" in batch else {}),
510
+ scale_factor=scale_factor)
511
+
512
+ return corresps
513
+
514
+ def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
515
+ feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
516
+ f_q_pyramid = feature_pyramid
517
+ f_s_pyramid = {
518
+ scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
519
+ for scale, f_scale in feature_pyramid.items()
520
+ }
521
+ corresps = self.decoder(f_q_pyramid,
522
+ f_s_pyramid,
523
+ upsample = upsample,
524
+ **(batch["corresps"] if "corresps" in batch else {}),
525
+ scale_factor=scale_factor)
526
+ return corresps
527
+
528
+ def conf_from_fb_consistency(self, flow_forward, flow_backward, th = 2):
529
+ # assumes that flow forward is of shape (..., H, W, 2)
530
+ has_batch = False
531
+ if len(flow_forward.shape) == 3:
532
+ flow_forward, flow_backward = flow_forward[None], flow_backward[None]
533
+ else:
534
+ has_batch = True
535
+ H,W = flow_forward.shape[-3:-1]
536
+ th_n = 2 * th / max(H,W)
537
+ coords = torch.stack(torch.meshgrid(
538
+ torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
539
+ torch.linspace(-1 + 1 / H, 1 - 1 / H, H), indexing = "xy"),
540
+ dim = -1).to(flow_forward.device)
541
+ coords_fb = F.grid_sample(
542
+ flow_backward.permute(0, 3, 1, 2),
543
+ flow_forward,
544
+ align_corners=False, mode="bilinear").permute(0, 2, 3, 1)
545
+ diff = (coords - coords_fb).norm(dim=-1)
546
+ in_th = (diff < th_n).float()
547
+ if not has_batch:
548
+ in_th = in_th[0]
549
+ return in_th
550
+
551
+ def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
552
+ if coords.shape[-1] == 2:
553
+ return self._to_pixel_coordinates(coords, H_A, W_A)
554
+
555
+ if isinstance(coords, (list, tuple)):
556
+ kpts_A, kpts_B = coords[0], coords[1]
557
+ else:
558
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
559
+ return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
560
+
561
+ def _to_pixel_coordinates(self, coords, H, W):
562
+ kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
563
+ return kpts
564
+
565
+ def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
566
+ if isinstance(coords, (list, tuple)):
567
+ kpts_A, kpts_B = coords[0], coords[1]
568
+ else:
569
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
570
+ kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
571
+ kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
572
+ return kpts_A, kpts_B
573
+
574
+ def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
575
+ x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
576
+ cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
577
+ D = torch.cdist(x_A_to_B, x_B)
578
+ inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
579
+
580
+ if return_tuple:
581
+ if return_inds:
582
+ return inds_A, inds_B
583
+ else:
584
+ return x_A[inds_A], x_B[inds_B]
585
+ else:
586
+ if return_inds:
587
+ return torch.cat((inds_A, inds_B),dim=-1)
588
+ else:
589
+ return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
590
+
591
+ def get_roi(self, certainty, W, H, thr = 0.025):
592
+ raise NotImplementedError("WIP, disable for now")
593
+ hs,ws = certainty.shape
594
+ certainty = certainty/certainty.sum(dim=(-1,-2))
595
+ cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
596
+ cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
597
+ print(cum_certainty_w)
598
+ print(torch.min(torch.nonzero(cum_certainty_w > thr)))
599
+ print(torch.min(torch.nonzero(cum_certainty_w < thr)))
600
+ left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
601
+ right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
602
+ top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
603
+ bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
604
+ print(left, right, top, bottom)
605
+ return left, top, right, bottom
606
+
607
+ def recrop(self, certainty, image_path):
608
+ roi = self.get_roi(certainty, *Image.open(image_path).size)
609
+ return Image.open(image_path).convert("RGB").crop(roi)
610
+
611
+ @torch.inference_mode()
612
+ def match(
613
+ self,
614
+ im_A_path: Union[str, os.PathLike, Image.Image],
615
+ im_B_path: Union[str, os.PathLike, Image.Image],
616
+ *args,
617
+ batched=False,
618
+ device = None,
619
+ ):
620
+ if device is None:
621
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
622
+ if isinstance(im_A_path, (str, os.PathLike)):
623
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
624
+ else:
625
+ im_A, im_B = im_A_path, im_B_path
626
+
627
+ symmetric = self.symmetric
628
+ self.train(False)
629
+ with torch.no_grad():
630
+ if not batched:
631
+ b = 1
632
+ w, h = im_A.size
633
+ w2, h2 = im_B.size
634
+ # Get images in good format
635
+ ws = self.w_resized
636
+ hs = self.h_resized
637
+
638
+ test_transform = get_tuple_transform_ops(
639
+ resize=(hs, ws), normalize=True, clahe = False
640
+ )
641
+ im_A, im_B = test_transform((im_A, im_B))
642
+ batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
643
+ else:
644
+ b, c, h, w = im_A.shape
645
+ b, c, h2, w2 = im_B.shape
646
+ assert w == w2 and h == h2, "For batched images we assume same size"
647
+ batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
648
+ if h != self.h_resized or self.w_resized != w:
649
+ warn("Model resolution and batch resolution differ, may produce unexpected results")
650
+ hs, ws = h, w
651
+ finest_scale = 1
652
+ # Run matcher
653
+ if symmetric:
654
+ corresps = self.forward_symmetric(batch)
655
+ else:
656
+ corresps = self.forward(batch, batched = True)
657
+
658
+ if self.upsample_preds:
659
+ hs, ws = self.upsample_res
660
+
661
+ if self.attenuate_cert:
662
+ low_res_certainty = F.interpolate(
663
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
664
+ )
665
+ cert_clamp = 0
666
+ factor = 0.5
667
+ low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
668
+
669
+ if self.upsample_preds:
670
+ finest_corresps = corresps[finest_scale]
671
+ torch.cuda.empty_cache()
672
+ test_transform = get_tuple_transform_ops(
673
+ resize=(hs, ws), normalize=True
674
+ )
675
+ if self.recrop_upsample:
676
+ raise NotImplementedError("recrop_upsample not implemented")
677
+ certainty = corresps[finest_scale]["certainty"]
678
+ print(certainty.shape)
679
+ im_A = self.recrop(certainty[0,0], im_A_path)
680
+ im_B = self.recrop(certainty[1,0], im_B_path)
681
+ #TODO: need to adjust corresps when doing this
682
+ im_A, im_B = test_transform((im_A, im_B))
683
+ im_A, im_B = im_A[None].to(device), im_B[None].to(device)
684
+ scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
685
+ batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
686
+ if symmetric:
687
+ corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
688
+ else:
689
+ corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
690
+
691
+ im_A_to_im_B = corresps[finest_scale]["flow"]
692
+ certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
693
+ if finest_scale != 1:
694
+ im_A_to_im_B = F.interpolate(
695
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
696
+ )
697
+ certainty = F.interpolate(
698
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
699
+ )
700
+ im_A_to_im_B = im_A_to_im_B.permute(
701
+ 0, 2, 3, 1
702
+ )
703
+ # Create im_A meshgrid
704
+ im_A_coords = torch.meshgrid(
705
+ (
706
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
707
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
708
+ )
709
+ )
710
+ im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
711
+ im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
712
+ certainty = certainty.sigmoid() # logits -> probs
713
+ im_A_coords = im_A_coords.permute(0, 2, 3, 1)
714
+ if (im_A_to_im_B.abs() > 1).any():
715
+ wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
716
+ certainty[wrong[:,None]] = 0
717
+ im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
718
+ if symmetric:
719
+ A_to_B, B_to_A = im_A_to_im_B.chunk(2)
720
+ q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
721
+ im_B_coords = im_A_coords
722
+ s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
723
+ warp = torch.cat((q_warp, s_warp),dim=2)
724
+ certainty = torch.cat(certainty.chunk(2), dim=3)
725
+ else:
726
+ warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
727
+ if batched:
728
+ return (
729
+ warp,
730
+ certainty[:, 0]
731
+ )
732
+ else:
733
+ return (
734
+ warp[0],
735
+ certainty[0, 0],
736
+ )
737
+
738
+ def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
739
+ im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None, unnormalize = False):
740
+ #assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
741
+ H,W2,_ = warp.shape
742
+ W = W2//2 if symmetric else W2
743
+ if im_A is None:
744
+ from PIL import Image
745
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
746
+ if not isinstance(im_A, torch.Tensor):
747
+ im_A = im_A.resize((W,H))
748
+ im_B = im_B.resize((W,H))
749
+ x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
750
+ if symmetric:
751
+ x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
752
+ else:
753
+ if symmetric:
754
+ x_A = im_A
755
+ x_B = im_B
756
+ im_A_transfer_rgb = F.grid_sample(
757
+ x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
758
+ )[0]
759
+ if symmetric:
760
+ im_B_transfer_rgb = F.grid_sample(
761
+ x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
762
+ )[0]
763
+ warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
764
+ white_im = torch.ones((H,2*W),device=device)
765
+ else:
766
+ warp_im = im_A_transfer_rgb
767
+ white_im = torch.ones((H, W), device = device)
768
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
769
+ if save_path is not None:
770
+ from romatch.utils import tensor_to_pil
771
+ tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
772
+ return vis_im
third_party/RoMa/romatch/models/model_zoo/__init__.py ADDED
@@ -0,0 +1,70 @@
+ from typing import Union
2
+ import torch
3
+ from .roma_models import roma_model, tiny_roma_v1_model
4
+
5
+ weight_urls = {
6
+ "romatch": {
7
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/romatch/roma_outdoor.pth",
8
+ "indoor": "https://github.com/Parskatt/storage/releases/download/romatch/roma_indoor.pth",
9
+ },
10
+ "tiny_roma_v1": {
11
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/romatch/tiny_roma_v1_outdoor.pth",
12
+ },
13
+ "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
14
+ }
15
+
16
+ def tiny_roma_v1_outdoor(device, weights = None, xfeat = None):
17
+ if weights is None:
18
+ weights = torch.hub.load_state_dict_from_url(
19
+ weight_urls["tiny_roma_v1"]["outdoor"],
20
+ map_location=device)
21
+ if xfeat is None:
22
+ xfeat = torch.hub.load(
23
+ 'verlab/accelerated_features',
24
+ 'XFeat',
25
+ pretrained = True,
26
+ top_k = 4096).net
27
+
28
+ return tiny_roma_v1_model(weights = weights, xfeat = xfeat).to(device)
29
+
30
+ def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
31
+ if isinstance(coarse_res, int):
32
+ coarse_res = (coarse_res, coarse_res)
33
+ if isinstance(upsample_res, int):
34
+ upsample_res = (upsample_res, upsample_res)
35
+
36
+ assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
37
+ assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
38
+
39
+ if weights is None:
40
+ weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["outdoor"],
41
+ map_location=device)
42
+ if dinov2_weights is None:
43
+ dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
44
+ map_location=device)
45
+ model = roma_model(resolution=coarse_res, upsample_preds=True,
46
+ weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
47
+ model.upsample_res = upsample_res
48
+ print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
49
+ return model
50
+
51
+ def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
52
+ if isinstance(coarse_res, int):
53
+ coarse_res = (coarse_res, coarse_res)
54
+ if isinstance(upsample_res, int):
55
+ upsample_res = (upsample_res, upsample_res)
56
+
57
+ assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
58
+ assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
59
+
60
+ if weights is None:
61
+ weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["indoor"],
62
+ map_location=device)
63
+ if dinov2_weights is None:
64
+ dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
65
+ map_location=device)
66
+ model = roma_model(resolution=coarse_res, upsample_preds=True,
67
+ weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
68
+ model.upsample_res = upsample_res
69
+ print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
70
+ return model
third_party/RoMa/romatch/models/model_zoo/roma_models.py ADDED
@@ -0,0 +1,170 @@
+ import warnings
2
+ import torch.nn as nn
3
+ import torch
4
+ from romatch.models.matcher import *
5
+ from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
6
+ from romatch.models.encoders import *
7
+ from romatch.models.tiny import TinyRoMa
8
+
9
+ def tiny_roma_v1_model(weights = None, freeze_xfeat=False, exact_softmax=False, xfeat = None):
10
+ model = TinyRoMa(
11
+ xfeat = xfeat,
12
+ freeze_xfeat=freeze_xfeat,
13
+ exact_softmax=exact_softmax)
14
+ if weights is not None:
15
+ model.load_state_dict(weights)
16
+ return model
17
+
18
+ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs):
19
+ # romatch weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
20
+ #torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul TODO: these probably ruin stuff, should be careful
21
+ #torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
22
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
23
+ gp_dim = 512
24
+ feat_dim = 512
25
+ decoder_dim = gp_dim + feat_dim
26
+ cls_to_coord_res = 64
27
+ coordinate_decoder = TransformerDecoder(
28
+ nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
29
+ decoder_dim,
30
+ cls_to_coord_res**2 + 1,
31
+ is_classifier=True,
32
+ amp = True,
33
+ pos_enc = False,)
34
+ dw = True
35
+ hidden_blocks = 8
36
+ kernel_size = 5
37
+ displacement_emb = "linear"
38
+ disable_local_corr_grad = True
39
+
40
+ conv_refiner = nn.ModuleDict(
41
+ {
42
+ "16": ConvRefiner(
43
+ 2 * 512+128+(2*7+1)**2,
44
+ 2 * 512+128+(2*7+1)**2,
45
+ 2 + 1,
46
+ kernel_size=kernel_size,
47
+ dw=dw,
48
+ hidden_blocks=hidden_blocks,
49
+ displacement_emb=displacement_emb,
50
+ displacement_emb_dim=128,
51
+ local_corr_radius = 7,
52
+ corr_in_other = True,
53
+ amp = True,
54
+ disable_local_corr_grad = disable_local_corr_grad,
55
+ bn_momentum = 0.01,
56
+ ),
57
+ "8": ConvRefiner(
58
+ 2 * 512+64+(2*3+1)**2,
59
+ 2 * 512+64+(2*3+1)**2,
60
+ 2 + 1,
61
+ kernel_size=kernel_size,
62
+ dw=dw,
63
+ hidden_blocks=hidden_blocks,
64
+ displacement_emb=displacement_emb,
65
+ displacement_emb_dim=64,
66
+ local_corr_radius = 3,
67
+ corr_in_other = True,
68
+ amp = True,
69
+ disable_local_corr_grad = disable_local_corr_grad,
70
+ bn_momentum = 0.01,
71
+ ),
72
+ "4": ConvRefiner(
73
+ 2 * 256+32+(2*2+1)**2,
74
+ 2 * 256+32+(2*2+1)**2,
75
+ 2 + 1,
76
+ kernel_size=kernel_size,
77
+ dw=dw,
78
+ hidden_blocks=hidden_blocks,
79
+ displacement_emb=displacement_emb,
80
+ displacement_emb_dim=32,
81
+ local_corr_radius = 2,
82
+ corr_in_other = True,
83
+ amp = True,
84
+ disable_local_corr_grad = disable_local_corr_grad,
85
+ bn_momentum = 0.01,
86
+ ),
87
+ "2": ConvRefiner(
88
+ 2 * 64+16,
89
+ 128+16,
90
+ 2 + 1,
91
+ kernel_size=kernel_size,
92
+ dw=dw,
93
+ hidden_blocks=hidden_blocks,
94
+ displacement_emb=displacement_emb,
95
+ displacement_emb_dim=16,
96
+ amp = True,
97
+ disable_local_corr_grad = disable_local_corr_grad,
98
+ bn_momentum = 0.01,
99
+ ),
100
+ "1": ConvRefiner(
101
+ 2 * 9 + 6,
102
+ 24,
103
+ 2 + 1,
104
+ kernel_size=kernel_size,
105
+ dw=dw,
106
+ hidden_blocks = hidden_blocks,
107
+ displacement_emb = displacement_emb,
108
+ displacement_emb_dim = 6,
109
+ amp = True,
110
+ disable_local_corr_grad = disable_local_corr_grad,
111
+ bn_momentum = 0.01,
112
+ ),
113
+ }
114
+ )
115
+ kernel_temperature = 0.2
116
+ learn_temperature = False
117
+ no_cov = True
118
+ kernel = CosKernel
119
+ only_attention = False
120
+ basis = "fourier"
121
+ gp16 = GP(
122
+ kernel,
123
+ T=kernel_temperature,
124
+ learn_temperature=learn_temperature,
125
+ only_attention=only_attention,
126
+ gp_dim=gp_dim,
127
+ basis=basis,
128
+ no_cov=no_cov,
129
+ )
130
+ gps = nn.ModuleDict({"16": gp16})
131
+ proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
132
+ proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
133
+ proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
134
+ proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
135
+ proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
136
+ proj = nn.ModuleDict({
137
+ "16": proj16,
138
+ "8": proj8,
139
+ "4": proj4,
140
+ "2": proj2,
141
+ "1": proj1,
142
+ })
143
+ displacement_dropout_p = 0.0
144
+ gm_warp_dropout_p = 0.0
145
+ decoder = Decoder(coordinate_decoder,
146
+ gps,
147
+ proj,
148
+ conv_refiner,
149
+ detach=True,
150
+ scales=["16", "8", "4", "2", "1"],
151
+ displacement_dropout_p = displacement_dropout_p,
152
+ gm_warp_dropout_p = gm_warp_dropout_p)
153
+
154
+ encoder = CNNandDinov2(
155
+ cnn_kwargs = dict(
156
+ pretrained=False,
157
+ amp = True),
158
+ amp = True,
159
+ use_vgg = True,
160
+ dinov2_weights = dinov2_weights,
161
+ amp_dtype=amp_dtype,
162
+ )
163
+ h,w = resolution
164
+ symmetric = True
165
+ attenuate_cert = True
166
+ sample_mode = "threshold_balanced"
167
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
168
+ symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
169
+ matcher.load_state_dict(weights)
170
+ return matcher
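
The snippet below is an illustrative usage sketch, not part of the patch: it shows how roma_model might be driven once checkpoints are available. The checkpoint file names, the (560, 560) coarse resolution, and loading via torch.load are assumptions for illustration; the shipped entry points in romatch/models/model_zoo handle weight download and resolution defaults.

import torch
from romatch.models.model_zoo.roma_models import roma_model

# Hypothetical local checkpoint paths; RoMa weights and DINOv2 weights stay separate.
roma_weights = torch.load("roma_outdoor.pth", map_location="cpu")
dinov2_weights = torch.load("dinov2_vitl14_pretrain.pth", map_location="cpu")

device = "cuda" if torch.cuda.is_available() else "cpu"
matcher = roma_model(
    resolution=(560, 560),     # coarse (H, W); assumed here, must be divisible by the ViT patch size
    upsample_preds=True,       # refine the warp at a higher resolution
    weights=roma_weights,
    dinov2_weights=dinov2_weights,
    device=device,
)
warp, certainty = matcher.match("im_A.jpg", "im_B.jpg")  # dense warp + confidence (GPU recommended)
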
third_party/RoMa/romatch/models/tiny.py ADDED
@@ -0,0 +1,304 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ import torch
6
+ from pathlib import Path
7
+ import math
8
+ import numpy as np
9
+
10
+ from torch import nn
11
+ from PIL import Image
12
+ from torchvision.transforms import ToTensor
13
+ from romatch.utils.kde import kde
14
+
15
+ class BasicLayer(nn.Module):
16
+ """
17
+ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
18
+ """
19
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
20
+ super().__init__()
21
+ self.layer = nn.Sequential(
22
+ nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
23
+ nn.BatchNorm2d(out_channels, affine=False),
24
+ nn.ReLU(inplace = True) if relu else nn.Identity()
25
+ )
26
+
27
+ def forward(self, x):
28
+ return self.layer(x)
29
+
30
+ class TinyRoMa(nn.Module):
31
+ """
32
+ Implementation of architecture described in
33
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
34
+ """
35
+
36
+ def __init__(self, xfeat = None,
37
+ freeze_xfeat = True,
38
+ sample_mode = "threshold_balanced",
39
+ symmetric = False,
40
+ exact_softmax = False):
41
+ super().__init__()
42
+ del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
43
+ if freeze_xfeat:
44
+ xfeat.train(False)
45
+ self.xfeat = [xfeat]# hide params from ddp
46
+ else:
47
+ self.xfeat = nn.ModuleList([xfeat])
48
+ self.freeze_xfeat = freeze_xfeat
49
+ match_dim = 256
50
+ self.coarse_matcher = nn.Sequential(
51
+ BasicLayer(64+64+2, match_dim,),
52
+ BasicLayer(match_dim, match_dim,),
53
+ BasicLayer(match_dim, match_dim,),
54
+ BasicLayer(match_dim, match_dim,),
55
+ nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
56
+ fine_match_dim = 64
57
+ self.fine_matcher = nn.Sequential(
58
+ BasicLayer(24+24+2, fine_match_dim,),
59
+ BasicLayer(fine_match_dim, fine_match_dim,),
60
+ BasicLayer(fine_match_dim, fine_match_dim,),
61
+ BasicLayer(fine_match_dim, fine_match_dim,),
62
+ nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
63
+ self.sample_mode = sample_mode
64
+ self.sample_thresh = 0.05
65
+ self.symmetric = symmetric
66
+ self.exact_softmax = exact_softmax
67
+
68
+ @property
69
+ def device(self):
70
+ return self.fine_matcher[-1].weight.device
71
+
72
+ def preprocess_tensor(self, x):
73
+ """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
74
+ H, W = x.shape[-2:]
75
+ _H, _W = (H//32) * 32, (W//32) * 32
76
+ rh, rw = H/_H, W/_W
77
+
78
+ x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
79
+ return x, rh, rw
80
+
81
+ def forward_single(self, x):
82
+ with torch.inference_mode(self.freeze_xfeat or not self.training):
83
+ xfeat = self.xfeat[0]
84
+ with torch.no_grad():
85
+ x = x.mean(dim=1, keepdim = True)
86
+ x = xfeat.norm(x)
87
+
88
+ #main backbone
89
+ x1 = xfeat.block1(x)
90
+ x2 = xfeat.block2(x1 + xfeat.skip1(x))
91
+ x3 = xfeat.block3(x2)
92
+ x4 = xfeat.block4(x3)
93
+ x5 = xfeat.block5(x4)
94
+ x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
95
+ x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
96
+ feats = xfeat.block_fusion( x3 + x4 + x5 )
97
+ if self.freeze_xfeat:
98
+ return x2.clone(), feats.clone()
99
+ return x2, feats
100
+
101
+ def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
102
+ if coords.shape[-1] == 2:
103
+ return self._to_pixel_coordinates(coords, H_A, W_A)
104
+
105
+ if isinstance(coords, (list, tuple)):
106
+ kpts_A, kpts_B = coords[0], coords[1]
107
+ else:
108
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
109
+ return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
110
+
111
+ def _to_pixel_coordinates(self, coords, H, W):
112
+ kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
113
+ return kpts
114
+
115
+ def pos_embed(self, corr_volume: torch.Tensor):
116
+ B, H1, W1, H0, W0 = corr_volume.shape
117
+ grid = torch.stack(
118
+ torch.meshgrid(
119
+ torch.linspace(-1+1/W1,1-1/W1, W1),
120
+ torch.linspace(-1+1/H1,1-1/H1, H1),
121
+ indexing = "xy"),
122
+ dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
123
+ down = 4
124
+ if not self.training and not self.exact_softmax:
125
+ grid_lr = torch.stack(
126
+ torch.meshgrid(
127
+ torch.linspace(-1+down/W1,1-down/W1, W1//down),
128
+ torch.linspace(-1+down/H1,1-down/H1, H1//down),
129
+ indexing = "xy"),
130
+ dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
131
+ cv = corr_volume
132
+ best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, HW, H, W
133
+ P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
134
+ pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
135
+ pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
136
+ #print("hej")
137
+ else:
138
+ P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
139
+ pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
140
+ return pos_embeddings
141
+
142
+ def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
143
+ im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
144
+ device = warp.device
145
+ H,W2,_ = warp.shape
146
+ W = W2//2 if symmetric else W2
147
+ if im_A is None:
148
+ from PIL import Image
149
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
150
+ if not isinstance(im_A, torch.Tensor):
151
+ im_A = im_A.resize((W,H))
152
+ im_B = im_B.resize((W,H))
153
+ x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
154
+ if symmetric:
155
+ x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
156
+ else:
157
+ if symmetric:
158
+ x_A = im_A
159
+ x_B = im_B
160
+ im_A_transfer_rgb = F.grid_sample(
161
+ x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
162
+ )[0]
163
+ if symmetric:
164
+ im_B_transfer_rgb = F.grid_sample(
165
+ x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
166
+ )[0]
167
+ warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
168
+ white_im = torch.ones((H,2*W),device=device)
169
+ else:
170
+ warp_im = im_A_transfer_rgb
171
+ white_im = torch.ones((H, W), device = device)
172
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
173
+ if save_path is not None:
174
+ from romatch.utils import tensor_to_pil
175
+ tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
176
+ return vis_im
177
+
178
+ def corr_volume(self, feat0, feat1):
179
+ """
180
+ input:
181
+ feat0 -> torch.Tensor(B, C, H, W)
182
+ feat1 -> torch.Tensor(B, C, H, W)
183
+ return:
184
+ corr_volume -> torch.Tensor(B, H, W, H, W)
185
+ """
186
+ B, C, H0, W0 = feat0.shape
187
+ B, C, H1, W1 = feat1.shape
188
+ feat0 = feat0.view(B, C, H0*W0)
189
+ feat1 = feat1.view(B, C, H1*W1)
190
+ corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
191
+ return corr_volume
192
+
193
+ @torch.inference_mode()
194
+ def match_from_path(self, im0_path, im1_path):
195
+ device = self.device
196
+ im0 = ToTensor()(Image.open(im0_path))[None].to(device)
197
+ im1 = ToTensor()(Image.open(im1_path))[None].to(device)
198
+ return self.match(im0, im1, batched = False)
199
+
200
+ @torch.inference_mode()
201
+ def match(self, im0, im1, *args, batched = True):
202
+ # convenience: also accept image paths or PIL images here
203
+ if isinstance(im0, (str, Path)):
204
+ return self.match_from_path(im0, im1)
205
+ elif isinstance(im0, Image.Image):
206
+ batched = False
207
+ device = self.device
208
+ im0 = ToTensor()(im0)[None].to(device)
209
+ im1 = ToTensor()(im1)[None].to(device)
210
+
211
+ B,C,H0,W0 = im0.shape
212
+ B,C,H1,W1 = im1.shape
213
+ self.train(False)
214
+ corresps = self.forward({"im_A":im0, "im_B":im1})
215
+ #return 1,1
216
+ flow = F.interpolate(
217
+ corresps[4]["flow"],
218
+ size = (H0, W0),
219
+ mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
220
+ grid = torch.stack(
221
+ torch.meshgrid(
222
+ torch.linspace(-1+1/W0,1-1/W0, W0),
223
+ torch.linspace(-1+1/H0,1-1/H0, H0),
224
+ indexing = "xy"),
225
+ dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
226
+
227
+ certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
228
+ warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
229
+ if batched:
230
+ return warp, cert
231
+ else:
232
+ return warp[0], cert[0]
233
+
234
+ def sample(
235
+ self,
236
+ matches,
237
+ certainty,
238
+ num=5_000,
239
+ ):
240
+ H,W,_ = matches.shape
241
+ if "threshold" in self.sample_mode:
242
+ upper_thresh = self.sample_thresh
243
+ certainty = certainty.clone()
244
+ certainty[certainty > upper_thresh] = 1
245
+ matches, certainty = (
246
+ matches.reshape(-1, 4),
247
+ certainty.reshape(-1),
248
+ )
249
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
250
+ good_samples = torch.multinomial(certainty,
251
+ num_samples = min(expansion_factor*num, len(certainty)),
252
+ replacement=False)
253
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
254
+ if "balanced" not in self.sample_mode:
255
+ return good_matches, good_certainty
256
+ use_half = True if matches.device.type == "cuda" else False
257
+ down = 1 if matches.device.type == "cuda" else 8
258
+ density = kde(good_matches, std=0.1, half = use_half, down = down)
259
+ p = 1 / (density+1)
260
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
261
+ balanced_samples = torch.multinomial(p,
262
+ num_samples = min(num,len(good_certainty)),
263
+ replacement=False)
264
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
265
+
266
+
267
+ def forward(self, batch):
268
+ """
269
+ input:
270
+ x -> torch.Tensor(B, C, H, W) grayscale or rgb images
271
+ return:
+ corresps -> dict keyed by scale (8 and 4), each holding "flow" and "certainty" tensors
273
+ """
274
+ im0 = batch["im_A"]
275
+ im1 = batch["im_B"]
276
+ corresps = {}
277
+ im0, rh0, rw0 = self.preprocess_tensor(im0)
278
+ im1, rh1, rw1 = self.preprocess_tensor(im1)
279
+ B, C, H0, W0 = im0.shape
280
+ B, C, H1, W1 = im1.shape
281
+ to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
282
+
283
+ if im0.shape[-2:] == im1.shape[-2:]:
284
+ x = torch.cat([im0, im1], dim=0)
285
+ x = self.forward_single(x)
286
+ feats_x0_c, feats_x1_c = x[1].chunk(2)
287
+ feats_x0_f, feats_x1_f = x[0].chunk(2)
288
+ else:
289
+ feats_x0_f, feats_x0_c = self.forward_single(im0)
290
+ feats_x1_f, feats_x1_c = self.forward_single(im1)
291
+ corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
292
+ coarse_warp = self.pos_embed(corr_volume)
293
+ coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
294
+ feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
295
+ coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
296
+ coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
297
+ corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
298
+ coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
299
+ coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
300
+ feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
301
+ fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
302
+ fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
303
+ corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
304
+ return corresps
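
As a rough, illustrative sketch (not part of the patch): given a constructed TinyRoMa instance, for example from tiny_roma_v1_model with an XFeat backbone, correspondences can be sampled from the dense warp as below. The image file names are placeholders.

from PIL import Image

# `tiny_roma` is assumed to be a TinyRoMa instance already moved to the right device.
im_A, im_B = Image.open("im_A.jpg"), Image.open("im_B.jpg")
warp, certainty = tiny_roma.match("im_A.jpg", "im_B.jpg")    # (H, W, 4) warp, (H, W) certainty
matches, cert = tiny_roma.sample(warp, certainty, num=5000)  # threshold + kde-balanced sampling
kpts_A, kpts_B = tiny_roma.to_pixel_coordinates(
    matches, im_A.size[1], im_A.size[0], im_B.size[1], im_B.size[0]
)
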
third_party/RoMa/romatch/models/transformer/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from romatch.utils.utils import get_grid
6
+ from .layers.block import Block
7
+ from .layers.attention import MemEffAttention
8
+ from .dinov2 import vit_large
9
+
10
+ class TransformerDecoder(nn.Module):
11
+ def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args,
12
+ amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
13
+ super().__init__(*args, **kwargs)
14
+ self.blocks = blocks
15
+ self.to_out = nn.Linear(hidden_dim, out_dim)
16
+ self.hidden_dim = hidden_dim
17
+ self.out_dim = out_dim
18
+ self._scales = [16]
19
+ self.is_classifier = is_classifier
20
+ self.amp = amp
21
+ self.amp_dtype = amp_dtype
22
+ self.pos_enc = pos_enc
23
+ self.learned_embeddings = learned_embeddings
24
+ if self.learned_embeddings:
25
+ self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
26
+
27
+ def scales(self):
28
+ return self._scales.copy()
29
+
30
+ def forward(self, gp_posterior, features, old_stuff, new_scale):
31
+ with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
32
+ B,C,H,W = gp_posterior.shape
33
+ x = torch.cat((gp_posterior, features), dim = 1)
34
+ B,C,H,W = x.shape
35
+ grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
36
+ if self.learned_embeddings:
37
+ pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
38
+ else:
39
+ pos_enc = 0
40
+ tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
41
+ z = self.blocks(tokens)
42
+ out = self.to_out(z)
43
+ out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
44
+ warp, certainty = out[:, :-1], out[:, -1:]
45
+ return warp, certainty, None
46
+
47
+
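
A toy-sized shape check for the decoder above (illustrative only, not part of the patch; the decoder built in roma_models.py uses hidden dim 1024, five blocks, and a 64x64 coordinate grid):

import torch
import torch.nn as nn
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention

dim, coord_res = 64, 8
decoder = TransformerDecoder(
    nn.Sequential(*[Block(dim, 8, attn_class=MemEffAttention) for _ in range(2)]),
    hidden_dim=dim, out_dim=coord_res**2 + 1, is_classifier=True, amp=False, pos_enc=False,
)
gp_posterior = torch.randn(1, 32, 20, 20)   # (B, C_gp, H/16, W/16)
features = torch.randn(1, 32, 20, 20)       # channel counts must sum to hidden_dim
logits, certainty, _ = decoder(gp_posterior, features, None, None)
print(logits.shape, certainty.shape)        # (1, 64, 20, 20) and (1, 1, 20, 20)
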
third_party/RoMa/romatch/models/transformer/dinov2.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ from functools import partial
12
+ import math
13
+ import logging
14
+ from typing import Sequence, Tuple, Union, Callable
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.utils.checkpoint
19
+ from torch.nn.init import trunc_normal_
20
+
21
+ from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
22
+
23
+
24
+
25
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
26
+ if not depth_first and include_root:
27
+ fn(module=module, name=name)
28
+ for child_name, child_module in module.named_children():
29
+ child_name = ".".join((name, child_name)) if name else child_name
30
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
31
+ if depth_first and include_root:
32
+ fn(module=module, name=name)
33
+ return module
34
+
35
+
36
+ class BlockChunk(nn.ModuleList):
37
+ def forward(self, x):
38
+ for b in self:
39
+ x = b(x)
40
+ return x
41
+
42
+
43
+ class DinoVisionTransformer(nn.Module):
44
+ def __init__(
45
+ self,
46
+ img_size=224,
47
+ patch_size=16,
48
+ in_chans=3,
49
+ embed_dim=768,
50
+ depth=12,
51
+ num_heads=12,
52
+ mlp_ratio=4.0,
53
+ qkv_bias=True,
54
+ ffn_bias=True,
55
+ proj_bias=True,
56
+ drop_path_rate=0.0,
57
+ drop_path_uniform=False,
58
+ init_values=None, # for layerscale: None or 0 => no layerscale
59
+ embed_layer=PatchEmbed,
60
+ act_layer=nn.GELU,
61
+ block_fn=Block,
62
+ ffn_layer="mlp",
63
+ block_chunks=1,
64
+ ):
65
+ """
66
+ Args:
67
+ img_size (int, tuple): input image size
68
+ patch_size (int, tuple): patch size
69
+ in_chans (int): number of input channels
70
+ embed_dim (int): embedding dimension
71
+ depth (int): depth of transformer
72
+ num_heads (int): number of attention heads
73
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
74
+ qkv_bias (bool): enable bias for qkv if True
75
+ proj_bias (bool): enable bias for proj in attn if True
76
+ ffn_bias (bool): enable bias for ffn if True
77
+ drop_path_rate (float): stochastic depth rate
78
+ drop_path_uniform (bool): apply uniform drop rate across blocks
79
+ weight_init (str): weight init scheme
80
+ init_values (float): layer-scale init values
81
+ embed_layer (nn.Module): patch embedding layer
82
+ act_layer (nn.Module): MLP activation layer
83
+ block_fn (nn.Module): transformer block class
84
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
85
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
86
+ """
87
+ super().__init__()
88
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
89
+
90
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
91
+ self.num_tokens = 1
92
+ self.n_blocks = depth
93
+ self.num_heads = num_heads
94
+ self.patch_size = patch_size
95
+
96
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
97
+ num_patches = self.patch_embed.num_patches
98
+
99
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
100
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
101
+
102
+ if drop_path_uniform is True:
103
+ dpr = [drop_path_rate] * depth
104
+ else:
105
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
106
+
107
+ if ffn_layer == "mlp":
108
+ ffn_layer = Mlp
109
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
110
+ ffn_layer = SwiGLUFFNFused
111
+ elif ffn_layer == "identity":
112
+
113
+ def f(*args, **kwargs):
114
+ return nn.Identity()
115
+
116
+ ffn_layer = f
117
+ else:
118
+ raise NotImplementedError
119
+
120
+ blocks_list = [
121
+ block_fn(
122
+ dim=embed_dim,
123
+ num_heads=num_heads,
124
+ mlp_ratio=mlp_ratio,
125
+ qkv_bias=qkv_bias,
126
+ proj_bias=proj_bias,
127
+ ffn_bias=ffn_bias,
128
+ drop_path=dpr[i],
129
+ norm_layer=norm_layer,
130
+ act_layer=act_layer,
131
+ ffn_layer=ffn_layer,
132
+ init_values=init_values,
133
+ )
134
+ for i in range(depth)
135
+ ]
136
+ if block_chunks > 0:
137
+ self.chunked_blocks = True
138
+ chunked_blocks = []
139
+ chunksize = depth // block_chunks
140
+ for i in range(0, depth, chunksize):
141
+ # this is to keep the block index consistent if we chunk the block list
142
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
143
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
144
+ else:
145
+ self.chunked_blocks = False
146
+ self.blocks = nn.ModuleList(blocks_list)
147
+
148
+ self.norm = norm_layer(embed_dim)
149
+ self.head = nn.Identity()
150
+
151
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
152
+
153
+ self.init_weights()
154
+ for param in self.parameters():
155
+ param.requires_grad = False
156
+
157
+ @property
158
+ def device(self):
159
+ return self.cls_token.device
160
+
161
+ def init_weights(self):
162
+ trunc_normal_(self.pos_embed, std=0.02)
163
+ nn.init.normal_(self.cls_token, std=1e-6)
164
+ named_apply(init_weights_vit_timm, self)
165
+
166
+ def interpolate_pos_encoding(self, x, w, h):
167
+ previous_dtype = x.dtype
168
+ npatch = x.shape[1] - 1
169
+ N = self.pos_embed.shape[1] - 1
170
+ if npatch == N and w == h:
171
+ return self.pos_embed
172
+ pos_embed = self.pos_embed.float()
173
+ class_pos_embed = pos_embed[:, 0]
174
+ patch_pos_embed = pos_embed[:, 1:]
175
+ dim = x.shape[-1]
176
+ w0 = w // self.patch_size
177
+ h0 = h // self.patch_size
178
+ # we add a small number to avoid floating point error in the interpolation
179
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
180
+ w0, h0 = w0 + 0.1, h0 + 0.1
181
+
182
+ patch_pos_embed = nn.functional.interpolate(
183
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
184
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
185
+ mode="bicubic",
186
+ )
187
+
188
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
189
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
190
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
191
+
192
+ def prepare_tokens_with_masks(self, x, masks=None):
193
+ B, nc, w, h = x.shape
194
+ x = self.patch_embed(x)
195
+ if masks is not None:
196
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
197
+
198
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
199
+ x = x + self.interpolate_pos_encoding(x, w, h)
200
+
201
+ return x
202
+
203
+ def forward_features_list(self, x_list, masks_list):
204
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
205
+ for blk in self.blocks:
206
+ x = blk(x)
207
+
208
+ all_x = x
209
+ output = []
210
+ for x, masks in zip(all_x, masks_list):
211
+ x_norm = self.norm(x)
212
+ output.append(
213
+ {
214
+ "x_norm_clstoken": x_norm[:, 0],
215
+ "x_norm_patchtokens": x_norm[:, 1:],
216
+ "x_prenorm": x,
217
+ "masks": masks,
218
+ }
219
+ )
220
+ return output
221
+
222
+ def forward_features(self, x, masks=None):
223
+ if isinstance(x, list):
224
+ return self.forward_features_list(x, masks)
225
+
226
+ x = self.prepare_tokens_with_masks(x, masks)
227
+
228
+ for blk in self.blocks:
229
+ x = blk(x)
230
+
231
+ x_norm = self.norm(x)
232
+ return {
233
+ "x_norm_clstoken": x_norm[:, 0],
234
+ "x_norm_patchtokens": x_norm[:, 1:],
235
+ "x_prenorm": x,
236
+ "masks": masks,
237
+ }
238
+
239
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
240
+ x = self.prepare_tokens_with_masks(x)
241
+ # If n is an int, take the n last blocks. If it's a list, take them
242
+ output, total_block_len = [], len(self.blocks)
243
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
244
+ for i, blk in enumerate(self.blocks):
245
+ x = blk(x)
246
+ if i in blocks_to_take:
247
+ output.append(x)
248
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
249
+ return output
250
+
251
+ def _get_intermediate_layers_chunked(self, x, n=1):
252
+ x = self.prepare_tokens_with_masks(x)
253
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
254
+ # If n is an int, take the n last blocks. If it's a list, take them
255
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
256
+ for block_chunk in self.blocks:
257
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
258
+ x = blk(x)
259
+ if i in blocks_to_take:
260
+ output.append(x)
261
+ i += 1
262
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
263
+ return output
264
+
265
+ def get_intermediate_layers(
266
+ self,
267
+ x: torch.Tensor,
268
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
269
+ reshape: bool = False,
270
+ return_class_token: bool = False,
271
+ norm=True,
272
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
273
+ if self.chunked_blocks:
274
+ outputs = self._get_intermediate_layers_chunked(x, n)
275
+ else:
276
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
277
+ if norm:
278
+ outputs = [self.norm(out) for out in outputs]
279
+ class_tokens = [out[:, 0] for out in outputs]
280
+ outputs = [out[:, 1:] for out in outputs]
281
+ if reshape:
282
+ B, _, w, h = x.shape
283
+ outputs = [
284
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
285
+ for out in outputs
286
+ ]
287
+ if return_class_token:
288
+ return tuple(zip(outputs, class_tokens))
289
+ return tuple(outputs)
290
+
291
+ def forward(self, *args, is_training=False, **kwargs):
292
+ ret = self.forward_features(*args, **kwargs)
293
+ if is_training:
294
+ return ret
295
+ else:
296
+ return self.head(ret["x_norm_clstoken"])
297
+
298
+
299
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
300
+ """ViT weight initialization, original timm impl (for reproducibility)"""
301
+ if isinstance(module, nn.Linear):
302
+ trunc_normal_(module.weight, std=0.02)
303
+ if module.bias is not None:
304
+ nn.init.zeros_(module.bias)
305
+
306
+
307
+ def vit_small(patch_size=16, **kwargs):
308
+ model = DinoVisionTransformer(
309
+ patch_size=patch_size,
310
+ embed_dim=384,
311
+ depth=12,
312
+ num_heads=6,
313
+ mlp_ratio=4,
314
+ block_fn=partial(Block, attn_class=MemEffAttention),
315
+ **kwargs,
316
+ )
317
+ return model
318
+
319
+
320
+ def vit_base(patch_size=16, **kwargs):
321
+ model = DinoVisionTransformer(
322
+ patch_size=patch_size,
323
+ embed_dim=768,
324
+ depth=12,
325
+ num_heads=12,
326
+ mlp_ratio=4,
327
+ block_fn=partial(Block, attn_class=MemEffAttention),
328
+ **kwargs,
329
+ )
330
+ return model
331
+
332
+
333
+ def vit_large(patch_size=16, **kwargs):
334
+ model = DinoVisionTransformer(
335
+ patch_size=patch_size,
336
+ embed_dim=1024,
337
+ depth=24,
338
+ num_heads=16,
339
+ mlp_ratio=4,
340
+ block_fn=partial(Block, attn_class=MemEffAttention),
341
+ **kwargs,
342
+ )
343
+ return model
344
+
345
+
346
+ def vit_giant2(patch_size=16, **kwargs):
347
+ """
348
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
349
+ """
350
+ model = DinoVisionTransformer(
351
+ patch_size=patch_size,
352
+ embed_dim=1536,
353
+ depth=40,
354
+ num_heads=24,
355
+ mlp_ratio=4,
356
+ block_fn=partial(Block, attn_class=MemEffAttention),
357
+ **kwargs,
358
+ )
359
+ return model
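
For orientation, a small sketch (not part of the patch) of how this frozen backbone can be instantiated; the patch_size=14 / img_size=518 / init_values=1.0 / block_chunks=0 settings mirror what the RoMa encoder appears to use and should be treated as assumptions here:

import torch
from romatch.models.transformer.dinov2 import vit_large

vit = vit_large(patch_size=14, img_size=518, init_values=1.0, block_chunks=0)
img = torch.randn(1, 3, 224, 224)             # H and W must be multiples of the patch size
out = vit.forward_features(img)
print(out["x_norm_patchtokens"].shape)        # torch.Size([1, 256, 1024]) -> 16x16 patch tokens
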
third_party/RoMa/romatch/models/transformer/layers/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .dino_head import DINOHead
+ from .mlp import Mlp
+ from .patch_embed import PatchEmbed
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+ from .block import NestedTensorBlock
+ from .attention import MemEffAttention
third_party/RoMa/romatch/models/transformer/layers/attention.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
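
A quick shape check (illustrative, not part of the patch); without xFormers installed, MemEffAttention silently falls back to the plain Attention.forward above as long as no attn_bias is passed:

import torch
from romatch.models.transformer.layers.attention import MemEffAttention

attn = MemEffAttention(dim=64, num_heads=8)
tokens = torch.randn(2, 100, 64)              # (batch, sequence, dim)
print(attn(tokens).shape)                     # torch.Size([2, 100, 64])
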
third_party/RoMa/romatch/models/transformer/layers/block.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError
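
Illustrative check (not part of the patch): given a single tensor, NestedTensorBlock behaves like a plain pre-norm transformer block; the nested list path, which requires xFormers, is only taken for list inputs:

import torch
from romatch.models.transformer.layers.block import NestedTensorBlock
from romatch.models.transformer.layers.attention import MemEffAttention

block = NestedTensorBlock(dim=64, num_heads=8, attn_class=MemEffAttention)
x = torch.randn(2, 100, 64)
print(block(x).shape)                         # torch.Size([2, 100, 64])
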
third_party/RoMa/romatch/models/transformer/layers/dino_head.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn.init import trunc_normal_
10
+ from torch.nn.utils import weight_norm
11
+
12
+
13
+ class DINOHead(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_dim,
17
+ out_dim,
18
+ use_bn=False,
19
+ nlayers=3,
20
+ hidden_dim=2048,
21
+ bottleneck_dim=256,
22
+ mlp_bias=True,
23
+ ):
24
+ super().__init__()
25
+ nlayers = max(nlayers, 1)
26
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
27
+ self.apply(self._init_weights)
28
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
29
+ self.last_layer.weight_g.data.fill_(1)
30
+
31
+ def _init_weights(self, m):
32
+ if isinstance(m, nn.Linear):
33
+ trunc_normal_(m.weight, std=0.02)
34
+ if isinstance(m, nn.Linear) and m.bias is not None:
35
+ nn.init.constant_(m.bias, 0)
36
+
37
+ def forward(self, x):
38
+ x = self.mlp(x)
39
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
40
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
41
+ x = self.last_layer(x)
42
+ return x
43
+
44
+
45
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
46
+ if nlayers == 1:
47
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
48
+ else:
49
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
50
+ if use_bn:
51
+ layers.append(nn.BatchNorm1d(hidden_dim))
52
+ layers.append(nn.GELU())
53
+ for _ in range(nlayers - 2):
54
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
55
+ if use_bn:
56
+ layers.append(nn.BatchNorm1d(hidden_dim))
57
+ layers.append(nn.GELU())
58
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
59
+ return nn.Sequential(*layers)
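
Shape sketch (illustrative only, not part of the patch; this projection head belongs to DINOv2-style self-distillation and is not used by the matcher itself):

import torch
from romatch.models.transformer.layers.dino_head import DINOHead

head = DINOHead(in_dim=1024, out_dim=65536, nlayers=3, hidden_dim=2048, bottleneck_dim=256)
cls_tokens = torch.randn(4, 1024)
print(head(cls_tokens).shape)                 # torch.Size([4, 65536])
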
third_party/RoMa/romatch/models/transformer/layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+ from torch import nn
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0:
+         random_tensor.div_(keep_prob)
+     output = x * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
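
Small numerical sanity check (not part of the patch) of the 1/keep_prob rescaling: in expectation the output matches the input even though a fraction of samples is zeroed, and the module is an identity in eval mode:

import torch
from romatch.models.transformer.layers.drop_path import DropPath

torch.manual_seed(0)
dp = DropPath(drop_prob=0.3).train()
x = torch.ones(10_000, 8)
print(dp(x).mean().item())     # close to 1.0
dp.eval()
print(torch.equal(dp(x), x))   # True: identity at inference time
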
third_party/RoMa/romatch/models/transformer/layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+ from typing import Union
+
+ import torch
+ from torch import Tensor
+ from torch import nn
+
+
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: Union[float, Tensor] = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
third_party/RoMa/romatch/models/transformer/layers/mlp.py ADDED
@@ -0,0 +1,41 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+ from typing import Callable, Optional
+
+ from torch import Tensor, nn
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
third_party/RoMa/romatch/models/transformer/layers/patch_embed.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ _, _, H, W = x.shape
71
+ patch_H, patch_W = self.patch_size
72
+
73
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75
+
76
+ x = self.proj(x) # B C H W
77
+ H, W = x.size(2), x.size(3)
78
+ x = x.flatten(2).transpose(1, 2) # B HW C
79
+ x = self.norm(x)
80
+ if not self.flatten_embedding:
81
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82
+ return x
83
+
84
+ def flops(self) -> float:
85
+ Ho, Wo = self.patches_resolution
86
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87
+ if self.norm is not None:
88
+ flops += Ho * Wo * self.embed_dim
89
+ return flops
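
Shape sketch (illustrative, not part of the patch): a 224x224 image with 14x14 patches yields 16*16 = 256 tokens:

import torch
from romatch.models.transformer.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=1024)
x = torch.randn(1, 3, 224, 224)
print(embed(x).shape)          # torch.Size([1, 256, 1024])
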
third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
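
Illustrative sketch (not part of the patch): SwiGLU splits a single projection into a gate and a value, out = w3(silu(x1) * x2); SwiGLUFFNFused additionally scales hidden_features by 2/3 and rounds it up to a multiple of 8:

import torch
from romatch.models.transformer.layers.swiglu_ffn import SwiGLUFFN

ffn = SwiGLUFFN(in_features=64, hidden_features=128)
x = torch.randn(2, 10, 64)
print(ffn(x).shape)            # torch.Size([2, 10, 64])
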
third_party/RoMa/romatch/train/__init__.py ADDED
@@ -0,0 +1 @@
+ from .train import train_k_epochs
third_party/RoMa/romatch/train/train.py ADDED
@@ -0,0 +1,102 @@
1
+ from tqdm import tqdm
2
+ from romatch.utils.utils import to_cuda
3
+ import romatch
4
+ import torch
5
+ import wandb
6
+
7
+ def log_param_statistics(named_parameters, norm_type = 2):
8
+ named_parameters = list(named_parameters)
9
+ grads = [p.grad for n, p in named_parameters if p.grad is not None]
10
+ weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
11
+ names = [n for n,p in named_parameters if p.grad is not None]
12
+ param_norm = torch.stack(weight_norms).norm(p=norm_type)
13
+ device = grads[0].device
14
+ grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
15
+ nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
16
+ nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
17
+ total_grad_norm = torch.norm(grad_norms, norm_type)
18
+ if torch.any(nans_or_infs):
19
+ print(f"These params have nan or inf grads: {nan_inf_names}")
20
+ wandb.log({"grad_norm": total_grad_norm.item()}, step = romatch.GLOBAL_STEP)
21
+ wandb.log({"param_norm": param_norm.item()}, step = romatch.GLOBAL_STEP)
22
+
23
+ def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
24
+ optimizer.zero_grad()
25
+ out = model(train_batch)
26
+ l = objective(out, train_batch)
27
+ grad_scaler.scale(l).backward()
28
+ grad_scaler.unscale_(optimizer)
29
+ log_param_statistics(model.named_parameters())
30
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
31
+ grad_scaler.step(optimizer)
32
+ grad_scaler.update()
33
+ wandb.log({"grad_scale": grad_scaler._scale.item()}, step = romatch.GLOBAL_STEP)
34
+ if grad_scaler._scale < 1.:
35
+ grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
36
+ romatch.GLOBAL_STEP = romatch.GLOBAL_STEP + romatch.STEP_SIZE # increment global step
37
+ return {"train_out": out, "train_loss": l.item()}
38
+
39
+
40
+ def train_k_steps(
41
+ n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None, pbar_n_seconds = 1,
42
+ ):
43
+ for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or romatch.RANK > 0, mininterval=pbar_n_seconds):
44
+ batch = next(dataloader)
45
+ model.train(True)
46
+ batch = to_cuda(batch)
47
+ train_step(
48
+ train_batch=batch,
49
+ model=model,
50
+ objective=objective,
51
+ optimizer=optimizer,
52
+ lr_scheduler=lr_scheduler,
53
+ grad_scaler=grad_scaler,
54
+ n=n,
55
+ grad_clip_norm = grad_clip_norm,
56
+ )
57
+ if ema_model is not None:
58
+ ema_model.update()
59
+ if warmup is not None:
60
+ with warmup.dampening():
61
+ lr_scheduler.step()
62
+ else:
63
+ lr_scheduler.step()
64
+ [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
65
+
66
+
67
+ def train_epoch(
68
+ dataloader=None,
69
+ model=None,
70
+ objective=None,
71
+ optimizer=None,
72
+ lr_scheduler=None,
73
+ epoch=None,
74
+ ):
75
+ model.train(True)
76
+ print(f"At epoch {epoch}")
77
+ for batch in tqdm(dataloader, mininterval=5.0):
78
+ batch = to_cuda(batch)
79
+ train_step(
80
+ train_batch=batch, model=model, objective=objective, optimizer=optimizer
81
+ )
82
+ lr_scheduler.step()
83
+ return {
84
+ "model": model,
85
+ "optimizer": optimizer,
86
+ "lr_scheduler": lr_scheduler,
87
+ "epoch": epoch,
88
+ }
89
+
90
+
91
+ def train_k_epochs(
92
+ start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
93
+ ):
94
+ for epoch in range(start_epoch, end_epoch + 1):
95
+ train_epoch(
96
+ dataloader=dataloader,
97
+ model=model,
98
+ objective=objective,
99
+ optimizer=optimizer,
100
+ lr_scheduler=lr_scheduler,
101
+ epoch=epoch,
102
+ )
third_party/RoMa/romatch/utils/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .utils import (
+     pose_auc,
+     get_pose,
+     compute_relative_pose,
+     compute_pose_error,
+     estimate_pose,
+     estimate_pose_uncalibrated,
+     rotate_intrinsic,
+     get_tuple_transform_ops,
+     get_depth_tuple_transform_ops,
+     warp_kpts,
+     numpy_to_pil,
+     tensor_to_pil,
+     recover_pose,
+     signed_left_to_right_epipolar_distance,
+ )