Realcat committed on
Commit 9cde3b4 • 1 Parent(s): d64a873

update: roma

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. build_docker.sh +1 -0
  2. hloc/matchers/roma.py +3 -1
  3. third_party/{Roma → RoMa}/.gitignore +0 -0
  4. third_party/RoMa/LICENSE +21 -0
  5. third_party/{Roma → RoMa}/README.md +44 -15
  6. third_party/{Roma → RoMa}/assets/sacre_coeur_A.jpg +0 -0
  7. third_party/{Roma → RoMa}/assets/sacre_coeur_B.jpg +0 -0
  8. third_party/RoMa/assets/toronto_A.jpg +3 -0
  9. third_party/RoMa/assets/toronto_B.jpg +3 -0
  10. third_party/{Roma → RoMa}/data/.gitignore +0 -0
  11. third_party/RoMa/demo/demo_3D_effect.py +46 -0
  12. third_party/{Roma → RoMa}/demo/demo_fundamental.py +5 -10
  13. third_party/{Roma → RoMa}/demo/demo_match.py +11 -14
  14. third_party/RoMa/demo/demo_match_opencv_sift.py +43 -0
  15. third_party/RoMa/demo/gif/.gitignore +2 -0
  16. third_party/{Roma → RoMa}/pretrained/dinov2_vitl14_pretrain.pth +0 -0
  17. third_party/{Roma → RoMa}/pretrained/roma_outdoor.pth +0 -0
  18. third_party/{Roma → RoMa}/requirements.txt +1 -1
  19. third_party/{Roma → RoMa}/roma/__init__.py +2 -2
  20. third_party/{Roma → RoMa}/roma/benchmarks/__init__.py +0 -0
  21. third_party/{Roma → RoMa}/roma/benchmarks/hpatches_sequences_homog_benchmark.py +7 -5
  22. third_party/{Roma → RoMa}/roma/benchmarks/megadepth_dense_benchmark.py +11 -27
  23. third_party/{Roma → RoMa}/roma/benchmarks/megadepth_pose_estimation_benchmark.py +19 -49
  24. third_party/{Roma → RoMa}/roma/benchmarks/scannet_benchmark.py +30 -27
  25. third_party/{Roma → RoMa}/roma/checkpointing/__init__.py +0 -0
  26. third_party/{Roma → RoMa}/roma/checkpointing/checkpoint.py +4 -5
  27. third_party/{Roma → RoMa}/roma/datasets/__init__.py +1 -1
  28. third_party/{Roma → RoMa}/roma/datasets/megadepth.py +42 -81
  29. third_party/{Roma → RoMa}/roma/datasets/scannet.py +72 -103
  30. third_party/RoMa/roma/losses/__init__.py +1 -0
  31. third_party/{Roma → RoMa}/roma/losses/robust_loss.py +54 -119
  32. third_party/RoMa/roma/models/__init__.py +1 -0
  33. third_party/{Roma → RoMa}/roma/models/encoders.py +7 -15
  34. third_party/{Roma → RoMa}/roma/models/matcher.py +100 -21
  35. third_party/RoMa/roma/models/model_zoo/__init__.py +53 -0
  36. third_party/{Roma → RoMa}/roma/models/model_zoo/roma_models.py +69 -84
  37. third_party/{Roma → RoMa}/roma/models/transformer/__init__.py +14 -46
  38. third_party/{Roma → RoMa}/roma/models/transformer/dinov2.py +23 -71
  39. third_party/{Roma → RoMa}/roma/models/transformer/layers/__init__.py +0 -0
  40. third_party/{Roma → RoMa}/roma/models/transformer/layers/attention.py +1 -5
  41. third_party/{Roma → RoMa}/roma/models/transformer/layers/block.py +13 -45
  42. third_party/{Roma → RoMa}/roma/models/transformer/layers/dino_head.py +2 -11
  43. third_party/{Roma → RoMa}/roma/models/transformer/layers/drop_path.py +1 -3
  44. third_party/{Roma → RoMa}/roma/models/transformer/layers/layer_scale.py +0 -0
  45. third_party/{Roma → RoMa}/roma/models/transformer/layers/mlp.py +0 -0
  46. third_party/{Roma → RoMa}/roma/models/transformer/layers/patch_embed.py +4 -16
  47. third_party/{Roma → RoMa}/roma/models/transformer/layers/swiglu_ffn.py +0 -0
  48. third_party/{Roma → RoMa}/roma/train/__init__.py +0 -0
  49. third_party/{Roma → RoMa}/roma/train/train.py +15 -39
  50. third_party/{Roma → RoMa}/roma/utils/__init__.py +0 -0
build_docker.sh CHANGED
@@ -1,3 +1,4 @@
 docker build -t image-matching-webui:latest . --no-cache
 docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest
 docker push vincentqin/image-matching-webui:latest
+
hloc/matchers/roma.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
 from ..utils.base_model import BaseModel
 from .. import logger
 
-roma_path = Path(__file__).parent / "../../third_party/Roma"
+roma_path = Path(__file__).parent / "../../third_party/RoMa"
 sys.path.append(str(roma_path))
 
 from roma.models.model_zoo.roma_models import roma_model
@@ -63,6 +63,8 @@ class Roma(BaseModel):
             weights=weights,
             dinov2_weights=dinov2_weights,
             device=device,
+            # temp fix issue: https://github.com/Parskatt/RoMa/issues/26
+            amp_dtype=torch.float32,
         )
         logger.info(f"Load Roma model done.")
 
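For context, a minimal sketch of how a matcher might construct RoMa with the float32 autocast override introduced above; only the `weights`, `dinov2_weights`, `device`, and `amp_dtype` keyword arguments come from this diff, while the checkpoint paths, `resolution`, and `upsample_preds` values are assumptions for illustration:

```python
# Hypothetical sketch, not part of the commit: loading RoMa the way the updated
# matcher does, with the fp32 workaround from https://github.com/Parskatt/RoMa/issues/26.
import sys
from pathlib import Path

import torch

# Renamed third-party directory (Roma -> RoMa), as in the diff above.
roma_path = Path(__file__).parent / "../../third_party/RoMa"
sys.path.append(str(roma_path))

from roma.models.model_zoo.roma_models import roma_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = roma_model(
    resolution=(560, 560),      # assumed coarse resolution, not shown in the diff
    upsample_preds=False,       # assumed default for matching in hloc
    weights=torch.load("pretrained/roma_outdoor.pth", map_location=device),
    dinov2_weights=torch.load("pretrained/dinov2_vitl14_pretrain.pth", map_location=device),
    device=device,
    # temp fix: keep autocast in float32, per RoMa issue #26
    amp_dtype=torch.float32,
)
```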
third_party/{Roma → RoMa}/.gitignore RENAMED
File without changes
third_party/RoMa/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Johan Edstedt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
third_party/{Roma → RoMa}/README.md RENAMED
@@ -1,14 +1,29 @@
-# RoMa: Revisiting Robust Losses for Dense Feature Matching
-### [Project Page (TODO)](https://parskatt.github.io/RoMa) | [Paper](https://arxiv.org/abs/2305.15404)
+#
+<p align="center">
+  <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
+  <p align="center">
+    <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
+  </p>
+  <h2 align="center"><p>
+    <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
+    <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
+  </p></h2>
+  <div align="center"></div>
+</p>
 <br/>
-
-> RoMa: Revisiting Robust Lossses for Dense Feature Matching
-> [Johan Edstedt](https://scholar.google.com/citations?user=Ul-vMR0AAAAJ), [Qiyu Sun](https://scholar.google.com/citations?user=HS2WuHkAAAAJ), [Georg Bökman](https://scholar.google.com/citations?user=FUE3Wd0AAAAJ), [Mårten Wadenbäck](https://scholar.google.com/citations?user=6WRQpCQAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?&user=lkWfR08AAAAJ)
-> Arxiv 2023
-
-**NOTE!!! Very early code, there might be bugs**
-
-The codebase is in the [roma folder](roma).
+<p align="center">
+  <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
+  <br>
+  <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
+</p>
 
 ## Setup/Install
 In your python environment (tested on Linux python 3.10), run:
@@ -32,6 +47,19 @@ F, mask = cv2.findFundamentalMat(
 kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
 )
 ```
+
+**New**: You can also match arbitrary keypoints with RoMa. A demo for this will be added soon.
+## Settings
+
+### Resolution
+By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
+You can change this at construction (see roma_outdoor kwargs).
+You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
+
+### Sampling
+roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
+
+
 ## Reproducing Results
 The experiments in the paper are provided in the [experiments folder](experiments).
 
@@ -46,7 +74,8 @@ torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outd
 python experiments/roma_outdoor.py --only_test --benchmark mega-1500
 ```
 ## License
-Due to our dependency on [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE), the license is sadly non-commercial only for the moment.
+All our code except DINOv2 is MIT license.
+DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
 
 ## Acknowledgement
 Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
@@ -54,10 +83,10 @@ Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
 ## BibTeX
 If you find our models useful, please consider citing our paper!
 ```
-@article{edstedt2023roma,
-title={{RoMa}: Revisiting Robust Lossses for Dense Feature Matching},
+@article{edstedt2024roma,
+title={{RoMa: Robust Dense Feature Matching}},
 author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
-journal={arXiv preprint arXiv:2305.15404},
-year={2023}
+journal={IEEE Conference on Computer Vision and Pattern Recognition},
+year={2024}
 }
 ```
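The new README's Settings section describes the resolution and sampling knobs in prose; a small hedged sketch of how they might be adjusted (attribute and kwarg names come from the README and demo diffs above, the concrete values are example assumptions):

```python
# Sketch based on the Settings section of the updated README; values are examples only.
import torch
from roma import roma_outdoor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Resolutions can be set at construction (see roma_outdoor kwargs, as in demo_match.py)...
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))

# ...or changed afterwards via the attributes the README mentions.
roma_model.h_resized, roma_model.w_resized = 560, 560
roma_model.upsample_res = (864, 864)

# sample_thresh controls the certainty threshold used when sampling matches for estimation.
roma_model.sample_thresh = 0.05  # arbitrary example value

warp, certainty = roma_model.match("assets/toronto_A.jpg", "assets/toronto_B.jpg", device=device)
matches, certainty = roma_model.sample(warp, certainty)
```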
third_party/{Roma → RoMa}/assets/sacre_coeur_A.jpg RENAMED
File without changes
third_party/{Roma → RoMa}/assets/sacre_coeur_B.jpg RENAMED
File without changes
third_party/RoMa/assets/toronto_A.jpg ADDED

Git LFS Details

  • SHA256: 40270c227df93f0f31b55e0f2ff38eb24f47940c4800c83758a74a5dfd7346ec
  • Pointer size: 131 Bytes
  • Size of remote file: 525 kB
third_party/RoMa/assets/toronto_B.jpg ADDED

Git LFS Details

  • SHA256: a2c07550ed87e40fca8c38076eb3a81395d760a88bf0b8615167704107deff2f
  • Pointer size: 131 Bytes
  • Size of remote file: 286 kB
third_party/{Roma → RoMa}/data/.gitignore RENAMED
File without changes
third_party/RoMa/demo/demo_3D_effect.py ADDED
@@ -0,0 +1,46 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from roma.utils.utils import tensor_to_pil
+
+from roma import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    # Create model
+    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
+    roma_model.symmetric = False
+
+    H, W = roma_model.get_output_resolution()
+
+    im1 = Image.open(im1_path).resize((W, H))
+    im2 = Image.open(im2_path).resize((W, H))
+
+    # Match
+    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+    # Sampling not needed, but can be done with model.sample(warp, certainty)
+    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+    coords_A, coords_B = warp[...,:2], warp[...,2:]
+    for i, x in enumerate(np.linspace(0,2*np.pi,200)):
+        t = (1 + np.cos(x))/2
+        interp_warp = (1-t)*coords_A + t*coords_B
+        im2_transfer_rgb = F.grid_sample(
+            x2[None], interp_warp[None], mode="bilinear", align_corners=False
+        )[0]
+        tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
third_party/{Roma → RoMa}/demo/demo_fundamental.py RENAMED
@@ -3,12 +3,11 @@ import torch
 import cv2
 from roma import roma_outdoor
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
-
     parser = ArgumentParser()
     parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
     parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
@@ -20,6 +19,7 @@ if __name__ == "__main__":
     # Create model
     roma_model = roma_outdoor(device=device)
 
+
     W_A, H_A = Image.open(im1_path).size
     W_B, H_B = Image.open(im2_path).size
 
@@ -27,12 +27,7 @@ if __name__ == "__main__":
     warp, certainty = roma_model.match(im1_path, im2_path, device=device)
     # Sample matches for estimation
     matches, certainty = roma_model.sample(warp, certainty)
-    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
     F, mask = cv2.findFundamentalMat(
-        kpts1.cpu().numpy(),
-        kpts2.cpu().numpy(),
-        ransacReprojThreshold=0.2,
-        method=cv2.USAC_MAGSAC,
-        confidence=0.999999,
-        maxIters=10000,
-    )
+        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+    )
third_party/{Roma → RoMa}/demo/demo_match.py RENAMED
@@ -4,20 +4,17 @@ import torch.nn.functional as F
 import numpy as np
 from roma.utils.utils import tensor_to_pil
 
-from roma import roma_indoor
+from roma import roma_outdoor
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
-
     parser = ArgumentParser()
-    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
-    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
-    parser.add_argument(
-        "--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str
-    )
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
 
     args, _ = parser.parse_known_args()
     im1_path = args.im_A_path
@@ -25,7 +22,7 @@ if __name__ == "__main__":
     save_path = args.save_path
 
     # Create model
-    roma_model = roma_indoor(device=device)
+    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
 
     H, W = roma_model.get_output_resolution()
 
@@ -39,12 +36,12 @@ if __name__ == "__main__":
     x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
 
     im2_transfer_rgb = F.grid_sample(
-        x2[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
+        x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
     )[0]
     im1_transfer_rgb = F.grid_sample(
-        x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+        x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
     )[0]
-    warp_im = torch.cat((im2_transfer_rgb, im1_transfer_rgb), dim=2)
-    white_im = torch.ones((H, 2 * W), device=device)
+    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
+    white_im = torch.ones((H,2*W),device=device)
     vis_im = certainty * warp_im + (1 - certainty) * white_im
-    tensor_to_pil(vis_im, unnormalize=False).save(save_path)
+    tensor_to_pil(vis_im, unnormalize=False).save(save_path)
third_party/RoMa/demo/demo_match_opencv_sift.py ADDED
@@ -0,0 +1,43 @@
+from PIL import Image
+import numpy as np
+
+import numpy as np
+import cv2 as cv
+import matplotlib.pyplot as plt
+
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE)  # queryImage
+    img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE)  # trainImage
+    # Initiate SIFT detector
+    sift = cv.SIFT_create()
+    # find the keypoints and descriptors with SIFT
+    kp1, des1 = sift.detectAndCompute(img1,None)
+    kp2, des2 = sift.detectAndCompute(img2,None)
+    # BFMatcher with default params
+    bf = cv.BFMatcher()
+    matches = bf.knnMatch(des1,des2,k=2)
+    # Apply ratio test
+    good = []
+    for m,n in matches:
+        if m.distance < 0.75*n.distance:
+            good.append([m])
+    # cv.drawMatchesKnn expects list of lists as matches.
+    draw_params = dict(matchColor = (255,0,0), # draw matches in red color
+                       singlePointColor = None,
+                       flags = 2)
+
+    img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
+    Image.fromarray(img3).save("demo/sift_matches.png")
third_party/RoMa/demo/gif/.gitignore ADDED
@@ -0,0 +1,2 @@
+*
+!.gitignore
third_party/{Roma → RoMa}/pretrained/dinov2_vitl14_pretrain.pth RENAMED
File without changes
third_party/{Roma → RoMa}/pretrained/roma_outdoor.pth RENAMED
File without changes
third_party/{Roma → RoMa}/requirements.txt RENAMED
@@ -10,4 +10,4 @@ matplotlib
 h5py
 wandb
 timm
-xformers # Optional, used for memefficient attention
+#xformers # Optional, used for memefficient attention
third_party/{Roma → RoMa}/roma/__init__.py RENAMED
@@ -2,7 +2,7 @@ import os
 from .models import roma_outdoor, roma_indoor
 
 DEBUG_MODE = False
-RANK = int(os.environ.get("RANK", default=0))
+RANK = int(os.environ.get('RANK', default = 0))
 GLOBAL_STEP = 0
 STEP_SIZE = 1
-LOCAL_RANK = -1
+LOCAL_RANK = -1
third_party/{Roma → RoMa}/roma/benchmarks/__init__.py RENAMED
File without changes
third_party/{Roma → RoMa}/roma/benchmarks/hpatches_sequences_homog_benchmark.py RENAMED
@@ -53,7 +53,7 @@ class HpatchesHomogBenchmark:
         )
         return im_A_coords, im_A_to_im_B
 
-    def benchmark(self, model, model_name=None):
+    def benchmark(self, model, model_name = None):
         n_matches = []
         homog_dists = []
         for seq_idx, seq_name in tqdm(
@@ -69,7 +69,9 @@ class HpatchesHomogBenchmark:
             H = np.loadtxt(
                 os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
             )
-            dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
+            dense_matches, dense_certainty = model.match(
+                im_A_path, im_B_path
+            )
             good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
             pos_a, pos_b = self.convert_coordinates(
                 good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
@@ -78,9 +80,9 @@ class HpatchesHomogBenchmark:
                 H_pred, inliers = cv2.findHomography(
                     pos_a,
                     pos_b,
-                    method=cv2.RANSAC,
-                    confidence=0.99999,
-                    ransacReprojThreshold=3 * min(w2, h2) / 480,
+                    method = cv2.RANSAC,
+                    confidence = 0.99999,
+                    ransacReprojThreshold = 3 * min(w2, h2) / 480,
                 )
             except:
                 H_pred = None
third_party/{Roma → RoMa}/roma/benchmarks/megadepth_dense_benchmark.py RENAMED
@@ -6,11 +6,8 @@ from roma.utils import warp_kpts
 from torch.utils.data import ConcatDataset
 import roma
 
-
 class MegadepthDenseBenchmark:
-    def __init__(
-        self, data_root="data/megadepth", h=384, w=512, num_samples=2000
-    ) -> None:
+    def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
         mega = MegadepthBuilder(data_root=data_root)
         self.dataset = ConcatDataset(
             mega.build_scenes(split="test_loftr", ht=h, wt=w)
@@ -52,15 +49,13 @@ class MegadepthDenseBenchmark:
         pck_3_tot = 0.0
         pck_5_tot = 0.0
         sampler = torch.utils.data.WeightedRandomSampler(
-            torch.ones(len(self.dataset)),
-            replacement=False,
-            num_samples=self.num_samples,
+            torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
         )
         B = batch_size
         dataloader = torch.utils.data.DataLoader(
             self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
         )
-        for idx, data in tqdm.tqdm(enumerate(dataloader), disable=roma.RANK > 0):
+        for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0):
            im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
                data["im_A"],
                data["im_B"],
@@ -77,36 +72,25 @@ class MegadepthDenseBenchmark:
            if roma.DEBUG_MODE:
                from roma.utils.utils import tensor_to_pil
                import torch.nn.functional as F
-
                path = "vis"
                H, W = model.get_output_resolution()
-               white_im = torch.ones((B, 1, H, W), device="cuda")
+               white_im = torch.ones((B,1,H,W),device="cuda")
                im_B_transfer_rgb = F.grid_sample(
-                   im_B.cuda(),
-                   matches[:, :, :W, 2:],
-                   mode="bilinear",
-                   align_corners=False,
+                   im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
                )
                warp_im = im_B_transfer_rgb
-               c_b = certainty[
-                   :, None
-               ]  # (certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
+               c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
                vis_im = c_b * warp_im + (1 - c_b) * white_im
                for b in range(B):
                    import os
-
-                   os.makedirs(
-                       f"{path}/{model.name}/{idx}_{b}_{H}_{W}", exist_ok=True
-                   )
+                   os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
                    tensor_to_pil(vis_im[b], unnormalize=True).save(
-                       f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg"
-                   )
+                       f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
                    tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
-                       f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg"
-                   )
+                       f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
                    tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
-                       f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg"
-                   )
+                       f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
+
 
            gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
                gd_tot + gd.mean(),
third_party/{Roma → RoMa}/roma/benchmarks/megadepth_pose_estimation_benchmark.py RENAMED
@@ -7,9 +7,8 @@ import torch.nn.functional as F
 import roma
 import kornia.geometry.epipolar as kepi
 
-
 class MegaDepthPoseEstimationBenchmark:
-    def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
+    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
         if scene_names is None:
             self.scene_names = [
                 "0015_0.1_0.3.npz",
@@ -26,22 +25,13 @@ class MegaDepthPoseEstimationBenchmark:
         ]
         self.data_root = data_root
 
-    def benchmark(
-        self,
-        model,
-        model_name=None,
-        resolution=None,
-        scale_intrinsics=True,
-        calibrated=True,
-    ):
-        H, W = model.get_output_resolution()
+    def benchmark(self, model, model_name = None):
         with torch.no_grad():
             data_root = self.data_root
             tot_e_t, tot_e_R, tot_e_pose = [], [], []
             thresholds = [5, 10, 20]
             for scene_ind in range(len(self.scenes)):
                 import os
-
                 scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
                 scene = self.scenes[scene_ind]
                 pairs = scene["pair_infos"]
@@ -58,22 +48,21 @@ class MegaDepthPoseEstimationBenchmark:
                 T2 = poses[idx2].copy()
                 R2, t2 = T2[:3, :3], T2[:3, 3]
                 R, t = compute_relative_pose(R1, t1, R2, t2)
-                T1_to_2 = np.concatenate((R, t[:, None]), axis=-1)
+                T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
                 im_A_path = f"{data_root}/{im_paths[idx1]}"
                 im_B_path = f"{data_root}/{im_paths[idx2]}"
                 dense_matches, dense_certainty = model.match(
                     im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
                 )
-                sparse_matches, _ = model.sample(
-                    dense_matches, dense_certainty, 5000
+                sparse_matches,_ = model.sample(
+                    dense_matches, dense_certainty, 5_000
                 )
-
+
                 im_A = Image.open(im_A_path)
                 w1, h1 = im_A.size
                 im_B = Image.open(im_B_path)
                 w2, h2 = im_B.size
-
-                if scale_intrinsics:
+                if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
                     scale1 = 1200 / max(w1, h1)
                     scale2 = 1200 / max(w2, h2)
                     w1, h1 = scale1 * w1, scale1 * h1
@@ -82,42 +71,23 @@ class MegaDepthPoseEstimationBenchmark:
                     K1[:2] = K1[:2] * scale1
                     K2[:2] = K2[:2] * scale2
 
-                kpts1 = sparse_matches[:, :2]
-                kpts1 = np.stack(
-                    (
-                        w1 * (kpts1[:, 0] + 1) / 2,
-                        h1 * (kpts1[:, 1] + 1) / 2,
-                    ),
-                    axis=-1,
-                )
-                kpts2 = sparse_matches[:, 2:]
-                kpts2 = np.stack(
-                    (
-                        w2 * (kpts2[:, 0] + 1) / 2,
-                        h2 * (kpts2[:, 1] + 1) / 2,
-                    ),
-                    axis=-1,
-                )
-
+                kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
+                kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
                 for _ in range(5):
                     shuffling = np.random.permutation(np.arange(len(kpts1)))
                     kpts1 = kpts1[shuffling]
                     kpts2 = kpts2[shuffling]
                     try:
-                        threshold = 0.5
-                        if calibrated:
-                            norm_threshold = threshold / (
-                                np.mean(np.abs(K1[:2, :2]))
-                                + np.mean(np.abs(K2[:2, :2]))
-                            )
-                        R_est, t_est, mask = estimate_pose(
-                            kpts1,
-                            kpts2,
-                            K1,
-                            K2,
-                            norm_threshold,
-                            conf=0.99999,
-                        )
+                        threshold = 0.5
+                        norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                        R_est, t_est, mask = estimate_pose(
+                            kpts1,
+                            kpts2,
+                            K1,
+                            K2,
+                            norm_threshold,
+                            conf=0.99999,
+                        )
                         T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
                         e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                         e_pose = max(e_t, e_R)
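The benchmark change above swaps the hand-rolled unnormalization of sampled matches for `model.to_pixel_coordinates`. A hedged sketch of what that replaces, using the same convention as the removed code (normalized [-1, 1] coordinates, x before y); the helper name here is hypothetical, not RoMa API:

```python
# Sketch only: the mapping that the removed lines implemented explicitly.
# sparse_matches has shape (N, 4) = (x_A, y_A, x_B, y_B) in normalized [-1, 1] coords.
import numpy as np

def to_pixel_coords(coords: np.ndarray, h: int, w: int) -> np.ndarray:
    """Map normalized [-1, 1] coordinates to pixel coordinates (same formula as the old code)."""
    return np.stack(
        (w * (coords[:, 0] + 1) / 2, h * (coords[:, 1] + 1) / 2),
        axis=-1,
    )

# kpts1 = to_pixel_coords(sparse_matches[:, :2], h1, w1)
# kpts2 = to_pixel_coords(sparse_matches[:, 2:], h2, w2)
```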
third_party/{Roma → RoMa}/roma/benchmarks/scannet_benchmark.py RENAMED
@@ -10,7 +10,7 @@ class ScanNetBenchmark:
     def __init__(self, data_root="data/scannet") -> None:
         self.data_root = data_root
 
-    def benchmark(self, model, model_name=None):
+    def benchmark(self, model, model_name = None):
         model.train(False)
         with torch.no_grad():
             data_root = self.data_root
@@ -24,20 +24,20 @@ class ScanNetBenchmark:
                 scene = pairs[pairind]
                 scene_name = f"scene0{scene[0]}_00"
                 im_A_path = osp.join(
-                    self.data_root,
-                    "scans_test",
-                    scene_name,
-                    "color",
-                    f"{scene[2]}.jpg",
-                )
+                        self.data_root,
+                        "scans_test",
+                        scene_name,
+                        "color",
+                        f"{scene[2]}.jpg",
+                        )
                 im_A = Image.open(im_A_path)
                 im_B_path = osp.join(
-                    self.data_root,
-                    "scans_test",
-                    scene_name,
-                    "color",
-                    f"{scene[3]}.jpg",
-                )
+                        self.data_root,
+                        "scans_test",
+                        scene_name,
+                        "color",
+                        f"{scene[3]}.jpg",
+                        )
                 im_B = Image.open(im_B_path)
                 T_gt = rel_pose[pairind].reshape(3, 4)
                 R, t = T_gt[:3, :3], T_gt[:3, 3]
@@ -76,20 +76,24 @@ class ScanNetBenchmark:
 
                 offset = 0.5
                 kpts1 = sparse_matches[:, :2]
-                kpts1 = np.stack(
-                    (
-                        w1 * (kpts1[:, 0] + 1) / 2 - offset,
-                        h1 * (kpts1[:, 1] + 1) / 2 - offset,
-                    ),
-                    axis=-1,
+                kpts1 = (
+                    np.stack(
+                        (
+                            w1 * (kpts1[:, 0] + 1) / 2 - offset,
+                            h1 * (kpts1[:, 1] + 1) / 2 - offset,
+                        ),
+                        axis=-1,
+                    )
                 )
                 kpts2 = sparse_matches[:, 2:]
-                kpts2 = np.stack(
-                    (
-                        w2 * (kpts2[:, 0] + 1) / 2 - offset,
-                        h2 * (kpts2[:, 1] + 1) / 2 - offset,
-                    ),
-                    axis=-1,
+                kpts2 = (
+                    np.stack(
+                        (
+                            w2 * (kpts2[:, 0] + 1) / 2 - offset,
+                            h2 * (kpts2[:, 1] + 1) / 2 - offset,
+                        ),
+                        axis=-1,
+                    )
                 )
                 for _ in range(5):
                     shuffling = np.random.permutation(np.arange(len(kpts1)))
@@ -97,8 +101,7 @@ class ScanNetBenchmark:
                     kpts2 = kpts2[shuffling]
                     try:
                         norm_threshold = 0.5 / (
-                            np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
-                        )
+                        np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
                         R_est, t_est, mask = estimate_pose(
                             kpts1,
                             kpts2,
third_party/{Roma → RoMa}/roma/checkpointing/__init__.py RENAMED
File without changes
third_party/{Roma → RoMa}/roma/checkpointing/checkpoint.py RENAMED
@@ -7,7 +7,6 @@ import gc
 
 import roma
 
-
 class CheckPoint:
     def __init__(self, dir=None, name="tmp"):
         self.name = name
@@ -20,7 +19,7 @@ class CheckPoint:
         optimizer,
         lr_scheduler,
         n,
-    ):
+    ):
         if roma.RANK == 0:
             assert model is not None
             if isinstance(model, (DataParallel, DistributedDataParallel)):
@@ -33,14 +32,14 @@ class CheckPoint:
             }
             torch.save(states, self.dir + self.name + f"_latest.pth")
             logger.info(f"Saved states {list(states.keys())}, at step {n}")
-
+
     def load(
         self,
         model,
        optimizer,
        lr_scheduler,
        n,
-    ):
+    ):
        if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0:
            states = torch.load(self.dir + self.name + f"_latest.pth")
            if "model" in states:
@@ -58,4 +57,4 @@ class CheckPoint:
        del states
        gc.collect()
        torch.cuda.empty_cache()
-        return model, optimizer, lr_scheduler, n
+        return model, optimizer, lr_scheduler, n
third_party/{Roma → RoMa}/roma/datasets/__init__.py RENAMED
@@ -1,2 +1,2 @@
 from .megadepth import MegadepthBuilder
-from .scannet import ScanNetBuilder
+from .scannet import ScanNetBuilder
third_party/{Roma → RoMa}/roma/datasets/megadepth.py RENAMED
@@ -10,7 +10,6 @@ import roma
 from roma.utils import *
 import math
 
-
 class MegadepthScene:
     def __init__(
         self,
@@ -23,20 +22,18 @@ class MegadepthScene:
         shake_t=0,
         rot_prob=0.0,
         normalize=True,
-        max_num_pairs=100_000,
-        scene_name=None,
-        use_horizontal_flip_aug=False,
-        use_single_horizontal_flip_aug=False,
-        colorjiggle_params=None,
-        random_eraser=None,
-        use_randaug=False,
-        randaug_params=None,
-        randomize_size=False,
+        max_num_pairs = 100_000,
+        scene_name = None,
+        use_horizontal_flip_aug = False,
+        use_single_horizontal_flip_aug = False,
+        colorjiggle_params = None,
+        random_eraser = None,
+        use_randaug = False,
+        randaug_params = None,
+        randomize_size = False,
     ) -> None:
         self.data_root = data_root
-        self.scene_name = (
-            os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}"
-        )
+        self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
         self.image_paths = scene_info["image_paths"]
         self.depth_paths = scene_info["depth_paths"]
         self.intrinsics = scene_info["intrinsics"]
@@ -54,18 +51,18 @@ class MegadepthScene:
             self.overlaps = self.overlaps[pairinds]
         if randomize_size:
             area = ht * wt
-            s = int(16 * (math.sqrt(area) // 16))
-            sizes = ((ht, wt), (s, s), (wt, ht))
+            s = int(16 * (math.sqrt(area)//16))
+            sizes = ((ht,wt), (s,s), (wt,ht))
             choice = roma.RANK % 3
-            ht, wt = sizes[choice]
+            ht, wt = sizes[choice]
         # counts, bins = np.histogram(self.overlaps,20)
         # print(counts)
         self.im_transform_ops = get_tuple_transform_ops(
-            resize=(ht, wt),
-            normalize=normalize,
-            colorjiggle_params=colorjiggle_params,
+            resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
         )
-        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt))
+        self.depth_transform_ops = get_depth_tuple_transform_ops(
+            resize=(ht, wt)
+        )
         self.wt, self.ht = wt, ht
         self.shake_t = shake_t
         self.random_eraser = random_eraser
@@ -78,19 +75,17 @@ class MegadepthScene:
     def load_im(self, im_path):
         im = Image.open(im_path)
         return im
-
-    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
+
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
         im_A = im_A.flip(-1)
         im_B = im_B.flip(-1)
-        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
-        flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to(
-            K_A.device
-        )
-        K_A = flip_mat @ K_A
-        K_B = flip_mat @ K_B
-
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+        K_A = flip_mat@K_A
+        K_B = flip_mat@K_B
+
         return im_A, im_B, depth_A, depth_B, K_A, K_B
-
+
     def load_depth(self, depth_ref, crop=None):
         depth = np.array(h5py.File(depth_ref, "r")["depth"])
         return torch.from_numpy(depth)
@@ -145,31 +140,29 @@ class MegadepthScene:
         depth_A, depth_B = self.depth_transform_ops(
             (depth_A[None, None], depth_B[None, None])
         )
-
-        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(
-            im_A, im_B, depth_A, depth_B
-        )
+
+        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
         K1[:2, 2] += t
         K2[:2, 2] += t
-
+
         im_A, im_B = im_A[None], im_B[None]
         if self.random_eraser is not None:
             im_A, depth_A = self.random_eraser(im_A, depth_A)
            im_B, depth_B = self.random_eraser(im_B, depth_B)
-
+
        if self.use_horizontal_flip_aug:
            if np.random.rand() > 0.5:
-                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
-                    im_A, im_B, depth_A, depth_B, K1, K2
-                )
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
        if self.use_single_horizontal_flip_aug:
            if np.random.rand() > 0.5:
                im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
-
+
        if roma.DEBUG_MODE:
-            tensor_to_pil(im_A[0], unnormalize=True).save(f"vis/im_A.jpg")
-            tensor_to_pil(im_B[0], unnormalize=True).save(f"vis/im_B.jpg")
-
+            tensor_to_pil(im_A[0], unnormalize=True).save(
+                f"vis/im_A.jpg")
+            tensor_to_pil(im_B[0], unnormalize=True).save(
+                f"vis/im_B.jpg")
+
        data_dict = {
            "im_A": im_A[0],
            "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
@@ -182,53 +175,25 @@ class MegadepthScene:
            "T_1to2": T_1to2,
            "im_A_path": im_A_ref,
            "im_B_path": im_B_ref,
+
        }
        return data_dict
 
 
 class MegadepthBuilder:
-    def __init__(
-        self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True
-    ) -> None:
+    def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
        self.data_root = data_root
        self.scene_info_root = os.path.join(data_root, "prep_scene_info")
        self.all_scenes = os.listdir(self.scene_info_root)
        self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
        # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
-        self.loftr_ignore_scenes = set(
-            [
-                "0121.npy",
-                "0133.npy",
-                "0168.npy",
-                "0178.npy",
-                "0229.npy",
-                "0349.npy",
-                "0412.npy",
-                "0430.npy",
-                "0443.npy",
-                "1001.npy",
-                "5014.npy",
-                "5015.npy",
-                "5016.npy",
-            ]
-        )
-        self.imc21_scenes = set(
-            [
-                "0008.npy",
-                "0019.npy",
-                "0021.npy",
-                "0024.npy",
-                "0025.npy",
-                "0032.npy",
-                "0063.npy",
-                "1589.npy",
-            ]
-        )
+        self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
+        self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
        self.test_scenes_loftr = ["0015.npy", "0022.npy"]
        self.loftr_ignore = loftr_ignore
        self.imc21_ignore = imc21_ignore
 
-    def build_scenes(self, split="train", min_overlap=0.0, scene_names=None, **kwargs):
+    def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
        if split == "train":
            scene_names = set(self.all_scenes) - set(self.test_scenes)
        elif split == "train_loftr":
@@ -252,11 +217,7 @@ class MegadepthBuilder:
                ).item()
                scenes.append(
                    MegadepthScene(
-                        self.data_root,
-                        scene_info,
-                        min_overlap=min_overlap,
-                        scene_name=scene_name,
-                        **kwargs,
+                        self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
                    )
                )
        return scenes
third_party/{Roma β†’ RoMa}/roma/datasets/scannet.py RENAMED
@@ -5,7 +5,10 @@ import cv2
5
  import h5py
6
  import numpy as np
7
  import torch
8
- from torch.utils.data import Dataset, DataLoader, ConcatDataset
 
 
 
9
 
10
  import torchvision.transforms.functional as tvf
11
  import kornia.augmentation as K
@@ -16,36 +19,22 @@ from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
16
  from roma.utils.transforms import GeometricSequential
17
  from tqdm import tqdm
18
 
19
-
20
  class ScanNetScene:
21
- def __init__(
22
- self,
23
- data_root,
24
- scene_info,
25
- ht=384,
26
- wt=512,
27
- min_overlap=0.0,
28
- shake_t=0,
29
- rot_prob=0.0,
30
- use_horizontal_flip_aug=False,
31
- ) -> None:
32
- self.scene_root = osp.join(data_root, "scans", "scans_train")
33
- self.data_names = scene_info["name"]
34
- self.overlaps = scene_info["score"]
35
  # Only sample 10s
36
- valid = (self.data_names[:, -2:] % 10).sum(axis=-1) == 0
37
  self.overlaps = self.overlaps[valid]
38
  self.data_names = self.data_names[valid]
39
  if len(self.data_names) > 10000:
40
- pairinds = np.random.choice(
41
- np.arange(0, len(self.data_names)), 10000, replace=False
42
- )
43
  self.data_names = self.data_names[pairinds]
44
  self.overlaps = self.overlaps[pairinds]
45
  self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
46
- self.depth_transform_ops = get_depth_tuple_transform_ops(
47
- resize=(ht, wt), normalize=False
48
- )
49
  self.wt, self.ht = wt, ht
50
  self.shake_t = shake_t
51
  self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
@@ -54,7 +43,7 @@ class ScanNetScene:
54
  def load_im(self, im_B, crop=None):
55
  im = Image.open(im_B)
56
  return im
57
-
58
  def load_depth(self, depth_ref, crop=None):
59
  depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
60
  depth = depth / 1000
@@ -63,73 +52,64 @@ class ScanNetScene:
63
 
64
  def __len__(self):
65
  return len(self.data_names)
66
-
67
  def scale_intrinsic(self, K, wi, hi):
68
- sx, sy = self.wt / wi, self.ht / hi
69
- sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
70
- return sK @ K
 
 
71
 
72
- def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
73
  im_A = im_A.flip(-1)
74
  im_B = im_B.flip(-1)
75
- depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
76
- flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to(
77
- K_A.device
78
- )
79
- K_A = flip_mat @ K_A
80
- K_B = flip_mat @ K_B
81
-
82
  return im_A, im_B, depth_A, depth_B, K_A, K_B
83
-
84
- def read_scannet_pose(self, path):
85
- """Read ScanNet's Camera2World pose and transform it to World2Camera.
86
-
87
  Returns:
88
  pose_w2c (np.ndarray): (4, 4)
89
  """
90
- cam2world = np.loadtxt(path, delimiter=" ")
91
  world2cam = np.linalg.inv(cam2world)
92
  return world2cam
93
 
94
- def read_scannet_intrinsic(self, path):
95
- """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
96
- intrinsic = np.loadtxt(path, delimiter=" ")
97
- return torch.tensor(intrinsic[:-1, :-1], dtype=torch.float)
 
 
98
 
99
  def __getitem__(self, pair_idx):
100
  # read intrinsics of original size
101
  data_name = self.data_names[pair_idx]
102
  scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
103
- scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"
104
-
105
  # read the intrinsic of depthmap
106
- K1 = K2 = self.read_scannet_intrinsic(
107
- osp.join(self.scene_root, scene_name, "intrinsic", "intrinsic_color.txt")
108
- ) # the depth K is not the same, but doesnt really matter
109
  # read and compute relative poses
110
- T1 = self.read_scannet_pose(
111
- osp.join(self.scene_root, scene_name, "pose", f"{stem_name_1}.txt")
112
- )
113
- T2 = self.read_scannet_pose(
114
- osp.join(self.scene_root, scene_name, "pose", f"{stem_name_2}.txt")
115
- )
116
- T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
117
- :4, :4
118
- ] # (4, 4)
119
 
120
  # Load positive pair data
121
- im_A_ref = os.path.join(
122
- self.scene_root, scene_name, "color", f"{stem_name_1}.jpg"
123
- )
124
- im_B_ref = os.path.join(
125
- self.scene_root, scene_name, "color", f"{stem_name_2}.jpg"
126
- )
127
- depth_A_ref = os.path.join(
128
- self.scene_root, scene_name, "depth", f"{stem_name_1}.png"
129
- )
130
- depth_B_ref = os.path.join(
131
- self.scene_root, scene_name, "depth", f"{stem_name_2}.png"
132
- )
133
 
134
  im_A = self.load_im(im_A_ref)
135
  im_B = self.load_im(im_B_ref)
@@ -141,51 +121,40 @@ class ScanNetScene:
141
  K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
142
  # Process images
143
  im_A, im_B = self.im_transform_ops((im_A, im_B))
144
- depth_A, depth_B = self.depth_transform_ops(
145
- (depth_A[None, None], depth_B[None, None])
146
- )
147
  if self.use_horizontal_flip_aug:
148
  if np.random.rand() > 0.5:
149
- im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
150
- im_A, im_B, depth_A, depth_B, K1, K2
151
- )
152
-
153
- data_dict = {
154
- "im_A": im_A,
155
- "im_B": im_B,
156
- "im_A_depth": depth_A[0, 0],
157
- "im_B_depth": depth_B[0, 0],
158
- "K1": K1,
159
- "K2": K2,
160
- "T_1to2": T_1to2,
161
- }
162
  return data_dict
163
 
164
 
165
  class ScanNetBuilder:
166
- def __init__(self, data_root="data/scannet") -> None:
167
  self.data_root = data_root
168
- self.scene_info_root = os.path.join(data_root, "scannet_indices")
169
  self.all_scenes = os.listdir(self.scene_info_root)
170
-
171
- def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
172
  # Note: split doesn't matter here as we always use same scannet_train scenes
173
  scene_names = self.all_scenes
174
  scenes = []
175
- for scene_name in tqdm(scene_names, disable=roma.RANK > 0):
176
- scene_info = np.load(
177
- os.path.join(self.scene_info_root, scene_name), allow_pickle=True
178
- )
179
- scenes.append(
180
- ScanNetScene(
181
- self.data_root, scene_info, min_overlap=min_overlap, **kwargs
182
- )
183
- )
184
  return scenes
185
-
186
- def weight_scenes(self, concat_dataset, alpha=0.5):
187
  ns = []
188
  for d in concat_dataset.datasets:
189
  ns.append(len(d))
190
- ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
191
  return ws
 
5
  import h5py
6
  import numpy as np
7
  import torch
8
+ from torch.utils.data import (
9
+ Dataset,
10
+ DataLoader,
11
+ ConcatDataset)
12
 
13
  import torchvision.transforms.functional as tvf
14
  import kornia.augmentation as K
 
19
  from roma.utils.transforms import GeometricSequential
20
  from tqdm import tqdm
21
 
 
22
  class ScanNetScene:
23
+ def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
24
+ ) -> None:
25
+ self.scene_root = osp.join(data_root,"scans","scans_train")
26
+ self.data_names = scene_info['name']
27
+ self.overlaps = scene_info['score']
 
 
 
 
 
 
 
 
 
28
  # Only sample 10s
29
+ valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
30
  self.overlaps = self.overlaps[valid]
31
  self.data_names = self.data_names[valid]
32
  if len(self.data_names) > 10000:
33
+ pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
 
 
34
  self.data_names = self.data_names[pairinds]
35
  self.overlaps = self.overlaps[pairinds]
36
  self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
37
+ self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
 
 
38
  self.wt, self.ht = wt, ht
39
  self.shake_t = shake_t
40
  self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
 
43
  def load_im(self, im_B, crop=None):
44
  im = Image.open(im_B)
45
  return im
46
+
47
  def load_depth(self, depth_ref, crop=None):
48
  depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
49
  depth = depth / 1000
 
52
 
53
  def __len__(self):
54
  return len(self.data_names)
55
+
56
  def scale_intrinsic(self, K, wi, hi):
57
+ sx, sy = self.wt / wi, self.ht / hi
58
+ sK = torch.tensor([[sx, 0, 0],
59
+ [0, sy, 0],
60
+ [0, 0, 1]])
61
+ return sK@K
62
 
63
+ def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
64
  im_A = im_A.flip(-1)
65
  im_B = im_B.flip(-1)
66
+ depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
67
+ flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
68
+ K_A = flip_mat@K_A
69
+ K_B = flip_mat@K_B
70
+
 
 
71
  return im_A, im_B, depth_A, depth_B, K_A, K_B
72
+ def read_scannet_pose(self,path):
73
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
74
+
 
75
  Returns:
76
  pose_w2c (np.ndarray): (4, 4)
77
  """
78
+ cam2world = np.loadtxt(path, delimiter=' ')
79
  world2cam = np.linalg.inv(cam2world)
80
  return world2cam
81
 
82
+
83
+ def read_scannet_intrinsic(self,path):
84
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
85
+ """
86
+ intrinsic = np.loadtxt(path, delimiter=' ')
87
+ return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
88
 
89
  def __getitem__(self, pair_idx):
90
  # read intrinsics of original size
91
  data_name = self.data_names[pair_idx]
92
  scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
93
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
94
+
95
  # read the intrinsic of depthmap
96
+ K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
97
+ scene_name,
98
+ 'intrinsic', 'intrinsic_color.txt'))  # the depth K is not the same, but doesn't really matter
99
  # read and compute relative poses
100
+ T1 = self.read_scannet_pose(osp.join(self.scene_root,
101
+ scene_name,
102
+ 'pose', f'{stem_name_1}.txt'))
103
+ T2 = self.read_scannet_pose(osp.join(self.scene_root,
104
+ scene_name,
105
+ 'pose', f'{stem_name_2}.txt'))
106
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
 
 
107
 
108
  # Load positive pair data
109
+ im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
110
+ im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
111
+ depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
112
+ depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
113
 
114
  im_A = self.load_im(im_A_ref)
115
  im_B = self.load_im(im_B_ref)
 
121
  K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
122
  # Process images
123
  im_A, im_B = self.im_transform_ops((im_A, im_B))
124
+ depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
 
 
125
  if self.use_horizontal_flip_aug:
126
  if np.random.rand() > 0.5:
127
+ im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
128
+
129
+ data_dict = {'im_A': im_A,
130
+ 'im_B': im_B,
131
+ 'im_A_depth': depth_A[0,0],
132
+ 'im_B_depth': depth_B[0,0],
133
+ 'K1': K1,
134
+ 'K2': K2,
135
+ 'T_1to2':T_1to2,
136
+ }
 
 
 
137
  return data_dict
138
 
139
 
140
  class ScanNetBuilder:
141
+ def __init__(self, data_root = 'data/scannet') -> None:
142
  self.data_root = data_root
143
+ self.scene_info_root = os.path.join(data_root,'scannet_indices')
144
  self.all_scenes = os.listdir(self.scene_info_root)
145
+
146
+ def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
147
  # Note: split doesn't matter here as we always use same scannet_train scenes
148
  scene_names = self.all_scenes
149
  scenes = []
150
+ for scene_name in tqdm(scene_names, disable = roma.RANK > 0):
151
+ scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
152
+ scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
153
  return scenes
154
+
155
+ def weight_scenes(self, concat_dataset, alpha=.5):
156
  ns = []
157
  for d in concat_dataset.datasets:
158
  ns.append(len(d))
159
+ ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
160
  return ws
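
The ScanNetBuilder/ScanNetScene pair above is normally consumed by concatenating the per-scene datasets and sampling with the per-scene weights. A minimal usage sketch, assuming the data layout and import path implied by this diff (nothing below is part of the commit):

from torch.utils.data import ConcatDataset, WeightedRandomSampler
from roma.datasets.scannet import ScanNetBuilder  # assumed import path for the file shown above

builder = ScanNetBuilder(data_root="data/scannet")
scenes = builder.build_scenes(split="train", ht=384, wt=512)  # extra kwargs are forwarded to ScanNetScene
dataset = ConcatDataset(scenes)
weights = builder.weight_scenes(dataset, alpha=0.5)  # one weight per sample, ~ n**-0.5 within each scene
sampler = WeightedRandomSampler(weights, num_samples=len(weights))
# Each sample is a dict with im_A, im_B, im_A_depth, im_B_depth, K1, K2 and the relative pose T_1to2.
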
third_party/RoMa/roma/losses/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .robust_loss import RobustLosses
third_party/{Roma β†’ RoMa}/roma/losses/robust_loss.py RENAMED
@@ -7,7 +7,6 @@ import wandb
7
  import roma
8
  import math
9
 
10
-
11
  class RobustLosses(nn.Module):
12
  def __init__(
13
  self,
@@ -18,12 +17,12 @@ class RobustLosses(nn.Module):
18
  local_loss=True,
19
  local_dist=4.0,
20
  local_largest_scale=8,
21
- smooth_mask=False,
22
- depth_interpolation_mode="bilinear",
23
- mask_depth_loss=False,
24
- relative_depth_error_threshold=0.05,
25
- alpha=1.0,
26
- c=1e-3,
27
  ):
28
  super().__init__()
29
  self.robust = robust # measured in pixels
@@ -46,103 +45,68 @@ class RobustLosses(nn.Module):
46
  B, C, H, W = scale_gm_cls.shape
47
  device = x2.device
48
  cls_res = round(math.sqrt(C))
49
- G = torch.meshgrid(
50
- *[
51
- torch.linspace(
52
- -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
53
- )
54
- for _ in range(2)
55
- ]
56
- )
57
- G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2)
58
- GT = (
59
- (G[None, :, None, None, :] - x2[:, None])
60
- .norm(dim=-1)
61
- .min(dim=1)
62
- .indices
63
- )
64
- cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction="none")[prob > 0.99]
65
  if not torch.any(cls_loss):
66
- cls_loss = certainty_loss * 0.0 # Prevent issues where prob is 0 everywhere
67
 
68
- certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:, 0], prob)
69
  losses = {
70
  f"gm_certainty_loss_{scale}": certainty_loss.mean(),
71
  f"gm_cls_loss_{scale}": cls_loss.mean(),
72
  }
73
- wandb.log(losses, step=roma.GLOBAL_STEP)
74
  return losses
75
 
76
- def delta_cls_loss(
77
- self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale
78
- ):
79
  with torch.no_grad():
80
  B, C, H, W = delta_cls.shape
81
  device = x2.device
82
  cls_res = round(math.sqrt(C))
83
- G = torch.meshgrid(
84
- *[
85
- torch.linspace(
86
- -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
87
- )
88
- for _ in range(2)
89
- ]
90
- )
91
- G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2) * offset_scale
92
- GT = (
93
- (G[None, :, None, None, :] + flow_pre_delta[:, None] - x2[:, None])
94
- .norm(dim=-1)
95
- .min(dim=1)
96
- .indices
97
- )
98
- cls_loss = F.cross_entropy(delta_cls, GT, reduction="none")[prob > 0.99]
99
  if not torch.any(cls_loss):
100
- cls_loss = certainty_loss * 0.0 # Prevent issues where prob is 0 everywhere
101
- certainty_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
102
  losses = {
103
  f"delta_certainty_loss_{scale}": certainty_loss.mean(),
104
  f"delta_cls_loss_{scale}": cls_loss.mean(),
105
  }
106
- wandb.log(losses, step=roma.GLOBAL_STEP)
107
  return losses
108
 
109
- def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
110
- epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1)
111
  if scale == 1:
112
- pck_05 = (epe[prob > 0.99] < 0.5 * (2 / 512)).float().mean()
113
- wandb.log({"train_pck_05": pck_05}, step=roma.GLOBAL_STEP)
114
 
115
  ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
116
  a = self.alpha
117
  cs = self.c * scale
118
  x = epe[prob > 0.99]
119
- reg_loss = cs**a * ((x / (cs)) ** 2 + 1**2) ** (a / 2)
120
  if not torch.any(reg_loss):
121
- reg_loss = ce_loss * 0.0 # Prevent issues where prob is 0 everywhere
122
  losses = {
123
  f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
124
  f"{mode}_regression_loss_{scale}": reg_loss.mean(),
125
  }
126
- wandb.log(losses, step=roma.GLOBAL_STEP)
127
  return losses
128
 
129
  def forward(self, corresps, batch):
130
  scales = list(corresps.keys())
131
  tot_loss = 0.0
132
  # scale_weights due to differences in scale for regression gradients and classification gradients
133
- scale_weights = {1: 1, 2: 1, 4: 1, 8: 1, 16: 1}
134
  for scale in scales:
135
  scale_corresps = corresps[scale]
136
- (
137
- scale_certainty,
138
- flow_pre_delta,
139
- delta_cls,
140
- offset_scale,
141
- scale_gm_cls,
142
- scale_gm_certainty,
143
- flow,
144
- scale_gm_flow,
145
- ) = (
146
  scale_corresps["certainty"],
147
  scale_corresps["flow_pre_delta"],
148
  scale_corresps.get("delta_cls"),
@@ -151,72 +115,43 @@ class RobustLosses(nn.Module):
151
  scale_corresps.get("gm_certainty"),
152
  scale_corresps["flow"],
153
  scale_corresps.get("gm_flow"),
 
154
  )
155
  flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
156
  b, h, w, d = flow_pre_delta.shape
157
- gt_warp, gt_prob = get_gt_warp(
158
- batch["im_A_depth"],
159
- batch["im_B_depth"],
160
- batch["T_1to2"],
161
- batch["K1"],
162
- batch["K2"],
163
- H=h,
164
- W=w,
165
- )
166
  x2 = gt_warp.float()
167
  prob = gt_prob
168
-
169
  if self.local_largest_scale >= scale:
170
  prob = prob * (
171
- F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[
172
- :, 0
173
- ]
174
- < (2 / 512) * (self.local_dist[scale] * scale)
175
- )
176
-
177
  if scale_gm_cls is not None:
178
- gm_cls_losses = self.gm_cls_loss(
179
- x2, prob, scale_gm_cls, scale_gm_certainty, scale
180
- )
181
- gm_loss = (
182
- self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"]
183
- + gm_cls_losses[f"gm_cls_loss_{scale}"]
184
- )
185
  tot_loss = tot_loss + scale_weights[scale] * gm_loss
186
  elif scale_gm_flow is not None:
187
- gm_flow_losses = self.regression_loss(
188
- x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm"
189
- )
190
- gm_loss = (
191
- self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"]
192
- + gm_flow_losses[f"gm_regression_loss_{scale}"]
193
- )
194
  tot_loss = tot_loss + scale_weights[scale] * gm_loss
195
-
196
  if delta_cls is not None:
197
- delta_cls_losses = self.delta_cls_loss(
198
- x2,
199
- prob,
200
- flow_pre_delta,
201
- delta_cls,
202
- scale_certainty,
203
- scale,
204
- offset_scale,
205
- )
206
- delta_cls_loss = (
207
- self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"]
208
- + delta_cls_losses[f"delta_cls_loss_{scale}"]
209
- )
210
  tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
211
  else:
212
- delta_regression_losses = self.regression_loss(
213
- x2, prob, flow, scale_certainty, scale
214
- )
215
- reg_loss = (
216
- self.ce_weight
217
- * delta_regression_losses[f"delta_certainty_loss_{scale}"]
218
- + delta_regression_losses[f"delta_regression_loss_{scale}"]
219
- )
220
  tot_loss = tot_loss + scale_weights[scale] * reg_loss
221
- prev_epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1).detach()
222
  return tot_loss
 
7
  import roma
8
  import math
9
 
 
10
  class RobustLosses(nn.Module):
11
  def __init__(
12
  self,
 
17
  local_loss=True,
18
  local_dist=4.0,
19
  local_largest_scale=8,
20
+ smooth_mask = False,
21
+ depth_interpolation_mode = "bilinear",
22
+ mask_depth_loss = False,
23
+ relative_depth_error_threshold = 0.05,
24
+ alpha = 1.,
25
+ c = 1e-3,
26
  ):
27
  super().__init__()
28
  self.robust = robust # measured in pixels
 
45
  B, C, H, W = scale_gm_cls.shape
46
  device = x2.device
47
  cls_res = round(math.sqrt(C))
48
+ G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
49
+ G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
50
+ GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
51
+ cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
 
 
 
 
 
 
 
 
 
 
 
 
52
  if not torch.any(cls_loss):
53
+ cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
54
 
55
+ certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
56
  losses = {
57
  f"gm_certainty_loss_{scale}": certainty_loss.mean(),
58
  f"gm_cls_loss_{scale}": cls_loss.mean(),
59
  }
60
+ wandb.log(losses, step = roma.GLOBAL_STEP)
61
  return losses
62
 
63
+ def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
 
 
64
  with torch.no_grad():
65
  B, C, H, W = delta_cls.shape
66
  device = x2.device
67
  cls_res = round(math.sqrt(C))
68
+ G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
69
+ G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
70
+ GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
71
+ cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99]
 
 
 
 
 
 
 
 
 
 
 
 
72
  if not torch.any(cls_loss):
73
+ cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
74
+ certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
75
  losses = {
76
  f"delta_certainty_loss_{scale}": certainty_loss.mean(),
77
  f"delta_cls_loss_{scale}": cls_loss.mean(),
78
  }
79
+ wandb.log(losses, step = roma.GLOBAL_STEP)
80
  return losses
81
 
82
+ def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
83
+ epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
84
  if scale == 1:
85
+ pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
86
+ wandb.log({"train_pck_05": pck_05}, step = roma.GLOBAL_STEP)
87
 
88
  ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
89
  a = self.alpha
90
  cs = self.c * scale
91
  x = epe[prob > 0.99]
92
+ reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
93
  if not torch.any(reg_loss):
94
+ reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
95
  losses = {
96
  f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
97
  f"{mode}_regression_loss_{scale}": reg_loss.mean(),
98
  }
99
+ wandb.log(losses, step = roma.GLOBAL_STEP)
100
  return losses
101
 
102
  def forward(self, corresps, batch):
103
  scales = list(corresps.keys())
104
  tot_loss = 0.0
105
  # scale_weights due to differences in scale for regression gradients and classification gradients
106
+ scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
107
  for scale in scales:
108
  scale_corresps = corresps[scale]
109
+ scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
 
 
 
 
 
 
 
 
 
110
  scale_corresps["certainty"],
111
  scale_corresps["flow_pre_delta"],
112
  scale_corresps.get("delta_cls"),
 
115
  scale_corresps.get("gm_certainty"),
116
  scale_corresps["flow"],
117
  scale_corresps.get("gm_flow"),
118
+
119
  )
120
  flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
121
  b, h, w, d = flow_pre_delta.shape
122
+ gt_warp, gt_prob = get_gt_warp(
123
+ batch["im_A_depth"],
124
+ batch["im_B_depth"],
125
+ batch["T_1to2"],
126
+ batch["K1"],
127
+ batch["K2"],
128
+ H=h,
129
+ W=w,
130
+ )
131
  x2 = gt_warp.float()
132
  prob = gt_prob
133
+
134
  if self.local_largest_scale >= scale:
135
  prob = prob * (
136
+ F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
137
+ < (2 / 512) * (self.local_dist[scale] * scale))
138
+
 
 
 
139
  if scale_gm_cls is not None:
140
+ gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
141
+ gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
142
  tot_loss = tot_loss + scale_weights[scale] * gm_loss
143
  elif scale_gm_flow is not None:
144
+ gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
145
+ gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
146
  tot_loss = tot_loss + scale_weights[scale] * gm_loss
147
+
148
  if delta_cls is not None:
149
+ delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
150
+ delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
151
  tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
152
  else:
153
+ delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
154
+ reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
155
  tot_loss = tot_loss + scale_weights[scale] * reg_loss
156
+ prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
157
  return tot_loss
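
For intuition, the per-pixel regression term used in regression_loss above is a generalized Charbonnier penalty on the end-point error, reg = (c*scale)**alpha * ((epe/(c*scale))**2 + 1)**(alpha/2); it is roughly quadratic for epe much smaller than c*scale and roughly linear beyond that when alpha = 1. A standalone sketch of just that term (illustrative only, not part of the commit):

import torch

def robust_reg_term(epe: torch.Tensor, scale: int, alpha: float = 1.0, c: float = 1e-3) -> torch.Tensor:
    # epe: end-point error |flow - gt_warp| in normalized coordinates, as in regression_loss above
    cs = c * scale
    return cs**alpha * ((epe / cs) ** 2 + 1) ** (alpha / 2)

epe = torch.tensor([0.0, 1e-3, 1e-2, 1e-1])
print(robust_reg_term(epe, scale=1))  # small errors are damped, large errors grow roughly linearly for alpha=1
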
third_party/RoMa/roma/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model_zoo import roma_outdoor, roma_indoor
third_party/{Roma β†’ RoMa}/roma/models/encoders.py RENAMED
@@ -8,7 +8,8 @@ import gc
8
 
9
 
10
  class ResNet50(nn.Module):
11
- def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
 
12
  super().__init__()
13
  if dilation is None:
14
  dilation = [False,False,False]
@@ -24,10 +25,7 @@ class ResNet50(nn.Module):
24
  self.freeze_bn = freeze_bn
25
  self.early_exit = early_exit
26
  self.amp = amp
27
- if not torch.cuda.is_available():
28
- self.amp_dtype = torch.float32
29
- else:
30
- self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
31
 
32
  def forward(self, x, **kwargs):
33
  with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -59,14 +57,11 @@ class ResNet50(nn.Module):
59
  pass
60
 
61
  class VGG19(nn.Module):
62
- def __init__(self, pretrained=False, amp = False) -> None:
63
  super().__init__()
64
  self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
65
  self.amp = amp
66
- if not torch.cuda.is_available():
67
- self.amp_dtype = torch.float32
68
- else:
69
- self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
70
 
71
  def forward(self, x, **kwargs):
72
  with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -80,7 +75,7 @@ class VGG19(nn.Module):
80
  return feats
81
 
82
  class CNNandDinov2(nn.Module):
83
- def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
84
  super().__init__()
85
  if dinov2_weights is None:
86
  dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
@@ -100,10 +95,7 @@ class CNNandDinov2(nn.Module):
100
  else:
101
  self.cnn = VGG19(**cnn_kwargs)
102
  self.amp = amp
103
- if not torch.cuda.is_available():
104
- self.amp_dtype = torch.float32
105
- else:
106
- self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
107
  if self.amp:
108
  dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
109
  self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
 
8
 
9
 
10
  class ResNet50(nn.Module):
11
+ def __init__(self, pretrained=False, high_res = False, weights = None,
12
+ dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None:
13
  super().__init__()
14
  if dilation is None:
15
  dilation = [False,False,False]
 
25
  self.freeze_bn = freeze_bn
26
  self.early_exit = early_exit
27
  self.amp = amp
28
+ self.amp_dtype = amp_dtype
 
 
 
29
 
30
  def forward(self, x, **kwargs):
31
  with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
 
57
  pass
58
 
59
  class VGG19(nn.Module):
60
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
61
  super().__init__()
62
  self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
63
  self.amp = amp
64
+ self.amp_dtype = amp_dtype
 
 
 
65
 
66
  def forward(self, x, **kwargs):
67
  with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
 
75
  return feats
76
 
77
  class CNNandDinov2(nn.Module):
78
+ def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16):
79
  super().__init__()
80
  if dinov2_weights is None:
81
  dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
 
95
  else:
96
  self.cnn = VGG19(**cnn_kwargs)
97
  self.amp = amp
98
+ self.amp_dtype = amp_dtype
 
 
 
99
  if self.amp:
100
  dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
101
  self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
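
The change above moves the autocast dtype out of the encoder classes and into an amp_dtype constructor argument. A caller who wants the previous auto-detection behaviour can reproduce it with a small helper; the helper below is hypothetical and simply mirrors the logic removed in the '-' lines:

import torch

def pick_amp_dtype() -> torch.dtype:
    # Same policy the encoders used to hard-code: fp32 on CPU, bf16 where supported, else fp16.
    if not torch.cuda.is_available():
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# e.g. VGG19(pretrained=False, amp=True, amp_dtype=pick_amp_dtype())
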
third_party/{Roma β†’ RoMa}/roma/models/matcher.py RENAMED
@@ -7,6 +7,7 @@ import torch.nn.functional as F
7
  from einops import rearrange
8
  import warnings
9
  from warnings import warn
 
10
 
11
  import roma
12
  from roma.utils import get_tuple_transform_ops
@@ -37,6 +38,7 @@ class ConvRefiner(nn.Module):
37
  sample_mode = "bilinear",
38
  norm_type = nn.BatchNorm2d,
39
  bn_momentum = 0.1,
 
40
  ):
41
  super().__init__()
42
  self.bn_momentum = bn_momentum
@@ -71,12 +73,8 @@ class ConvRefiner(nn.Module):
71
  self.disable_local_corr_grad = disable_local_corr_grad
72
  self.is_classifier = is_classifier
73
  self.sample_mode = sample_mode
74
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
75
- if not torch.cuda.is_available():
76
- self.amp_dtype = torch.float32
77
- else:
78
- self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
79
-
80
  def create_block(
81
  self,
82
  in_dim,
@@ -113,8 +111,8 @@ class ConvRefiner(nn.Module):
113
  if self.has_displacement_emb:
114
  im_A_coords = torch.meshgrid(
115
  (
116
- torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=self.device),
117
- torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=self.device),
118
  )
119
  )
120
  im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -278,7 +276,7 @@ class Decoder(nn.Module):
278
  def __init__(
279
  self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
280
  num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
281
- flow_upsample_mode = "bilinear"
282
  ):
283
  super().__init__()
284
  self.embedding_decoder = embedding_decoder
@@ -300,11 +298,8 @@ class Decoder(nn.Module):
300
  self.displacement_dropout_p = displacement_dropout_p
301
  self.gm_warp_dropout_p = gm_warp_dropout_p
302
  self.flow_upsample_mode = flow_upsample_mode
303
- if not torch.cuda.is_available():
304
- self.amp_dtype = torch.float32
305
- else:
306
- self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
307
-
308
  def get_placeholder_flow(self, b, h, w, device):
309
  coarse_coords = torch.meshgrid(
310
  (
@@ -367,7 +362,7 @@ class Decoder(nn.Module):
367
  corresps[ins] = {}
368
  f1_s, f2_s = f1[ins], f2[ins]
369
  if new_scale in self.proj:
370
- with torch.autocast("cuda", self.amp_dtype):
371
  f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
372
 
373
  if ins in coarse_scales:
@@ -429,11 +424,12 @@ class RegressionMatcher(nn.Module):
429
  decoder,
430
  h=448,
431
  w=448,
432
- sample_mode = "threshold",
433
  upsample_preds = False,
434
  symmetric = False,
435
  name = None,
436
  attenuate_cert = None,
 
437
  ):
438
  super().__init__()
439
  self.attenuate_cert = attenuate_cert
@@ -448,6 +444,7 @@ class RegressionMatcher(nn.Module):
448
  self.upsample_res = (14*16*6, 14*16*6)
449
  self.symmetric = symmetric
450
  self.sample_thresh = 0.05
 
451
 
452
  def get_output_resolution(self):
453
  if not self.upsample_preds:
@@ -527,12 +524,62 @@ class RegressionMatcher(nn.Module):
527
  scale_factor=scale_factor)
528
  return corresps
529
 
530
- def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
531
- kpts_A, kpts_B = matches[...,:2], matches[...,2:]
 
 
 
532
  kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
533
  kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
534
return kpts_A, kpts_B
535
 
 
536
  def match(
537
  self,
538
  im_A_path,
@@ -543,9 +590,8 @@ class RegressionMatcher(nn.Module):
543
  ):
544
  if device is None:
545
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
546
- from PIL import Image
547
  if isinstance(im_A_path, (str, os.PathLike)):
548
- im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
549
  else:
550
  # Assume its not a path
551
  im_A, im_B = im_A_path, im_B_path
@@ -597,7 +643,14 @@ class RegressionMatcher(nn.Module):
597
  test_transform = get_tuple_transform_ops(
598
  resize=(hs, ws), normalize=True
599
  )
600
- im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
601
  im_A, im_B = test_transform((im_A, im_B))
602
  im_A, im_B = im_A[None].to(device), im_B[None].to(device)
603
  scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
@@ -653,4 +706,30 @@ class RegressionMatcher(nn.Module):
653
  warp[0],
654
  certainty[0, 0],
655
)
656
 
7
  from einops import rearrange
8
  import warnings
9
  from warnings import warn
10
+ from PIL import Image
11
 
12
  import roma
13
  from roma.utils import get_tuple_transform_ops
 
38
  sample_mode = "bilinear",
39
  norm_type = nn.BatchNorm2d,
40
  bn_momentum = 0.1,
41
+ amp_dtype = torch.float16,
42
  ):
43
  super().__init__()
44
  self.bn_momentum = bn_momentum
 
73
  self.disable_local_corr_grad = disable_local_corr_grad
74
  self.is_classifier = is_classifier
75
  self.sample_mode = sample_mode
76
+ self.amp_dtype = amp_dtype
77
+
 
 
 
 
78
  def create_block(
79
  self,
80
  in_dim,
 
111
  if self.has_displacement_emb:
112
  im_A_coords = torch.meshgrid(
113
  (
114
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
115
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
116
  )
117
  )
118
  im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
 
276
  def __init__(
277
  self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
278
  num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
279
+ flow_upsample_mode = "bilinear", amp_dtype = torch.float16,
280
  ):
281
  super().__init__()
282
  self.embedding_decoder = embedding_decoder
 
298
  self.displacement_dropout_p = displacement_dropout_p
299
  self.gm_warp_dropout_p = gm_warp_dropout_p
300
  self.flow_upsample_mode = flow_upsample_mode
301
+ self.amp_dtype = amp_dtype
302
+
 
 
 
303
  def get_placeholder_flow(self, b, h, w, device):
304
  coarse_coords = torch.meshgrid(
305
  (
 
362
  corresps[ins] = {}
363
  f1_s, f2_s = f1[ins], f2[ins]
364
  if new_scale in self.proj:
365
+ with torch.autocast("cuda", dtype = self.amp_dtype):
366
  f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
367
 
368
  if ins in coarse_scales:
 
424
  decoder,
425
  h=448,
426
  w=448,
427
+ sample_mode = "threshold_balanced",
428
  upsample_preds = False,
429
  symmetric = False,
430
  name = None,
431
  attenuate_cert = None,
432
+ recrop_upsample = False,
433
  ):
434
  super().__init__()
435
  self.attenuate_cert = attenuate_cert
 
444
  self.upsample_res = (14*16*6, 14*16*6)
445
  self.symmetric = symmetric
446
  self.sample_thresh = 0.05
447
+ self.recrop_upsample = recrop_upsample
448
 
449
  def get_output_resolution(self):
450
  if not self.upsample_preds:
 
524
  scale_factor=scale_factor)
525
  return corresps
526
 
527
+ def to_pixel_coordinates(self, coords, H_A, W_A, H_B, W_B):
528
+ if isinstance(coords, (list, tuple)):
529
+ kpts_A, kpts_B = coords[0], coords[1]
530
+ else:
531
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
532
  kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
533
  kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
534
  return kpts_A, kpts_B
535
+
536
+ def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
537
+ if isinstance(coords, (list, tuple)):
538
+ kpts_A, kpts_B = coords[0], coords[1]
539
+ else:
540
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
541
+ kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
542
+ kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
543
+ return kpts_A, kpts_B
544
 
545
+ def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
546
+ x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
547
+ cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
548
+ D = torch.cdist(x_A_to_B, x_B)
549
+ inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
550
+
551
+ if return_tuple:
552
+ if return_inds:
553
+ return inds_A, inds_B
554
+ else:
555
+ return x_A[inds_A], x_B[inds_B]
556
+ else:
557
+ if return_inds:
558
+ return torch.cat((inds_A, inds_B),dim=-1)
559
+ else:
560
+ return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
561
+
562
+ def get_roi(self, certainty, W, H, thr = 0.025):
563
+ raise NotImplementedError("WIP, disable for now")
564
+ hs,ws = certainty.shape
565
+ certainty = certainty/certainty.sum(dim=(-1,-2))
566
+ cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
567
+ cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
568
+ print(cum_certainty_w)
569
+ print(torch.min(torch.nonzero(cum_certainty_w > thr)))
570
+ print(torch.min(torch.nonzero(cum_certainty_w < thr)))
571
+ left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
572
+ right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
573
+ top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
574
+ bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
575
+ print(left, right, top, bottom)
576
+ return left, top, right, bottom
577
+
578
+ def recrop(self, certainty, image_path):
579
+ roi = self.get_roi(certainty, *Image.open(image_path).size)
580
+ return Image.open(image_path).convert("RGB").crop(roi)
581
+
582
+ @torch.inference_mode()
583
  def match(
584
  self,
585
  im_A_path,
 
590
  ):
591
  if device is None:
592
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
593
  if isinstance(im_A_path, (str, os.PathLike)):
594
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
595
  else:
596
  # Assume its not a path
597
  im_A, im_B = im_A_path, im_B_path
 
643
  test_transform = get_tuple_transform_ops(
644
  resize=(hs, ws), normalize=True
645
  )
646
+ if self.recrop_upsample:
647
+ certainty = corresps[finest_scale]["certainty"]
648
+ print(certainty.shape)
649
+ im_A = self.recrop(certainty[0,0], im_A_path)
650
+ im_B = self.recrop(certainty[1,0], im_B_path)
651
+ #TODO: need to adjust corresps when doing this
652
+ else:
653
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
654
  im_A, im_B = test_transform((im_A, im_B))
655
  im_A, im_B = im_A[None].to(device), im_B[None].to(device)
656
  scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
 
706
  warp[0],
707
  certainty[0, 0],
708
  )
709
+
710
+ def visualize_warp(self, warp, certainty, im_A = None, im_B = None, im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None):
711
+ assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
712
+ H,W2,_ = warp.shape
713
+ W = W2//2 if symmetric else W2
714
+ if im_A is None:
715
+ from PIL import Image
716
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
717
+ im_A = im_A.resize((W,H))
718
+ im_B = im_B.resize((W,H))
719
+
720
+ x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
721
+ x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
722
 
723
+ im_A_transfer_rgb = F.grid_sample(
724
+ x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
725
+ )[0]
726
+ im_B_transfer_rgb = F.grid_sample(
727
+ x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
728
+ )[0]
729
+ warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
730
+ white_im = torch.ones((H,2*W),device=device)
731
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
732
+ if save_path is not None:
733
+ from roma.utils import tensor_to_pil
734
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
735
+ return vis_im
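
Taken together, the additions above give a small keypoint-matching API on top of the dense warp: match produces a warp and a certainty map, match_keypoints keeps mutual nearest neighbours of detections mapped through the warp, and to_pixel_coordinates converts normalized [-1, 1] coordinates back to pixels. A usage sketch, with placeholder image paths and random detections standing in for a real detector (not part of the commit):

import torch
from PIL import Image

# `matcher` is assumed to be a RegressionMatcher instance, e.g. built via the model zoo below.
warp, certainty = matcher.match("im_A.jpg", "im_B.jpg")

W_A, H_A = Image.open("im_A.jpg").size
W_B, H_B = Image.open("im_B.jpg").size
kpts_A = torch.rand(500, 2, device=warp.device) * 2 - 1  # placeholder detections in normalized coords
kpts_B = torch.rand(500, 2, device=warp.device) * 2 - 1
mkpts_A, mkpts_B = matcher.match_keypoints(kpts_A, kpts_B, warp, certainty)
px_A, px_B = matcher.to_pixel_coordinates((mkpts_A, mkpts_B), H_A, W_A, H_B, W_B)
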
third_party/RoMa/roma/models/model_zoo/__init__.py ADDED
@@ -0,0 +1,53 @@
 
1
+ from typing import Union
2
+ import torch
3
+ from .roma_models import roma_model
4
+
5
+ weight_urls = {
6
+ "roma": {
7
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
8
+ "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
9
+ },
10
+ "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
11
+ }
12
+
13
+ def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
14
+ if isinstance(coarse_res, int):
15
+ coarse_res = (coarse_res, coarse_res)
16
+ if isinstance(upsample_res, int):
17
+ upsample_res = (upsample_res, upsample_res)
18
+
19
+ assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
20
+ assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
21
+
22
+ if weights is None:
23
+ weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"],
24
+ map_location=device)
25
+ if dinov2_weights is None:
26
+ dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
27
+ map_location=device)
28
+ model = roma_model(resolution=coarse_res, upsample_preds=True,
29
+ weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
30
+ model.upsample_res = upsample_res
31
+ print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
32
+ return model
33
+
34
+ def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
35
+ if isinstance(coarse_res, int):
36
+ coarse_res = (coarse_res, coarse_res)
37
+ if isinstance(upsample_res, int):
38
+ upsample_res = (upsample_res, upsample_res)
39
+
40
+ assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
41
+ assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
42
+
43
+ if weights is None:
44
+ weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"],
45
+ map_location=device)
46
+ if dinov2_weights is None:
47
+ dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
48
+ map_location=device)
49
+ model = roma_model(resolution=coarse_res, upsample_preds=True,
50
+ weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
51
+ model.upsample_res = upsample_res
52
+ print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
53
+ return model
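
The two entry points above download the RoMa and DINOv2 checkpoints on first use (when weights/dinov2_weights are None) and accept either an int or an (h, w) tuple for the resolutions, with the coarse resolution constrained to multiples of 14 by the DINOv2 patch size. A minimal usage sketch (not part of the commit):

import torch
from roma.models import roma_outdoor  # re-exported in roma/models/__init__.py as shown above

device = "cuda" if torch.cuda.is_available() else "cpu"
model = roma_outdoor(device, coarse_res=560, upsample_res=(864, 1152), amp_dtype=torch.float16)
# prints the chosen coarse/upsample resolutions and returns a ready RegressionMatcher
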
third_party/{Roma β†’ RoMa}/roma/models/model_zoo/roma_models.py RENAMED
@@ -1,98 +1,91 @@
1
  import warnings
2
  import torch.nn as nn
 
3
  from roma.models.matcher import *
4
  from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
5
  from roma.models.encoders import *
6
 
7
-
8
- def roma_model(
9
- resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs
10
- ):
11
  # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
12
- torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
13
- torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
14
- warnings.filterwarnings(
15
- "ignore", category=UserWarning, message="TypedStorage is deprecated"
16
- )
17
  gp_dim = 512
18
  feat_dim = 512
19
  decoder_dim = gp_dim + feat_dim
20
  cls_to_coord_res = 64
21
  coordinate_decoder = TransformerDecoder(
22
- nn.Sequential(
23
- *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
24
- ),
25
- decoder_dim,
26
  cls_to_coord_res**2 + 1,
27
  is_classifier=True,
28
- amp=True,
29
- pos_enc=False,
30
- )
31
  dw = True
32
  hidden_blocks = 8
33
  kernel_size = 5
34
  displacement_emb = "linear"
35
  disable_local_corr_grad = True
36
-
37
  conv_refiner = nn.ModuleDict(
38
  {
39
  "16": ConvRefiner(
40
- 2 * 512 + 128 + (2 * 7 + 1) ** 2,
41
- 2 * 512 + 128 + (2 * 7 + 1) ** 2,
42
  2 + 1,
43
  kernel_size=kernel_size,
44
  dw=dw,
45
  hidden_blocks=hidden_blocks,
46
  displacement_emb=displacement_emb,
47
  displacement_emb_dim=128,
48
- local_corr_radius=7,
49
- corr_in_other=True,
50
- amp=True,
51
- disable_local_corr_grad=disable_local_corr_grad,
52
- bn_momentum=0.01,
53
  ),
54
  "8": ConvRefiner(
55
- 2 * 512 + 64 + (2 * 3 + 1) ** 2,
56
- 2 * 512 + 64 + (2 * 3 + 1) ** 2,
57
  2 + 1,
58
  kernel_size=kernel_size,
59
  dw=dw,
60
  hidden_blocks=hidden_blocks,
61
  displacement_emb=displacement_emb,
62
  displacement_emb_dim=64,
63
- local_corr_radius=3,
64
- corr_in_other=True,
65
- amp=True,
66
- disable_local_corr_grad=disable_local_corr_grad,
67
- bn_momentum=0.01,
68
  ),
69
  "4": ConvRefiner(
70
- 2 * 256 + 32 + (2 * 2 + 1) ** 2,
71
- 2 * 256 + 32 + (2 * 2 + 1) ** 2,
72
  2 + 1,
73
  kernel_size=kernel_size,
74
  dw=dw,
75
  hidden_blocks=hidden_blocks,
76
  displacement_emb=displacement_emb,
77
  displacement_emb_dim=32,
78
- local_corr_radius=2,
79
- corr_in_other=True,
80
- amp=True,
81
- disable_local_corr_grad=disable_local_corr_grad,
82
- bn_momentum=0.01,
83
  ),
84
  "2": ConvRefiner(
85
- 2 * 64 + 16,
86
- 128 + 16,
87
  2 + 1,
88
  kernel_size=kernel_size,
89
  dw=dw,
90
  hidden_blocks=hidden_blocks,
91
  displacement_emb=displacement_emb,
92
  displacement_emb_dim=16,
93
- amp=True,
94
- disable_local_corr_grad=disable_local_corr_grad,
95
- bn_momentum=0.01,
96
  ),
97
  "1": ConvRefiner(
98
  2 * 9 + 6,
@@ -100,12 +93,12 @@ def roma_model(
100
  2 + 1,
101
  kernel_size=kernel_size,
102
  dw=dw,
103
- hidden_blocks=hidden_blocks,
104
- displacement_emb=displacement_emb,
105
- displacement_emb_dim=6,
106
- amp=True,
107
- disable_local_corr_grad=disable_local_corr_grad,
108
- bn_momentum=0.01,
109
  ),
110
  }
111
  )
@@ -130,46 +123,38 @@ def roma_model(
130
  proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
131
  proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
132
  proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
133
- proj = nn.ModuleDict(
134
- {
135
- "16": proj16,
136
- "8": proj8,
137
- "4": proj4,
138
- "2": proj2,
139
- "1": proj1,
140
- }
141
- )
142
  displacement_dropout_p = 0.0
143
  gm_warp_dropout_p = 0.0
144
- decoder = Decoder(
145
- coordinate_decoder,
146
- gps,
147
- proj,
148
- conv_refiner,
149
- detach=True,
150
- scales=["16", "8", "4", "2", "1"],
151
- displacement_dropout_p=displacement_dropout_p,
152
- gm_warp_dropout_p=gm_warp_dropout_p,
153
- )
154
-
155
  encoder = CNNandDinov2(
156
- cnn_kwargs=dict(pretrained=False, amp=True),
157
- amp=True,
158
- use_vgg=True,
159
- dinov2_weights=dinov2_weights,
 
 
 
160
  )
161
- h, w = resolution
162
  symmetric = True
163
  attenuate_cert = True
164
- matcher = RegressionMatcher(
165
- encoder,
166
- decoder,
167
- h=h,
168
- w=w,
169
- upsample_preds=upsample_preds,
170
- symmetric=symmetric,
171
- attenuate_cert=attenuate_cert,
172
- **kwargs
173
- ).to(device)
174
  matcher.load_state_dict(weights)
175
  return matcher
 
1
  import warnings
2
  import torch.nn as nn
3
+ import torch
4
  from roma.models.matcher import *
5
  from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
6
  from roma.models.encoders import *
7
 
8
+ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs):
 
 
 
9
  # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
10
+ #torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul TODO: these probably ruin stuff, should be careful
11
+ #torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
12
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
 
 
13
  gp_dim = 512
14
  feat_dim = 512
15
  decoder_dim = gp_dim + feat_dim
16
  cls_to_coord_res = 64
17
  coordinate_decoder = TransformerDecoder(
18
+ nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
19
+ decoder_dim,
 
 
20
  cls_to_coord_res**2 + 1,
21
  is_classifier=True,
22
+ amp = True,
23
+ pos_enc = False,)
 
24
  dw = True
25
  hidden_blocks = 8
26
  kernel_size = 5
27
  displacement_emb = "linear"
28
  disable_local_corr_grad = True
29
+
30
  conv_refiner = nn.ModuleDict(
31
  {
32
  "16": ConvRefiner(
33
+ 2 * 512+128+(2*7+1)**2,
34
+ 2 * 512+128+(2*7+1)**2,
35
  2 + 1,
36
  kernel_size=kernel_size,
37
  dw=dw,
38
  hidden_blocks=hidden_blocks,
39
  displacement_emb=displacement_emb,
40
  displacement_emb_dim=128,
41
+ local_corr_radius = 7,
42
+ corr_in_other = True,
43
+ amp = True,
44
+ disable_local_corr_grad = disable_local_corr_grad,
45
+ bn_momentum = 0.01,
46
  ),
47
  "8": ConvRefiner(
48
+ 2 * 512+64+(2*3+1)**2,
49
+ 2 * 512+64+(2*3+1)**2,
50
  2 + 1,
51
  kernel_size=kernel_size,
52
  dw=dw,
53
  hidden_blocks=hidden_blocks,
54
  displacement_emb=displacement_emb,
55
  displacement_emb_dim=64,
56
+ local_corr_radius = 3,
57
+ corr_in_other = True,
58
+ amp = True,
59
+ disable_local_corr_grad = disable_local_corr_grad,
60
+ bn_momentum = 0.01,
61
  ),
62
  "4": ConvRefiner(
63
+ 2 * 256+32+(2*2+1)**2,
64
+ 2 * 256+32+(2*2+1)**2,
65
  2 + 1,
66
  kernel_size=kernel_size,
67
  dw=dw,
68
  hidden_blocks=hidden_blocks,
69
  displacement_emb=displacement_emb,
70
  displacement_emb_dim=32,
71
+ local_corr_radius = 2,
72
+ corr_in_other = True,
73
+ amp = True,
74
+ disable_local_corr_grad = disable_local_corr_grad,
75
+ bn_momentum = 0.01,
76
  ),
77
  "2": ConvRefiner(
78
+ 2 * 64+16,
79
+ 128+16,
80
  2 + 1,
81
  kernel_size=kernel_size,
82
  dw=dw,
83
  hidden_blocks=hidden_blocks,
84
  displacement_emb=displacement_emb,
85
  displacement_emb_dim=16,
86
+ amp = True,
87
+ disable_local_corr_grad = disable_local_corr_grad,
88
+ bn_momentum = 0.01,
89
  ),
90
  "1": ConvRefiner(
91
  2 * 9 + 6,
 
93
  2 + 1,
94
  kernel_size=kernel_size,
95
  dw=dw,
96
+ hidden_blocks = hidden_blocks,
97
+ displacement_emb = displacement_emb,
98
+ displacement_emb_dim = 6,
99
+ amp = True,
100
+ disable_local_corr_grad = disable_local_corr_grad,
101
+ bn_momentum = 0.01,
102
  ),
103
  }
104
  )
 
123
  proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
124
  proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
125
  proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
126
+ proj = nn.ModuleDict({
127
+ "16": proj16,
128
+ "8": proj8,
129
+ "4": proj4,
130
+ "2": proj2,
131
+ "1": proj1,
132
+ })
 
 
133
  displacement_dropout_p = 0.0
134
  gm_warp_dropout_p = 0.0
135
+ decoder = Decoder(coordinate_decoder,
136
+ gps,
137
+ proj,
138
+ conv_refiner,
139
+ detach=True,
140
+ scales=["16", "8", "4", "2", "1"],
141
+ displacement_dropout_p = displacement_dropout_p,
142
+ gm_warp_dropout_p = gm_warp_dropout_p)
143
+
 
 
144
  encoder = CNNandDinov2(
145
+ cnn_kwargs = dict(
146
+ pretrained=False,
147
+ amp = True),
148
+ amp = True,
149
+ use_vgg = True,
150
+ dinov2_weights = dinov2_weights,
151
+ amp_dtype=amp_dtype,
152
  )
153
+ h,w = resolution
154
  symmetric = True
155
  attenuate_cert = True
156
+ sample_mode = "threshold_balanced"
157
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
158
+ symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
 
159
  matcher.load_state_dict(weights)
160
  return matcher
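
roma_model is the low-level constructor the model zoo wraps: it wires the VGG19+DINOv2 encoder, the transformer coordinate decoder, and the per-scale ConvRefiners together, and now threads a single amp_dtype through the components that autocast. If checkpoints are already on disk it can be called directly; the paths below are placeholders (not part of the commit):

import torch
from roma.models.model_zoo.roma_models import roma_model

weights = torch.load("path/to/roma_outdoor.pth", map_location="cpu")
dinov2_weights = torch.load("path/to/dinov2_vitl14_pretrain.pth", map_location="cpu")
matcher = roma_model(resolution=(560, 560), upsample_preds=True, weights=weights,
                     dinov2_weights=dinov2_weights, device="cuda", amp_dtype=torch.float16)
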
third_party/{Roma β†’ RoMa}/roma/models/transformer/__init__.py RENAMED
@@ -7,23 +7,9 @@ from .layers.block import Block
7
  from .layers.attention import MemEffAttention
8
  from .dinov2 import vit_large
9
 
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
-
12
-
13
  class TransformerDecoder(nn.Module):
14
- def __init__(
15
- self,
16
- blocks,
17
- hidden_dim,
18
- out_dim,
19
- is_classifier=False,
20
- *args,
21
- amp=False,
22
- pos_enc=True,
23
- learned_embeddings=False,
24
- embedding_dim=None,
25
- **kwargs
26
- ) -> None:
27
  super().__init__(*args, **kwargs)
28
  self.blocks = blocks
29
  self.to_out = nn.Linear(hidden_dim, out_dim)
@@ -32,48 +18,30 @@ class TransformerDecoder(nn.Module):
32
  self._scales = [16]
33
  self.is_classifier = is_classifier
34
  self.amp = amp
35
- if torch.cuda.is_available():
36
- if torch.cuda.is_bf16_supported():
37
- self.amp_dtype = torch.bfloat16
38
- else:
39
- self.amp_dtype = torch.float16
40
- else:
41
- self.amp_dtype = torch.float32
42
-
43
  self.pos_enc = pos_enc
44
  self.learned_embeddings = learned_embeddings
45
  if self.learned_embeddings:
46
- self.learned_pos_embeddings = nn.Parameter(
47
- nn.init.kaiming_normal_(
48
- torch.empty((1, hidden_dim, embedding_dim, embedding_dim))
49
- )
50
- )
51
 
52
  def scales(self):
53
  return self._scales.copy()
54
 
55
  def forward(self, gp_posterior, features, old_stuff, new_scale):
56
- with torch.autocast(device, dtype=self.amp_dtype, enabled=self.amp):
57
- B, C, H, W = gp_posterior.shape
58
- x = torch.cat((gp_posterior, features), dim=1)
59
- B, C, H, W = x.shape
60
- grid = get_grid(B, H, W, x.device).reshape(B, H * W, 2)
61
  if self.learned_embeddings:
62
- pos_enc = (
63
- F.interpolate(
64
- self.learned_pos_embeddings,
65
- size=(H, W),
66
- mode="bilinear",
67
- align_corners=False,
68
- )
69
- .permute(0, 2, 3, 1)
70
- .reshape(1, H * W, C)
71
- )
72
  else:
73
  pos_enc = 0
74
- tokens = x.reshape(B, C, H * W).permute(0, 2, 1) + pos_enc
75
  z = self.blocks(tokens)
76
  out = self.to_out(z)
77
- out = out.permute(0, 2, 1).reshape(B, self.out_dim, H, W)
78
  warp, certainty = out[:, :-1], out[:, -1:]
79
  return warp, certainty, None
 
 
 
7
  from .layers.attention import MemEffAttention
8
  from .dinov2 import vit_large
9
 
 
 
 
10
  class TransformerDecoder(nn.Module):
11
+ def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args,
12
+ amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
 
13
  super().__init__(*args, **kwargs)
14
  self.blocks = blocks
15
  self.to_out = nn.Linear(hidden_dim, out_dim)
 
18
  self._scales = [16]
19
  self.is_classifier = is_classifier
20
  self.amp = amp
21
+ self.amp_dtype = amp_dtype
 
22
  self.pos_enc = pos_enc
23
  self.learned_embeddings = learned_embeddings
24
  if self.learned_embeddings:
25
+ self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
 
 
 
 
26
 
27
  def scales(self):
28
  return self._scales.copy()
29
 
30
  def forward(self, gp_posterior, features, old_stuff, new_scale):
31
+ with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
32
+ B,C,H,W = gp_posterior.shape
33
+ x = torch.cat((gp_posterior, features), dim = 1)
34
+ B,C,H,W = x.shape
35
+ grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
36
  if self.learned_embeddings:
37
+ pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
 
38
  else:
39
  pos_enc = 0
40
+ tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
41
  z = self.blocks(tokens)
42
  out = self.to_out(z)
43
+ out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
44
  warp, certainty = out[:, :-1], out[:, -1:]
45
  return warp, certainty, None
46
+
47
+
third_party/{Roma β†’ RoMa}/roma/models/transformer/dinov2.py RENAMED
@@ -18,29 +18,16 @@ import torch.nn as nn
18
  import torch.utils.checkpoint
19
  from torch.nn.init import trunc_normal_
20
 
21
- from .layers import (
22
- Mlp,
23
- PatchEmbed,
24
- SwiGLUFFNFused,
25
- MemEffAttention,
26
- NestedTensorBlock as Block,
27
- )
28
-
29
-
30
- def named_apply(
31
- fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
32
- ) -> nn.Module:
33
  if not depth_first and include_root:
34
  fn(module=module, name=name)
35
  for child_name, child_module in module.named_children():
36
  child_name = ".".join((name, child_name)) if name else child_name
37
- named_apply(
38
- fn=fn,
39
- module=child_module,
40
- name=child_name,
41
- depth_first=depth_first,
42
- include_root=True,
43
- )
44
  if depth_first and include_root:
45
  fn(module=module, name=name)
46
  return module
@@ -100,33 +87,22 @@ class DinoVisionTransformer(nn.Module):
100
  super().__init__()
101
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
102
 
103
- self.num_features = (
104
- self.embed_dim
105
- ) = embed_dim # num_features for consistency with other models
106
  self.num_tokens = 1
107
  self.n_blocks = depth
108
  self.num_heads = num_heads
109
  self.patch_size = patch_size
110
 
111
- self.patch_embed = embed_layer(
112
- img_size=img_size,
113
- patch_size=patch_size,
114
- in_chans=in_chans,
115
- embed_dim=embed_dim,
116
- )
117
  num_patches = self.patch_embed.num_patches
118
 
119
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
120
- self.pos_embed = nn.Parameter(
121
- torch.zeros(1, num_patches + self.num_tokens, embed_dim)
122
- )
123
 
124
  if drop_path_uniform is True:
125
  dpr = [drop_path_rate] * depth
126
  else:
127
- dpr = [
128
- x.item() for x in torch.linspace(0, drop_path_rate, depth)
129
- ] # stochastic depth decay rule
130
 
131
  if ffn_layer == "mlp":
132
  ffn_layer = Mlp
@@ -163,9 +139,7 @@ class DinoVisionTransformer(nn.Module):
163
  chunksize = depth // block_chunks
164
  for i in range(0, depth, chunksize):
165
  # this is to keep the block index consistent if we chunk the block list
166
- chunked_blocks.append(
167
- [nn.Identity()] * i + blocks_list[i : i + chunksize]
168
- )
169
  self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
170
  else:
171
  self.chunked_blocks = False
@@ -179,7 +153,7 @@ class DinoVisionTransformer(nn.Module):
179
  self.init_weights()
180
  for param in self.parameters():
181
  param.requires_grad = False
182
-
183
  @property
184
  def device(self):
185
  return self.cls_token.device
@@ -206,29 +180,20 @@ class DinoVisionTransformer(nn.Module):
206
  w0, h0 = w0 + 0.1, h0 + 0.1
207
 
208
  patch_pos_embed = nn.functional.interpolate(
209
- patch_pos_embed.reshape(
210
- 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
211
- ).permute(0, 3, 1, 2),
212
  scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
213
  mode="bicubic",
214
  )
215
 
216
- assert (
217
- int(w0) == patch_pos_embed.shape[-2]
218
- and int(h0) == patch_pos_embed.shape[-1]
219
- )
220
  patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
221
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
222
- previous_dtype
223
- )
224
 
225
  def prepare_tokens_with_masks(self, x, masks=None):
226
  B, nc, w, h = x.shape
227
  x = self.patch_embed(x)
228
  if masks is not None:
229
- x = torch.where(
230
- masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
231
- )
232
 
233
  x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
234
  x = x + self.interpolate_pos_encoding(x, w, h)
@@ -236,10 +201,7 @@ class DinoVisionTransformer(nn.Module):
236
  return x
237
 
238
  def forward_features_list(self, x_list, masks_list):
239
- x = [
240
- self.prepare_tokens_with_masks(x, masks)
241
- for x, masks in zip(x_list, masks_list)
242
- ]
243
  for blk in self.blocks:
244
  x = blk(x)
245
 
@@ -278,34 +240,26 @@ class DinoVisionTransformer(nn.Module):
278
  x = self.prepare_tokens_with_masks(x)
279
  # If n is an int, take the n last blocks. If it's a list, take them
280
  output, total_block_len = [], len(self.blocks)
281
- blocks_to_take = (
282
- range(total_block_len - n, total_block_len) if isinstance(n, int) else n
283
- )
284
  for i, blk in enumerate(self.blocks):
285
  x = blk(x)
286
  if i in blocks_to_take:
287
  output.append(x)
288
- assert len(output) == len(
289
- blocks_to_take
290
- ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
291
  return output
292
 
293
  def _get_intermediate_layers_chunked(self, x, n=1):
294
  x = self.prepare_tokens_with_masks(x)
295
  output, i, total_block_len = [], 0, len(self.blocks[-1])
296
  # If n is an int, take the n last blocks. If it's a list, take them
297
- blocks_to_take = (
298
- range(total_block_len - n, total_block_len) if isinstance(n, int) else n
299
- )
300
  for block_chunk in self.blocks:
301
  for blk in block_chunk[i:]: # Passing the nn.Identity()
302
  x = blk(x)
303
  if i in blocks_to_take:
304
  output.append(x)
305
  i += 1
306
- assert len(output) == len(
307
- blocks_to_take
308
- ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
309
  return output
310
 
311
  def get_intermediate_layers(
@@ -327,9 +281,7 @@ class DinoVisionTransformer(nn.Module):
327
  if reshape:
328
  B, _, w, h = x.shape
329
  outputs = [
330
- out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
331
- .permute(0, 3, 1, 2)
332
- .contiguous()
333
  for out in outputs
334
  ]
335
  if return_class_token:
@@ -404,4 +356,4 @@ def vit_giant2(patch_size=16, **kwargs):
404
  block_fn=partial(Block, attn_class=MemEffAttention),
405
  **kwargs,
406
  )
407
- return model
 
18
  import torch.utils.checkpoint
19
  from torch.nn.init import trunc_normal_
20
 
21
+ from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
22
+
23
+
24
+
25
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
26
  if not depth_first and include_root:
27
  fn(module=module, name=name)
28
  for child_name, child_module in module.named_children():
29
  child_name = ".".join((name, child_name)) if name else child_name
30
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
 
31
  if depth_first and include_root:
32
  fn(module=module, name=name)
33
  return module
 
87
  super().__init__()
88
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
89
 
90
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
 
 
91
  self.num_tokens = 1
92
  self.n_blocks = depth
93
  self.num_heads = num_heads
94
  self.patch_size = patch_size
95
 
96
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
 
 
 
 
 
97
  num_patches = self.patch_embed.num_patches
98
 
99
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
100
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
 
 
101
 
102
  if drop_path_uniform is True:
103
  dpr = [drop_path_rate] * depth
104
  else:
105
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
 
 
106
 
107
  if ffn_layer == "mlp":
108
  ffn_layer = Mlp
 
139
  chunksize = depth // block_chunks
140
  for i in range(0, depth, chunksize):
141
  # this is to keep the block index consistent if we chunk the block list
142
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
 
 
143
  self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
144
  else:
145
  self.chunked_blocks = False
 
153
  self.init_weights()
154
  for param in self.parameters():
155
  param.requires_grad = False
156
+
157
  @property
158
  def device(self):
159
  return self.cls_token.device
 
180
  w0, h0 = w0 + 0.1, h0 + 0.1
181
 
182
  patch_pos_embed = nn.functional.interpolate(
183
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
 
 
184
  scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
185
  mode="bicubic",
186
  )
187
 
188
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
 
 
 
189
  patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
190
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
 
 
191
 
192
  def prepare_tokens_with_masks(self, x, masks=None):
193
  B, nc, w, h = x.shape
194
  x = self.patch_embed(x)
195
  if masks is not None:
196
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
 
 
197
 
198
  x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
199
  x = x + self.interpolate_pos_encoding(x, w, h)
 
201
  return x
202
 
203
  def forward_features_list(self, x_list, masks_list):
204
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
 
 
 
205
  for blk in self.blocks:
206
  x = blk(x)
207
 
 
240
  x = self.prepare_tokens_with_masks(x)
241
  # If n is an int, take the n last blocks. If it's a list, take them
242
  output, total_block_len = [], len(self.blocks)
243
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
 
 
244
  for i, blk in enumerate(self.blocks):
245
  x = blk(x)
246
  if i in blocks_to_take:
247
  output.append(x)
248
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
 
 
249
  return output
250
 
251
  def _get_intermediate_layers_chunked(self, x, n=1):
252
  x = self.prepare_tokens_with_masks(x)
253
  output, i, total_block_len = [], 0, len(self.blocks[-1])
254
  # If n is an int, take the n last blocks. If it's a list, take them
255
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
 
 
256
  for block_chunk in self.blocks:
257
  for blk in block_chunk[i:]: # Passing the nn.Identity()
258
  x = blk(x)
259
  if i in blocks_to_take:
260
  output.append(x)
261
  i += 1
262
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
 
 
263
  return output
264
 
265
  def get_intermediate_layers(
 
281
  if reshape:
282
  B, _, w, h = x.shape
283
  outputs = [
284
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
 
 
285
  for out in outputs
286
  ]
287
  if return_class_token:
 
356
  block_fn=partial(Block, attn_class=MemEffAttention),
357
  **kwargs,
358
  )
359
+ return model
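
For context on the hunk above: the `reshape` branch of `get_intermediate_layers` turns the ViT's patch tokens back into a channels-first feature map. A minimal sketch of the same reshape, with made-up sizes (nothing below is taken from RoMa itself):

```python
import torch

# Illustrative only: mimics the reshape in get_intermediate_layers above.
# B, H, W, patch_size, C are example values, not values from the diff.
B, H, W, patch_size, C = 2, 224, 224, 14, 1024
tokens = torch.randn(B, (H // patch_size) * (W // patch_size), C)  # patch tokens, class token removed

feature_map = (
    tokens.reshape(B, H // patch_size, W // patch_size, C)
    .permute(0, 3, 1, 2)  # -> (B, C, H/p, W/p), channels-first like a CNN feature map
    .contiguous()
)
print(feature_map.shape)  # torch.Size([2, 1024, 16, 16])
```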
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/__init__.py RENAMED
File without changes
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/attention.py RENAMED
@@ -48,11 +48,7 @@ class Attention(nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         B, N, C = x.shape
-        qkv = (
-            self.qkv(x)
-            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
-            .permute(2, 0, 3, 1, 4)
-        )
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
         attn = q @ k.transpose(-2, -1)
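
The reformatted line above packs the fused qkv projection into per-head tensors in one chain. A small sketch of the shapes involved, with example dimensions that are assumptions for illustration only:

```python
import torch

# Illustrative only: the qkv split used in Attention.forward above.
# B, N, C, num_heads are example values, not taken from the diff.
B, N, C, num_heads = 2, 197, 768, 12
qkv_out = torch.randn(B, N, 3 * C)  # output of the fused qkv linear layer

qkv = qkv_out.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                              # each (B, num_heads, N, head_dim)
attn = (q * (C // num_heads) ** -0.5) @ k.transpose(-2, -1)  # (B, num_heads, N, N)
print(q.shape, attn.shape)
```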
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/block.py RENAMED
@@ -62,9 +62,7 @@ class Block(nn.Module):
             attn_drop=attn_drop,
             proj_drop=drop,
         )
-        self.ls1 = (
-            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
         self.norm2 = norm_layer(dim)
@@ -76,9 +74,7 @@ class Block(nn.Module):
             drop=drop,
             bias=ffn_bias,
         )
-        self.ls2 = (
-            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
         self.sample_drop_ratio = drop_path
@@ -131,9 +127,7 @@ def drop_add_residual_stochastic_depth(
     residual_scale_factor = b / sample_subset_size
 
     # 3) add the residual
-    x_plus_residual = torch.index_add(
-        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
-    )
+    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
     return x_plus_residual.view_as(x)
 
 
@@ -149,16 +143,10 @@ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None
     if scaling_vector is None:
         x_flat = x.flatten(1)
         residual = residual.flatten(1)
-        x_plus_residual = torch.index_add(
-            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
-        )
+        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
     else:
         x_plus_residual = scaled_index_add(
-            x,
-            brange,
-            residual.to(dtype=x.dtype),
-            scaling=scaling_vector,
-            alpha=residual_scale_factor,
+            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
         )
     return x_plus_residual
 
@@ -170,11 +158,7 @@ def get_attn_bias_and_cat(x_list, branges=None):
     """
     this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
-    batch_sizes = (
-        [b.shape[0] for b in branges]
-        if branges is not None
-        else [x.shape[0] for x in x_list]
-    )
+    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
     all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
     if all_shapes not in attn_bias_cache.keys():
         seqlens = []
@@ -186,9 +170,7 @@ def get_attn_bias_and_cat(x_list, branges=None):
         attn_bias_cache[all_shapes] = attn_bias
 
     if branges is not None:
-        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
-            1, -1, x_list[0].shape[-1]
-        )
+        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
     else:
         tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
         cat_tensors = torch.cat(tensors_bs1, dim=1)
@@ -203,9 +185,7 @@ def drop_add_residual_stochastic_depth_list(
     scaling_vector=None,
 ) -> Tensor:
     # 1) generate random set of indices for dropping samples in the batch
-    branges_scales = [
-        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
-    ]
+    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
     branges = [s[0] for s in branges_scales]
     residual_scale_factors = [s[1] for s in branges_scales]
 
@@ -216,14 +196,8 @@ def drop_add_residual_stochastic_depth_list(
     residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
 
     outputs = []
-    for x, brange, residual, residual_scale_factor in zip(
-        x_list, branges, residual_list, residual_scale_factors
-    ):
-        outputs.append(
-            add_residual(
-                x, brange, residual, residual_scale_factor, scaling_vector
-            ).view_as(x)
-        )
+    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
     return outputs
 
 
@@ -246,17 +220,13 @@ class NestedTensorBlock(Block):
                 x_list,
                 residual_func=attn_residual_func,
                 sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls1.gamma
-                if isinstance(self.ls1, LayerScale)
-                else None,
+                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
             )
             x_list = drop_add_residual_stochastic_depth_list(
                 x_list,
                 residual_func=ffn_residual_func,
                 sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls2.gamma
-                if isinstance(self.ls1, LayerScale)
-                else None,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
             )
             return x_list
         else:
@@ -276,9 +246,7 @@ class NestedTensorBlock(Block):
         if isinstance(x_or_x_list, Tensor):
             return super().forward(x_or_x_list)
         elif isinstance(x_or_x_list, list):
-            assert (
-                XFORMERS_AVAILABLE
-            ), "Please install xFormers for nested tensors usage"
+            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
             return self.forward_nested(x_or_x_list)
         else:
             raise AssertionError
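
The `torch.index_add` calls reformatted above implement the stochastic-depth trick: only a random subset of the batch is run through the block, and its residual is added back, rescaled, into the full batch. A minimal sketch under assumed toy sizes:

```python
import torch

# Illustrative only: the rescaled residual add used by drop_add_residual_stochastic_depth above.
# x is the full batch; `brange` indexes the subset that was actually run through the block.
B, N, C = 8, 16, 32                      # example sizes, not from the diff
x = torch.randn(B, N, C)
brange = torch.tensor([0, 3, 5])         # the kept samples
residual = torch.randn(len(brange), N, C)
scale = B / len(brange)                  # rescale so the expected update matches full-batch training

x_plus_residual = torch.index_add(
    x.flatten(1), 0, brange, residual.flatten(1).to(x.dtype), alpha=scale
).view_as(x)
print(x_plus_residual.shape)
```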
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/dino_head.py RENAMED
@@ -23,14 +23,7 @@ class DINOHead(nn.Module):
     ):
         super().__init__()
         nlayers = max(nlayers, 1)
-        self.mlp = _build_mlp(
-            nlayers,
-            in_dim,
-            bottleneck_dim,
-            hidden_dim=hidden_dim,
-            use_bn=use_bn,
-            bias=mlp_bias,
-        )
+        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
         self.apply(self._init_weights)
         self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
         self.last_layer.weight_g.data.fill_(1)
@@ -49,9 +42,7 @@
         return x
 
 
-def _build_mlp(
-    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
-):
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
     if nlayers == 1:
         return nn.Linear(in_dim, bottleneck_dim, bias=bias)
     else:
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/drop_path.py RENAMED
@@ -16,9 +16,7 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
     if drop_prob == 0.0 or not training:
         return x
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (
-        x.ndim - 1
-    )  # work with diff dim tensors, not just 2D ConvNets
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
     if keep_prob > 0.0:
         random_tensor.div_(keep_prob)
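
The reformatted `shape` tuple above gives one keep/drop decision per sample, broadcast over all remaining dimensions. A tiny sketch with assumed sizes:

```python
import torch

# Illustrative only: the per-sample mask built by drop_path above; sizes are example values.
x = torch.randn(4, 197, 768)                  # (batch, tokens, dim)
drop_prob = 0.2
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)   # (4, 1, 1): one decision per sample
mask = x.new_empty(shape).bernoulli_(keep_prob).div_(keep_prob)
out = x * mask                                # dropped samples contribute zero residual
print(mask.view(-1))
```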
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/layer_scale.py RENAMED
File without changes
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/mlp.py RENAMED
File without changes
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/patch_embed.py RENAMED
@@ -63,21 +63,15 @@ class PatchEmbed(nn.Module):
 
         self.flatten_embedding = flatten_embedding
 
-        self.proj = nn.Conv2d(
-            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
-        )
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
     def forward(self, x: Tensor) -> Tensor:
         _, _, H, W = x.shape
         patch_H, patch_W = self.patch_size
 
-        assert (
-            H % patch_H == 0
-        ), f"Input image height {H} is not a multiple of patch height {patch_H}"
-        assert (
-            W % patch_W == 0
-        ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
+        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
 
         x = self.proj(x)  # B C H W
         H, W = x.size(2), x.size(3)
@@ -89,13 +83,7 @@
 
     def flops(self) -> float:
         Ho, Wo = self.patches_resolution
-        flops = (
-            Ho
-            * Wo
-            * self.embed_dim
-            * self.in_chans
-            * (self.patch_size[0] * self.patch_size[1])
-        )
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
         if self.norm is not None:
             flops += Ho * Wo * self.embed_dim
         return flops
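
For context: `PatchEmbed` patchifies an image with a convolution whose kernel and stride both equal the patch size, which is why the divisibility asserts above exist. A minimal sketch with assumed sizes:

```python
import torch
import torch.nn as nn

# Illustrative only: the patchify step in PatchEmbed.forward above; sizes are example values.
in_chans, embed_dim, patch = 3, 768, 16
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch, stride=patch)

img = torch.randn(1, 3, 224, 224)
assert img.shape[-2] % patch == 0 and img.shape[-1] % patch == 0
x = proj(img)                              # (1, 768, 14, 14)
tokens = x.flatten(2).transpose(1, 2)      # (1, 196, 768): one embedding per patch
print(tokens.shape)
```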
third_party/{Roma β†’ RoMa}/roma/models/transformer/layers/swiglu_ffn.py RENAMED
File without changes
third_party/{Roma β†’ RoMa}/roma/train/__init__.py RENAMED
File without changes
third_party/{Roma β†’ RoMa}/roma/train/train.py RENAMED
@@ -4,62 +4,41 @@ import roma
 import torch
 import wandb
 
-
-def log_param_statistics(named_parameters, norm_type=2):
+def log_param_statistics(named_parameters, norm_type = 2):
     named_parameters = list(named_parameters)
     grads = [p.grad for n, p in named_parameters if p.grad is not None]
-    weight_norms = [
-        p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None
-    ]
-    names = [n for n, p in named_parameters if p.grad is not None]
+    weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
+    names = [n for n,p in named_parameters if p.grad is not None]
     param_norm = torch.stack(weight_norms).norm(p=norm_type)
     device = grads[0].device
-    grad_norms = torch.stack(
-        [torch.norm(g.detach(), norm_type).to(device) for g in grads]
-    )
+    grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
     nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
     nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
     total_grad_norm = torch.norm(grad_norms, norm_type)
     if torch.any(nans_or_infs):
         print(f"These params have nan or inf grads: {nan_inf_names}")
-    wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
-    wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)
-
+    wandb.log({"grad_norm": total_grad_norm.item()}, step = roma.GLOBAL_STEP)
+    wandb.log({"param_norm": param_norm.item()}, step = roma.GLOBAL_STEP)
 
-def train_step(
-    train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs
-):
+def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
     optimizer.zero_grad()
     out = model(train_batch)
     l = objective(out, train_batch)
     grad_scaler.scale(l).backward()
     grad_scaler.unscale_(optimizer)
     log_param_statistics(model.named_parameters())
-    torch.nn.utils.clip_grad_norm_(
-        model.parameters(), grad_clip_norm
-    )  # what should max norm be?
+    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
     grad_scaler.step(optimizer)
     grad_scaler.update()
-    wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP)
-    if grad_scaler._scale < 1.0:
-        grad_scaler._scale = torch.tensor(1.0).to(grad_scaler._scale)
-    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE  # increment global step
+    wandb.log({"grad_scale": grad_scaler._scale.item()}, step = roma.GLOBAL_STEP)
+    if grad_scaler._scale < 1.:
+        grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
+    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step
     return {"train_out": out, "train_loss": l.item()}
 
 
 def train_k_steps(
-    n_0,
-    k,
-    dataloader,
-    model,
-    objective,
-    optimizer,
-    lr_scheduler,
-    grad_scaler,
-    progress_bar=True,
-    grad_clip_norm=1.0,
-    warmup=None,
-    ema_model=None,
+    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None,
 ):
     for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
         batch = next(dataloader)
@@ -73,7 +52,7 @@ def train_k_steps(
             lr_scheduler=lr_scheduler,
             grad_scaler=grad_scaler,
             n=n,
-            grad_clip_norm=grad_clip_norm,
+            grad_clip_norm = grad_clip_norm,
         )
         if ema_model is not None:
             ema_model.update()
@@ -82,10 +61,7 @@ def train_k_steps(
             lr_scheduler.step()
         else:
             lr_scheduler.step()
-            [
-                wandb.log({f"lr_group_{grp}": lr})
-                for grp, lr in enumerate(lr_scheduler.get_last_lr())
-            ]
+            [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
 
 
 def train_epoch(
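
The `train_step` shown above follows the usual AMP ordering: scale the loss, backward, unscale before clipping, then step and update the scaler. A minimal, self-contained sketch of that ordering (the toy model and data are assumptions, not RoMa's training setup):

```python
import torch
from torch.cuda.amp import GradScaler

# Illustrative only: the scale -> backward -> unscale_ -> clip -> step -> update order used in train_step above.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=torch.cuda.is_available())  # disabled on CPU so the sketch still runs

x, y = torch.randn(8, 4), torch.randn(8, 1)
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                              # grads back in fp32 so clipping sees true magnitudes
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
print(loss.item())
```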
third_party/{Roma β†’ RoMa}/roma/utils/__init__.py RENAMED
File without changes