JohannesK14 committed
Commit e44f283 · 1 Parent(s): 42d9aa3

initial commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,31 @@
1
+ Software Copyright License for Academic Use of RIPE, Version 2.0
2
+
3
+ © Copyright (2025) Fraunhofer-Gesellschaft zur Förderung der angewandten Forschung e.V.
4
+
5
+ 1. INTRODUCTION
6
+
7
+ RIPE which means any source code, object code or binary files provided by Fraunhofer excluding third party software and materials, is made available under this Software Copyright License.
8
+
9
+ 2. COPYRIGHT LICENSE
10
+
11
+ Internal use of RIPE, in source and binary forms, with or without modification, is permitted without payment of copyright license fees for non-commercial purposes of evaluation, testing and academic research.
12
+
13
+ No right or license, express or implied, is granted to any part of RIPE except and solely to the extent as expressly set forth herein. Any commercial use or exploitation of RIPE and/or any modifications thereto under this license are prohibited.
14
+
15
+ For any other use of RIPE than permitted by this software copyright license You need another license from Fraunhofer. In such case please contact Fraunhofer under the CONTACT INFORMATION below.
16
+
17
+ 3. LIMITED PATENT LICENSE
18
+
19
+ If Fraunhofer patents are implemented by RIPE their use is permitted for internal non-commercial purposes of evaluation, testing and academic research. No patent grant is provided for any other use, including but not limited to commercial use or exploitation.
20
+
21
+ Fraunhofer provides no warranty of patent non-infringement with respect to RIPE.
22
+
23
+ 4. DISCLAIMER
24
+
25
+ RIPE is provided by Fraunhofer "AS IS" and WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES, including but not limited to the implied warranties of fitness for a particular purpose. IN NO EVENT SHALL FRAUNHOFER BE LIABLE for any direct, indirect, incidental, special, exemplary, or consequential damages, including but not limited to procurement of substitute goods or services; loss of use, data, or profits, or business interruption, however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence), arising in any way out of the use of the Fraunhofer Software, even if advised of the possibility of such damage.
26
+
27
+ 5. CONTACT INFORMATION
28
+
29
+ Fraunhofer-Institut für Nachrichtentechnik, Heinrich-Hertz-Institut, HHI
30
+ Einsteinufer 37, 10587 Berlin, Germany
31
+ info@hhi.fraunhofer.de
app.py ADDED
@@ -0,0 +1,268 @@
1
+ # This is a small gradio interface to access our RIPE keypoint extractor.
2
+ # You can either upload two images or use one of the example image pairs.
3
+
4
+ import os
5
+
6
+ import gradio as gr
7
+ from PIL import Image
8
+
9
+ from ripe import vgg_hyper
10
+
11
+ SEED = 32000
12
+ os.environ["PYTHONHASHSEED"] = str(SEED)
13
+
14
+ import random
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ torch.manual_seed(SEED)
21
+ np.random.seed(SEED)
22
+ random.seed(SEED)
23
+ import cv2
24
+ import kornia.feature as KF
25
+ import kornia.geometry as KG
26
+
27
+ from ripe.utils.utils import cv2_matches_from_kornia, to_cv_kpts
28
+
29
+ MIN_SIZE = 512
30
+ MAX_SIZE = 768
31
+
32
+ description_text = """
33
+ <p align='center'>
34
+ <h1 align='center'>🌊🌺 ICCV 2025 🌺🌊</h1>
35
+ <p align='center'>
36
+ <a href='https://scholar.google.com/citations?user=ybMR38kAAAAJ'>Johannes Künzel</a> ·
37
+ <a href='https://scholar.google.com/citations?user=5yTuyGIAAAAJ'>Anna Hilsmann</a> ·
38
+ <a href='https://scholar.google.com/citations?user=BCElyCkAAAAJ'>Peter Eisert</a>
39
+ </p>
40
+ <h2 align='center'>
41
+ <a href='https://arxiv.org/abs/2507.04839'>Arxiv</a> |
42
+ <a href='https://fraunhoferhhi.github.io/RIPE/'>Project Page</a> |
43
+ <a href='https://github.com/fraunhoferhhi/RIPE'>Code</a>
44
+ </h2>
45
+ </p>
46
+
47
+ <br/>
48
+ <div align='center'>
49
+
50
+ ### This demo showcases our new keypoint extractor model, RIPE (Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction).
51
+
52
+ ### RIPE is trained without requiring pose or depth supervision or artificial augmentations. By leveraging reinforcement learning, it learns to extract keypoints solely based on whether an image pair depicts the same scene or not.
53
+
54
+ ### For more detailed information, please refer to our [paper](https://arxiv.org/abs/2507.04839).
55
+
56
+ The demo code extracts the top 2048 keypoints from the two input images. It uses the mutual nearest neighbor (MNN) descriptor matcher from kornia to find matches between the two images.
57
+ If the number of matches is greater than 8, it applies RANSAC to filter out outliers based on the inlier threshold provided by the user.
58
+ Images are resized so that their longer side lies between 512 and 768 pixels while maintaining the aspect ratio.
59
+
60
+ </div>
61
+ """
62
+
63
+ model = vgg_hyper("./weights_ripe.pth")
64
+
65
+
66
+ def get_new_image_size(image, min_size=1600, max_size=2048):
67
+ """
68
+ Get a new size for the image that is scaled to fit between min_size and max_size while maintaining the aspect ratio.
69
+
70
+ Args:
71
+ image (PIL.Image): Input image.
72
+ min_size (int): Minimum allowed size for width and height.
73
+ max_size (int): Maximum allowed size for width and height.
74
+
75
+ Returns:
76
+ tuple: New size (width, height) for the image.
77
+ """
78
+ width, height = image.size
79
+
80
+ aspect_ratio = width / height
81
+ if width > height:
82
+ new_width = max(min_size, min(max_size, width))
83
+ new_height = int(new_width / aspect_ratio)
84
+ else:
85
+ new_height = max(min_size, min(max_size, height))
86
+ new_width = int(new_height * aspect_ratio)
87
+
88
+ new_size = (new_width, new_height)
89
+
90
+ return new_size
91
+
92
+
93
+ def extract_keypoints(image1, image2, inl_th):
94
+ """
95
+ Extract keypoints from two input images using the RIPE model.
96
+
97
+ Args:
98
+ image1 (PIL.Image): First input image.
99
+ image2 (PIL.Image): Second input image.
100
+ inl_th (float): RANSAC inlier threshold.
101
+
102
+ Returns:
103
+ dict: A dictionary containing keypoints and matches.
104
+ """
105
+ log_text = "Extracting keypoints and matches with RIPE\n"
106
+
107
+ log_text += f"Image 1 size: {image1.size}\n"
108
+ log_text += f"Image 2 size: {image2.size}\n"
109
+
110
+ # resize so the longer side lies between MIN_SIZE and MAX_SIZE while keeping the aspect ratio
111
+ new_size = get_new_image_size(image1, min_size=MIN_SIZE, max_size=MAX_SIZE)
112
+ image1 = image1.resize(new_size)
113
+
114
+ new_size = get_new_image_size(image2, min_size=MIN_SIZE, max_size=MAX_SIZE)
115
+ image2 = image2.resize(new_size)
116
+
117
+ log_text += f"Resized Image 1 size: {image1.size}\n"
118
+ log_text += f"Resized Image 2 size: {image2.size}\n"
119
+
120
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
121
+ model.to(dev)
122
+
123
+ image1 = image1.convert("RGB")
124
+ image2 = image2.convert("RGB")
125
+
126
+ image1_original = image1.copy()
127
+ image2_original = image2.copy()
128
+
129
+ # convert PIL images to numpy arrays
130
+ image1_original = np.array(image1_original)
131
+ image2_original = np.array(image2_original)
132
+
133
+ # convert PIL images to tensors
134
+ image1 = torch.tensor(np.array(image1)).permute(2, 0, 1).float() / 255.0
135
+ image2 = torch.tensor(np.array(image2)).permute(2, 0, 1).float() / 255.0
136
+
137
+ image1 = image1.to(dev).unsqueeze(0) # Add batch dimension
138
+ image2 = image2.to(dev).unsqueeze(0) # Add batch dimension
139
+
140
+ kpts_1, desc_1, score_1 = model.detectAndCompute(image1, threshold=0.5, top_k=2048)
141
+ kpts_2, desc_2, score_2 = model.detectAndCompute(image2, threshold=0.5, top_k=2048)
142
+
143
+ log_text += f"Number of keypoints in image 1: {kpts_1.shape[0]}\n"
144
+ log_text += f"Number of keypoints in image 2: {kpts_2.shape[0]}\n"
145
+
146
+ matcher = KF.DescriptorMatcher("mnn") # threshold is not used with mnn
147
+ match_dists, match_idxs = matcher(desc_1, desc_2)
148
+
149
+ log_text += f"Number of MNN matches: {match_idxs.shape[0]}\n"
150
+
151
+ cv2_matches = cv2_matches_from_kornia(match_dists, match_idxs)
152
+
153
+ do_ransac = match_idxs.shape[0] > 8
154
+
155
+ if do_ransac:
156
+ matched_pts_1 = kpts_1[match_idxs[:, 0]]
157
+ matched_pts_2 = kpts_2[match_idxs[:, 1]]
158
+
159
+ H, mask = KG.ransac.RANSAC(model_type="fundamental", inl_th=inl_th)(matched_pts_1, matched_pts_2)
160
+ matchesMask = mask.int().ravel().tolist()
161
+
162
+ log_text += f"RANSAC found {mask.sum().item()} inliers out of {mask.shape[0]} matches with an inlier threshold of {inl_th}.\n"
163
+ else:
164
+ log_text += "Not enough matches for RANSAC, skipping RANSAC step.\n"
165
+
166
+ kpts_1 = to_cv_kpts(kpts_1, score_1)
167
+ kpts_2 = to_cv_kpts(kpts_2, score_2)
168
+
169
+ keypoints_raw_1 = cv2.drawKeypoints(image1_original, kpts_1, image1_original, color=(0, 255, 0))
170
+ keypoints_raw_2 = cv2.drawKeypoints(image2_original, kpts_2, image2_original, color=(0, 255, 0))
171
+
172
+ # pad height smaller image to match the height of the larger image
173
+ if keypoints_raw_1.shape[0] < keypoints_raw_2.shape[0]:
174
+ pad_height = keypoints_raw_2.shape[0] - keypoints_raw_1.shape[0]
175
+ keypoints_raw_1 = np.pad(
176
+ keypoints_raw_1, ((0, pad_height), (0, 0), (0, 0)), mode="constant", constant_values=255
177
+ )
178
+ elif keypoints_raw_1.shape[0] > keypoints_raw_2.shape[0]:
179
+ pad_height = keypoints_raw_1.shape[0] - keypoints_raw_2.shape[0]
180
+ keypoints_raw_2 = np.pad(
181
+ keypoints_raw_2, ((0, pad_height), (0, 0), (0, 0)), mode="constant", constant_values=255
182
+ )
183
+
184
+ # concatenate keypoints images horizontally
185
+ keypoints_raw = np.concatenate((keypoints_raw_1, keypoints_raw_2), axis=1)
186
+ keypoints_raw_pil = Image.fromarray(keypoints_raw)
187
+
188
+ result_raw = cv2.drawMatches(
189
+ image1_original,
190
+ kpts_1,
191
+ image2_original,
192
+ kpts_2,
193
+ cv2_matches,
194
+ None,
195
+ matchColor=(0, 255, 0),
196
+ matchesMask=None,
197
+ # matchesMask=None,
198
+ flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
199
+ )
200
+
201
+ if not do_ransac:
202
+ result_ransac = None
203
+ else:
204
+ result_ransac = cv2.drawMatches(
205
+ image1_original,
206
+ kpts_1,
207
+ image2_original,
208
+ kpts_2,
209
+ cv2_matches,
210
+ None,
211
+ matchColor=(0, 255, 0),
212
+ matchesMask=matchesMask,
213
+ singlePointColor=(0, 0, 255),
214
+ flags=cv2.DrawMatchesFlags_DEFAULT,
215
+ )
216
+
217
+ # result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) # Convert BGR to RGB for display
218
+
219
+ # convert to PIL Image
220
+ result_raw_pil = Image.fromarray(result_raw)
221
+ if result_ransac is not None:
222
+ result_ransac_pil = Image.fromarray(result_ransac)
223
+ else:
224
+ result_ransac_pil = None
225
+
226
+ return log_text, result_ransac_pil, result_raw_pil, keypoints_raw_pil
227
+
228
+
229
+ demo = gr.Interface(
230
+ fn=extract_keypoints,
231
+ inputs=[
232
+ gr.Image(type="pil", label="Image 1"),
233
+ gr.Image(type="pil", label="Image 2"),
234
+ gr.Slider(
235
+ minimum=0.1,
236
+ maximum=3.0,
237
+ step=0.1,
238
+ value=0.5,
239
+ label="RANSAC inlier threshold",
240
+ info="Threshold for RANSAC inlier detection. Lower values may yield fewer inliers but more robust matches.",
241
+ ),
242
+ ],
243
+ outputs=[
244
+ gr.Textbox(type="text", label="Log"),
245
+ gr.Image(type="pil", label="Keypoints and Matches (RANSAC)"),
246
+ gr.Image(type="pil", label="Keypoints and Matches"),
247
+ gr.Image(type="pil", label="Keypoint Detection Results"),
248
+ ],
249
+ title="RIPE: Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction",
250
+ description=description_text,
251
+ examples=[
252
+ [
253
+ "assets_gradio/all_souls_000013.jpg",
254
+ "assets_gradio/all_souls_000055.jpg",
255
+ ],
256
+ [
257
+ "assets_gradio/167170681_0e5c42fd21_o.jpg",
258
+ "assets_gradio/170804731_6bf4fbecd4_o.jpg",
259
+ ],
260
+ [
261
+ "assets_gradio/4171014767_0fe879b783_o.jpg",
262
+ "assets_gradio/4174108353_20422632d6_o.jpg",
263
+ ],
264
+ ],
265
+ flagging_mode="never",
266
+ theme="default",
267
+ )
268
+ demo.launch()
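
For readers who want to run RIPE without the Gradio UI, here is a minimal sketch of the core matching pipeline used in app.py above. It assumes the same vgg_hyper checkpoint loader, the detectAndCompute(threshold, top_k) interface, and the bundled example images; treat it as an illustration, not an authoritative API reference.

import numpy as np
import torch
import kornia.feature as KF
from PIL import Image

from ripe import vgg_hyper

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vgg_hyper("./weights_ripe.pth").to(dev)


def load_as_tensor(path):
    # PIL image -> float tensor in [0, 1] with a batch dimension, as in app.py
    img = Image.open(path).convert("RGB")
    t = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
    return t.unsqueeze(0).to(dev)


img1 = load_as_tensor("assets_gradio/all_souls_000013.jpg")
img2 = load_as_tensor("assets_gradio/all_souls_000055.jpg")

kpts_1, desc_1, score_1 = model.detectAndCompute(img1, threshold=0.5, top_k=2048)
kpts_2, desc_2, score_2 = model.detectAndCompute(img2, threshold=0.5, top_k=2048)

matcher = KF.DescriptorMatcher("mnn")  # mutual nearest neighbor matching
match_dists, match_idxs = matcher(desc_1, desc_2)
print(f"{match_idxs.shape[0]} MNN matches between the two images")
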
packages.txt ADDED
@@ -0,0 +1,4 @@
1
+ libeigen3-dev
2
+ cmake
3
+ build-essential
4
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ setuptools
2
+ poselib @ git+https://github.com/PoseLib/PoseLib.git@56d158f744d3561b0b70174e6d8ca9a7fc9bd9c1
3
+ opencv-python
4
+ kornia
5
+ numpy
6
+ torch
7
+ torchvision
ripe/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model_zoo import vgg_hyper # noqa: F401
ripe/benchmarks/imw_2020.py ADDED
@@ -0,0 +1,320 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import kornia.feature as KF
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import poselib
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from ripe import utils
13
+ from ripe.data.data_transforms import Compose, Normalize, Resize
14
+ from ripe.data.datasets.disk_imw import DISK_IMW
15
+ from ripe.utils.pose_error import AUCMetric, relative_pose_error
16
+ from ripe.utils.utils import (
17
+ cv2_matches_from_kornia,
18
+ cv_resize_and_pad_to_shape,
19
+ to_cv_kpts,
20
+ )
21
+
22
+ log = utils.get_pylogger(__name__)
23
+
24
+
25
+ class IMW_2020_Benchmark:
26
+ def __init__(
27
+ self,
28
+ use_predefined_subset: bool = True,
29
+ conf_inference=None,
30
+ edge_input_divisible_by=None,
31
+ ):
32
+ data_dir = os.getenv("DATA_DIR")
33
+ if data_dir is None:
34
+ raise ValueError("Environment variable DATA_DIR is not set.")
35
+ root_path = Path(data_dir) / "disk-data"
36
+
37
+ self.data = DISK_IMW(
38
+ str(
39
+ root_path
40
+ ), # Resize only to ensure that the input size is divisible the value of edge_input_divisible_by
41
+ transforms=Compose(
42
+ [
43
+ Resize(None, edge_input_divisible_by),
44
+ Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
45
+ ]
46
+ ),
47
+ )
48
+ self.ids_subset = None
49
+ self.results = []
50
+ self.conf_inference = conf_inference
51
+
52
+ # fmt: off
53
+ if use_predefined_subset:
54
+ self.ids_subset = [4921, 3561, 3143, 6040, 802, 6828, 5338, 9275, 10764, 10085, 5124, 11355, 7, 10027, 2161, 4433, 6887, 3311, 10766,
55
+ 11451, 11433, 8539, 2581, 10300, 10562, 1723, 8803, 6275, 10140, 11487, 6238, 638, 8092, 9979, 201, 10394, 3414,
56
+ 9002, 7456, 2431, 632, 6589, 9265, 9889, 3139, 7890, 10619, 4899, 675, 176, 4309, 4814, 3833, 3519, 148, 4560, 10705,
57
+ 3744, 1441, 4049, 1791, 5106, 575, 1540, 1105, 6791, 1383, 9344, 501, 2504, 4335, 8992, 10970, 10786, 10405, 9317,
58
+ 5279, 1396, 5044, 9408, 11125, 10417, 7627, 7480, 1358, 7738, 5461, 10178, 9226, 8106, 2766, 6216, 4032, 7298, 259,
59
+ 3021, 2645, 8756, 7513, 3163, 2510, 6701, 6684, 3159, 9689, 7425, 6066, 1904, 6382, 3052, 777, 6277, 7409, 5997, 2987,
60
+ 11316, 2894, 4528, 1927, 10366, 8605, 2726, 1886, 2416, 2164, 3352, 2997, 6636, 6765, 5609, 3679, 76, 10956, 3612, 6699,
61
+ 1741, 8811, 3755, 1285, 9520, 2476, 3977, 370, 9823, 1834, 7551, 6227, 7303, 6399, 4758, 10713, 5050, 380, 11056, 7620,
62
+ 4826, 6090, 9011, 7523, 7355, 8021, 9801, 1801, 6522, 7138, 10017, 8732, 6402, 3116, 4031, 6088, 3975, 9841, 9082, 9412,
63
+ 5406, 217, 2385, 8791, 8361, 494, 4319, 5275, 3274, 335, 6731, 207, 10095, 3068, 5996, 3951, 2808, 5877, 6134, 7772, 10042,
64
+ 8574, 5501, 10885, 7871]
65
+ # self.ids_subset = self.ids_subset[:10]
66
+ # fmt: on
67
+
68
+ def evaluate_sample(self, model, sample, dev):
69
+ img_1 = sample["src_image"].unsqueeze(0).to(dev)
70
+ img_2 = sample["trg_image"].unsqueeze(0).to(dev)
71
+
72
+ scale_h_1, scale_w_1 = (
73
+ sample["orig_size_src"][0] / img_1.shape[2],
74
+ sample["orig_size_src"][1] / img_1.shape[3],
75
+ )
76
+ scale_h_2, scale_w_2 = (
77
+ sample["orig_size_trg"][0] / img_2.shape[2],
78
+ sample["orig_size_trg"][1] / img_2.shape[3],
79
+ )
80
+
81
+ M = None
82
+ info = {}
83
+ kpts_1, desc_1, score_1 = None, None, None
84
+ kpts_2, desc_2, score_2 = None, None, None
85
+ match_dists, match_idxs = None, None
86
+
87
+ try:
88
+ kpts_1, desc_1, score_1 = model.detectAndCompute(img_1, **self.conf_inference)
89
+ kpts_2, desc_2, score_2 = model.detectAndCompute(img_2, **self.conf_inference)
90
+
91
+ if kpts_1.dim() == 3:
92
+ assert kpts_1.shape[0] == 1 and kpts_2.shape[0] == 1, "Batch size must be 1"
93
+
94
+ kpts_1, desc_1, score_1 = (
95
+ kpts_1.squeeze(0),
96
+ desc_1[0].squeeze(0),
97
+ score_1[0].squeeze(0),
98
+ )
99
+ kpts_2, desc_2, score_2 = (
100
+ kpts_2.squeeze(0),
101
+ desc_2[0].squeeze(0),
102
+ score_2[0].squeeze(0),
103
+ )
104
+
105
+ scale_1 = torch.tensor([scale_w_1, scale_h_1], dtype=torch.float).to(dev)
106
+ scale_2 = torch.tensor([scale_w_2, scale_h_2], dtype=torch.float).to(dev)
107
+
108
+ kpts_1 = kpts_1 * scale_1
109
+ kpts_2 = kpts_2 * scale_2
110
+
111
+ matcher = KF.DescriptorMatcher("mnn") # threshold is not used with mnn
112
+ match_dists, match_idxs = matcher(desc_1, desc_2)
113
+
114
+ matched_pts_1 = kpts_1[match_idxs[:, 0]]
115
+ matched_pts_2 = kpts_2[match_idxs[:, 1]]
116
+
117
+ camera_1 = sample["src_camera"]
118
+ camera_2 = sample["trg_camera"]
119
+
120
+ M, info = poselib.estimate_relative_pose(
121
+ matched_pts_1.cpu().numpy(),
122
+ matched_pts_2.cpu().numpy(),
123
+ camera_1.to_cameradict(),
124
+ camera_2.to_cameradict(),
125
+ {
126
+ "max_epipolar_error": 0.5,
127
+ },
128
+ {},
129
+ )
130
+ except RuntimeError as e:
131
+ if "No keypoints detected" in str(e):
132
+ pass
133
+ else:
134
+ raise e
135
+
136
+ success = M is not None
137
+ if success:
138
+ M = {
139
+ "R": torch.tensor(M.R, dtype=torch.float),
140
+ "t": torch.tensor(M.t, dtype=torch.float),
141
+ }
142
+ inl = info["inliers"]
143
+ else:
144
+ M = {
145
+ "R": torch.eye(3, dtype=torch.float),
146
+ "t": torch.zeros((3), dtype=torch.float),
147
+ }
148
+ inl = np.zeros((0,)).astype(bool)
149
+
150
+ t_err, r_err = relative_pose_error(sample["s2t_R"].cpu(), sample["s2t_T"].cpu(), M["R"], M["t"])
151
+
152
+ rel_pose_error = max(t_err.item(), r_err.item()) if success else np.inf
153
+ ransac_inl = np.sum(inl)
154
+ ransac_inl_ratio = np.mean(inl)
155
+
156
+ if success:
157
+ assert match_dists is not None and match_idxs is not None, "Matches must be computed"
158
+ cv_keypoints_src = to_cv_kpts(kpts_1, score_1)
159
+ cv_keypoints_trg = to_cv_kpts(kpts_2, score_2)
160
+ cv_matches = cv2_matches_from_kornia(match_dists, match_idxs)
161
+ cv_mask = [int(m) for m in inl]
162
+ else:
163
+ cv_keypoints_src, cv_keypoints_trg = [], []
164
+ cv_matches, cv_mask = [], []
165
+
166
+ estimation = {
167
+ "success": success,
168
+ "M_0to1": M,
169
+ "inliers": torch.tensor(inl).to(img_1),
170
+ "rel_pose_error": rel_pose_error,
171
+ "ransac_inl": ransac_inl,
172
+ "ransac_inl_ratio": ransac_inl_ratio,
173
+ "path_src_image": sample["src_path"],
174
+ "path_trg_image": sample["trg_path"],
175
+ "cv_keypoints_src": cv_keypoints_src,
176
+ "cv_keypoints_trg": cv_keypoints_trg,
177
+ "cv_matches": cv_matches,
178
+ "cv_mask": cv_mask,
179
+ }
180
+
181
+ return estimation
182
+
183
+ def evaluate(self, model, dev, progress_bar=False):
184
+ model.eval()
185
+
186
+ # reset results
187
+ self.results = []
188
+
189
+ for idx in tqdm(
190
+ self.ids_subset if self.ids_subset is not None else range(len(self.data)),
191
+ disable=not progress_bar,
192
+ ):
193
+ sample = self.data[idx]
194
+ self.results.append(self.evaluate_sample(model, sample, dev))
195
+
196
+ def get_auc(self, threshold=5, downsampled=False):
197
+ if len(self.results) == 0:
198
+ raise ValueError("No results to log. Run evaluate first.")
199
+
200
+ summary_results = self.calc_auc(downsampled=downsampled)
201
+
202
+ return summary_results[f"rel_pose_error@{threshold}°{'__original' if not downsampled else '__downsampled'}"]
203
+
204
+ def plot_results(self, num_samples=10, logger=None, step=None, downsampled=False):
205
+ if len(self.results) == 0:
206
+ raise ValueError("No results to plot. Run evaluate first.")
207
+
208
+ plot_data = []
209
+
210
+ for result in self.results[:num_samples]:
211
+ img1 = cv2.imread(result["path_src_image"])
212
+ img2 = cv2.imread(result["path_trg_image"])
213
+
214
+ # from BGR to RGB
215
+ img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
216
+ img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
217
+
218
+ plt_matches = cv2.drawMatches(
219
+ img1,
220
+ result["cv_keypoints_src"],
221
+ img2,
222
+ result["cv_keypoints_trg"],
223
+ result["cv_matches"],
224
+ None,
225
+ matchColor=None,
226
+ matchesMask=result["cv_mask"],
227
+ flags=cv2.DrawMatchesFlags_DEFAULT,
228
+ )
229
+ file_name = (
230
+ Path(result["path_src_image"]).parent.parent.name
231
+ + "_"
232
+ + Path(result["path_src_image"]).stem
233
+ + Path(result["path_trg_image"]).stem
234
+ + ("_downsampled" if downsampled else "")
235
+ + ".png"
236
+ )
237
+ # print rel_pose_error on image
238
+ plt_matches = cv2.putText(
239
+ plt_matches,
240
+ f"rel_pose_error: {result['rel_pose_error']:.2f} num_inliers: {result['ransac_inl']} inl_ratio: {result['ransac_inl_ratio']:.2f} num_matches: {len(result['cv_matches'])} num_keypoints: {len(result['cv_keypoints_src'])}/{len(result['cv_keypoints_trg'])}",
241
+ (10, 30),
242
+ cv2.FONT_HERSHEY_SIMPLEX,
243
+ 1,
244
+ (0, 0, 0),
245
+ 2,
246
+ cv2.LINE_8,
247
+ )
248
+
249
+ plot_data.append({"file_name": file_name, "image": plt_matches})
250
+
251
+ if logger is None:
252
+ log.info("No logger provided. Using plt to plot results.")
253
+ for image in plot_data:
254
+ plt.imsave(
255
+ image["file_name"],
256
+ cv_resize_and_pad_to_shape(image["image"], (1024, 2048)),
257
+ )
258
+ plt.close()
259
+ else:
260
+ import wandb
261
+
262
+ log.info(f"Logging images to wandb with step={step}")
263
+ if not downsampled:
264
+ logger.log(
265
+ {
266
+ "examples": [
267
+ wandb.Image(cv_resize_and_pad_to_shape(image["image"], (1024, 2048))) for image in plot_data
268
+ ]
269
+ },
270
+ step=step,
271
+ )
272
+ else:
273
+ logger.log(
274
+ {
275
+ "examples_downsampled": [
276
+ wandb.Image(cv_resize_and_pad_to_shape(image["image"], (1024, 2048))) for image in plot_data
277
+ ]
278
+ },
279
+ step=step,
280
+ )
281
+
282
+ def log_results(self, logger=None, step=None, downsampled=False):
283
+ if len(self.results) == 0:
284
+ raise ValueError("No results to log. Run evaluate first.")
285
+
286
+ summary_results = self.calc_auc(downsampled=downsampled)
287
+
288
+ if logger is not None:
289
+ logger.log(summary_results, step=step)
290
+ else:
291
+ log.warning("No logger provided. Printing results instead.")
292
+ print(self.calc_auc())
293
+
294
+ def print_results(self):
295
+ if len(self.results) == 0:
296
+ raise ValueError("No results to print. Run evaluate first.")
297
+
298
+ print(self.calc_auc())
299
+
300
+ def calc_auc(self, auc_thresholds=None, downsampled=False):
301
+ if auc_thresholds is None:
302
+ auc_thresholds = [5, 10, 20]
303
+ if not isinstance(auc_thresholds, list):
304
+ auc_thresholds = [auc_thresholds]
305
+
306
+ if len(self.results) == 0:
307
+ raise ValueError("No results to calculate auc. Run evaluate first.")
308
+
309
+ rel_pose_errors = [r["rel_pose_error"] for r in self.results]
310
+
311
+ pose_aucs = AUCMetric(auc_thresholds, rel_pose_errors).compute()
312
+ assert isinstance(pose_aucs, list) and len(pose_aucs) == len(auc_thresholds)
313
+
314
+ ext = "_downsampled" if downsampled else "_original"
315
+
316
+ summary = {}
317
+ for i, ath in enumerate(auc_thresholds):
318
+ summary[f"rel_pose_error@{ath}°_{ext}"] = pose_aucs[i]
319
+
320
+ return summary
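
The benchmark above is driven entirely by an environment variable and constructor arguments. A hedged usage sketch follows; the DATA_DIR value and checkpoint path are placeholders, and the conf_inference values simply mirror the detectAndCompute arguments used in app.py.

import os

import torch

from ripe import vgg_hyper
from ripe.benchmarks.imw_2020 import IMW_2020_Benchmark

os.environ["DATA_DIR"] = "/path/to/data"  # assumed to contain disk-data/imw2020-val

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vgg_hyper("./weights_ripe.pth").to(dev)

benchmark = IMW_2020_Benchmark(
    use_predefined_subset=True,
    conf_inference={"threshold": 0.5, "top_k": 2048},  # forwarded to detectAndCompute
)
benchmark.evaluate(model, dev, progress_bar=True)
benchmark.print_results()              # pose AUC at 5/10/20 degrees
print(benchmark.get_auc(threshold=5))  # single AUC value at 5 degrees
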
ripe/data/__init__.py ADDED
File without changes
ripe/data/data_transforms.py ADDED
@@ -0,0 +1,204 @@
1
+ import collections
2
+ import collections.abc
3
+
4
+ import kornia.geometry as KG
5
+ import numpy as np
6
+ import torch
7
+ from torchvision.transforms import functional as TF
8
+
9
+
10
+ class Compose:
11
+ """Composes several transforms together. The transforms are applied in the order they are passed in.
12
+ Args: transforms (list): A list of transforms to be applied.
13
+ """
14
+
15
+ def __init__(self, transforms):
16
+ self.transforms = transforms
17
+
18
+ def __call__(self, src, trg, src_mask, trg_mask, h):
19
+ for t in self.transforms:
20
+ src, trg, src_mask, trg_mask, h = t(src, trg, src_mask, trg_mask, h)
21
+
22
+ return src, trg, src_mask, trg_mask, h
23
+
24
+
25
+ class Transform:
26
+ """Base class for all transforms. It provides a method to apply a transformation function to the input images and masks.
27
+ Args:
28
+ src (torch.Tensor): The source image tensor.
29
+ trg (torch.Tensor): The target image tensor.
30
+ src_mask (torch.Tensor): The source image mask tensor.
31
+ trg_mask (torch.Tensor): The target image mask tensor.
32
+ h (torch.Tensor): The homography matrix tensor.
33
+ Returns:
34
+ tuple: A tuple containing the transformed source image, the transformed target image, the transformed source mask,
35
+ the transformed target mask and the updated homography matrix.
36
+ """
37
+
38
+ def __init__(self):
39
+ pass
40
+
41
+ def apply_transform(self, src, trg, src_mask, trg_mask, h, transform_function):
42
+ src, trg, src_mask, trg_mask, h = transform_function(src, trg, src_mask, trg_mask, h)
43
+ return src, trg, src_mask, trg_mask, h
44
+
45
+
46
+ class Normalize(Transform):
47
+ def __init__(self, mean, std):
48
+ self.mean = mean
49
+ self.std = std
50
+
51
+ def __call__(self, src, trg, src_mask, trg_mask, h):
52
+ return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
53
+
54
+ def transform_function(self, src, trg, src_mask, trg_mask, h):
55
+ src = TF.normalize(src, mean=self.mean, std=self.std)
56
+ trg = TF.normalize(trg, mean=self.mean, std=self.std)
57
+ return src, trg, src_mask, trg_mask, h
58
+
59
+
60
+ class ResizeAndPadWithHomography(Transform):
61
+ def __init__(self, target_size_longer_side=768):
62
+ self.target_size = target_size_longer_side
63
+
64
+ def __call__(self, src, trg, src_mask, trg_mask, h):
65
+ return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
66
+
67
+ def transform_function(self, src, trg, src_mask, trg_mask, h):
68
+ src_w, src_h = src.shape[-1], src.shape[-2]
69
+ trg_w, trg_h = trg.shape[-1], trg.shape[-2]
70
+
71
+ # Resizing logic for both images
72
+ scale_src, new_src_w, new_src_h = self.compute_resize(src_w, src_h)
73
+ scale_trg, new_trg_w, new_trg_h = self.compute_resize(trg_w, trg_h)
74
+
75
+ # Resize both images
76
+ src_resized = TF.resize(src, [new_src_h, new_src_w])
77
+ trg_resized = TF.resize(trg, [new_trg_h, new_trg_w])
78
+
79
+ src_mask_resized = TF.resize(src_mask, [new_src_h, new_src_w])
80
+ trg_mask_resized = TF.resize(trg_mask, [new_trg_h, new_trg_w])
81
+
82
+ # Pad the resized images to be square (768x768)
83
+ src_padded, src_padding = self.apply_padding(src_resized, new_src_w, new_src_h)
84
+ trg_padded, trg_padding = self.apply_padding(trg_resized, new_trg_w, new_trg_h)
85
+
86
+ src_mask_padded, _ = self.apply_padding(src_mask_resized, new_src_w, new_src_h)
87
+ trg_mask_padded, _ = self.apply_padding(trg_mask_resized, new_trg_w, new_trg_h)
88
+
89
+ # Update the homography matrix
90
+ h = self.update_homography(h, scale_src, src_padding, scale_trg, trg_padding)
91
+
92
+ return src_padded, trg_padded, src_mask_padded, trg_mask_padded, h
93
+
94
+ def compute_resize(self, w, h):
95
+ if w > h:
96
+ scale = self.target_size / w
97
+ new_w = self.target_size
98
+ new_h = int(h * scale)
99
+ else:
100
+ scale = self.target_size / h
101
+ new_h = self.target_size
102
+ new_w = int(w * scale)
103
+ return scale, new_w, new_h
104
+
105
+ def apply_padding(self, img, new_w, new_h):
106
+ pad_w = (self.target_size - new_w) // 2
107
+ pad_h = (self.target_size - new_h) // 2
108
+ padding = [
109
+ pad_w,
110
+ pad_h,
111
+ self.target_size - new_w - pad_w,
112
+ self.target_size - new_h - pad_h,
113
+ ]
114
+ img_padded = TF.pad(img, padding, fill=0) # Zero-pad
115
+ return img_padded, padding
116
+
117
+ def update_homography(self, h, scale_src, padding_src, scale_trg, padding_trg):
118
+ # Create the scaling matrices
119
+ scale_matrix_src = np.array([[scale_src, 0, 0], [0, scale_src, 0], [0, 0, 1]])
120
+ scale_matrix_trg = np.array([[scale_trg, 0, 0], [0, scale_trg, 0], [0, 0, 1]])
121
+
122
+ # Create the padding translation matrices
123
+ pad_matrix_src = np.array([[1, 0, padding_src[0]], [0, 1, padding_src[1]], [0, 0, 1]])
124
+ pad_matrix_trg = np.array([[1, 0, -padding_trg[0]], [0, 1, -padding_trg[1]], [0, 0, 1]])
125
+
126
+ # Update the homography: apply scaling and translation
127
+ h_updated = (
128
+ pad_matrix_trg
129
+ @ scale_matrix_trg
130
+ @ h.numpy()
131
+ @ np.linalg.inv(scale_matrix_src)
132
+ @ np.linalg.inv(pad_matrix_src)
133
+ )
134
+
135
+ return torch.from_numpy(h_updated).float()
136
+
137
+
138
+ class Resize(Transform):
139
+ def __init__(self, output_size, edge_divisible_by=None, side="long", antialias=True):
140
+ self.output_size = output_size
141
+ self.edge_divisible_by = edge_divisible_by
142
+ self.side = side
143
+ self.antialias = antialias
144
+
145
+ def __call__(self, src, trg, src_mask, trg_mask, h):
146
+ return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
147
+
148
+ def transform_function(self, src, trg, src_mask, trg_mask, h):
149
+ new_size_src = self.get_new_image_size(src)
150
+ new_size_trg = self.get_new_image_size(trg)
151
+
152
+ src, T_src = self.resize(src, new_size_src)
153
+ trg, T_trg = self.resize(trg, new_size_trg)
154
+
155
+ src_mask, _ = self.resize(src_mask, new_size_src)
156
+ trg_mask, _ = self.resize(trg_mask, new_size_trg)
157
+
158
+ h = torch.from_numpy(T_trg @ h.numpy() @ T_src).float()
159
+
160
+ return src, trg, src_mask, trg_mask, h
161
+
162
+ def resize(self, img, size):
163
+ h, w = img.shape[-2:]
164
+
165
+ img = KG.transform.resize(
166
+ img,
167
+ size,
168
+ side=self.side,
169
+ antialias=self.antialias,
170
+ align_corners=None,
171
+ interpolation="bilinear",
172
+ )
173
+
174
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
175
+ T = np.diag([scale[0].item(), scale[1].item(), 1])
176
+
177
+ return img, T
178
+
179
+ def get_new_image_size(self, img):
180
+ h, w = img.shape[-2:]
181
+
182
+ if isinstance(self.output_size, collections.abc.Iterable):
183
+ assert len(self.output_size) == 2
184
+ return tuple(self.output_size)
185
+ if self.output_size is None: # keep the original size, but possibly make it divisible by edge_divisible_by
186
+ size = (h, w)
187
+ else:
188
+ side_size = self.output_size
189
+ aspect_ratio = w / h
190
+ if self.side not in ("short", "long", "vert", "horz"):
191
+ raise ValueError(f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{self.side}'")
192
+ if self.side == "vert":
193
+ size = side_size, int(side_size * aspect_ratio)
194
+ elif self.side == "horz":
195
+ size = int(side_size / aspect_ratio), side_size
196
+ elif (self.side == "short") ^ (aspect_ratio < 1.0):
197
+ size = side_size, int(side_size * aspect_ratio)
198
+ else:
199
+ size = int(side_size / aspect_ratio), side_size
200
+
201
+ if self.edge_divisible_by is not None:
202
+ df = self.edge_divisible_by
203
+ size = list(map(lambda x: int(x // df * df), size))
204
+ return size
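
Each transform above consumes and returns the same 5-tuple (src, trg, src_mask, trg_mask, h), so they can be chained through Compose. Below is a small sketch with random stand-in tensors (batched 4D shapes chosen for illustration; the sizes are arbitrary assumptions).

import torch

from ripe.data.data_transforms import Compose, Normalize, Resize

src, trg = torch.rand(1, 3, 480, 640), torch.rand(1, 3, 600, 800)
src_mask, trg_mask = torch.ones(1, 1, 480, 640), torch.ones(1, 1, 600, 800)
h = torch.eye(3)  # identity homography, as used for positive pairs

transforms = Compose(
    [
        Resize(None, edge_divisible_by=16),  # keep size, round each edge down to a multiple of 16
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
src, trg, src_mask, trg_mask, h = transforms(src, trg, src_mask, trg_mask, h)
print(src.shape, trg.shape)  # torch.Size([1, 3, 480, 640]) torch.Size([1, 3, 592, 800])
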
ripe/data/datasets/__init__.py ADDED
File without changes
ripe/data/datasets/acdc.py ADDED
@@ -0,0 +1,154 @@
1
+ from pathlib import Path
2
+ from typing import Any, Callable, Dict, Optional
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchvision.io import read_image
7
+
8
+ from ripe import utils
9
+ from ripe.data.data_transforms import Compose
10
+ from ripe.utils.utils import get_other_random_id
11
+
12
+ log = utils.get_pylogger(__name__)
13
+
14
+
15
+ class ACDC(Dataset):
16
+ def __init__(
17
+ self,
18
+ root: Path,
19
+ stage: str = "train",
20
+ condition: str = "rain",
21
+ transforms: Optional[Callable] = None,
22
+ positive_only: bool = False,
23
+ ) -> None:
24
+ self.root = root
25
+ self.stage = stage
26
+ self.condition = condition
27
+ self.transforms = transforms
28
+ self.positive_only = positive_only
29
+
30
+ if isinstance(self.root, str):
31
+ self.root = Path(self.root)
32
+
33
+ if not self.root.exists():
34
+ raise FileNotFoundError(f"Dataset not found at {self.root}")
35
+
36
+ if transforms is None:
37
+ self.transforms = Compose([])
38
+ else:
39
+ self.transforms = transforms
40
+
41
+ if self.stage not in ["train", "val", "test", "pred"]:
42
+ raise RuntimeError(
43
+ "Unknown option "
44
+ + self.stage
45
+ + " as training stage variable. Valid options: 'train', 'val', 'test' and 'pred'"
46
+ )
47
+
48
+ if self.stage == "pred": # prediction uses the test set
49
+ self.stage = "test"
50
+
51
+ if self.stage in ["val", "test", "pred"]:
52
+ self.positive_only = True
53
+ log.info(f"{self.stage} stage: Using only positive pairs!")
54
+
55
+ weather_conditions = ["fog", "night", "rain", "snow"]
56
+
57
+ if self.condition not in weather_conditions + ["all"]:
58
+ raise RuntimeError(
59
+ "Unknown option "
60
+ + self.condition
61
+ + " as weather condition variable. Valid options: 'fog', 'night', 'rain', 'snow' and 'all'"
62
+ )
63
+
64
+ self.weather_condition_query = weather_conditions if self.condition == "all" else [self.condition]
65
+
66
+ self._read_sample_files()
67
+
68
+ if positive_only:
69
+ log.warning("Using only positive pairs!")
70
+ log.info(f"Found {len(self.src_images)} source images and {len(self.trg_images)} target images.")
71
+
72
+ def _read_sample_files(self):
73
+ file_name_pattern_ref = "_ref_anon.png"
74
+ file_name_pattern = "_rgb_anon.png"
75
+
76
+ self.trg_images = []
77
+ self.src_images = []
78
+
79
+ for weather_condition in self.weather_condition_query:
80
+ rgb_files = sorted(
81
+ list(self.root.glob("rgb_anon/" + weather_condition + "/" + self.stage + "/**/*" + file_name_pattern)),
82
+ key=lambda i: i.stem[:21],
83
+ )
84
+
85
+ src_images = sorted(
86
+ list(
87
+ self.root.glob(
88
+ "rgb_anon/" + weather_condition + "/" + self.stage + "_ref" + "/**/*" + file_name_pattern_ref
89
+ )
90
+ ),
91
+ key=lambda i: i.stem[:21],
92
+ )
93
+
94
+ self.trg_images += rgb_files
95
+ self.src_images += src_images
96
+
97
+ def __len__(self) -> int:
98
+ if self.positive_only:
99
+ return len(self.trg_images)
100
+ return 2 * len(self.trg_images)
101
+
102
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
103
+ sample: Any = {}
104
+
105
+ positive_sample = (idx % 2 == 0) or (self.positive_only)
106
+ if not self.positive_only:
107
+ idx = idx // 2
108
+
109
+ sample["label"] = positive_sample
110
+
111
+ if positive_sample:
112
+ sample["src_path"] = str(self.src_images[idx])
113
+ sample["trg_path"] = str(self.trg_images[idx])
114
+
115
+ assert self.src_images[idx].stem[:21] == self.trg_images[idx].stem[:21], (
116
+ f"Source and target image mismatch: {self.src_images[idx]} vs {self.trg_images[idx]}"
117
+ )
118
+
119
+ src_img = read_image(sample["src_path"])
120
+ trg_img = read_image(sample["trg_path"])
121
+
122
+ homography = torch.eye(3, dtype=torch.float32)
123
+ else:
124
+ sample["src_path"] = str(self.src_images[idx])
125
+ idx_other = get_other_random_id(idx, len(self) // 2)
126
+ sample["trg_path"] = str(self.trg_images[idx_other])
127
+
128
+ assert self.src_images[idx].stem[:21] != self.trg_images[idx_other].stem[:21], (
129
+ f"Source and target image match for negative sample: {self.src_images[idx]} vs {self.trg_images[idx_other]}"
130
+ )
131
+
132
+ src_img = read_image(sample["src_path"])
133
+ trg_img = read_image(sample["trg_path"])
134
+
135
+ homography = torch.zeros((3, 3), dtype=torch.float32)
136
+
137
+ src_img = src_img / 255.0
138
+ trg_img = trg_img / 255.0
139
+
140
+ _, H, W = src_img.shape
141
+
142
+ src_mask = torch.ones((1, H, W), dtype=torch.uint8)
143
+ trg_mask = torch.ones((1, H, W), dtype=torch.uint8)
144
+
145
+ if self.transforms:
146
+ src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
147
+
148
+ sample["src_image"] = src_img
149
+ sample["trg_image"] = trg_img
150
+ sample["src_mask"] = src_mask.to(torch.bool)
151
+ sample["trg_mask"] = trg_mask.to(torch.bool)
152
+ sample["homography"] = homography
153
+
154
+ return sample
ripe/data/datasets/dataset_combinator.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+
3
+ from ripe import utils
4
+
5
+ log = utils.get_pylogger(__name__)
6
+
7
+
8
+ class DatasetCombinator:
9
+ """Combines multiple datasets into one. Length of the combined dataset is the length of the
10
+ longest dataset. Shorter datasets are looped over.
11
+
12
+ Args:
13
+ datasets: List of datasets to combine.
14
+ mode: How to sample from the datasets. Can be "uniform", "weighted" or "custom".
15
+ In "uniform" mode, each dataset is sampled with equal probability.
16
+ In "weighted" mode, each dataset is sampled with probability proportional to its length.
+ In "custom" mode, the sampling probabilities are given explicitly via the weights argument.
17
+ """
18
+
19
+ def __init__(self, datasets, mode="uniform", weights=None):
20
+ self.datasets = datasets
21
+
22
+ names_datasets = [type(ds).__name__ for ds in self.datasets]
23
+ self.lengths = [len(ds) for ds in datasets]
24
+
25
+ if mode == "weighted":
26
+ self.probs_datasets = [length / sum(self.lengths) for length in self.lengths]
27
+ elif mode == "uniform":
28
+ self.probs_datasets = [1 / len(self.datasets) for _ in self.datasets]
29
+ elif mode == "custom":
30
+ assert weights is not None, "Weights must be provided in custom mode"
31
+ assert len(weights) == len(datasets), "Number of weights must match number of datasets"
32
+ assert sum(weights) == 1.0, "Weights must sum to 1"
33
+ self.probs_datasets = weights
34
+ else:
35
+ raise ValueError(f"Unknown mode {mode}")
36
+
37
+ log.info("Got the following datasets: ")
38
+
39
+ for name, length, prob in zip(names_datasets, self.lengths, self.probs_datasets):
40
+ log.info(f"{name} with {length} samples and probability {prob}")
41
+ log.info(f"Total number of samples: {sum(self.lengths)}")
42
+
43
+ self.num_samples = max(self.lengths)
44
+
45
+ self.dataset_dist = torch.distributions.Categorical(probs=torch.tensor(self.probs_datasets))
46
+
47
+ def __len__(self):
48
+ return self.num_samples
49
+
50
+ def __getitem__(self, idx: int):
51
+ positive_sample = idx % 2 == 0
52
+
53
+ if positive_sample:
54
+ dataset_idx = self.dataset_dist.sample().item()
55
+
56
+ idx = torch.randint(0, self.lengths[dataset_idx], (1,)).item()
57
+ while idx % 2 == 1:
58
+ idx = torch.randint(0, self.lengths[dataset_idx], (1,)).item()
59
+
60
+ return self.datasets[dataset_idx][idx]
61
+ else:
62
+ dataset_idx_1 = self.dataset_dist.sample().item()
63
+ dataset_idx_2 = self.dataset_dist.sample().item()
64
+
65
+ if dataset_idx_1 == dataset_idx_2:
66
+ idx = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
67
+ while idx % 2 == 0:
68
+ idx = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
69
+ return self.datasets[dataset_idx_1][idx]
70
+
71
+ else:
72
+ idx_1 = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
73
+ idx_2 = torch.randint(0, self.lengths[dataset_idx_2], (1,)).item()
74
+
75
+ sample_1 = self.datasets[dataset_idx_1][idx_1]
76
+ sample_2 = self.datasets[dataset_idx_2][idx_2]
77
+
78
+ sample = {
79
+ "label": False,
80
+ "src_path": sample_1["src_path"],
81
+ "trg_path": sample_2["trg_path"],
82
+ "src_image": sample_1["src_image"],
83
+ "trg_image": sample_2["trg_image"],
84
+ "src_mask": sample_1["src_mask"],
85
+ "trg_mask": sample_2["trg_mask"],
86
+ "homography": sample_2["homography"],
87
+ }
88
+ return sample
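
A hedged sketch of how the combinator might be wired up with two of the datasets defined in this commit; the root paths are placeholders and must point at local copies of the respective datasets.

from ripe.data.datasets.acdc import ACDC
from ripe.data.datasets.dataset_combinator import DatasetCombinator
from ripe.data.datasets.tokyo247 import Tokyo247

acdc = ACDC(root="/path/to/ACDC", stage="train", condition="all")
tokyo = Tokyo247(root="/path/to/tokyo247", stage="train")

combined = DatasetCombinator([acdc, tokyo], mode="weighted")

print(len(combined))  # length of the longest dataset
sample = combined[0]  # even index -> positive pair drawn from one of the datasets
print(sample["label"], sample["src_image"].shape, sample["trg_image"].shape)
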
ripe/data/datasets/disk_imw.py ADDED
@@ -0,0 +1,160 @@
1
+ import json
2
+ import random
3
+ from itertools import accumulate
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Dict, Optional, Tuple
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from torchvision.io import read_image
10
+
11
+ from ripe import utils
12
+ from ripe.data.data_transforms import Compose
13
+ from ripe.utils.image_utils import Camera, cameras2F
14
+
15
+ log = utils.get_pylogger(__name__)
16
+
17
+
18
+ class DISK_IMW(Dataset):
19
+ def __init__(
20
+ self,
21
+ root: str,
22
+ stage: str = "val",
23
+ # condition: str = "rain",
24
+ transforms: Optional[Callable] = None,
25
+ ) -> None:
26
+ self.root = root
27
+ self.stage = stage
28
+ self.transforms = transforms
29
+
30
+ if isinstance(self.root, str):
31
+ self.root = Path(self.root)
32
+
33
+ if not self.root.exists():
34
+ raise FileNotFoundError(f"Dataset not found at {self.root}")
35
+
36
+ if transforms is None:
37
+ self.transforms = Compose([])
38
+ else:
39
+ self.transforms = transforms
40
+
41
+ if self.stage not in ["val"]:
42
+ raise RuntimeError("Unknown option " + self.stage + " as training stage variable. Valid options: 'train'")
43
+
44
+ json_path = self.root / "imw2020-val" / "dataset.json"
45
+ with open(json_path) as json_file:
46
+ json_data = json.load(json_file)
47
+
48
+ self.scenes = []
49
+
50
+ for scene in json_data:
51
+ self.scenes.append(Scene(self.root / "imw2020-val", json_data[scene]))
52
+
53
+ self.tuples_per_scene = [len(scene) for scene in self.scenes]
54
+
55
+ def __len__(self) -> int:
56
+ return sum(self.tuples_per_scene)
57
+
58
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
59
+ sample: Any = {}
60
+
61
+ i_scene, i_image = self._get_scene_and_image_id_from_idx(idx)
62
+
63
+ sample["src_path"], sample["trg_path"], path_calib_src, path_calib_trg = self.scenes[i_scene][i_image]
64
+
65
+ cam_src = Camera.from_calibration_file(path_calib_src)
66
+ cam_trg = Camera.from_calibration_file(path_calib_trg)
67
+
68
+ F = self.get_F(cam_src, cam_trg)
69
+ s2t_R, s2t_T = self.get_relative_pose(cam_src, cam_trg)
70
+
71
+ src_img = read_image(sample["src_path"]) / 255.0
72
+ trg_img = read_image(sample["trg_path"]) / 255.0
73
+
74
+ _, H_src, W_src = src_img.shape
75
+ _, H_trg, W_trg = trg_img.shape
76
+
77
+ src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
78
+ trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
79
+
80
+ H = torch.eye(3)
81
+ if self.transforms:
82
+ src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, H)
83
+
84
+ # check if transformations in self.transforms. Only Normalize is allowed
85
+ for t in self.transforms.transforms:
86
+ if t.__class__.__name__ not in ["Normalize", "Resize"]:
87
+ raise ValueError(f"Transform {t.__class__.__name__} not allowed in DISK_IMW dataset")
88
+
89
+ sample["src_image"] = src_img
90
+ sample["trg_image"] = trg_img
91
+ sample["orig_size_src"] = (H_src, W_src)
92
+ sample["orig_size_trg"] = (H_trg, W_trg)
93
+ sample["src_mask"] = src_mask.to(torch.bool)
94
+ sample["trg_mask"] = trg_mask.to(torch.bool)
95
+ sample["F"] = F
96
+ sample["s2t_R"] = s2t_R
97
+ sample["s2t_T"] = s2t_T
98
+ sample["src_camera"] = cam_src
99
+ sample["trg_camera"] = cam_trg
100
+
101
+ return sample
102
+
103
+ def get_relative_pose(self, cam_src: Camera, cam_trg: Camera) -> Tuple[torch.Tensor, torch.Tensor]:
104
+ R = cam_trg.R @ cam_src.R.T
105
+ T = cam_trg.t - R @ cam_src.t
106
+
107
+ return R, T
108
+
109
+ def get_F(self, cam_src: Camera, cam_trg: Camera) -> torch.Tensor:
110
+ F = cameras2F(cam_src, cam_trg)
111
+
112
+ return F
113
+
114
+ def _get_scene_and_image_id_from_idx(self, idx: int) -> Tuple[int, int]:
115
+ accumulated_tuples = accumulate(self.tuples_per_scene)
116
+
117
+ if idx >= sum(self.tuples_per_scene):
118
+ raise IndexError(f"Index {idx} out of bounds")
119
+
120
+ idx_scene = None
121
+ for i, accumulated_tuple in enumerate(accumulated_tuples):
122
+ idx_scene = i
123
+ if idx < accumulated_tuple:
124
+ break
125
+
126
+ idx_image = idx - sum(self.tuples_per_scene[:idx_scene])
127
+
128
+ return idx_scene, idx_image
129
+
130
+ def _get_other_random_scene_and_image_id(self, scene_id_to_exclude: int) -> Tuple[int, int]:
131
+ possible_scene_ids = list(range(len(self.scenes)))
132
+ possible_scene_ids.remove(scene_id_to_exclude)
133
+
134
+ idx_scene = random.choice(possible_scene_ids)
135
+ idx_image = random.randint(0, len(self.scenes[idx_scene]) - 1)
136
+
137
+ return idx_scene, idx_image
138
+
139
+
140
+ class Scene:
141
+ def __init__(self, root_path, scene_data: Dict[str, Any]) -> None:
142
+ self.root_path = root_path
143
+ self.image_path = Path(scene_data["image_path"])
144
+ self.calib_path = Path(scene_data["calib_path"])
145
+ self.image_names = scene_data["images"]
146
+ self.tuples = scene_data["tuples"]
147
+
148
+ def __len__(self) -> int:
149
+ return len(self.tuples)
150
+
151
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
152
+ idx_1 = self.tuples[idx][0]
153
+ idx_2 = self.tuples[idx][1]
154
+
155
+ path_image_1 = str(self.root_path / self.image_path / self.image_names[idx_1]) + ".jpg"
156
+ path_image_2 = str(self.root_path / self.image_path / self.image_names[idx_2]) + ".jpg"
157
+ path_calib_1 = str(self.root_path / self.calib_path / ("calibration_" + self.image_names[idx_1])) + ".h5"
158
+ path_calib_2 = str(self.root_path / self.calib_path / ("calibration_" + self.image_names[idx_2])) + ".h5"
159
+
160
+ return path_image_1, path_image_2, path_calib_1, path_calib_2
ripe/data/datasets/disk_megadepth.py ADDED
@@ -0,0 +1,157 @@
1
+ import json
2
+ import random
3
+ from itertools import accumulate
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Dict, Optional, Tuple
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from torchvision.io import read_image
10
+
11
+ from ripe import utils
12
+ from ripe.data.data_transforms import Compose
13
+
14
+ log = utils.get_pylogger(__name__)
15
+
16
+
17
+ class DISK_Megadepth(Dataset):
18
+ def __init__(
19
+ self,
20
+ root: str,
21
+ max_scene_size: int,
22
+ stage: str = "train",
23
+ # condition: str = "rain",
24
+ transforms: Optional[Callable] = None,
25
+ positive_only: bool = False,
26
+ ) -> None:
27
+ self.root = root
28
+ self.stage = stage
29
+ self.transforms = transforms
30
+ self.positive_only = positive_only
31
+
32
+ if isinstance(self.root, str):
33
+ self.root = Path(self.root)
34
+
35
+ if not self.root.exists():
36
+ raise FileNotFoundError(f"Dataset not found at {self.root}")
37
+
38
+ if transforms is None:
39
+ self.transforms = Compose([])
40
+ else:
41
+ self.transforms = transforms
42
+
43
+ if self.stage not in ["train"]:
44
+ raise RuntimeError("Unknown option " + self.stage + " as training stage variable. Valid options: 'train'")
45
+
46
+ json_path = self.root / "megadepth" / "dataset.json"
47
+ with open(json_path) as json_file:
48
+ json_data = json.load(json_file)
49
+
50
+ self.scenes = []
51
+
52
+ for scene in json_data:
53
+ self.scenes.append(Scene(self.root / "megadepth", json_data[scene], max_scene_size))
54
+
55
+ self.tuples_per_scene = [len(scene) for scene in self.scenes]
56
+
57
+ if positive_only:
58
+ log.warning("Using only positive pairs!")
59
+
60
+ def __len__(self) -> int:
61
+ if self.positive_only:
62
+ return sum(self.tuples_per_scene)
63
+ return 2 * sum(self.tuples_per_scene)
64
+
65
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
66
+ sample: Any = {}
67
+
68
+ positive_sample = idx % 2 == 0 or self.positive_only
69
+ if not self.positive_only:
70
+ idx = idx // 2
71
+
72
+ sample["label"] = positive_sample
73
+
74
+ i_scene, i_image = self._get_scene_and_image_id_from_idx(idx)
75
+
76
+ if positive_sample:
77
+ sample["src_path"], sample["trg_path"] = self.scenes[i_scene][i_image]
78
+
79
+ homography = torch.eye(3, dtype=torch.float32)
80
+ else:
81
+ sample["src_path"], _ = self.scenes[i_scene][i_image]
82
+
83
+ i_scene_other, i_image_other = self._get_other_random_scene_and_image_id(i_scene)
84
+
85
+ sample["trg_path"], _ = self.scenes[i_scene_other][i_image_other]
86
+
87
+ homography = torch.zeros((3, 3), dtype=torch.float32)
88
+
89
+ src_img = read_image(sample["src_path"]) / 255.0
90
+ trg_img = read_image(sample["trg_path"]) / 255.0
91
+
92
+ _, H_src, W_src = src_img.shape
93
+ _, H_trg, W_trg = trg_img.shape
94
+
95
+ src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
96
+ trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
97
+
98
+ if self.transforms:
99
+ src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
100
+
101
+ sample["src_image"] = src_img
102
+ sample["trg_image"] = trg_img
103
+ sample["src_mask"] = src_mask.to(torch.bool)
104
+ sample["trg_mask"] = trg_mask.to(torch.bool)
105
+ sample["homography"] = homography
106
+
107
+ return sample
108
+
109
+ def _get_scene_and_image_id_from_idx(self, idx: int) -> Tuple[int, int]:
110
+ accumulated_tuples = accumulate(self.tuples_per_scene)
111
+
112
+ if idx >= sum(self.tuples_per_scene):
113
+ raise IndexError(f"Index {idx} out of bounds")
114
+
115
+ idx_scene = None
116
+ for i, accumulated_tuple in enumerate(accumulated_tuples):
117
+ idx_scene = i
118
+ if idx < accumulated_tuple:
119
+ break
120
+
121
+ idx_image = idx - sum(self.tuples_per_scene[:idx_scene])
122
+
123
+ return idx_scene, idx_image
124
+
125
+ def _get_other_random_scene_and_image_id(self, scene_id_to_exclude: int) -> Tuple[int, int]:
126
+ possible_scene_ids = list(range(len(self.scenes)))
127
+ possible_scene_ids.remove(scene_id_to_exclude)
128
+
129
+ idx_scene = random.choice(possible_scene_ids)
130
+ idx_image = random.randint(0, len(self.scenes[idx_scene]) - 1)
131
+
132
+ return idx_scene, idx_image
133
+
134
+
135
+ class Scene:
136
+ def __init__(self, root_path, scene_data: Dict[str, Any], max_size_scene) -> None:
137
+ self.root_path = root_path
138
+ self.image_path = Path(scene_data["image_path"])
139
+ self.image_names = scene_data["images"]
140
+
141
+ # randomly sample tuples
142
+ if max_size_scene > 0:
143
+ self.tuples = random.sample(scene_data["tuples"], min(max_size_scene, len(scene_data["tuples"])))
144
+
145
+ def __len__(self) -> int:
146
+ return len(self.tuples)
147
+
148
+ def __getitem__(self, idx: int) -> Tuple[str, str]:
149
+ idx_1, idx_2 = random.sample([0, 1, 2], 2)
150
+
151
+ idx_1 = self.tuples[idx][idx_1]
152
+ idx_2 = self.tuples[idx][idx_2]
153
+
154
+ path_image_1 = str(self.root_path / self.image_path / self.image_names[idx_1])
155
+ path_image_2 = str(self.root_path / self.image_path / self.image_names[idx_2])
156
+
157
+ return path_image_1, path_image_2
ripe/data/datasets/tokyo247.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+ import random
3
+ from glob import glob
4
+ from typing import Any, Callable, Optional
5
+
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from torchvision.io import read_image
9
+
10
+ from ripe import utils
11
+ from ripe.data.data_transforms import Compose
12
+
13
+ log = utils.get_pylogger(__name__)
14
+
15
+
16
+ class Tokyo247(Dataset):
17
+ def __init__(
18
+ self,
19
+ root: str,
20
+ stage: str = "train",
21
+ transforms: Optional[Callable] = None,
22
+ positive_only: bool = False,
23
+ ):
24
+ if stage != "train":
25
+ raise ValueError("Tokyo247Dataset only supports the 'train' stage.")
26
+
27
+ # check if the root directory exists
28
+ if not os.path.isdir(root):
29
+ raise FileNotFoundError(f"Directory {root} does not exist.")
30
+
31
+ self.root_dir = root
32
+ self.transforms = transforms if transforms is not None else Compose([])
33
+ self.positive_only = positive_only
34
+
35
+ self.image_paths = []
36
+ self.positive_pairs = []
37
+
38
+ # Collect images grouped by location folder
39
+ self.locations = {}
40
+ for location_rough in sorted(os.listdir(self.root_dir)):
41
+ location_rough_path = os.path.join(self.root_dir, location_rough)
42
+
43
+ # check if the location_rough_path is a directory
44
+ if not os.path.isdir(location_rough_path):
45
+ continue
46
+
47
+ for location_fine in sorted(os.listdir(location_rough_path)):
48
+ location_fine_path = os.path.join(self.root_dir, location_rough, location_fine)
49
+
50
+ if os.path.isdir(location_fine_path):
51
+ images = sorted(
52
+ glob(os.path.join(location_fine_path, "*.png")),
53
+ key=lambda i: int(i[-7:-4]),
54
+ )
55
+ if len(images) >= 12:
56
+ self.locations[location_fine] = images
57
+ self.image_paths.extend(images)
58
+
59
+ # Generate positive pairs
60
+ for _, images in self.locations.items():
61
+ for i in range(len(images) - 1):
62
+ self.positive_pairs.append((images[i], images[i + 1]))
63
+ self.positive_pairs.append((images[-1], images[0]))
64
+
65
+ if positive_only:
66
+ log.warning("Using only positive pairs!")
67
+
68
+ log.info(f"Found {len(self.positive_pairs)} image pairs.")
69
+
70
+ def __len__(self):
71
+ if self.positive_only:
72
+ return len(self.positive_pairs)
73
+ return 2 * len(self.positive_pairs)
74
+
75
+ def __getitem__(self, idx):
76
+ sample: Any = {}
77
+
78
+ positive_sample = (idx % 2 == 0) or (self.positive_only)
79
+ if not self.positive_only:
80
+ idx = idx // 2
81
+
82
+ sample["label"] = positive_sample
83
+
84
+ if positive_sample: # Positive pair
85
+ img1_path, img2_path = self.positive_pairs[idx]
86
+
87
+ assert os.path.dirname(img1_path) == os.path.dirname(img2_path), (
88
+ f"Source and target image mismatch: {img1_path} vs {img2_path}"
89
+ )
90
+
91
+ homography = torch.eye(3, dtype=torch.float32)
92
+ else: # Negative pair
93
+ img1_path = random.choice(self.image_paths)
94
+ img2_path = random.choice(self.image_paths)
95
+
96
+ # Ensure images are from different folders
97
+ esc = 0
98
+ while os.path.dirname(img1_path) == os.path.dirname(img2_path):
99
+ img2_path = random.choice(self.image_paths)
100
+
101
+ esc += 1
102
+ if esc > 100:
103
+ raise RuntimeError("Could not find a negative pair.")
104
+
105
+ assert os.path.dirname(img1_path) != os.path.dirname(img2_path), (
106
+ f"Source and target image match for negative pair: {img1_path} vs {img2_path}"
107
+ )
108
+
109
+ homography = torch.zeros((3, 3), dtype=torch.float32)
110
+
111
+ sample["src_path"] = img1_path
112
+ sample["trg_path"] = img2_path
113
+
114
+ # Load images
115
+ src_img = read_image(sample["src_path"]) / 255.0
116
+ trg_img = read_image(sample["trg_path"]) / 255.0
117
+
118
+ _, H_src, W_src = src_img.shape
119
+ _, H_trg, W_trg = trg_img.shape
120
+
121
+ src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
122
+ trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
123
+
124
+ # Apply transformations
125
+ if self.transforms:
126
+ src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
127
+
128
+ sample["src_image"] = src_img
129
+ sample["trg_image"] = trg_img
130
+ sample["src_mask"] = src_mask.to(torch.bool)
131
+ sample["trg_mask"] = trg_mask.to(torch.bool)
132
+ sample["homography"] = homography
133
+
134
+ return sample
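For reference, a minimal usage sketch of the dataset above (the root path is a placeholder; batching assumes the configured transforms bring all images to a common resolution):

from torch.utils.data import DataLoader

from ripe.data.datasets.tokyo247 import Tokyo247

ds = Tokyo247(root="/data/tokyo247/images", stage="train")
dl = DataLoader(ds, batch_size=4, shuffle=True, num_workers=4)

batch = next(iter(dl))
# batch["label"] is True for positive pairs (identity homography) and False for
# random negative pairs (all-zero homography placeholder).
print(batch["src_image"].shape, batch["label"])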
ripe/losses/__init__.py ADDED
File without changes
ripe/losses/contrastive_loss.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def second_nearest_neighbor(desc1, desc2):
7
+ if desc2.shape[0] < 2: # need at least 2 descriptors in desc2 for the second-nearest-neighbour check
8
+ raise ValueError("desc2 should have at least 2 descriptors")
9
+
10
+ dist = torch.cdist(desc1, desc2, p=2)
11
+
12
+ vals, idxs = torch.topk(dist, 2, dim=1, largest=False)
13
+ idxs_in_2 = idxs[:, 1]
14
+ idxs_in_1 = torch.arange(0, idxs_in_2.size(0), device=dist.device)
15
+
16
+ matches_idxs = torch.cat([idxs_in_1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)
17
+
18
+ return vals[:, 1].view(-1, 1), matches_idxs
19
+
20
+
21
+ def contrastive_loss(
22
+ desc1,
23
+ desc2,
24
+ matches,
25
+ inliers,
26
+ label,
27
+ logits_1,
28
+ logits_2,
29
+ pos_margin=1.0,
30
+ neg_margin=1.0,
31
+ ):
32
+ if inliers.sum() < 8: # if there are too few inliers, calculate loss on all matches
33
+ inliers = torch.ones_like(inliers)
34
+
35
+ matched_inliers_descs1 = desc1[matches[:, 0][inliers]]
36
+ matched_inliers_descs2 = desc2[matches[:, 1][inliers]]
37
+
38
+ if logits_1 is not None and logits_2 is not None:
39
+ matched_inliers_logits1 = logits_1[matches[:, 0][inliers]]
40
+ matched_inliers_logits2 = logits_2[matches[:, 1][inliers]]
41
+ logits = torch.minimum(matched_inliers_logits1, matched_inliers_logits2)
42
+ else:
43
+ logits = torch.ones_like(matches[:, 0][inliers])
44
+
45
+ if label:
46
+ snn_match_dists_1, idx1 = second_nearest_neighbor(matched_inliers_descs1, desc2)
47
+ snn_match_dists_2, idx2 = second_nearest_neighbor(matched_inliers_descs2, desc1)
48
+
49
+ dists = torch.hstack((snn_match_dists_1, snn_match_dists_2))
50
+ min_dists_idx = torch.min(dists, dim=1).indices.unsqueeze(1)
51
+
52
+ dists_hard = torch.gather(dists, 1, min_dists_idx).squeeze(-1)
53
+ dists_pos = F.pairwise_distance(matched_inliers_descs1, matched_inliers_descs2)
54
+
55
+ contrastive_loss = torch.clamp(pos_margin + dists_pos - dists_hard, min=0.0)
56
+
57
+ contrastive_loss = contrastive_loss * logits
58
+
59
+ contrastive_loss = contrastive_loss.sum() / (logits.sum() + 1e-8) # small epsilon to avoid division by zero
60
+ else:
61
+ dists = F.pairwise_distance(matched_inliers_descs1, matched_inliers_descs2)
62
+ contrastive_loss = torch.clamp(neg_margin - dists, min=0.0)
63
+
64
+ contrastive_loss = contrastive_loss * logits
65
+
66
+ contrastive_loss = contrastive_loss.sum() / (logits.sum() + 1e-8) # small epsilon to avoid division by zero
67
+
68
+ return contrastive_loss
69
+
70
+
71
+ class ContrastiveLoss(nn.Module):
72
+ def __init__(self, pos_margin=1.0, neg_margin=1.0):
73
+ super().__init__()
74
+ self.pos_margin = pos_margin
75
+ self.neg_margin = neg_margin
76
+
77
+ def forward(self, desc1, desc2, matches, inliers, label, logits_1=None, logits_2=None):
78
+ return contrastive_loss(
79
+ desc1,
80
+ desc2,
81
+ matches,
82
+ inliers,
83
+ label,
84
+ logits_1,
85
+ logits_2,
86
+ self.pos_margin,
87
+ self.neg_margin,
88
+ )
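A shape-level sketch of the loss interface above, on random descriptors (not a training example; note that with fewer than 8 inliers the loss falls back to all matches):

import torch
import torch.nn.functional as F

from ripe.losses.contrastive_loss import ContrastiveLoss

loss_fn = ContrastiveLoss(pos_margin=1.0, neg_margin=1.0)

desc1 = F.normalize(torch.randn(100, 256), dim=1)
desc2 = F.normalize(torch.randn(120, 256), dim=1)
matches = torch.stack((torch.randint(0, 100, (32,)), torch.randint(0, 120, (32,))), dim=1)
inliers = torch.rand(32) > 0.5   # boolean inlier mask, normally produced by the robust estimator
label = True                     # True: positive image pair, False: negative pair

loss = loss_fn(desc1, desc2, matches, inliers, label)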
ripe/matcher/__init__.py ADDED
File without changes
ripe/matcher/concurrent_matcher.py ADDED
@@ -0,0 +1,97 @@
1
+ import concurrent.futures
2
+
3
+ import torch
4
+
5
+
6
+ class ConcurrentMatcher:
7
+ """A class that performs matching and geometric filtering in parallel using a thread pool executor.
8
+ It matches keypoints from two sets of descriptors and applies a robust estimator to filter the matches based on geometric constraints.
9
+
10
+ Args:
11
+ matcher (callable): A callable that takes two sets of descriptors and returns distances and indices of matches.
12
+ robust_estimator (callable): A callable that estimates a geometric transformation and returns inliers.
13
+ min_num_matches (int, optional): Minimum number of matches required to perform geometric filtering. Defaults to 8.
14
+ max_workers (int, optional): Maximum number of threads in the thread pool executor. Defaults to 12.
15
+ """
16
+
17
+ def __init__(self, matcher, robust_estimator, min_num_matches=8, max_workers=12):
18
+ self.matcher = matcher
19
+ self.robust_estimator = robust_estimator
20
+ self.min_num_matches = min_num_matches
21
+
22
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
23
+
24
+ @torch.no_grad()
25
+ def __call__(
26
+ self,
27
+ kpts1,
28
+ kpts2,
29
+ pdesc1,
30
+ pdesc2,
31
+ selected_mask1,
32
+ selected_mask2,
33
+ inl_th,
34
+ label=None,
35
+ ):
36
+ dev = pdesc1.device
37
+ B = pdesc1.shape[0]
38
+
39
+ batch_rel_idx_matches = [None] * B
40
+ batch_idx_matches = [None] * B
41
+ future_results = [None] * B
42
+
43
+ for b in range(B):
44
+ if selected_mask1[b].sum() < 16 or selected_mask2[b].sum() < 16:
45
+ continue
46
+
47
+ dists, idx_matches = self.matcher(pdesc1[b][selected_mask1[b]], pdesc2[b][selected_mask2[b]])
48
+
49
+ batch_rel_idx_matches[b] = idx_matches.clone()
50
+
51
+ # calculate ABSOLUTE indexes
52
+ idx_matches[:, 0] = torch.nonzero(selected_mask1[b], as_tuple=False)[idx_matches[:, 0]].squeeze()
53
+ idx_matches[:, 1] = torch.nonzero(selected_mask2[b], as_tuple=False)[idx_matches[:, 1]].squeeze()
54
+
55
+ batch_idx_matches[b] = idx_matches
56
+
57
+ # if not enough matches
58
+ if idx_matches.shape[0] < self.min_num_matches:
59
+ ransac_inliers = torch.zeros((idx_matches.shape[0]), device=dev).bool()
60
+ future_results[b] = (None, ransac_inliers)
61
+ continue
62
+
63
+ # use label information to exclude negative pairs from the geometric filtering process -> enforces more discriminative descriptors
64
+ if label is not None and label[b] == 0:
65
+ ransac_inliers = torch.ones((idx_matches.shape[0]), device=dev).bool()
66
+ future_results[b] = (None, ransac_inliers)
67
+ continue
68
+
69
+ mkpts1 = kpts1[b][idx_matches[:, 0]]
70
+ mkpts2 = kpts2[b][idx_matches[:, 1]]
71
+
72
+ future_results[b] = self.executor.submit(self.robust_estimator, mkpts1, mkpts2, inl_th)
73
+
74
+ batch_ransac_inliers = [None] * B
75
+ batch_Fm = [None] * B
76
+
77
+ for b in range(B):
78
+ future_result = future_results[b]
79
+ if future_result is None:
80
+ ransac_inliers = None
81
+ Fm = None
82
+ elif isinstance(future_result, tuple):
83
+ Fm, ransac_inliers = future_result
84
+ else:
85
+ Fm, ransac_inliers = future_result.result()
86
+
87
+ # if no inliers
88
+ if ransac_inliers.sum() == 0:
89
+ ransac_inliers = ransac_inliers.squeeze(
90
+ -1
91
+ ) # kornia.geometry.ransac.RANSAC returns (N, 1) tensor if no inliers and (N,) tensor if inliers
92
+ Fm = None
93
+
94
+ batch_ransac_inliers[b] = ransac_inliers
95
+ batch_Fm[b] = Fm
96
+
97
+ return batch_rel_idx_matches, batch_idx_matches, batch_ransac_inliers, batch_Fm
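The `matcher` argument above is any callable that returns (distances, index pairs); kornia's mutual-nearest-neighbour matcher is one callable with that contract (an illustrative choice, not necessarily the training configuration):

import kornia.feature as KF
import torch

desc1, desc2 = torch.randn(500, 256), torch.randn(480, 256)
dists, idx_matches = KF.match_mnn(desc1, desc2)  # idx_matches: (N, 2) indices into desc1 / desc2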
ripe/matcher/pose_estimator_poselib.py ADDED
@@ -0,0 +1,31 @@
1
+ import poselib
2
+ import torch
3
+
4
+
5
+ class PoseLibRelativePoseEstimator:
6
+ """PoseLibRelativePoseEstimator estimates the fundamental matrix using poselib library.
7
+ It uses the poselib's estimate_fundamental function to compute the fundamental matrix and inliers based on the provided points.
8
+ Args:
9
+ None
10
+ """
11
+
12
+ def __init__(self):
13
+ pass
14
+
15
+ def __call__(self, pts0, pts1, inl_th):
16
+ F, info = poselib.estimate_fundamental(
17
+ pts0.cpu().numpy(),
18
+ pts1.cpu().numpy(),
19
+ {
20
+ "max_epipolar_error": inl_th,
21
+ },
22
+ )
23
+
24
+ success = F is not None
25
+ if success:
26
+ inliers = info.pop("inliers")
27
+ inliers = torch.tensor(inliers, dtype=torch.bool, device=pts0.device)
28
+ else:
29
+ inliers = torch.zeros(pts0.shape[0], dtype=torch.bool, device=pts0.device)
30
+
31
+ return F, inliers
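A wiring sketch for the two matcher components defined above (the settings shown are illustrative, not the repository defaults):

import kornia.feature as KF

from ripe.matcher.concurrent_matcher import ConcurrentMatcher
from ripe.matcher.pose_estimator_poselib import PoseLibRelativePoseEstimator

matcher = ConcurrentMatcher(
    matcher=KF.match_mnn,                       # returns (distances, index pairs)
    robust_estimator=PoseLibRelativePoseEstimator(),
    min_num_matches=8,
    max_workers=12,
)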
ripe/model_zoo/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .vgg_hyper import vgg_hyper # noqa: F401
ripe/model_zoo/vgg_hyper.py ADDED
@@ -0,0 +1,39 @@
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from ripe.models.backbones.vgg import VGG
6
+ from ripe.models.ripe import RIPE
7
+ from ripe.models.upsampler.hypercolumn_features import HyperColumnFeatures
8
+
9
+
10
+ def vgg_hyper(model_path: Path = None, desc_shares=None):
11
+ if model_path is None:
12
+ # check if the weights file exists in the current directory
13
+ model_path = Path("/tmp/ripe_weights.pth")
14
+
15
+ if model_path.exists():
16
+ print(f"Using existing weights from {model_path}")
17
+ else:
18
+ print("Weights file not found. Downloading ...")
19
+ torch.hub.download_url_to_file(
20
+ "https://cvg.hhi.fraunhofer.de/RIPE/ripe_weights.pth",
21
+ "/tmp/ripe_weights.pth",
22
+ )
23
+ else:
24
+ if not model_path.exists():
25
+ print(f"Error: {model_path} does not exist.")
26
+ raise FileNotFoundError(f"Error: {model_path} does not exist.")
27
+
28
+ backbone = VGG(pretrained=False)
29
+ upsampler = HyperColumnFeatures()
30
+
31
+ extractor = RIPE(
32
+ net=backbone,
33
+ upsampler=upsampler,
34
+ desc_shares=desc_shares,
35
+ )
36
+
37
+ extractor.load_state_dict(torch.load(model_path, map_location="cpu"))
38
+
39
+ return extractor
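Loading the pretrained extractor via the model zoo and running single-image inference (the image path is a placeholder; weights are downloaded to /tmp on first use):

import torch
from torchvision.io import read_image

from ripe.model_zoo import vgg_hyper

device = "cuda" if torch.cuda.is_available() else "cpu"
model = vgg_hyper().to(device).eval()

img = read_image("assets/example.jpg").float().to(device) / 255.0
kpts, descs, scores = model.detectAndCompute(img, threshold=0.5, top_k=2048)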
ripe/models/__init__.py ADDED
File without changes
ripe/models/backbones/__init__.py ADDED
File without changes
ripe/models/backbones/backbone_base.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class BackboneBase(nn.Module):
6
+ """Base class for backbone networks. Provides a standard interface for preprocessing inputs and
7
+ defining encoder dimensions.
8
+
9
+ Args:
10
+ nchannels (int): Number of input channels.
11
+ use_instance_norm (bool): Whether to apply instance normalization.
12
+ """
13
+
14
+ def __init__(self, nchannels=3, use_instance_norm=False):
15
+ super().__init__()
16
+ assert nchannels > 0, "Number of channels must be positive."
17
+ self.nchannels = nchannels
18
+ self.use_instance_norm = use_instance_norm
19
+ self.norm = nn.InstanceNorm2d(nchannels) if use_instance_norm else None
20
+
21
+ def get_dim_layers_encoder(self):
22
+ """Get dimensions of encoder layers."""
23
+ raise NotImplementedError("Subclasses must implement this method.")
24
+
25
+ def _forward(self, x):
26
+ """Define the forward pass for the backbone."""
27
+ raise NotImplementedError("Subclasses must implement this method.")
28
+
29
+ def forward(self, x: torch.Tensor, preprocess=True):
30
+ """Forward pass with optional preprocessing.
31
+
32
+ Args:
33
+ x (Tensor): Input tensor.
34
+ preprocess (bool): Whether to apply channel reduction.
35
+ """
36
+ if preprocess:
37
+ if x.dim() != 4:
38
+ if x.dim() == 2 and x.shape[0] > 3 and x.shape[1] > 3:
39
+ x = x.unsqueeze(0).unsqueeze(0)
40
+ elif x.dim() == 3:
41
+ x = x.unsqueeze(0)
42
+ else:
43
+ raise ValueError(f"Unexpected input shape: {x.shape}")
44
+
45
+ if self.nchannels == 1 and x.shape[1] != 1:
46
+ if len(x.shape) == 4: # Assumes (batch, channel, height, width)
47
+ x = torch.mean(x, axis=1, keepdim=True)
48
+ else:
49
+ raise ValueError(f"Unexpected input shape: {x.shape}")
50
+
51
+ #
52
+ if self.nchannels == 3 and x.shape[1] == 1:
53
+ if len(x.shape) == 4:
54
+ x = x.repeat(1, 3, 1, 1)
55
+ else:
56
+ raise ValueError(f"Unexpected input shape: {x.shape}")
57
+
58
+ if self.use_instance_norm:
59
+ x = self.norm(x)
60
+
61
+ return self._forward(x)
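A toy subclass illustrating the interface expected by the rest of the code (not part of the repository): `_forward` returns a dict with a detection heatmap and a list of coarse feature maps, and `get_dim_layers_encoder` reports their channel dimensions.

import torch.nn as nn

from ripe.models.backbones.backbone_base import BackboneBase


class TinyBackbone(BackboneBase):
    def __init__(self):
        super().__init__(nchannels=3, use_instance_norm=True)
        self.conv = nn.Conv2d(3, 16, 3, padding=1)

    def get_dim_layers_encoder(self):
        return [16]

    def _forward(self, x):
        feats = self.conv(x)
        # one coarse feature level and a 1-channel heatmap, mirroring the VGG backbone's outputs
        return {"heatmap": feats.mean(dim=1, keepdim=True), "coarse_descs": [feats]}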
ripe/models/backbones/vgg.py ADDED
@@ -0,0 +1,99 @@
1
+ # adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .backbone_base import BackboneBase
7
+ from .vgg_utils import VGG19, ConvRefiner, Decoder
8
+
9
+
10
+ class VGG(BackboneBase):
11
+ def __init__(self, nchannels=3, pretrained=True, use_instance_norm=True, mode="dect"):
12
+ super().__init__(nchannels=nchannels, use_instance_norm=use_instance_norm)
13
+
14
+ self.nchannels = nchannels
15
+ self.mode = mode
16
+
17
+ if self.mode not in ["dect", "desc", "dect+desc"]:
18
+ raise ValueError("mode should be 'dect', 'desc' or 'dect+desc'")
19
+
20
+ NUM_OUTPUT_CHANNELS, hidden_blocks = self._get_mode_params(mode)
21
+ conv_refiner = self._create_conv_refiner(NUM_OUTPUT_CHANNELS, hidden_blocks)
22
+
23
+ self.encoder = VGG19(pretrained=pretrained, num_input_channels=nchannels)
24
+ self.decoder = Decoder(conv_refiner, num_prototypes=NUM_OUTPUT_CHANNELS)
25
+
26
+ def _get_mode_params(self, mode):
27
+ """Get the number of output channels and the number of hidden blocks for the ConvRefiner.
28
+
29
+ Depending on the mode, the ConvRefiner will have a different number of output channels.
30
+ """
31
+
32
+ if mode == "dect":
33
+ return 1, 8
34
+ elif mode == "desc":
35
+ return 256, 5
36
+ elif mode == "dect+desc":
37
+ return 256 + 1, 8
38
+
39
+ def _create_conv_refiner(self, num_output_channels, hidden_blocks):
40
+ return nn.ModuleDict(
41
+ {
42
+ "8": ConvRefiner(
43
+ 512,
44
+ 512,
45
+ 256 + num_output_channels,
46
+ hidden_blocks=hidden_blocks,
47
+ residual=True,
48
+ ),
49
+ "4": ConvRefiner(
50
+ 256 + 256,
51
+ 256,
52
+ 128 + num_output_channels,
53
+ hidden_blocks=hidden_blocks,
54
+ residual=True,
55
+ ),
56
+ "2": ConvRefiner(
57
+ 128 + 128,
58
+ 128,
59
+ 64 + num_output_channels,
60
+ hidden_blocks=hidden_blocks,
61
+ residual=True,
62
+ ),
63
+ "1": ConvRefiner(
64
+ 64 + 64,
65
+ 64,
66
+ 1 + num_output_channels,
67
+ hidden_blocks=hidden_blocks,
68
+ residual=True,
69
+ ),
70
+ }
71
+ )
72
+
73
+ def get_dim_layers_encoder(self):
74
+ return self.encoder.get_dim_layers()
75
+
76
+ def _forward(self, x):
77
+ features, sizes = self.encoder(x)
78
+ output = 0
79
+ context = None
80
+ scales = self.decoder.scales
81
+ for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
82
+ delta_descriptor, context = self.decoder(feature_map, scale=scale, context=context)
83
+ output = output + delta_descriptor
84
+ if idx < len(scales) - 1:
85
+ size = sizes[-(idx + 2)]
86
+ output = F.interpolate(output, size=size, mode="bilinear", align_corners=False)
87
+ context = F.interpolate(context, size=size, mode="bilinear", align_corners=False)
88
+
89
+ if self.mode == "dect":
90
+ return {"heatmap": output, "coarse_descs": features}
91
+ elif self.mode == "desc":
92
+ return {"fine_descs": output, "coarse_descs": features}
93
+ elif self.mode == "dect+desc":
94
+ logits = output[:, :1].contiguous()
95
+ descs = output[:, 1:].contiguous()
96
+
97
+ return {"heatmap": logits, "fine_descs": descs, "coarse_descs": features}
98
+ else:
99
+ raise ValueError("mode should be 'dect', 'desc' or 'dect+desc'")
ripe/models/backbones/vgg_utils.py ADDED
@@ -0,0 +1,143 @@
1
+ # adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.models as tvm
6
+
7
+ from ripe import utils
8
+
9
+ log = utils.get_pylogger(__name__)
10
+
11
+
12
+ class Decoder(nn.Module):
13
+ def __init__(self, layers, *args, super_resolution=False, num_prototypes=1, **kwargs) -> None:
14
+ super().__init__(*args, **kwargs)
15
+ self.layers = layers
16
+ self.scales = self.layers.keys()
17
+ self.super_resolution = super_resolution
18
+ self.num_prototypes = num_prototypes
19
+
20
+ def forward(self, features, context=None, scale=None):
21
+ if context is not None:
22
+ features = torch.cat((features, context), dim=1)
23
+ stuff = self.layers[scale](features)
24
+ logits, context = (
25
+ stuff[:, : self.num_prototypes],
26
+ stuff[:, self.num_prototypes :],
27
+ )
28
+ return logits, context
29
+
30
+
31
+ class ConvRefiner(nn.Module):
32
+ def __init__(
33
+ self,
34
+ in_dim=6,
35
+ hidden_dim=16,
36
+ out_dim=2,
37
+ dw=True,
38
+ kernel_size=5,
39
+ hidden_blocks=5,
40
+ residual=False,
41
+ ):
42
+ super().__init__()
43
+ self.block1 = self.create_block(
44
+ in_dim,
45
+ hidden_dim,
46
+ dw=False,
47
+ kernel_size=1,
48
+ )
49
+ self.hidden_blocks = nn.Sequential(
50
+ *[
51
+ self.create_block(
52
+ hidden_dim,
53
+ hidden_dim,
54
+ dw=dw,
55
+ kernel_size=kernel_size,
56
+ )
57
+ for hb in range(hidden_blocks)
58
+ ]
59
+ )
60
+ self.hidden_blocks = self.hidden_blocks
61
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
62
+ self.residual = residual
63
+
64
+ def create_block(
65
+ self,
66
+ in_dim,
67
+ out_dim,
68
+ dw=True,
69
+ kernel_size=5,
70
+ bias=True,
71
+ norm_type=nn.BatchNorm2d,
72
+ ):
73
+ num_groups = 1 if not dw else in_dim
74
+ if dw:
75
+ assert out_dim % in_dim == 0, "outdim must be divisible by indim for depthwise"
76
+ conv1 = nn.Conv2d(
77
+ in_dim,
78
+ out_dim,
79
+ kernel_size=kernel_size,
80
+ stride=1,
81
+ padding=kernel_size // 2,
82
+ groups=num_groups,
83
+ bias=bias,
84
+ )
85
+ norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels=out_dim)
86
+ relu = nn.ReLU(inplace=True)
87
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
88
+ return nn.Sequential(conv1, norm, relu, conv2)
89
+
90
+ def forward(self, feats):
91
+ b, c, hs, ws = feats.shape
92
+ x0 = self.block1(feats)
93
+ x = self.hidden_blocks(x0)
94
+ if self.residual:
95
+ x = (x + x0) / 1.4
96
+ x = self.out_conv(x)
97
+ return x
98
+
99
+
100
+ class VGG19(nn.Module):
101
+ def __init__(self, pretrained=False, num_input_channels=3) -> None:
102
+ super().__init__()
103
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
104
+ # Maxpool layers: 6, 13, 26, 39
105
+
106
+ if num_input_channels != 3:
107
+ log.info(f"Changing input channels from 3 to {num_input_channels}")
108
+ self.layers[0] = nn.Conv2d(num_input_channels, 64, 3, 1, 1)
109
+
110
+ def get_dim_layers(self):
111
+ return [64, 128, 256, 512]
112
+
113
+ def forward(self, x, **kwargs):
114
+ feats = []
115
+ sizes = []
116
+ for layer in self.layers:
117
+ if isinstance(layer, nn.MaxPool2d):
118
+ feats.append(x)
119
+ sizes.append(x.shape[-2:])
120
+ x = layer(x)
121
+ return feats, sizes
122
+
123
+
124
+ class VGG(nn.Module):
125
+ def __init__(self, size="19", pretrained=False) -> None:
126
+ super().__init__()
127
+ if size == "11":
128
+ self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22])
129
+ elif size == "13":
130
+ self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28])
131
+ elif size == "19":
132
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
133
+ # Maxpool layers: 6, 13, 26, 39
134
+
135
+ def forward(self, x, **kwargs):
136
+ feats = []
137
+ sizes = []
138
+ for layer in self.layers:
139
+ if isinstance(layer, nn.MaxPool2d):
140
+ feats.append(x)
141
+ sizes.append(x.shape[-2:])
142
+ x = layer(x)
143
+ return feats, sizes
ripe/models/ripe.py ADDED
@@ -0,0 +1,303 @@
1
+ from typing import List, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ripe import utils
9
+ from ripe.utils.utils import gridify
10
+
11
+ log = utils.get_pylogger(__name__)
12
+
13
+
14
+ class KeypointSampler(nn.Module):
15
+ """
16
+ Sample keypoints according to a Heatmap
17
+ Adapted from: https://github.com/verlab/DALF_CVPR_2023/blob/main/modules/models/DALF.py
18
+ """
19
+
20
+ def __init__(self, window_size=8):
21
+ super().__init__()
22
+ self.window_size = window_size
23
+ self.idx_cells = None # Cache for meshgrid indices
24
+
25
+ def sample(self, grid):
26
+ """
27
+ Sample keypoints given a grid where each cell has logits stacked in last dimension
28
+ Input
29
+ grid: [B, C, H//w, W//w, w*w]
30
+
31
+ Returns
32
+ log_probs: [B, C, H//w, W//w ] - logprobs of selected samples
33
+ choices: [B, C, H//w, W//w] indices of choices
34
+ accept_mask: [B, C, H//w, W//w] mask of accepted keypoints
35
+
36
+ """
37
+ chooser = torch.distributions.Categorical(logits=grid)
38
+ choices = chooser.sample()
39
+ logits_selected = torch.gather(grid, -1, choices.unsqueeze(-1)).squeeze(-1)
40
+
41
+ flipper = torch.distributions.Bernoulli(logits=logits_selected)
42
+ accepted_choices = flipper.sample()
43
+
44
+ # Sum log-probabilities is equivalent to multiplying the probabilities
45
+ log_probs = chooser.log_prob(choices) + flipper.log_prob(accepted_choices)
46
+
47
+ accept_mask = accepted_choices.gt(0)
48
+
49
+ return (
50
+ log_probs.squeeze(1),
51
+ choices,
52
+ accept_mask.squeeze(1),
53
+ logits_selected.squeeze(1),
54
+ )
55
+
56
+ def precompute_idx_cells(self, H, W, device):
57
+ idx_cells = gridify(
58
+ torch.dstack(
59
+ torch.meshgrid(
60
+ torch.arange(H, dtype=torch.float32, device=device),
61
+ torch.arange(W, dtype=torch.float32, device=device),
62
+ )
63
+ )
64
+ .permute(2, 0, 1)
65
+ .unsqueeze(0)
66
+ .expand(1, -1, -1, -1),
67
+ window_size=self.window_size,
68
+ )
69
+
70
+ return idx_cells
71
+
72
+ def forward(self, x, mask_padding=None):
73
+ """
74
+ Sample keypoints from a heatmap
75
+ Input
76
+ x: [B, C, H, W] Heatmap
77
+ mask_padding: [B, 1, H, W] Mask for padding (optional)
78
+ Returns
79
+ keypoints: [B, H//w, W//w, 2] Keypoints in (x, y) format
80
+ log_probs: [B, H//w, W//w] Log probabilities of selected keypoints
81
+ mask: [B, H//w, W//w] Mask of accepted keypoints
82
+ mask_padding: [B, 1, H//w, W//w] Mask of padding (optional)
83
+ logits_selected: [B, H//w, W//w] Logits of selected keypoints
84
+ """
85
+
86
+ B, C, H, W = x.shape
87
+
88
+ keypoint_cells = gridify(x, self.window_size)
89
+
90
+ mask_padding = (
91
+ (torch.min(gridify(mask_padding, self.window_size), dim=4).values) if mask_padding is not None else None
92
+ )
93
+
94
+ if self.idx_cells is None or self.idx_cells.shape[2:4] != (
95
+ H // self.window_size,
96
+ W // self.window_size,
97
+ ):
98
+ self.idx_cells = self.precompute_idx_cells(H, W, x.device)
99
+
100
+ log_probs, idx, mask, logits_selected = self.sample(keypoint_cells)
101
+
102
+ keypoints = (
103
+ torch.gather(
104
+ self.idx_cells.expand(B, -1, -1, -1, -1),
105
+ -1,
106
+ idx.repeat(1, 2, 1, 1).unsqueeze(-1),
107
+ )
108
+ .squeeze(-1)
109
+ .permute(0, 2, 3, 1)
110
+ )
111
+
112
+ # flip keypoints to (x, y) format
113
+ return keypoints.flip(-1), log_probs, mask, mask_padding, logits_selected
114
+
115
+
116
+ class RIPE(nn.Module):
117
+ """
118
+ Base class for extracting keypoints and descriptors
119
+ Input
120
+ x: [B, C, H, W] Images
121
+
122
+ Returns
123
+ kpts:
124
+ list of size [B] with detected keypoints
125
+ descs:
126
+ list of size [B] with descriptors
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ net,
132
+ upsampler,
133
+ window_size: int = 8,
134
+ non_linearity_dect=None,
135
+ desc_shares: Optional[List[int]] = None,
136
+ descriptor_dim: int = 256,
137
+ device=None,
138
+ ):
139
+ super().__init__()
140
+ self.net = net
141
+
142
+ self.detector = KeypointSampler(window_size)
143
+ self.upsampler = upsampler
144
+ self.sampler = None
145
+ self.window_size = window_size
146
+ self.non_linearity_dect = non_linearity_dect if non_linearity_dect is not None else nn.Identity()
147
+
148
+ log.info(f"Training with window size {window_size}.")
149
+ log.info(f"Use {non_linearity_dect} as final non-linearity before the detection heatmap.")
150
+
151
+ dim_coarse_desc = self.get_dim_raw_desc()
152
+
153
+ if desc_shares is not None:
154
+ assert upsampler.name == "HyperColumnFeatures", (
155
+ "Individual descriptor convolutions are only supported with HyperColumnFeatures"
156
+ )
157
+ assert len(desc_shares) == 4, "desc_shares should have 4 elements"
158
+ assert sum(desc_shares) == descriptor_dim, f"sum of desc_shares should be {descriptor_dim}"
159
+
160
+ self.conv_dim_reduction_coarse_desc = nn.ModuleList()
161
+
162
+ for dim_in, dim_out in zip(dim_coarse_desc, desc_shares):
163
+ log.info(f"Training dim reduction descriptor with {dim_in} -> {dim_out} 1x1 conv")
164
+ self.conv_dim_reduction_coarse_desc.append(
165
+ nn.Conv1d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)
166
+ )
167
+ else:
168
+ if descriptor_dim is not None:
169
+ log.info(f"Training dim reduction descriptor with {sum(dim_coarse_desc)} -> {descriptor_dim} 1x1 conv")
170
+ self.conv_dim_reduction_coarse_desc = nn.Conv1d(
171
+ sum(dim_coarse_desc),
172
+ descriptor_dim,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0,
176
+ )
177
+ else:
178
+ log.warning(
179
+ f"No descriptor dimension specified, no 1x1 conv will be applied! Direct usage of {sum(dim_coarse_desc)}-dimensional raw descriptor"
180
+ )
181
+ self.conv_dim_reduction_coarse_desc = nn.Identity()
182
+
183
+ def get_dim_raw_desc(self):
184
+ layers_dims_encoder = self.net.get_dim_layers_encoder()
185
+
186
+ if self.upsampler.name == "InterpolateSparse2d":
187
+ return [layers_dims_encoder[-1]]
188
+ elif self.upsampler.name == "HyperColumnFeatures":
189
+ return layers_dims_encoder
190
+ else:
191
+ raise ValueError(f"Unknown interpolator {self.upsampler.name}")
192
+
193
+ @torch.inference_mode()
194
+ def detectAndCompute(self, img, threshold=0.5, top_k=2048, output_aux=False):
195
+ self.train(False)
196
+
197
+ if img.dim() == 3:
198
+ img = img.unsqueeze(0)
199
+
200
+ out = self(img, training=False)
201
+ B, K, H, W = out["heatmap"].shape
202
+
203
+ assert B == 1, "Batch size should be 1"
204
+
205
+ kpts = [{"xy": self.NMS(out["heatmap"][b], threshold)} for b in range(B)]
206
+
207
+ if top_k is not None:
208
+ for b in range(B):
209
+ scores = out["heatmap"][b].squeeze(0)[kpts[b]["xy"][:, 1].long(), kpts[b]["xy"][:, 0].long()]
210
+ sorted_idx = torch.argsort(-scores)
211
+ kpts[b]["xy"] = kpts[b]["xy"][sorted_idx[:top_k]]
212
+ if "logprobs" in kpts[b]:
213
+ kpts[b]["logprobs"] = kpts[b]["logprobs"][sorted_idx[:top_k]]
214
+
215
+ if kpts[0]["xy"].shape[0] == 0:
216
+ raise RuntimeError("No keypoints detected")
217
+
218
+ # the following works for batch size 1 only
219
+
220
+ descs = self.get_descs(out["coarse_descs"], img, kpts[0]["xy"].unsqueeze(0), H, W)
221
+ descs = descs.squeeze(0)
222
+
223
+ score_map = out["heatmap"][0].squeeze(0)
224
+
225
+ kpts = kpts[0]["xy"]
226
+
227
+ scores = score_map[kpts[:, 1], kpts[:, 0]]
228
+ scores /= score_map.max()
229
+
230
+ sort_idx = torch.argsort(-scores)
231
+ kpts, descs, scores = kpts[sort_idx], descs[sort_idx], scores[sort_idx]
232
+
233
+ if output_aux:
234
+ return (
235
+ kpts.float(),
236
+ descs,
237
+ scores,
238
+ {
239
+ "heatmap": out["heatmap"],
240
+ "descs": out["coarse_descs"],
241
+ "conv": self.conv_dim_reduction_coarse_desc,
242
+ },
243
+ )
244
+
245
+ return kpts.float(), descs, scores
246
+
247
+ def NMS(self, x, threshold=3.0, kernel_size=3):
248
+ pad = kernel_size // 2
249
+ local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
250
+
251
+ pos = (x == local_max) & (x > threshold)
252
+ return pos.nonzero()[..., 1:].flip(-1)
253
+
254
+ def get_descs(self, feature_map, guidance, kpts, H, W):
255
+ descs = self.upsampler(feature_map, kpts, H, W)
256
+
257
+ if isinstance(self.conv_dim_reduction_coarse_desc, nn.ModuleList):
258
+ # individual descriptor convolutions for each layer
259
+ desc_conv = []
260
+ for desc, conv in zip(descs, self.conv_dim_reduction_coarse_desc):
261
+ desc_conv.append(conv(desc.permute(0, 2, 1)).permute(0, 2, 1))
262
+ desc = torch.cat(desc_conv, dim=-1)
263
+ else:
264
+ desc = torch.cat(descs, dim=-1)
265
+ desc = self.conv_dim_reduction_coarse_desc(desc.permute(0, 2, 1)).permute(0, 2, 1)
266
+
267
+ desc = F.normalize(desc, dim=2)
268
+
269
+ return desc
270
+
271
+ def forward(self, x, mask_padding=None, training=False):
272
+ B, C, H, W = x.shape
273
+ out = self.net(x)
274
+ out["heatmap"] = self.non_linearity_dect(out["heatmap"])
275
+ # print(out['map'].shape, out['descr'].shape)
276
+ if training:
277
+ kpts, log_probs, mask, mask_padding, logits_selected = self.detector(out["heatmap"], mask_padding)
278
+
279
+ filter_A = kpts[:, :, :, 0] >= 16
280
+ filter_B = kpts[:, :, :, 1] >= 16
281
+ filter_C = kpts[:, :, :, 0] < W - 16
282
+ filter_D = kpts[:, :, :, 1] < H - 16
283
+ filter_all = filter_A * filter_B * filter_C * filter_D
284
+
285
+ mask = mask * filter_all
286
+
287
+ return (
288
+ kpts.view(B, -1, 2),
289
+ log_probs.view(B, -1),
290
+ mask.view(B, -1),
291
+ mask_padding.view(B, -1),
292
+ logits_selected.view(B, -1),
293
+ out,
294
+ )
295
+ else:
296
+ return out
297
+
298
+
299
+ def output_number_trainable_params(model):
300
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
301
+ nb_params = sum([np.prod(p.size()) for p in model_parameters])
302
+
303
+ print(f"Number of trainable parameters: {nb_params:d}")
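One way to match two images with the extractor above, using kornia's mutual nearest neighbours on the descriptors (a sketch; image paths are placeholders and the matcher choice is illustrative):

import kornia.feature as KF
import torch
from torchvision.io import read_image

from ripe.model_zoo import vgg_hyper

device = "cuda" if torch.cuda.is_available() else "cpu"
model = vgg_hyper().to(device).eval()

img1 = read_image("assets/img1.jpg").float().to(device) / 255.0
img2 = read_image("assets/img2.jpg").float().to(device) / 255.0

kpts1, descs1, _ = model.detectAndCompute(img1)
kpts2, descs2, _ = model.detectAndCompute(img2)

dists, idxs = KF.match_mnn(descs1, descs2)
mkpts1, mkpts2 = kpts1[idxs[:, 0]], kpts2[idxs[:, 1]]  # matched keypoint coordinates (x, y)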
ripe/models/upsampler/hypercolumn_features.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ class HyperColumnFeatures(nn.Module):
7
+ """
8
+ Interpolate 3D tensor given N sparse 2D positions
9
+ Input
10
+ x: list([C, H, W]) list of feature tensors at different scales (e.g. from a U-Net) -> extract hypercolumn features
11
+ pos: [N, 2] tensor of positions
12
+ H: int, height of the OUTPUT map
13
+ W: int, width of the OUTPUT map
14
+
15
+ Returns
16
+ [N, C] sampled features at 2d positions
17
+ """
18
+
19
+ def __init__(self, mode="bilinear"):
20
+ super().__init__()
21
+ self.mode = mode
22
+ self.name = "HyperColumnFeatures"
23
+
24
+ def normgrid(self, x, H, W):
25
+ return 2.0 * (x / (torch.tensor([W - 1, H - 1], device=x.device, dtype=x.dtype))) - 1.0
26
+
27
+ def extract_values_at_poses(self, x, pos, H, W):
28
+ """Extract values from tensor x at the positions given by pos.
29
+
30
+ Args:
31
+ - x (Tensor): Tensor of size (C, H, W).
32
+ - pos (Tensor): Tensor of size (N, 2) containing the x, y positions.
33
+
34
+ Returns:
35
+ - values (Tensor): Tensor of size (N, C) with the values from f at the positions given by p.
36
+ """
37
+
38
+ # check if grid is float32
39
+ if x.dtype != torch.float32:
40
+ x = x.to(torch.float32)
41
+
42
+ grid = self.normgrid(pos, H, W).unsqueeze(-2)
43
+
44
+ x = F.grid_sample(x, grid, mode=self.mode, align_corners=True)
45
+ return x.permute(0, 2, 3, 1).squeeze(-2)
46
+
47
+ def forward(self, x, pos, H, W):
48
+ descs = []
49
+
50
+ for layer in x:
51
+ desc = self.extract_values_at_poses(layer, pos, H, W)
52
+ descs.append(desc)
53
+
54
+ return descs
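A shape check for the hypercolumn sampler above, with random multi-scale feature maps standing in for the VGG encoder outputs:

import torch

from ripe.models.upsampler.hypercolumn_features import HyperColumnFeatures

upsampler = HyperColumnFeatures()

H, W, N = 256, 320, 50
feats = [torch.randn(1, c, H // s, W // s) for c, s in ((64, 1), (128, 2), (256, 4), (512, 8))]
pos = torch.stack((torch.randint(0, W, (N,)), torch.randint(0, H, (N,))), dim=1).float().unsqueeze(0)

descs = upsampler(feats, pos, H, W)  # list of 4 tensors, each of shape (1, N, C_layer)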
ripe/models/upsampler/interpolate_sparse2d.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ class InterpolateSparse2d(nn.Module):
7
+ """
8
+ Interpolate 3D tensor given N sparse 2D positions
9
+ Input
10
+ x: list([C, H, W]) feature tensors at different scales (e.g. from a U-Net), ONLY the last one is used
11
+ pos: [N, 2] tensor of positions
12
+ H: int, height of the OUTPUT map
13
+ W: int, width of the OUTPUT map
14
+
15
+ Returns
16
+ [N, C] sampled features at 2d positions
17
+ """
18
+
19
+ def __init__(self, mode="bicubic"):
20
+ super().__init__()
21
+ self.mode = mode
22
+ self.name = "InterpolateSparse2d"
23
+
24
+ def normgrid(self, x, H, W):
25
+ return 2.0 * (x / (torch.tensor([W - 1, H - 1], device=x.device, dtype=x.dtype))) - 1.0
26
+
27
+ def forward(self, x, pos, H, W):
28
+ x = x[-1] # only use the last layer
29
+
30
+ # check if grid is float32
31
+ if x.dtype != torch.float32:
32
+ x = x.to(torch.float32)
33
+
34
+ grid = self.normgrid(pos, H, W).unsqueeze(-2)
35
+
36
+ x = F.grid_sample(x, grid, mode=self.mode, align_corners=True)
37
+ return [x.permute(0, 2, 3, 1).squeeze(-2)]
ripe/scheduler/__init__.py ADDED
File without changes
ripe/scheduler/constant.py ADDED
@@ -0,0 +1,6 @@
1
+ class ConstantScheduler:
2
+ def __init__(self, value):
3
+ self.value = value
4
+
5
+ def __call__(self, step):
6
+ return self.value
ripe/scheduler/expDecay.py ADDED
@@ -0,0 +1,26 @@
1
+ import numpy as np
2
+
3
+ from ripe import utils
4
+
5
+ log = utils.get_pylogger(__name__)
6
+
7
+
8
+ class ExpDecay:
9
+ """Exponential decay scheduler.
10
+ args:
11
+ a: float, a + c = initial value
12
+ b: decay rate
13
+ c: float, final value
14
+
15
+ f(x) = a * e^(-b * x) + c
16
+ """
17
+
18
+ def __init__(self, a, b, c):
19
+ self.a = a
20
+ self.b = b
21
+ self.c = c
22
+
23
+ log.info(f"ExpDecay: a={a}, b={b}, c={c}")
24
+
25
+ def __call__(self, step):
26
+ return self.a * np.exp(-self.b * step) + self.c
ripe/scheduler/linearLR.py ADDED
@@ -0,0 +1,37 @@
1
+ class StepLinearLR:
2
+ """Decay the learning rate by a linearly changing factor at each STEP (not epoch).
3
+
4
+ Args:
5
+ optimizer (Optimizer): Wrapped optimizer.
6
+ num_steps (int): Total number of steps in the training process.
7
+ initial_lr (float): Initial learning rate.
8
+ final_lr (float): Final learning rate.
9
+ """
10
+
11
+ def __init__(self, optimizer, steps_init, num_steps, initial_lr, final_lr):
12
+ self.optimizer = optimizer
13
+ self.num_steps = num_steps
14
+ self.initial_lr = initial_lr
15
+ self.final_lr = final_lr
16
+ self.i_step = steps_init
17
+ self.decay_factor = (final_lr - initial_lr) / num_steps
18
+
19
+ def step(self):
20
+ """Decay the learning rate by decay_factor."""
21
+ self.i_step += 1
22
+
23
+ if self.i_step > self.num_steps:
24
+ return
25
+
26
+ lr = self.initial_lr + self.i_step * self.decay_factor
27
+ for param_group in self.optimizer.param_groups:
28
+ param_group["lr"] = lr
29
+
30
+ def get_lr(self):
31
+ return self.optimizer.param_groups[0]["lr"]
32
+
33
+ def get_last_lr(self):
34
+ return self.optimizer.param_groups[0]["lr"]
35
+
36
+ def get_step(self):
37
+ return self.i_step
ripe/scheduler/linear_with_plateaus.py ADDED
@@ -0,0 +1,44 @@
1
+ from ripe import utils
2
+
3
+ log = utils.get_pylogger(__name__)
4
+
5
+
6
+ class LinearWithPlateaus:
7
+ """Linear scheduler with plateaus.
8
+
9
+ Linearly interpolates from `start_val` to `end_val`.
10
+ Stays at `start_val` for `plateau_start_steps` steps and at `end_val` for `plateau_end_steps` steps.
11
+ Linearly changes from `start_val` to `end_val` during the remaining steps.
12
+ """
13
+
14
+ def __init__(
15
+ self,
16
+ start_val,
17
+ end_val,
18
+ steps_total,
19
+ rel_length_start_plateau=0.0,
20
+ rel_length_end_plateu=0.0,
21
+ ):
22
+ self.start_val = start_val
23
+ self.end_val = end_val
24
+ self.steps_total = steps_total
25
+ self.plateau_start_steps = steps_total * rel_length_start_plateau
26
+ self.plateau_end_steps = steps_total * rel_length_end_plateu
27
+
28
+ assert self.plateau_start_steps >= 0
29
+ assert self.plateau_end_steps >= 0
30
+ assert self.plateau_start_steps + self.plateau_end_steps <= self.steps_total
31
+
32
+ self.slope = (end_val - start_val) / (steps_total - self.plateau_start_steps - self.plateau_end_steps)
33
+
34
+ log.info(
35
+ f"LinearWithPlateaus: start_val={start_val}, end_val={end_val}, steps_total={steps_total}, "
36
+ f"plateau_start_steps={self.plateau_start_steps}, plateau_end_steps={self.plateau_end_steps}"
37
+ )
38
+
39
+ def __call__(self, step):
40
+ if step < self.plateau_start_steps:
41
+ return self.start_val
42
+ if step < self.steps_total - self.plateau_end_steps:
43
+ return self.start_val + self.slope * (step - self.plateau_start_steps)
44
+ return self.end_val
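A quick worked example of the schedule (keyword spellings follow the constructor above): a ramp from 0.0 to 1.0 over 100 steps with 20% plateaus on both ends.

from ripe.scheduler.linear_with_plateaus import LinearWithPlateaus

sched = LinearWithPlateaus(0.0, 1.0, 100, rel_length_start_plateau=0.2, rel_length_end_plateu=0.2)
print(sched(0), sched(10), sched(50), sched(90))  # 0.0, 0.0, 0.5, 1.0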
ripe/train.py ADDED
@@ -0,0 +1,410 @@
1
+ import pyrootutils
2
+
3
+ root = pyrootutils.setup_root(
4
+ search_from=__file__,
5
+ indicator=[".git", "pyproject.toml"],
6
+ pythonpath=True,
7
+ dotenv=True,
8
+ )
9
+
10
+ SEED = 32000
11
+
12
+ import collections
13
+ import os
14
+
15
+ import hydra
16
+ from hydra.utils import instantiate
17
+ from lightning.fabric import Fabric
18
+
19
+ print(SEED)
20
+ import random
21
+
22
+ os.environ["PYTHONHASHSEED"] = str(SEED)
23
+
24
+ import numpy as np
25
+ import torch
26
+ import tqdm
27
+ import wandb
28
+ from torch.optim.adamw import AdamW
29
+ from torch.utils.data import DataLoader
30
+
31
+ from ripe import utils
32
+ from ripe.benchmarks.imw_2020 import IMW_2020_Benchmark
33
+ from ripe.utils.utils import get_rewards
34
+ from ripe.utils.wandb_utils import get_flattened_wandb_cfg
35
+
36
+ log = utils.get_pylogger(__name__)
37
+ from pathlib import Path
38
+
39
+ torch.manual_seed(SEED)
40
+ np.random.seed(SEED)
41
+ random.seed(SEED)
42
+
43
+
44
+ def unpack_batch(batch):
45
+ src_image = batch["src_image"]
46
+ trg_image = batch["trg_image"]
47
+ trg_mask = batch["trg_mask"]
48
+ src_mask = batch["src_mask"]
49
+ label = batch["label"]
50
+ H = batch["homography"]
51
+
52
+ return src_image, trg_image, src_mask, trg_mask, H, label
53
+
54
+
55
+ @hydra.main(config_path="../conf/", config_name="config", version_base=None)
56
+ def train(cfg):
57
+ """Main training function for the RIPE model."""
58
+ # Prepare model, data and hyperparms
59
+
60
+ strategy = "ddp" if cfg.num_gpus > 1 else "auto"
61
+ fabric = Fabric(
62
+ accelerator="cuda",
63
+ devices=cfg.num_gpus,
64
+ precision=cfg.precision,
65
+ strategy=strategy,
66
+ )
67
+ fabric.launch()
68
+
69
+ output_dir = Path(cfg.output_dir)
70
+ experiment_name = output_dir.parent.parent.parent.name
71
+ run_id = output_dir.parent.parent.name
72
+ timestamp = output_dir.parent.name + "_" + output_dir.name
73
+
74
+ experiment_name = run_id + " " + timestamp + " " + experiment_name
75
+
76
+ # setup logger
77
+ wandb_logger = wandb.init(
78
+ project=cfg.project_name,
79
+ name=experiment_name,
80
+ config=get_flattened_wandb_cfg(cfg),
81
+ dir=cfg.output_dir,
82
+ mode=cfg.wandb_mode,
83
+ )
84
+
85
+ min_nums_matches = {"homography": 4, "fundamental": 8, "fundamental_7pt": 7}
86
+ min_num_matches = min_nums_matches[cfg.transformation_model]
87
+ print(f"Minimum number of matches for {cfg.transformation_model} is {min_num_matches}")
88
+
89
+ batch_size = cfg.batch_size
90
+ steps = cfg.num_steps
91
+ lr = cfg.lr
92
+
93
+ num_grad_accs = (
94
+ cfg.num_grad_accs
95
+ ) # this performs grad accumulation to simulate larger batch size, set to 1 to disable;
96
+
97
+ # instantiate dataset
98
+ ds = instantiate(cfg.data)
99
+
100
+ # prepare dataloader
101
+ dl = DataLoader(
102
+ ds,
103
+ batch_size=batch_size,
104
+ shuffle=True,
105
+ drop_last=True,
106
+ persistent_workers=False,
107
+ num_workers=cfg.num_workers,
108
+ )
109
+ dl = fabric.setup_dataloaders(dl)
110
+ i_dl = iter(dl)
111
+
112
+ # create matcher
113
+ matcher = instantiate(cfg.matcher)
114
+
115
+ if cfg.desc_loss_weight != 0.0:
116
+ descriptor_loss = instantiate(cfg.descriptor_loss)
117
+ else:
118
+ log.warning(
119
+ "Descriptor loss weight is 0.0, descriptor loss will not be used. 1x1 conv for descriptors will be deactivated!"
120
+ )
121
+ descriptor_loss = None
122
+
123
+ upsampler = instantiate(cfg.upsampler) if "upsampler" in cfg else None
124
+
125
+ # create network
126
+ net = instantiate(cfg.network)(
127
+ net=instantiate(cfg.backbones),
128
+ upsampler=upsampler,
129
+ descriptor_dim=cfg.descriptor_dim if descriptor_loss is not None else None,
130
+ device=fabric.device,
131
+ ).train()
132
+
133
+ # get num parameters
134
+ num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
135
+ log.info(f"Number of parameters: {num_params}")
136
+
137
+ fp_penalty = cfg.fp_penalty # small penalty for not finding a match
138
+ kp_penalty = cfg.kp_penalty # small penalty for low logprob keypoints
139
+
140
+ opt_pi = AdamW(filter(lambda x: x.requires_grad, net.parameters()), lr=lr, weight_decay=1e-5)
141
+ net, opt_pi = fabric.setup(net, opt_pi)
142
+
143
+ if cfg.lr_scheduler:
144
+ scheduler = instantiate(cfg.lr_scheduler)(optimizer=opt_pi, steps_init=0)
145
+ else:
146
+ scheduler = None
147
+
148
+ val_benchmark = IMW_2020_Benchmark(
149
+ use_predefined_subset=True,
150
+ conf_inference=cfg.conf_inference,
151
+ edge_input_divisible_by=None,
152
+ )
153
+
154
+ # mean average of skipped batches
155
+ # this is used to monitor how many batches were skipped due to not enough keypoints
156
+ # this is useful to detect if the model is not learning anything -> should be zero
157
+ ma_skipped_batches = collections.deque(maxlen=100)
158
+
159
+ opt_pi.zero_grad()
160
+
161
+ # initialize scheduler
162
+ alpha_scheduler = instantiate(cfg.alpha_scheduler)
163
+ beta_scheduler = instantiate(cfg.beta_scheduler)
164
+ inl_th_scheduler = instantiate(cfg.inl_th)
165
+
166
+ # ====== Training Loop ======
167
+ # check if the model is in training mode
168
+ net.train()
169
+
170
+ with tqdm.tqdm(total=steps) as pbar:
171
+ for i_step in range(steps):
172
+ alpha = alpha_scheduler(i_step)
173
+ beta = beta_scheduler(i_step)
174
+ inl_th = inl_th_scheduler(i_step)
175
+
176
+ if scheduler:
177
+ scheduler.step()
178
+
179
+ # Initialize vars for current step
180
+ # We need to handle batching because the description can have arbitrary number of keypoints
181
+ sum_reward_batch = 0
182
+ sum_num_keypoints_1 = 0
183
+ sum_num_keypoints_2 = 0
184
+ loss = None
185
+ loss_policy_stack = None
186
+ loss_desc_stack = None
187
+ loss_kp_stack = None
188
+
189
+ try:
190
+ batch = next(i_dl)
191
+ except StopIteration:
192
+ i_dl = iter(dl)
193
+ batch = next(i_dl)
194
+
195
+ p1, p2, mask_padding_1, mask_padding_2, Hs, label = unpack_batch(batch)
196
+
197
+ (
198
+ kpts1,
199
+ logprobs1,
200
+ selected_mask1,
201
+ mask_padding_grid_1,
202
+ logits_selected_1,
203
+ out1,
204
+ ) = net(p1, mask_padding_1, training=True)
205
+ (
206
+ kpts2,
207
+ logprobs2,
208
+ selected_mask2,
209
+ mask_padding_grid_2,
210
+ logits_selected_2,
211
+ out2,
212
+ ) = net(p2, mask_padding_2, training=True)
213
+
214
+ # upsample coarse descriptors for all keypoints from the intermediate feature maps from the encoder
215
+ desc_1 = net.get_descs(out1["coarse_descs"], p1, kpts1, p1.shape[2], p1.shape[3])
216
+ desc_2 = net.get_descs(out2["coarse_descs"], p2, kpts2, p2.shape[2], p2.shape[3])
217
+
218
+ if cfg.padding_filter_mode == "ignore": # remove keypoints that are in padding
219
+ batch_mask_selection_for_matching_1 = selected_mask1 & mask_padding_grid_1
220
+ batch_mask_selection_for_matching_2 = selected_mask2 & mask_padding_grid_2
221
+ elif cfg.padding_filter_mode == "punish":
222
+ batch_mask_selection_for_matching_1 = selected_mask1 # keep all keypoints
223
+ batch_mask_selection_for_matching_2 = selected_mask2 # punish the keypoints in the padding area
224
+ else:
225
+ raise ValueError(f"Unknown padding filter mode: {cfg.padding_filter_mode}")
226
+
227
+ (
228
+ batch_rel_idx_matches,
229
+ batch_abs_idx_matches,
230
+ batch_ransac_inliers,
231
+ batch_Fm,
232
+ ) = matcher(
233
+ kpts1,
234
+ kpts2,
235
+ desc_1,
236
+ desc_2,
237
+ batch_mask_selection_for_matching_1,
238
+ batch_mask_selection_for_matching_2,
239
+ inl_th,
240
+ label if cfg.no_filtering_negatives else None,
241
+ )
242
+
243
+ for b in range(batch_size):
244
+ # ignore if less than 16 keypoints have been detected
245
+ if batch_rel_idx_matches[b] is None:
246
+ ma_skipped_batches.append(1)
247
+ continue
248
+ else:
249
+ ma_skipped_batches.append(0)
250
+
251
+ mask_selection_for_matching_1 = batch_mask_selection_for_matching_1[b]
252
+ mask_selection_for_matching_2 = batch_mask_selection_for_matching_2[b]
253
+
254
+ rel_idx_matches = batch_rel_idx_matches[b]
255
+ abs_idx_matches = batch_abs_idx_matches[b]
256
+ ransac_inliers = batch_ransac_inliers[b]
257
+
258
+ if cfg.selected_only:
259
+ # every SELECTED keypoint with every other SELECTED keypoint
260
+ dense_logprobs = logprobs1[b][mask_selection_for_matching_1].view(-1, 1) + logprobs2[b][
261
+ mask_selection_for_matching_2
262
+ ].view(1, -1)
263
+ else:
264
+ if cfg.padding_filter_mode == "ignore":
265
+ # every keypoint with every other keypoint, but WITHOUT keypoint in the padding area
266
+ dense_logprobs = logprobs1[b][mask_padding_grid_1[b]].view(-1, 1) + logprobs2[b][
267
+ mask_padding_grid_2[b]
268
+ ].view(1, -1)
269
+ elif cfg.padding_filter_mode == "punish":
270
+ # every keypoint with every other keypoint, also WITH keypoints in the padding areas -> will be punished by the reward
271
+ dense_logprobs = logprobs1[b].view(-1, 1) + logprobs2[b].view(1, -1)
272
+ else:
273
+ raise ValueError(f"Unknown padding filter mode: {cfg.padding_filter_mode}")
274
+
275
+ reward = None
276
+
277
+ if cfg.reward_type == "inlier":
278
+ reward = (
279
+ 0.5 if cfg.no_filtering_negatives and not label[b] else 1.0
280
+ ) # reward is 1.0 if the pair is positive, 0.5 if negative and no filtering is applied
281
+ elif cfg.reward_type == "inlier_ratio":
282
+ ratio_inlier = ransac_inliers.sum() / len(abs_idx_matches)
283
+ reward = ratio_inlier # reward is the ratio of inliers -> higher if more matches are inliers
284
+ elif cfg.reward_type == "inlier+inlier_ratio":
285
+ ratio_inlier = ransac_inliers.sum() / len(abs_idx_matches)
286
+ reward = (
287
+ (1.0 - beta) * 1.0 + beta * ratio_inlier
288
+ ) # reward is a combination of the ratio of inliers and the number of inliers -> gradually changes
289
+ else:
290
+ raise ValueError(f"Unknown reward type: {cfg.reward_type}")
291
+
292
+ dense_rewards = get_rewards(
293
+ reward,
294
+ kpts1[b],
295
+ kpts2[b],
296
+ mask_selection_for_matching_1,
297
+ mask_selection_for_matching_2,
298
+ mask_padding_grid_1[b],
299
+ mask_padding_grid_2[b],
300
+ rel_idx_matches,
301
+ abs_idx_matches,
302
+ ransac_inliers,
303
+ label[b],
304
+ fp_penalty * alpha,
305
+ use_whitening=cfg.use_whitening,
306
+ selected_only=cfg.selected_only,
307
+ filter_mode=cfg.padding_filter_mode,
308
+ )
309
+
310
+ if descriptor_loss is not None:
311
+ hard_loss = descriptor_loss(
312
+ desc1=desc_1[b],
313
+ desc2=desc_2[b],
314
+ matches=abs_idx_matches,
315
+ inliers=ransac_inliers,
316
+ label=label[b],
317
+ logits_1=None,
318
+ logits_2=None,
319
+ )
320
+ loss_desc_stack = (
321
+ hard_loss if loss_desc_stack is None else torch.hstack((loss_desc_stack, hard_loss))
322
+ )
323
+
324
+ sum_reward_batch += dense_rewards.sum()
325
+
326
+ current_loss_policy = (dense_rewards * dense_logprobs).view(-1)
327
+
328
+ loss_policy_stack = (
329
+ current_loss_policy
330
+ if loss_policy_stack is None
331
+ else torch.hstack((loss_policy_stack, current_loss_policy))
332
+ )
333
+
334
+ if kp_penalty != 0.0:
335
+ # keypoints with low logprob are penalized
336
+ # as they get large negative logprob values multiplying them with the penalty will make the loss larger
337
+ loss_kp = (
338
+ logprobs1[b][mask_selection_for_matching_1]
339
+ * torch.full_like(
340
+ logprobs1[b][mask_selection_for_matching_1],
341
+ kp_penalty * alpha,
342
+ )
343
+ ).mean() + (
344
+ logprobs2[b][mask_selection_for_matching_2]
345
+ * torch.full_like(
346
+ logprobs2[b][mask_selection_for_matching_2],
347
+ kp_penalty * alpha,
348
+ )
349
+ ).mean()
350
+ loss_kp_stack = loss_kp if loss_kp_stack is None else torch.hstack((loss_kp_stack, loss_kp))
351
+
352
+ sum_num_keypoints_1 += mask_selection_for_matching_1.sum()
353
+ sum_num_keypoints_2 += mask_selection_for_matching_2.sum()
354
+
355
+ loss = loss_policy_stack.mean()
356
+ if loss_kp_stack is not None:
357
+ loss += loss_kp_stack.mean()
358
+
359
+ loss = -loss
360
+
361
+ if descriptor_loss is not None:
362
+ loss += cfg.desc_loss_weight * loss_desc_stack.mean()
363
+
364
+ pbar.set_description(
365
+ f"LP: {loss.item():.4f} - Det: ({sum_num_keypoints_1 / batch_size:.4f}, {sum_num_keypoints_2 / batch_size:.4f}), #mRwd: {sum_reward_batch / batch_size:.1f}"
366
+ )
367
+ pbar.update()
368
+
369
+ # backward pass
370
+ loss /= num_grad_accs
371
+ fabric.backward(loss)
372
+
373
+ if i_step % num_grad_accs == 0:
374
+ opt_pi.step()
375
+ opt_pi.zero_grad()
376
+
377
+ if i_step % cfg.log_interval == 0:
378
+ wandb_logger.log(
379
+ {
380
+ # "loss": loss.item() if not use_amp else scaled_loss.item(),
381
+ "loss": loss.item(),
382
+ "loss_policy": -loss_policy_stack.mean().item(),
383
+ "loss_kp": loss_kp_stack.mean().item() if loss_kp_stack is not None else 0.0,
384
+ "loss_hard": (loss_desc_stack.mean().item() if loss_desc_stack is not None else 0.0),
385
+ "mean_num_det_kpts1": sum_num_keypoints_1 / batch_size,
386
+ "mean_num_det_kpts2": sum_num_keypoints_2 / batch_size,
387
+ "mean_reward": sum_reward_batch / batch_size,
388
+ "lr": opt_pi.param_groups[0]["lr"],
389
+ "ma_skipped_batches": sum(ma_skipped_batches) / len(ma_skipped_batches),
390
+ "inl_th": inl_th,
391
+ },
392
+ step=i_step,
393
+ )
394
+
395
+ if i_step % cfg.val_interval == 0:
396
+ val_benchmark.evaluate(net, fabric.device, progress_bar=False)
397
+ val_benchmark.log_results(logger=wandb_logger, step=i_step)
398
+
399
+ # ensure that the model is in training mode again
400
+ net.train()
401
+
402
+ # save the model
403
+ torch.save(
404
+ net.state_dict(),
405
+ output_dir / ("model" + "_" + str(i_step + 1) + "_final" + ".pth"),
406
+ )
407
+
408
+
409
+ if __name__ == "__main__":
410
+ train()
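A possible launch command for this script (a sketch: it assumes the Hydra defaults in conf/config.yaml provide the remaining fields such as data, network, matcher and schedulers; only keys read in train() are overridden here):

python ripe/train.py num_gpus=1 batch_size=4 num_grad_accs=1 wandb_mode=offline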
ripe/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # from x_dd.utils import loggers
2
+ from ripe.utils.pylogger import get_pylogger # noqa: F401
ripe/utils/image_utils.py ADDED
@@ -0,0 +1,62 @@
1
+ import h5py
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ class Camera:
7
+ def __init__(self, K, R, t):
8
+ self.K = K
9
+ self.R = R
10
+ self.t = t
11
+
12
+ @classmethod
13
+ def from_calibration_file(cls, path: str):
14
+ with h5py.File(path, "r") as f:
15
+ K = torch.tensor(np.array(f["K"]), dtype=torch.float32)
16
+ R = torch.tensor(np.array(f["R"]), dtype=torch.float32)
17
+ T = torch.tensor(np.array(f["T"]), dtype=torch.float32)
18
+
19
+ return cls(K, R, T)
20
+
21
+ @property
22
+ def K_inv(self):
23
+ return self.K.inverse()
24
+
25
+ def to_cameradict(self):
26
+ fx = self.K[0, 0].item()
27
+ fy = self.K[1, 1].item()
28
+ cx = self.K[0, 2].item()
29
+ cy = self.K[1, 2].item()
30
+
31
+ params = {
32
+ "model": "PINHOLE",
33
+ "width": int(cx * 2),
34
+ "height": int(cy * 2),
35
+ "params": [fx, fy, cx, cy],
36
+ }
37
+
38
+ return params
39
+
40
+ def __repr__(self):
41
+ return f"ImageData(K={self.K}, R={self.R}, t={self.t})"
42
+
43
+
44
+ def cameras2F(cam1: Camera, cam2: Camera) -> torch.Tensor:
45
+ E = cameras2E(cam1, cam2)
46
+ return cam2.K_inv.T @ E @ cam1.K_inv
47
+
48
+
49
+ def cameras2E(cam1: Camera, cam2: Camera) -> torch.Tensor:
50
+ R = cam2.R @ cam1.R.T
51
+ T = cam2.t - R @ cam1.t
52
+ return cross_product_matrix(T) @ R
53
+
54
+
55
+ def cross_product_matrix(v) -> torch.Tensor:
56
+ """Following en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication."""
57
+
58
+ return torch.tensor(
59
+ [[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]],
60
+ dtype=v.dtype,
61
+ device=v.device,
62
+ )
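Sketch of deriving a ground-truth fundamental matrix from two calibration files with the helpers above (file paths are placeholders):

from ripe.utils.image_utils import Camera, cameras2F

cam1 = Camera.from_calibration_file("calib/image0.h5")
cam2 = Camera.from_calibration_file("calib/image1.h5")
F_gt = cameras2F(cam1, cam2)  # 3x3 tensor mapping points in image 1 to epipolar lines in image 2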
ripe/utils/pose_error.py ADDED
@@ -0,0 +1,62 @@
1
+ # mostly from: https://github.com/cvg/glue-factory/blob/main/gluefactory/geometry/epipolar.py
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def angle_error_mat(R1, R2):
8
+ cos = (torch.trace(torch.einsum("...ij, ...jk -> ...ik", R1.T, R2)) - 1) / 2
9
+ cos = torch.clip(cos, -1.0, 1.0) # numerical errors can make it out of bounds
10
+ return torch.rad2deg(torch.abs(torch.arccos(cos)))
11
+
12
+
13
+ def angle_error_vec(v1, v2, eps=1e-10):
14
+ n = torch.clip(v1.norm(dim=-1) * v2.norm(dim=-1), min=eps)
15
+ v1v2 = (v1 * v2).sum(dim=-1) # dot product in the last dimension
16
+ return torch.rad2deg(torch.arccos(torch.clip(v1v2 / n, -1.0, 1.0)))
17
+
18
+
19
+ def relative_pose_error(R_gt, t_gt, R, t, ignore_gt_t_thr=0.0, eps=1e-10):
20
+ # angle error between 2 vectors
21
+ t_err = angle_error_vec(t, t_gt, eps)
22
+ t_err = torch.minimum(t_err, 180 - t_err) # handle E ambiguity
23
+ if t_gt.norm() < ignore_gt_t_thr: # pure rotation is challenging
24
+ t_err = torch.zeros_like(t_err)
25
+
26
+ # angle error between 2 rotation matrices
27
+ r_err = angle_error_mat(R, R_gt)
28
+
29
+ return t_err, r_err
30
+
31
+
32
+ def cal_error_auc(errors, thresholds):
33
+ sort_idx = np.argsort(errors)
34
+ errors = np.array(errors.copy())[sort_idx]
35
+ recall = (np.arange(len(errors)) + 1) / len(errors)
36
+ errors = np.r_[0.0, errors]
37
+ recall = np.r_[0.0, recall]
38
+ aucs = []
39
+ for t in thresholds:
40
+ last_index = np.searchsorted(errors, t)
41
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
42
+ e = np.r_[errors[:last_index], t]
43
+ aucs.append(np.round((np.trapz(r, x=e) / t), 4))
44
+ return aucs
45
+
46
+
47
+ class AUCMetric:
48
+ def __init__(self, thresholds, elements=None):
49
+ self._elements = elements
50
+ self.thresholds = thresholds
51
+ if not isinstance(thresholds, list):
52
+ self.thresholds = [thresholds]
53
+
54
+ def update(self, tensor):
55
+ assert tensor.dim() == 1
56
+ self._elements += tensor.cpu().numpy().tolist()
57
+
58
+ def compute(self):
59
+ if len(self._elements) == 0:
60
+ return np.nan
61
+ else:
62
+ return cal_error_auc(self._elements, self.thresholds)
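
For orientation, a minimal sketch (not part of the commit) of how these helpers fit together when evaluating an estimated relative pose: relative_pose_error returns translation and rotation angular errors in degrees, and AUCMetric aggregates a per-pair error (a common choice is the maximum of the two) into AUC values at 5/10/20 degree thresholds. The poses below are made up and the import path is assumed from the file layout.

import torch

from ripe.utils.pose_error import AUCMetric, relative_pose_error  # path assumed

R_gt, t_gt = torch.eye(3), torch.tensor([1.0, 0.0, 0.0])          # ground-truth relative pose
R_est, t_est = torch.eye(3), torch.tensor([0.99, 0.01, 0.0])      # pose recovered by some solver

t_err, r_err = relative_pose_error(R_gt, t_gt, R_est, t_est)      # both angular errors in degrees
pose_err = torch.maximum(t_err, r_err)

metric = AUCMetric(thresholds=[5, 10, 20], elements=[])
metric.update(pose_err.reshape(1))                                # one entry per evaluated image pair
print(metric.compute())                                           # one AUC value per threshold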
ripe/utils/pylogger.py ADDED
@@ -0,0 +1,32 @@
+ import logging
+
+ # from pytorch_lightning.utilities import rank_zero_only
+
+
+ def init_base_pylogger():
+     """Initializes base python command line logger."""
+
+     logging.basicConfig(
+         level=logging.WARNING,
+         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+         handlers=[logging.StreamHandler()],
+     )
+
+
+ def get_pylogger(name=__name__) -> logging.Logger:
+     """Initializes multi-GPU-friendly python command line logger."""
+
+     if not logging.root.handlers:
+         init_base_pylogger()
+
+     logger = logging.getLogger(name)
+
+     logger.setLevel(logging.DEBUG)
+
+     # this ensures all logging levels get marked with the rank zero decorator
+     # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+     # logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+     # for level in logging_levels:
+     #     setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+     return logger
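
Elsewhere in this diff the helper is used through the ripe.utils re-export (log = utils.get_pylogger(__name__) in ripe/utils/utils.py). A direct-import sketch, not part of the commit, with made-up messages and an assumed module path:

from ripe.utils.pylogger import get_pylogger  # direct path; ripe.utils appears to re-export it

log = get_pylogger(__name__)
log.info("found %d matches", 42)               # handled by the root handler set up in init_base_pylogger
log.warning("low inlier ratio: %.2f", 0.12)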
ripe/utils/utils.py ADDED
@@ -0,0 +1,192 @@
+ import random
+ from typing import List
+
+ import cv2
+ import numpy as np
+ import torch
+ from torchvision.transforms.functional import resize
+
+ from ripe import utils
+
+ log = utils.get_pylogger(__name__)
+
+
+ def gridify(x, window_size):
+     """Turn a tensor of BxCxHxW into a tensor of
+     BxCx(H//window_size)x(W//window_size)x(window_size**2)
+
+     Params:
+         x: Input tensor of shape BxCxHxW
+         window_size: Size of the window
+
+     Returns:
+         x: Output tensor of shape BxCx(H//window_size)x(W//window_size)x(window_size**2)
+     """
+
+     assert x.dim() == 4, "Input tensor x must have 4 dimensions"
+
+     B, C, H, W = x.shape
+     x = (
+         x.unfold(2, window_size, window_size)
+         .unfold(3, window_size, window_size)
+         .reshape(B, C, H // window_size, W // window_size, window_size**2)
+     )
+
+     return x
+
+
+ def get_grid(B, H, W, device):
+     x1_n = torch.meshgrid(*[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (B, H, W)])
+     x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
+     return x1_n
+
+
+ def cv2_matches_from_kornia(match_dists: torch.Tensor, match_idxs: torch.Tensor) -> List[cv2.DMatch]:
+     return [cv2.DMatch(idx[0].item(), idx[1].item(), d.item()) for idx, d in zip(match_idxs, match_dists)]
+
+
+ def to_cv_kpts(kpts, scores):
+     kp = kpts.cpu().numpy().astype(np.int16)
+     s = scores.cpu().numpy()
+
+     cv_kp = [cv2.KeyPoint(kp[i][0], kp[i][1], 6, 0, s[i]) for i in range(len(kp))]
+
+     return cv_kp
+
+
+ def resize_image(image, min_size=512, max_size=768):
+     """Resize image to a new size while maintaining the aspect ratio.
+
+     Params:
+         image (torch.tensor): Image to be resized.
+         min_size (int): Minimum size of the smaller dimension.
+         max_size (int): Maximum size of the larger dimension.
+
+     Returns:
+         image: Resized image.
+     """
+
+     h, w = image.shape[-2:]
+
+     aspect_ratio = w / h
+
+     if w > h:
+         new_w = max(min_size, min(max_size, w))
+         new_h = int(new_w / aspect_ratio)
+     else:
+         new_h = max(min_size, min(max_size, h))
+         new_w = int(new_h * aspect_ratio)
+
+     new_size = (new_h, new_w)
+
+     image = resize(image, new_size)
+
+     return image
+
+
+ def get_rewards(
+     reward,
+     kps1,
+     kps2,
+     selected_mask1,
+     selected_mask2,
+     padding_mask1,
+     padding_mask2,
+     rel_idx_matches,
+     abs_idx_matches,
+     ransac_inliers,
+     label,
+     penalty=0.0,
+     use_whitening=False,
+     selected_only=False,
+     filter_mode=None,
+ ):
+     with torch.no_grad():
+         reward *= 1.0 if label else -1.0
+
+         dense_returns = torch.zeros((len(kps1), len(kps2)), device=kps1.device)
+
+         if filter_mode == "ignore":
+             dense_returns[
+                 abs_idx_matches[:, 0][ransac_inliers],
+                 abs_idx_matches[:, 1][ransac_inliers],
+             ] = reward
+         elif filter_mode == "punish":
+             both_in_image_area = (
+                 padding_mask1[abs_idx_matches[:, 0]] & padding_mask2[abs_idx_matches[:, 1]]
+             )  # both in the image area (not in padding area)
+
+             dense_returns[
+                 abs_idx_matches[:, 0][ransac_inliers & both_in_image_area],
+                 abs_idx_matches[:, 1][ransac_inliers & both_in_image_area],
+             ] = reward
+             dense_returns[
+                 abs_idx_matches[:, 0][ransac_inliers & ~both_in_image_area],
+                 abs_idx_matches[:, 1][ransac_inliers & ~both_in_image_area],
+             ] = -1.0
+         else:
+             raise ValueError(f"Unknown filter mode: {filter_mode}")
+
+         if selected_only:
+             dense_returns = dense_returns[selected_mask1, :][:, selected_mask2]
+         if filter_mode == "ignore" and not selected_only:
+             dense_returns = dense_returns[padding_mask1, :][:, padding_mask2]
+
+         if penalty != 0.0:
+             # pos. pair: small penalty for not finding a match
+             # neg. pair: small reward for not finding a match
+             penalty_val = penalty if label else -penalty
+
+             dense_returns[dense_returns == 0.0] = penalty_val
+
+         if use_whitening:
+             dense_returns = (dense_returns - dense_returns.mean()) / (dense_returns.std() + 1e-6)
+
+         return dense_returns
+
+
+ def get_other_random_id(idx: int, len_dataset: int, min_dist: int = 20):
+     for _ in range(10):
+         tgt_id = random.randint(0, len_dataset - 1)
+         if abs(idx - tgt_id) >= min_dist:
+             return tgt_id
+
+     raise ValueError(f"Could not find target image with distance >= {min_dist} from source image {idx}")
+
+
+ def cv_resize_and_pad_to_shape(image, new_shape, padding_color=(0, 0, 0)):
+     """Resizes image to new_shape while maintaining the aspect ratio and pads with padding_color if
+     needed.
+
+     Params:
+         image: Image to be resized.
+         new_shape: Expected (height, width) of new image.
+         padding_color: Tuple in BGR of padding color
+     Returns:
+         image: Resized image with padding
+     """
+     h, w = image.shape[:2]
+
+     scale_h = new_shape[0] / h
+     scale_w = new_shape[1] / w
+
+     scale = None
+     if scale_w * h > new_shape[0]:
+         scale = scale_h
+     elif scale_h * w > new_shape[1]:
+         scale = scale_w
+     else:
+         scale = max(scale_h, scale_w)
+
+     new_w, new_h = int(round(w * scale)), int(round(h * scale))
+
+     image = cv2.resize(image, (new_w, new_h))
+
+     missing_h = new_shape[0] - new_h
+     missing_w = new_shape[1] - new_w
+
+     top, bottom = missing_h // 2, missing_h - (missing_h // 2)
+     left, right = missing_w // 2, missing_w - (missing_w // 2)
+
+     image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padding_color)
+     return image
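
A short sketch (not part of the commit) of the shape conventions of the helpers above: gridify splits a score map into non-overlapping flattened windows, get_rewards scatters a scalar reward into a keypoints1 x keypoints2 grid, and cv_resize_and_pad_to_shape letterboxes an OpenCV image. All sizes, masks, and matches below are made up, the import path is assumed from the file layout, and the padding masks are assumed to be True for valid (non-padded) pixels, as the in-code comment suggests.

import numpy as np
import torch

from ripe.utils.utils import cv_resize_and_pad_to_shape, get_rewards, gridify  # path assumed

score_map = torch.rand(2, 1, 64, 96)                  # BxCxHxW
cells = gridify(score_map, window_size=8)             # one flattened 8x8 window per cell
assert cells.shape == (2, 1, 8, 12, 64)

kps1, kps2 = torch.zeros(4, 2), torch.zeros(5, 2)     # 4 and 5 keypoints (coordinates unused here)
matches = torch.tensor([[0, 1], [2, 3]])              # two tentative matches (indices into kps1/kps2)
inliers = torch.tensor([True, False])                 # RANSAC kept only the first one
returns = get_rewards(
    reward=1.0,
    kps1=kps1,
    kps2=kps2,
    selected_mask1=torch.ones(4, dtype=torch.bool),
    selected_mask2=torch.ones(5, dtype=torch.bool),
    padding_mask1=torch.ones(4, dtype=torch.bool),    # True = inside the actual image
    padding_mask2=torch.ones(5, dtype=torch.bool),
    rel_idx_matches=None,                             # not used by the function
    abs_idx_matches=matches,
    ransac_inliers=inliers,
    label=True,                                       # positive (matching) image pair
    filter_mode="ignore",
)
assert returns.shape == (4, 5) and returns[0, 1] == 1.0

bgr = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
padded = cv_resize_and_pad_to_shape(bgr, (512, 512))  # resized to 512x384, then padded to 512x512
assert padded.shape[:2] == (512, 512)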
ripe/utils/wandb_utils.py ADDED
@@ -0,0 +1,16 @@
+ import omegaconf
+
+
+ def get_flattened_wandb_cfg(conf_dict):
+     flattened = {}
+
+     def _flatten(cfg, prefix=""):
+         for k, v in cfg.items():
+             new_key = f"{prefix}.{k}" if prefix else k
+             if isinstance(v, omegaconf.dictconfig.DictConfig):
+                 _flatten(v, new_key)
+             else:
+                 flattened[new_key] = v
+
+     _flatten(conf_dict)
+     return flattened
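
A small usage sketch (not part of the commit): nested OmegaConf configs are flattened into dot-separated keys, the form typically passed to wandb as a flat config dict. The config contents below are made up and the import path is assumed from the file layout.

from omegaconf import OmegaConf

from ripe.utils.wandb_utils import get_flattened_wandb_cfg  # path assumed

cfg = OmegaConf.create({"model": {"backbone": "vgg", "lr": 1e-3}, "data": {"name": "demo"}, "seed": 42})
print(get_flattened_wandb_cfg(cfg))
# {'model.backbone': 'vgg', 'model.lr': 0.001, 'data.name': 'demo', 'seed': 42}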