Realcat commited on
Commit
2507d2f
1 Parent(s): 40c4807

add: omniglue

Browse files
README.md CHANGED
@@ -34,6 +34,7 @@ Here is a demo of the tool:
34
  ![demo](assets/demo.gif)
35
 
36
  The tool currently supports various popular image matching algorithms, namely:
 
37
  - [x] [XFeat](https://github.com/verlab/accelerated_features), CVPR 2024
38
  - [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024
39
  - [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024
 
34
  ![demo](assets/demo.gif)
35
 
36
  The tool currently supports various popular image matching algorithms, namely:
37
+ - [x] [OmniGlue](https://github.com/Vincentqyw/omniglue-onnx), CVPR 2024
38
  - [x] [XFeat](https://github.com/verlab/accelerated_features), CVPR 2024
39
  - [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024
40
  - [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024
common/app_class.py CHANGED
@@ -12,6 +12,7 @@ from common.utils import (
12
  run_ransac,
13
  gen_examples,
14
  GRADIO_VERSION,
 
15
  )
16
 
17
 
@@ -49,288 +50,327 @@ class ImageMatchingApp:
49
 
50
  def init_interface(self):
51
  with gr.Blocks() as self.app:
52
- with gr.Row():
53
- with gr.Column(scale=1):
54
- gr.Image(
55
- str(Path(__file__).parent.parent / "assets/logo.webp"),
56
- elem_id="logo-img",
57
- show_label=False,
58
- show_share_button=False,
59
- show_download_button=False,
60
- )
61
- with gr.Column(scale=3):
62
- gr.Markdown(DESCRIPTION)
63
- with gr.Row(equal_height=False):
64
- with gr.Column():
65
- with gr.Row():
66
- matcher_list = gr.Dropdown(
67
- choices=self.init_matcher_dropdown(),
68
- value="disk+lightglue",
69
- label="Matching Model",
70
- interactive=True,
71
- )
72
- match_image_src = gr.Radio(
73
- (
74
- ["upload", "webcam", "clipboard"]
75
- if GRADIO_VERSION > "3"
76
- else ["upload", "webcam", "canvas"]
77
  ),
78
- label="Image Source",
79
- value="upload",
80
- )
81
- with gr.Row():
82
- input_image0 = gr.Image(
83
- label="Image 0",
84
- type="numpy",
85
- image_mode="RGB",
86
- height=300 if GRADIO_VERSION > "3" else None,
87
- interactive=True,
88
- )
89
- input_image1 = gr.Image(
90
- label="Image 1",
91
- type="numpy",
92
- image_mode="RGB",
93
- height=300 if GRADIO_VERSION > "3" else None,
94
- interactive=True,
95
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- with gr.Row():
98
- button_reset = gr.Button(value="Reset")
99
- button_run = gr.Button(
100
- value="Run Match", variant="primary"
101
- )
102
 
103
- with gr.Accordion("Advanced Setting", open=False):
104
- with gr.Accordion("Matching Setting", open=True):
105
- with gr.Row():
106
- match_setting_threshold = gr.Slider(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  minimum=0.0,
108
- maximum=1,
109
- step=0.001,
110
- label="Match thres.",
111
- value=0.1,
112
- )
113
- match_setting_max_features = gr.Slider(
114
- minimum=10,
115
- maximum=10000,
116
- step=10,
117
- label="Max features",
118
- value=1000,
119
  )
120
- # TODO: add line settings
121
- with gr.Row():
122
- detect_keypoints_threshold = gr.Slider(
123
- minimum=0,
124
  maximum=1,
125
- step=0.001,
126
- label="Keypoint thres.",
127
- value=0.015,
 
 
128
  )
129
- detect_line_threshold = gr.Slider(
130
- minimum=0.1,
131
- maximum=1,
132
- step=0.01,
133
- label="Line thres.",
134
- value=0.2,
 
 
135
  )
136
- # matcher_lists = gr.Radio(
137
- # ["NN-mutual", "Dual-Softmax"],
138
- # label="Matcher mode",
139
- # value="NN-mutual",
140
- # )
141
- with gr.Accordion("RANSAC Setting", open=True):
142
- with gr.Row(equal_height=False):
143
- ransac_method = gr.Dropdown(
144
- choices=ransac_zoo.keys(),
145
- value=self.cfg["defaults"]["ransac_method"],
146
- label="RANSAC Method",
147
- interactive=True,
148
  )
149
- ransac_reproj_threshold = gr.Slider(
150
- minimum=0.0,
151
- maximum=12,
152
- step=0.01,
153
- label="Ransac Reproj threshold",
154
- value=8.0,
155
- )
156
- ransac_confidence = gr.Slider(
157
- minimum=0.0,
158
- maximum=1,
159
- step=0.00001,
160
- label="Ransac Confidence",
161
- value=self.cfg["defaults"]["ransac_confidence"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  )
163
- ransac_max_iter = gr.Slider(
164
- minimum=0.0,
165
- maximum=100000,
166
- step=100,
167
- label="Ransac Iterations",
168
- value=self.cfg["defaults"]["ransac_max_iter"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  )
170
- button_ransac = gr.Button(
171
- value="Rerun RANSAC", variant="primary"
 
 
 
 
 
172
  )
173
- with gr.Accordion("Geometry Setting", open=False):
174
- with gr.Row(equal_height=False):
175
- choice_geometry_type = gr.Radio(
176
- ["Fundamental", "Homography"],
177
- label="Reconstruct Geometry",
178
- value=self.cfg["defaults"][
179
- "setting_geometry"
180
- ],
181
  )
182
 
183
- # collect inputs
184
- state_cache = gr.State({})
185
- inputs = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  input_image0,
187
  input_image1,
188
  match_setting_threshold,
189
  match_setting_max_features,
190
  detect_keypoints_threshold,
191
  matcher_list,
 
 
 
 
 
 
 
 
 
 
192
  ransac_method,
193
  ransac_reproj_threshold,
194
  ransac_confidence,
195
  ransac_max_iter,
196
  choice_geometry_type,
197
- gr.State(self.matcher_zoo),
198
- # state_cache,
199
  ]
 
 
 
 
 
200
 
201
- # Add some examples
202
- with gr.Row():
203
- # Example inputs
204
- gr.Examples(
205
- examples=gen_examples(),
206
- inputs=inputs,
207
- outputs=[],
208
- fn=run_matching,
209
- cache_examples=False,
210
- label=(
211
- "Examples (click one of the images below to Run"
212
- " Match). Thx: WxBS"
213
- ),
214
- )
215
- with gr.Accordion("Supported Algorithms", open=False):
216
- # add a table of supported algorithms
217
- self.display_supported_algorithms()
218
-
219
- with gr.Column():
220
- output_keypoints = gr.Image(label="Keypoints", type="numpy")
221
- output_matches_raw = gr.Image(
222
- label="Raw Matches",
223
- type="numpy",
224
  )
225
- output_matches_ransac = gr.Image(
226
- label="Ransac Matches", type="numpy"
 
 
 
 
 
 
 
 
 
227
  )
228
- with gr.Accordion(
229
- "Open for More: Matches Statistics", open=False
230
- ):
231
- matches_result_info = gr.JSON(
232
- label="Matches Statistics"
233
- )
234
- matcher_info = gr.JSON(label="Match info")
235
 
236
- with gr.Accordion(
237
- "Open for More: Warped Image", open=False
238
- ):
239
- output_wrapped = gr.Image(
240
- label="Wrapped Pair", type="numpy"
 
 
 
 
 
 
241
  )
242
- with gr.Accordion(
243
- "Open for More: Geometry info", open=False
244
- ):
245
- geometry_result = gr.JSON(
246
- label="Reconstructed Geometry"
247
- )
248
-
249
- # callbacks
250
- match_image_src.change(
251
- fn=self.ui_change_imagebox,
252
- inputs=match_image_src,
253
- outputs=input_image0,
254
- )
255
- match_image_src.change(
256
- fn=self.ui_change_imagebox,
257
- inputs=match_image_src,
258
- outputs=input_image1,
259
- )
260
-
261
- # collect outputs
262
- outputs = [
263
- output_keypoints,
264
- output_matches_raw,
265
- output_matches_ransac,
266
- matches_result_info,
267
- matcher_info,
268
- geometry_result,
269
- output_wrapped,
270
- state_cache,
271
- ]
272
- # button callbacks
273
- button_run.click(
274
- fn=run_matching, inputs=inputs, outputs=outputs
275
- )
276
-
277
- # Reset images
278
- reset_outputs = [
279
- input_image0,
280
- input_image1,
281
- match_setting_threshold,
282
- match_setting_max_features,
283
- detect_keypoints_threshold,
284
- matcher_list,
285
- input_image0,
286
- input_image1,
287
- match_image_src,
288
- output_keypoints,
289
- output_matches_raw,
290
- output_matches_ransac,
291
- matches_result_info,
292
- matcher_info,
293
- output_wrapped,
294
- geometry_result,
295
- ransac_method,
296
- ransac_reproj_threshold,
297
- ransac_confidence,
298
- ransac_max_iter,
299
- choice_geometry_type,
300
- ]
301
- button_reset.click(
302
- fn=self.ui_reset_state, inputs=None, outputs=reset_outputs
303
- )
304
-
305
- # run ransac button action
306
- button_ransac.click(
307
- fn=run_ransac,
308
- inputs=[
309
- state_cache,
310
- choice_geometry_type,
311
- ransac_method,
312
- ransac_reproj_threshold,
313
- ransac_confidence,
314
- ransac_max_iter,
315
- ],
316
- outputs=[
317
- output_matches_ransac,
318
- matches_result_info,
319
- output_wrapped,
320
- ],
321
- )
322
-
323
- # estimate geo
324
- choice_geometry_type.change(
325
- fn=generate_warp_images,
326
- inputs=[
327
- input_image0,
328
- input_image1,
329
- geometry_result,
330
- choice_geometry_type,
331
- ],
332
- outputs=[output_wrapped, geometry_result],
333
- )
334
 
335
  def run(self):
336
  self.app.queue().launch(
 
12
  run_ransac,
13
  gen_examples,
14
  GRADIO_VERSION,
15
+ ROOT,
16
  )
17
 
18
 
 
50
 
51
  def init_interface(self):
52
  with gr.Blocks() as self.app:
53
+ with gr.Tab("Image Matching"):
54
+ with gr.Row():
55
+ with gr.Column(scale=1):
56
+ gr.Image(
57
+ str(
58
+ Path(__file__).parent.parent
59
+ / "assets/logo.webp"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  ),
61
+ elem_id="logo-img",
62
+ show_label=False,
63
+ show_share_button=False,
64
+ show_download_button=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
+ with gr.Column(scale=3):
67
+ gr.Markdown(DESCRIPTION)
68
+ with gr.Row(equal_height=False):
69
+ with gr.Column():
70
+ with gr.Row():
71
+ matcher_list = gr.Dropdown(
72
+ choices=self.init_matcher_dropdown(),
73
+ value="disk+lightglue",
74
+ label="Matching Model",
75
+ interactive=True,
76
+ )
77
+ match_image_src = gr.Radio(
78
+ (
79
+ ["upload", "webcam", "clipboard"]
80
+ if GRADIO_VERSION > "3"
81
+ else ["upload", "webcam", "canvas"]
82
+ ),
83
+ label="Image Source",
84
+ value="upload",
85
+ )
86
+ with gr.Row():
87
+ input_image0 = gr.Image(
88
+ label="Image 0",
89
+ type="numpy",
90
+ image_mode="RGB",
91
+ height=300 if GRADIO_VERSION > "3" else None,
92
+ interactive=True,
93
+ )
94
+ input_image1 = gr.Image(
95
+ label="Image 1",
96
+ type="numpy",
97
+ image_mode="RGB",
98
+ height=300 if GRADIO_VERSION > "3" else None,
99
+ interactive=True,
100
+ )
101
 
102
+ with gr.Row():
103
+ button_reset = gr.Button(value="Reset")
104
+ button_run = gr.Button(
105
+ value="Run Match", variant="primary"
106
+ )
107
 
108
+ with gr.Accordion("Advanced Setting", open=False):
109
+ with gr.Accordion("Matching Setting", open=True):
110
+ with gr.Row():
111
+ match_setting_threshold = gr.Slider(
112
+ minimum=0.0,
113
+ maximum=1,
114
+ step=0.001,
115
+ label="Match thres.",
116
+ value=0.1,
117
+ )
118
+ match_setting_max_features = gr.Slider(
119
+ minimum=10,
120
+ maximum=10000,
121
+ step=10,
122
+ label="Max features",
123
+ value=1000,
124
+ )
125
+ # TODO: add line settings
126
+ with gr.Row():
127
+ detect_keypoints_threshold = gr.Slider(
128
+ minimum=0,
129
+ maximum=1,
130
+ step=0.001,
131
+ label="Keypoint thres.",
132
+ value=0.015,
133
+ )
134
+ detect_line_threshold = gr.Slider(
135
+ minimum=0.1,
136
+ maximum=1,
137
+ step=0.01,
138
+ label="Line thres.",
139
+ value=0.2,
140
+ )
141
+ # matcher_lists = gr.Radio(
142
+ # ["NN-mutual", "Dual-Softmax"],
143
+ # label="Matcher mode",
144
+ # value="NN-mutual",
145
+ # )
146
+ with gr.Accordion("RANSAC Setting", open=True):
147
+ with gr.Row(equal_height=False):
148
+ ransac_method = gr.Dropdown(
149
+ choices=ransac_zoo.keys(),
150
+ value=self.cfg["defaults"][
151
+ "ransac_method"
152
+ ],
153
+ label="RANSAC Method",
154
+ interactive=True,
155
+ )
156
+ ransac_reproj_threshold = gr.Slider(
157
  minimum=0.0,
158
+ maximum=12,
159
+ step=0.01,
160
+ label="Ransac Reproj threshold",
161
+ value=8.0,
 
 
 
 
 
 
 
162
  )
163
+ ransac_confidence = gr.Slider(
164
+ minimum=0.0,
 
 
165
  maximum=1,
166
+ step=0.00001,
167
+ label="Ransac Confidence",
168
+ value=self.cfg["defaults"][
169
+ "ransac_confidence"
170
+ ],
171
  )
172
+ ransac_max_iter = gr.Slider(
173
+ minimum=0.0,
174
+ maximum=100000,
175
+ step=100,
176
+ label="Ransac Iterations",
177
+ value=self.cfg["defaults"][
178
+ "ransac_max_iter"
179
+ ],
180
  )
181
+ button_ransac = gr.Button(
182
+ value="Rerun RANSAC", variant="primary"
 
 
 
 
 
 
 
 
 
 
183
  )
184
+ with gr.Accordion("Geometry Setting", open=False):
185
+ with gr.Row(equal_height=False):
186
+ choice_geometry_type = gr.Radio(
187
+ ["Fundamental", "Homography"],
188
+ label="Reconstruct Geometry",
189
+ value=self.cfg["defaults"][
190
+ "setting_geometry"
191
+ ],
192
+ )
193
+
194
+ # collect inputs
195
+ state_cache = gr.State({})
196
+ inputs = [
197
+ input_image0,
198
+ input_image1,
199
+ match_setting_threshold,
200
+ match_setting_max_features,
201
+ detect_keypoints_threshold,
202
+ matcher_list,
203
+ ransac_method,
204
+ ransac_reproj_threshold,
205
+ ransac_confidence,
206
+ ransac_max_iter,
207
+ choice_geometry_type,
208
+ gr.State(self.matcher_zoo),
209
+ # state_cache,
210
+ ]
211
+
212
+ # Add some examples
213
+ with gr.Row():
214
+ # Example inputs
215
+ gr.Examples(
216
+ examples=gen_examples(),
217
+ inputs=inputs,
218
+ outputs=[],
219
+ fn=run_matching,
220
+ cache_examples=False,
221
+ label=(
222
+ "Examples (click one of the images below to Run"
223
+ " Match). Thx: WxBS"
224
+ ),
225
  )
226
+ with gr.Accordion("Supported Algorithms", open=False):
227
+ # add a table of supported algorithms
228
+ self.display_supported_algorithms()
229
+
230
+ with gr.Column():
231
+ output_keypoints = gr.Image(
232
+ label="Keypoints", type="numpy"
233
+ )
234
+ output_matches_raw = gr.Image(
235
+ label="Raw Matches",
236
+ type="numpy",
237
+ )
238
+ output_matches_ransac = gr.Image(
239
+ label="Ransac Matches", type="numpy"
240
+ )
241
+ with gr.Accordion(
242
+ "Open for More: Matches Statistics", open=False
243
+ ):
244
+ matches_result_info = gr.JSON(
245
+ label="Matches Statistics"
246
  )
247
+ matcher_info = gr.JSON(label="Match info")
248
+
249
+ with gr.Accordion(
250
+ "Open for More: Warped Image", open=False
251
+ ):
252
+ output_wrapped = gr.Image(
253
+ label="Wrapped Pair", type="numpy"
254
  )
255
+ with gr.Accordion(
256
+ "Open for More: Geometry info", open=False
257
+ ):
258
+ geometry_result = gr.JSON(
259
+ label="Reconstructed Geometry"
 
 
 
260
  )
261
 
262
+ # callbacks
263
+ match_image_src.change(
264
+ fn=self.ui_change_imagebox,
265
+ inputs=match_image_src,
266
+ outputs=input_image0,
267
+ )
268
+ match_image_src.change(
269
+ fn=self.ui_change_imagebox,
270
+ inputs=match_image_src,
271
+ outputs=input_image1,
272
+ )
273
+
274
+ # collect outputs
275
+ outputs = [
276
+ output_keypoints,
277
+ output_matches_raw,
278
+ output_matches_ransac,
279
+ matches_result_info,
280
+ matcher_info,
281
+ geometry_result,
282
+ output_wrapped,
283
+ state_cache,
284
+ ]
285
+ # button callbacks
286
+ button_run.click(
287
+ fn=run_matching, inputs=inputs, outputs=outputs
288
+ )
289
+
290
+ # Reset images
291
+ reset_outputs = [
292
  input_image0,
293
  input_image1,
294
  match_setting_threshold,
295
  match_setting_max_features,
296
  detect_keypoints_threshold,
297
  matcher_list,
298
+ input_image0,
299
+ input_image1,
300
+ match_image_src,
301
+ output_keypoints,
302
+ output_matches_raw,
303
+ output_matches_ransac,
304
+ matches_result_info,
305
+ matcher_info,
306
+ output_wrapped,
307
+ geometry_result,
308
  ransac_method,
309
  ransac_reproj_threshold,
310
  ransac_confidence,
311
  ransac_max_iter,
312
  choice_geometry_type,
 
 
313
  ]
314
+ button_reset.click(
315
+ fn=self.ui_reset_state,
316
+ inputs=None,
317
+ outputs=reset_outputs,
318
+ )
319
 
320
+ # run ransac button action
321
+ button_ransac.click(
322
+ fn=run_ransac,
323
+ inputs=[
324
+ state_cache,
325
+ choice_geometry_type,
326
+ ransac_method,
327
+ ransac_reproj_threshold,
328
+ ransac_confidence,
329
+ ransac_max_iter,
330
+ ],
331
+ outputs=[
332
+ output_matches_ransac,
333
+ matches_result_info,
334
+ output_wrapped,
335
+ ],
 
 
 
 
 
 
 
336
  )
337
+
338
+ # estimate geo
339
+ choice_geometry_type.change(
340
+ fn=generate_warp_images,
341
+ inputs=[
342
+ input_image0,
343
+ input_image1,
344
+ geometry_result,
345
+ choice_geometry_type,
346
+ ],
347
+ outputs=[output_wrapped, geometry_result],
348
  )
349
+ with gr.Tab("Under construction"):
350
+ self.init_tab_sfm()
 
 
 
 
 
351
 
352
+ def init_tab_sfm(self):
353
+ with gr.Row():
354
+ with gr.Column():
355
+ with gr.Row():
356
+ gr.Textbox("Under construction", label="A", visible=True)
357
+ gr.Textbox("Under construction", label="B", visible=True)
358
+ gr.Textbox("Under construction", label="C", visible=True)
359
+ with gr.Row():
360
+ with gr.Accordion("Open for More", open=False):
361
+ gr.Textbox(
362
+ "Under construction", label="A1", visible=True
363
  )
364
+ gr.Textbox(
365
+ "Under construction", label="B1", visible=True
366
+ )
367
+ gr.Textbox(
368
+ "Under construction", label="C1", visible=True
369
+ )
370
+ with gr.Column():
371
+ gr.Textbox("Under construction", label="D", visible=True)
372
+ gr.Textbox("Under construction", label="E", visible=True)
373
+ gr.Textbox("Under construction", label="F", visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
  def run(self):
376
  self.app.queue().launch(
common/config.yaml CHANGED
@@ -16,6 +16,17 @@ defaults:
16
  setting_geometry: Homography
17
 
18
  matcher_zoo:
 
 
 
 
 
 
 
 
 
 
 
19
  DUSt3R:
20
  # TODO: duster is under development
21
  enable: false
 
16
  setting_geometry: Homography
17
 
18
  matcher_zoo:
19
+ omniglue:
20
+ enable: true
21
+ matcher: omniglue
22
+ dense: true
23
+ info:
24
+ name: OmniGlue
25
+ source: "CVPR 2024"
26
+ github: https://github.com/Vincentqyw/omniglue-onnx
27
+ paper: https://arxiv.org/abs/2405.12979
28
+ project: https://hwjiang1510.github.io/OmniGlue/
29
+ display: true
30
  DUSt3R:
31
  # TODO: duster is under development
32
  enable: false
env-docker.txt CHANGED
@@ -29,4 +29,5 @@ tensorboardX==2.6.1
29
  torchmetrics==0.6.0
30
  torchvision==0.17.1
31
  tqdm==4.65.0
32
- yacs==0.1.8
 
 
29
  torchmetrics==0.6.0
30
  torchvision==0.17.1
31
  tqdm==4.65.0
32
+ yacs==0.1.8
33
+ onnxruntime
hloc/match_dense.py CHANGED
@@ -211,6 +211,20 @@ confs = {
211
  "dfactor": 8,
212
  },
213
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  "sold2": {
215
  "output": "matches-sold2",
216
  "model": {
 
211
  "dfactor": 8,
212
  },
213
  },
214
+ "omniglue": {
215
+ "output": "matches-omniglue",
216
+ "model": {
217
+ "name": "omniglue",
218
+ "match_threshold": 0.2,
219
+ "features": "null",
220
+ },
221
+ "preprocessing": {
222
+ "grayscale": False,
223
+ "resize_max": 1024,
224
+ "dfactor": 8,
225
+ "force_resize": False,
226
+ },
227
+ },
228
  "sold2": {
229
  "output": "matches-sold2",
230
  "model": {
hloc/matchers/omniglue.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import subprocess
4
+ import numpy as np
5
+ from pathlib import Path
6
+
7
+ from .. import logger
8
+ from ..utils.base_model import BaseModel
9
+
10
+ omniglue_path = Path(__file__).parent / "../../third_party/omniglue"
11
+ sys.path.append(str(omniglue_path))
12
+ from src import omniglue
13
+
14
+
15
+ class OmniGlue(BaseModel):
16
+ default_conf = {
17
+ "match_threshold": 0.02,
18
+ "max_keypoints": 2048,
19
+ }
20
+ required_inputs = ["image0", "image1"]
21
+ dino_v2_link_dict = {
22
+ "dinov2_vitb14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth"
23
+ }
24
+
25
+ def _init(self, conf):
26
+ logger.info(f"Loadeding OmniGlue model")
27
+ og_model_path = omniglue_path / "models" / "omniglue.onnx"
28
+ sp_model_path = omniglue_path / "models" / "sp_v6.onnx"
29
+ dino_model_path = (
30
+ omniglue_path / "models" / "dinov2_vitb14_pretrain.pth" # ~330MB
31
+ )
32
+ if not dino_model_path.exists():
33
+ link = self.dino_v2_link_dict.get(dino_model_path.name, None)
34
+ if link is not None:
35
+ cmd = ["wget", link, "-O", str(dino_model_path)]
36
+ logger.info(f"Downloading the dinov2 model with `{cmd}`.")
37
+ subprocess.run(cmd, check=True)
38
+ else:
39
+ logger.error(f"Invalid dinov2 model: {dino_model_path.name}")
40
+
41
+ self.net = omniglue.OmniGlue(
42
+ og_export=str(og_model_path),
43
+ sp_export=str(sp_model_path),
44
+ dino_export=str(dino_model_path),
45
+ max_keypoints=self.conf["max_keypoints"] * 4,
46
+ )
47
+ logger.info(f"Loaded OmniGlue model done!")
48
+
49
+ def _forward(self, data):
50
+ image0_rgb_np = data["image0"][0].permute(1, 2, 0).cpu().numpy() * 255
51
+ image1_rgb_np = data["image1"][0].permute(1, 2, 0).cpu().numpy() * 255
52
+ image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255
53
+ image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255
54
+ match_kp0, match_kp1, match_confidences = self.net.FindMatches(
55
+ image0_rgb_np, image1_rgb_np
56
+ )
57
+
58
+ # filter matches
59
+ match_threshold = self.conf["match_threshold"]
60
+ keep_idx = []
61
+ for i in range(match_kp0.shape[0]):
62
+ if match_confidences[i] > match_threshold:
63
+ keep_idx.append(i)
64
+ num_filtered_matches = len(keep_idx)
65
+ scores = torch.from_numpy(match_confidences[keep_idx]).reshape(-1, 1)
66
+ pred = {
67
+ "keypoints0": torch.from_numpy(match_kp0[keep_idx]),
68
+ "keypoints1": torch.from_numpy(match_kp1[keep_idx]),
69
+ "mconf": scores,
70
+ }
71
+
72
+ top_k = self.conf["max_keypoints"]
73
+ if top_k is not None and len(scores) > top_k:
74
+ keep = torch.argsort(scores, descending=True)[:top_k]
75
+ scores = scores[keep]
76
+ pred["keypoints0"], pred["keypoints1"], pred["mconf"] = (
77
+ pred["keypoints0"][keep],
78
+ pred["keypoints1"][keep],
79
+ scores,
80
+ )
81
+ return pred
requirements.txt CHANGED
@@ -30,4 +30,5 @@ tensorboardX==2.6.1
30
  torchmetrics==0.6.0
31
  torchvision==0.17.1
32
  tqdm==4.65.0
33
- yacs==0.1.8
 
 
30
  torchmetrics==0.6.0
31
  torchvision==0.17.1
32
  tqdm==4.65.0
33
+ yacs==0.1.8
34
+ onnxruntime
test_app_cli.py CHANGED
@@ -11,6 +11,7 @@ from common.utils import (
11
  )
12
  from common.api import ImageMatchingAPI
13
 
 
14
  def test_api(config: dict = None):
15
  img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
16
  img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
@@ -32,6 +33,7 @@ def test_api(config: dict = None):
32
  else:
33
  logger.info(f"Skipping {k} ...")
34
 
 
35
  if __name__ == "__main__":
36
  import argparse
37
 
 
11
  )
12
  from common.api import ImageMatchingAPI
13
 
14
+
15
  def test_api(config: dict = None):
16
  img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
17
  img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
 
33
  else:
34
  logger.info(f"Skipping {k} ...")
35
 
36
+
37
  if __name__ == "__main__":
38
  import argparse
39
 
third_party/omniglue/.gitignore ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Compiled python modules.
2
+ *.pyc
3
+
4
+ # Byte-compiled
5
+ _pycache__/
6
+ .cache/
7
+
8
+ # Poetry, setuptools, PyPI distribution artifacts.
9
+ /*.egg-info
10
+ .eggs/
11
+ build/
12
+ dist/
13
+ poetry.lock
14
+
15
+ # Tests
16
+ .pytest_cache/
17
+
18
+ # Type checking
19
+ .pytype/
20
+
21
+ # Other
22
+ *.DS_Store
23
+
24
+ # PyCharm
25
+ .idea
26
+ models/sp_v6*
27
+ models/og_export*
28
+ models/dinov2_vitb14_pretrain.pth
third_party/omniglue/CHANGELOG.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Changelog
2
+
3
+ <!--
4
+
5
+ Changelog follow the https://keepachangelog.com/ standard (at least the headers)
6
+
7
+ This allow to:
8
+
9
+ * auto-parsing release notes during the automated releases from github-action:
10
+ https://github.com/marketplace/actions/pypi-github-auto-release
11
+ * Have clickable headers in the rendered markdown
12
+
13
+ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
14
+
15
+ * Create a new `# [2.0.0] - YYYY-MM-DD` header and add the current
16
+ `[Unreleased]` notes.
17
+ * At the end of the file:
18
+ * Define the new link url:
19
+ `[2.0.0]: https://github.com/google-research/omniglue/compare/v1.0.0...v2.0.0`
20
+ * Update the `[Unreleased]` url: `v1.0.0...HEAD` -> `v2.0.0...HEAD`
21
+
22
+ -->
23
+
24
+ ## [Unreleased]
25
+
26
+ ## [0.1.0] - 2022-01-01
27
+
28
+ * Initial release
29
+
30
+ [Unreleased]: https://github.com/google-research/omniglue/compare/v0.1.0...HEAD
31
+ [0.1.0]: https://github.com/google-research/omniglue/releases/tag/v0.1.0
third_party/omniglue/CONTRIBUTING.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to Contribute
2
+
3
+ We'd love to accept your patches and contributions to this project. There are
4
+ just a few small guidelines you need to follow.
5
+
6
+ ## Contributor License Agreement
7
+
8
+ Contributions to this project must be accompanied by a Contributor License
9
+ Agreement (CLA). You (or your employer) retain the copyright to your
10
+ contribution; this simply gives us permission to use and redistribute your
11
+ contributions as part of the project. Head over to
12
+ <https://cla.developers.google.com/> to see your current agreements on file or
13
+ to sign a new one.
14
+
15
+ You generally only need to submit a CLA once, so if you've already submitted one
16
+ (even if it was for a different project), you probably don't need to do it
17
+ again.
18
+
19
+ ## Code Reviews
20
+
21
+ All submissions, including submissions by project members, require review. We
22
+ use GitHub pull requests for this purpose. Consult
23
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24
+ information on using pull requests.
25
+
26
+ ## Community Guidelines
27
+
28
+ This project follows
29
+ [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
third_party/omniglue/LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
third_party/omniglue/README.md ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # \[CVPR'24\] Code release for OmniGlue(ONNX)
4
+
5
+ [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/Realcat/image-matching-webui)
6
+
7
+ <p align="center">
8
+ <a href="https://hwjiang1510.github.io/">Hanwen Jiang</a>,
9
+ <a href="https://scholar.google.com/citations?user=jgSItF4AAAAJ">Arjun Karpur</a>,
10
+ <a href="https://scholar.google.com/citations?user=7EeSOcgAAAAJ">Bingyi Cao</a>,
11
+ <a href="https://www.cs.utexas.edu/~huangqx/">Qixing Huang</a>,
12
+ <a href="https://andrefaraujo.github.io/">Andre Araujo</a>
13
+ </p>
14
+
15
+ </div>
16
+
17
+ --------------------------------------------------------------------------------
18
+
19
+ <div align="center">
20
+ <a href="https://hwjiang1510.github.io/OmniGlue/"><strong>Project Page</strong></a> |
21
+ <a href="https://arxiv.org/abs/2405.12979"><strong>Paper</strong></a> |
22
+ <a href="#installation"><strong>Usage</strong></a> |
23
+ <a href="https://huggingface.co/spaces/qubvel-hf/omniglue"><strong>Demo</strong></a>
24
+ </div>
25
+
26
+ <br>
27
+
28
+ ONNX-compatible release for the CVPR 2024 paper: **OmniGlue: Generalizable Feature
29
+ Matching with Foundation Model Guidance**.
30
+
31
+ ![og_diagram.png](res/og_diagram.png "og_diagram.png")
32
+
33
+ **Abstract:** The image matching field has been witnessing a continuous
34
+ emergence of novel learnable feature matching techniques, with ever-improving
35
+ performance on conventional benchmarks. However, our investigation shows that
36
+ despite these gains, their potential for real-world applications is restricted
37
+ by their limited generalization capabilities to novel image domains. In this
38
+ paper, we introduce OmniGlue, the first learnable image matcher that is designed
39
+ with generalization as a core principle. OmniGlue leverages broad knowledge from
40
+ a vision foundation model to guide the feature matching process, boosting
41
+ generalization to domains not seen at training time. Additionally, we propose a
42
+ novel keypoint position-guided attention mechanism which disentangles spatial
43
+ and appearance information, leading to enhanced matching descriptors. We perform
44
+ comprehensive experiments on a suite of 6 datasets with varied image domains,
45
+ including scene-level, object-centric and aerial images. OmniGlue’s novel
46
+ components lead to relative gains on unseen domains of 18.8% with respect to a
47
+ directly comparable reference model, while also outperforming the recent
48
+ LightGlue method by 10.1% relatively.
49
+
50
+
51
+ ## Installation
52
+
53
+ First, use pip to install `omniglue`:
54
+
55
+ ```sh
56
+ conda create -n omniglue pip
57
+ conda activate omniglue
58
+
59
+ git clone https://github.com/google-research/omniglue.git
60
+ cd omniglue
61
+ pip install -e .
62
+ ```
63
+
64
+ Then, download the following models to `./models/`
65
+
66
+ ```sh
67
+ # Download to ./models/ dir.
68
+ mkdir models
69
+ cd models
70
+
71
+ # SuperPoint.
72
+ git clone https://github.com/rpautrat/SuperPoint.git
73
+ mv SuperPoint/pretrained_models/sp_v6.tgz . && rm -rf SuperPoint
74
+ tar zxvf sp_v6.tgz && rm sp_v6.tgz
75
+
76
+ # DINOv2 - vit-b14.
77
+ wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth
78
+
79
+ # OmniGlue.
80
+ wget https://storage.googleapis.com/omniglue/og_export.zip
81
+ unzip og_export.zip && rm og_export.zip
82
+ ```
83
+
84
+ Direct download links:
85
+
86
+ - [[SuperPoint weights]](https://github.com/rpautrat/SuperPoint/tree/master/pretrained_models): from [github.com/rpautrat/SuperPoint](https://github.com/rpautrat/SuperPoint)
87
+ - [[DINOv2 weights]](https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth): from [github.com/facebookresearch/dinov2](https://github.com/facebookresearch/dinov2) (ViT-B/14 distilled backbone without register).
88
+ - [[OmniGlue weights]](https://storage.googleapis.com/omniglue/og_export.zip)
89
+
90
+ ## Usage
91
+ The code snippet below outlines how you can perform OmniGlue inference in your
92
+ own python codebase.
93
+
94
+ ```py
95
+
96
+ from src import omniglue
97
+
98
+ image0 = ... # load images from file into np.array
99
+ image1 = ...
100
+
101
+ og = omniglue.OmniGlue(
102
+ og_export="./models/omniglue.onnx",
103
+ sp_export="./models/sp_v6.onnx",
104
+ dino_export="./models/dinov2_vitb14_pretrain.pth",
105
+ )
106
+
107
+ match_kp0s, match_kp1s, match_confidences = og.FindMatches(image0, image1)
108
+ # Output:
109
+ # match_kp0: (N, 2) array of (x,y) coordinates in image0.
110
+ # match_kp1: (N, 2) array of (x,y) coordinates in image1.
111
+ # match_confidences: N-dim array of each of the N match confidence scores.
112
+ ```
113
+
114
+ ## Demo
115
+
116
+ `demo.py` contains example usage of the `omniglue` module. To try with your own
117
+ images, replace `./res/demo1.jpg` and `./res/demo2.jpg` with your own
118
+ filepaths.
119
+
120
+ ```sh
121
+ conda activate omniglue
122
+ python demo.py ./res/demo1.jpg ./res/demo2.jpg
123
+ # <see output in './demo_output.png'>
124
+ ```
125
+
126
+ Expected output:
127
+ ![demo_output.png](res/demo_output.png "demo_output.png")
128
+
129
+ Comparison of Results Between TensorFlow and ONNX:
130
+ ![result_tf_and_onnx.png](res/result_tf_and_onnx.png "result_tf_and_onnx.png")
131
+
132
+
133
+ ## Repo TODOs
134
+
135
+ - ~~Provide `demo.py` example usage script.~~
136
+ - Support matching for pre-extracted features.
137
+ - Release eval pipelines for in-domain (MegaDepth).
138
+ - Release eval pipelines for all out-of-domain datasets.
139
+
140
+ ## BibTex
141
+ ```
142
+ @inproceedings{jiang2024Omniglue,
143
+ title={OmniGlue: Generalizable Feature Matching with Foundation Model Guidance},
144
+ author={Jiang, Hanwen and Karpur, Arjun and Cao, Bingyi and Huang, Qixing and Araujo, Andre},
145
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
146
+ year={2024},
147
+ }
148
+ ```
149
+
150
+ --------------------------------------------------------------------------------
151
+
152
+ This is not an officially supported Google product.
third_party/omniglue/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """omniglue API."""
16
+
17
+ # A new PyPI release will be pushed every time `__version__` is increased.
18
+ # When changing this, also update the CHANGELOG.md.
19
+ __version__ = "0.1.0"
third_party/omniglue/demo.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Demo script for performing OmniGlue inference."""
17
+
18
+ import sys
19
+ import time
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ from src import omniglue
23
+ from src.omniglue import utils
24
+ from PIL import Image
25
+
26
+
27
+ def main(argv) -> None:
28
+ if len(argv) != 3:
29
+ print("error - usage: python demo.py <img1_fp> <img2_fp>")
30
+ return
31
+
32
+ # Load images.
33
+ print("> Loading images...")
34
+ image0 = np.array(Image.open(argv[1]))
35
+ image1 = np.array(Image.open(argv[2]))
36
+
37
+ # Load models.
38
+ print("> Loading OmniGlue (and its submodules: SuperPoint & DINOv2)...")
39
+ start = time.time()
40
+ og = omniglue.OmniGlue(
41
+ og_export="./models/omniglue.onnx",
42
+ sp_export="./models/sp_v6.onnx",
43
+ dino_export="./models/dinov2_vitb14_pretrain.pth",
44
+ )
45
+ print(f"> \tTook {time.time() - start} seconds.")
46
+
47
+ # Perform inference.
48
+ print("> Finding matches...")
49
+ start = time.time()
50
+ match_kp0, match_kp1, match_confidences = og.FindMatches(image0, image1)
51
+ num_matches = match_kp0.shape[0]
52
+ print(f"> \tFound {num_matches} matches.")
53
+ print(f"> \tTook {time.time() - start} seconds.")
54
+
55
+ # Filter by confidence (0.02).
56
+ print("> Filtering matches...")
57
+ match_threshold = 0.02 # Choose any value [0.0, 1.0).
58
+ keep_idx = []
59
+ for i in range(match_kp0.shape[0]):
60
+ if match_confidences[i] > match_threshold:
61
+ keep_idx.append(i)
62
+ num_filtered_matches = len(keep_idx)
63
+ match_kp0 = match_kp0[keep_idx]
64
+ match_kp1 = match_kp1[keep_idx]
65
+ match_confidences = match_confidences[keep_idx]
66
+ print(
67
+ f"> \tFound {num_filtered_matches}/{num_matches} above threshold {match_threshold}"
68
+ )
69
+
70
+ # Visualize.
71
+ print("> Visualizing matches...")
72
+ viz = utils.visualize_matches(
73
+ image0,
74
+ image1,
75
+ match_kp0,
76
+ match_kp1,
77
+ np.eye(num_filtered_matches),
78
+ show_keypoints=True,
79
+ highlight_unmatched=True,
80
+ title=f"{num_filtered_matches} matches",
81
+ line_width=2,
82
+ )
83
+ plt.figure(figsize=(20, 10), dpi=100, facecolor="w", edgecolor="k")
84
+ plt.axis("off")
85
+ plt.imshow(viz)
86
+ plt.imsave("./demo_output.png", viz)
87
+ print("> \tSaved visualization to ./demo_output.png")
88
+
89
+
90
+ if __name__ == "__main__":
91
+ main(sys.argv)
third_party/omniglue/init_repo.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ mkdir models
3
+ cd models
4
+
5
+ # SuperPoint.
6
+ git clone https://github.com/rpautrat/SuperPoint.git
7
+ mv SuperPoint/pretrained_models/sp_v6.tgz . && rm -rf SuperPoint
8
+ tar zxvf sp_v6.tgz && rm sp_v6.tgz
9
+
10
+ # DINOv2 - vit-b14.
11
+ wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth
12
+
13
+ # OmniGlue.
14
+ wget https://storage.googleapis.com/omniglue/og_export.zip
15
+ unzip og_export.zip && rm og_export.zip
16
+
17
+ cd ..
18
+
19
+ saved_model=./models/og_export
20
+ output_onnx=./models/omniglue.onnx
21
+ python -m tf2onnx.convert --saved-model ${saved_model} --output ${output_onnx} --tag serve
22
+
23
+
24
+ saved_model=./models/sp_v6
25
+ output_onnx=./models/sp_v6.onnx
26
+ python -m tf2onnx.convert --saved-model ${saved_model} --output ${output_onnx} --tag serve
27
+
third_party/omniglue/models/omniglue.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cc095d640e8d32b9ef2b29e8029d316e8a50cfed94968d3881265811b03ad28
3
+ size 51182029
third_party/omniglue/pyproject.toml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ # Project metadata. Available keys are documented at:
3
+ # https://packaging.python.org/en/latest/specifications/declaring-project-metadata
4
+ name = "omniglue"
5
+ description = "Official code release for CVPR'24 paper 'OmniGlue: Generalizable Feature Matching with Foundation Model Guidance"
6
+ readme = "README.md"
7
+ requires-python = ">=3.8"
8
+ license = {file = "LICENSE"}
9
+ authors = [{name = "OmniGlue authors"}]
10
+ classifiers = [ # List of https://pypi.org/classifiers/
11
+ "License :: OSI Approved :: Apache Software License",
12
+ "Intended Audience :: Science/Research",
13
+ ]
14
+ keywords = ["feature matching"]
15
+ dynamic = ["version", "dependencies"]
16
+
17
+ # pip dependencies of the project
18
+ # Installed locally with `pip install -e .`
19
+ [tool.setuptools.dynamic]
20
+ dependencies = {file = ["requirements.txt"]}
21
+
22
+ [project.urls]
23
+ homepage = "https://github.com/google-research/omniglue"
24
+ repository = "https://github.com/google-research/omniglue"
25
+ changelog = "https://github.com/google-research/omniglue/blob/main/CHANGELOG.md"
26
+ # documentation = ""
27
+
28
+ [tool.setuptools.packages.find]
29
+ where = ["src", "third_party"]
30
+ include = ["omniglue*", "dinov2*"]
31
+
32
+ [project.optional-dependencies]
33
+ # Development deps (unittest, linting, formating,...)
34
+ # Installed through `pip install -e .[dev]`
35
+ dev = [
36
+ "pytest",
37
+ "pytest-xdist",
38
+ "pylint>=2.6.0",
39
+ "pyink",
40
+ ]
41
+
42
+ [tool.pyink]
43
+ # Formatting configuration to follow Google style-guide
44
+ line-length = 80
45
+ unstable = true
46
+ pyink-indentation = 2
47
+ pyink-use-majority-quotes = true
48
+
49
+ [build-system]
50
+ # Build system specify which backend is used to build/install the project (flit,
51
+ # poetry, setuptools,...). All backends are supported by `pip install`
52
+ requires = ["setuptools", "wheel"]
53
+ build-backend = "setuptools.build_meta"
54
+
55
+ [tool.flit.sdist]
56
+ # Flit specific options (files to exclude from the PyPI package).
57
+ # If using another build backend (setuptools, poetry), you can remove this
58
+ # section.
59
+ exclude = [
60
+ # Do not release tests files on PyPI
61
+ "**/*_test.py",
62
+ ]
third_party/omniglue/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ matplotlib
2
+ numpy
3
+ opencv-python
4
+ Pillow
5
+ torch
6
+ gdown
7
+ tf2onnx
8
+ onnxruntime
third_party/omniglue/res/demo1.jpg ADDED

Git LFS Details

  • SHA256: 0c3719183ae9139e45569e16861f42ac8e47b46c86f3536fdc52b22011f31871
  • Pointer size: 130 Bytes
  • Size of remote file: 85.3 kB
third_party/omniglue/res/demo2.jpg ADDED

Git LFS Details

  • SHA256: 24dbe3a2ee909002b265e647b96a7141419c954a2a90b235699c186f927705c4
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
third_party/omniglue/res/demo_output.png ADDED

Git LFS Details

  • SHA256: 6ecf8c48a70baefb6982c088167774a5bbc75c704e6697c23958f56a55a0a717
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
third_party/omniglue/res/og_diagram.png ADDED

Git LFS Details

  • SHA256: c0f8ee5541fde5f6cbb81485106ddd268c58de006590f8b6dea58039e5b0a476
  • Pointer size: 132 Bytes
  • Size of remote file: 4.82 MB
third_party/omniglue/res/result_tf_and_onnx.png ADDED

Git LFS Details

  • SHA256: 3d2949de656dbfd39103d9819fc2392b9ad65ed189ead9fe2ef844f618ac204c
  • Pointer size: 131 Bytes
  • Size of remote file: 978 kB
third_party/omniglue/src/omniglue/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from . import omniglue_extract
16
+
17
+ OmniGlue = omniglue_extract.OmniGlue
third_party/omniglue/src/omniglue/dino_extract.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Wrapper for performing DINOv2 inference."""
16
+
17
+ import cv2
18
+ import numpy as np
19
+ from third_party.dinov2 import dino
20
+
21
+ from . import utils
22
+ import torch
23
+
24
+
25
+ class DINOExtract:
26
+ """Class to initialize DINO model and extract features from an image."""
27
+
28
+ def __init__(self, cpt_path: str, feature_layer: int = 1):
29
+ self.feature_layer = feature_layer
30
+ self.model = dino.vit_base()
31
+ state_dict_raw = torch.load(cpt_path, map_location="cpu")
32
+
33
+ # state_dict = {}
34
+ # for k, v in state_dict_raw.items():
35
+ # state_dict[k.replace('blocks', 'blocks.0')] = v
36
+
37
+ self.model.load_state_dict(state_dict_raw)
38
+ self.model.eval()
39
+
40
+ self.image_size_max = 630
41
+
42
+ self.h_down_rate = self.model.patch_embed.patch_size[0]
43
+ self.w_down_rate = self.model.patch_embed.patch_size[1]
44
+
45
+ def __call__(self, image) -> np.ndarray:
46
+ return self.forward(image)
47
+
48
+ def forward(self, image: np.ndarray) -> np.ndarray:
49
+ """Feeds image through DINO ViT model to extract features.
50
+
51
+ Args:
52
+ image: (H, W, 3) numpy array, decoded image bytes, value range [0, 255].
53
+
54
+ Returns:
55
+ features: (H // 14, W // 14, C) numpy array image features.
56
+ """
57
+ image = self._resize_input_image(image)
58
+ image_processed = self._process_image(image)
59
+ image_processed = image_processed.unsqueeze(0).float()
60
+ features = self.extract_feature(image_processed)
61
+ features = features.squeeze(0).permute(1, 2, 0).cpu().numpy()
62
+ return features
63
+
64
+ def _resize_input_image(
65
+ self, image: np.ndarray, interpolation=cv2.INTER_LINEAR
66
+ ):
67
+ """Resizes image such that both dimensions are divisble by down_rate."""
68
+ h_image, w_image = image.shape[:2]
69
+ h_larger_flag = h_image > w_image
70
+ large_side_image = max(h_image, w_image)
71
+
72
+ # resize the image with the largest side length smaller than a threshold
73
+ # to accelerate ViT backbone inference (which has quadratic complexity).
74
+ if large_side_image > self.image_size_max:
75
+ if h_larger_flag:
76
+ h_image_target = self.image_size_max
77
+ w_image_target = int(self.image_size_max * w_image / h_image)
78
+ else:
79
+ w_image_target = self.image_size_max
80
+ h_image_target = int(self.image_size_max * h_image / w_image)
81
+ else:
82
+ h_image_target = h_image
83
+ w_image_target = w_image
84
+
85
+ h, w = (
86
+ h_image_target // self.h_down_rate,
87
+ w_image_target // self.w_down_rate,
88
+ )
89
+ h_resize, w_resize = h * self.h_down_rate, w * self.w_down_rate
90
+ image = cv2.resize(
91
+ image, (w_resize, h_resize), interpolation=interpolation
92
+ )
93
+ return image
94
+
95
+ def _process_image(self, image: np.ndarray) -> torch.Tensor:
96
+ """Turn image into pytorch tensor and normalize it."""
97
+ mean = np.array([0.485, 0.456, 0.406])
98
+ std = np.array([0.229, 0.224, 0.225])
99
+
100
+ image_processed = image / 255.0
101
+ image_processed = (image_processed - mean) / std
102
+ image_processed = torch.from_numpy(image_processed).permute(2, 0, 1)
103
+ return image_processed
104
+
105
+ def extract_feature(self, image):
106
+ """Extracts features from image.
107
+
108
+ Args:
109
+ image: (B, 3, H, W) torch tensor, normalized with ImageNet mean/std.
110
+
111
+ Returns:
112
+ features: (B, C, H//14, W//14) torch tensor image features.
113
+ """
114
+ b, _, h_origin, w_origin = image.shape
115
+ out = self.model.get_intermediate_layers(image, n=self.feature_layer)[0]
116
+ h = int(h_origin / self.h_down_rate)
117
+ w = int(w_origin / self.w_down_rate)
118
+ dim = out.shape[-1]
119
+ out = out.reshape(b, h, w, dim).permute(0, 3, 1, 2).detach()
120
+ return out
121
+
122
+
123
+ def _preprocess_shape(
124
+ h_image, w_image, image_size_max=630, h_down_rate=14, w_down_rate=14
125
+ ):
126
+ h_image = h_image.squeeze()
127
+ w_image = w_image.squeeze()
128
+
129
+ h_larger_flag = h_image > w_image
130
+ large_side_image = max(h_image, w_image)
131
+
132
+ def resize_h_larger():
133
+ h_image_target = image_size_max
134
+ w_image_target = int(image_size_max * w_image / h_image)
135
+ return h_image_target, w_image_target
136
+
137
+ def resize_w_larger_or_equal():
138
+ w_image_target = image_size_max
139
+ h_image_target = int(image_size_max * h_image / w_image)
140
+ return h_image_target, w_image_target
141
+
142
+ def keep_original():
143
+ return h_image, w_image
144
+
145
+ if large_side_image > image_size_max:
146
+ if h_larger_flag:
147
+ h_image_target, w_image_target = resize_h_larger()
148
+ else:
149
+ h_image_target, w_image_target = resize_w_larger_or_equal()
150
+ else:
151
+ h_image_target, w_image_target = keep_original()
152
+
153
+ h = h_image_target // h_down_rate
154
+ w = w_image_target // w_down_rate
155
+ h_resize = torch.tensor(h * h_down_rate)
156
+ w_resize = torch.tensor(w * w_down_rate)
157
+
158
+ h_resize = h_resize.unsqueeze(0)
159
+ w_resize = w_resize.unsqueeze(0)
160
+
161
+ return h_resize, w_resize
162
+
163
+
164
+ def get_dino_descriptors(dino_features, keypoints, height, width, feature_dim):
165
+ """Get DINO descriptors using Superpoint keypoints.
166
+
167
+ Args:
168
+ dino_features: DINO features in 1-D.
169
+ keypoints: Superpoint keypoint locations, in format (x, y), in pixels, shape
170
+ (N, 2).
171
+ height: image height, type torch int32.
172
+ width: image width, type torch int32.
173
+ feature_dim: DINO feature channel size, type torch int32.
174
+
175
+ Returns:
176
+ Interpolated DINO descriptors.
177
+ """
178
+ height_1d = height.reshape([1])
179
+ width_1d = width.reshape([1])
180
+
181
+ height_1d_resized, width_1d_resized = _preprocess_shape(
182
+ height_1d, width_1d, image_size_max=630, h_down_rate=14, w_down_rate=14
183
+ )
184
+
185
+ height_feat = height_1d_resized // 14
186
+ width_feat = width_1d_resized // 14
187
+ feature_dim_1d = torch.tensor(feature_dim).reshape([1])
188
+
189
+ dino_features = dino_features.reshape(
190
+ height_feat, width_feat, feature_dim_1d
191
+ )
192
+
193
+ img_size = torch.cat([width_1d, height_1d], dim=0).float()
194
+ feature_size = torch.cat([width_feat, height_feat], dim=0).float()
195
+ keypoints_feature = (
196
+ keypoints[0] / img_size.unsqueeze(0) * feature_size.unsqueeze(0)
197
+ )
198
+
199
+ dino_descriptors = []
200
+ for kp in keypoints_feature:
201
+ dino_descriptors.append(
202
+ utils.lookup_descriptor_bilinear(kp.numpy(), dino_features)
203
+ )
204
+ dino_descriptors = torch.tensor(
205
+ np.array(dino_descriptors), dtype=torch.float32
206
+ )
207
+ return dino_descriptors
third_party/omniglue/src/omniglue/omniglue_extract.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Wrapper for performing OmniGlue inference, plus (optionally) SP/DINO."""
16
+ import cv2
17
+ import torch
18
+ import numpy as np
19
+ import onnxruntime
20
+
21
+ from . import dino_extract
22
+ from . import superpoint_extract
23
+ from . import utils
24
+
25
+
26
+ DINO_FEATURE_DIM = 768
27
+ MATCH_THRESHOLD = 1e-3
28
+
29
+
30
+ class OmniGlue:
31
+ # TODO(omniglue): class docstring
32
+
33
+ def __init__(
34
+ self,
35
+ og_export: str,
36
+ sp_export: str | None = None,
37
+ dino_export: str | None = None,
38
+ max_keypoints: int = 2048,
39
+ ) -> None:
40
+ self.max_keypoints = max_keypoints
41
+ self.matcher = onnxruntime.InferenceSession(og_export)
42
+ if sp_export is not None:
43
+ self.sp_extract = superpoint_extract.SuperPointExtract(sp_export)
44
+ if dino_export is not None:
45
+ self.dino_extract = dino_extract.DINOExtract(
46
+ dino_export, feature_layer=1
47
+ )
48
+
49
+ def FindMatches(self, image0: np.ndarray, image1: np.ndarray):
50
+ """TODO(omniglue): docstring."""
51
+ height0, width0 = image0.shape[:2]
52
+ height1, width1 = image1.shape[:2]
53
+ # TODO: numpy to torch inputs
54
+ sp_features0 = self.sp_extract(image0, num_features=self.max_keypoints)
55
+ sp_features1 = self.sp_extract(image1, num_features=self.max_keypoints)
56
+ dino_features0 = self.dino_extract(image0)
57
+ dino_features1 = self.dino_extract(image1)
58
+ dino_descriptors0 = dino_extract.get_dino_descriptors(
59
+ dino_features0,
60
+ sp_features0,
61
+ torch.tensor(height0),
62
+ torch.tensor(width0),
63
+ DINO_FEATURE_DIM,
64
+ )
65
+ dino_descriptors1 = dino_extract.get_dino_descriptors(
66
+ dino_features1,
67
+ sp_features1,
68
+ torch.tensor(height1),
69
+ torch.tensor(width1),
70
+ DINO_FEATURE_DIM,
71
+ )
72
+
73
+ inputs = self._construct_inputs(
74
+ width0,
75
+ height0,
76
+ width1,
77
+ height1,
78
+ sp_features0,
79
+ sp_features1,
80
+ dino_descriptors0,
81
+ dino_descriptors1,
82
+ )
83
+
84
+ og_outputs = self.matcher.run(None, inputs)
85
+ soft_assignment = torch.from_numpy(og_outputs[0][:, :-1, :-1])
86
+
87
+ match_matrix = (
88
+ utils.soft_assignment_to_match_matrix(
89
+ soft_assignment, MATCH_THRESHOLD
90
+ )
91
+ .numpy()
92
+ .squeeze()
93
+ )
94
+
95
+ # Filter out any matches with 0.0 confidence keypoints.
96
+ match_indices = np.argwhere(match_matrix)
97
+ keep = []
98
+ for i in range(match_indices.shape[0]):
99
+ match = match_indices[i, :]
100
+ if (sp_features0[2][match[0]] > 0.0) and (
101
+ sp_features1[2][match[1]] > 0.0
102
+ ):
103
+ keep.append(i)
104
+ match_indices = match_indices[keep]
105
+
106
+ # Format matches in terms of keypoint locations.
107
+ match_kp0s = []
108
+ match_kp1s = []
109
+ match_confidences = []
110
+ for match in match_indices:
111
+ match_kp0s.append(sp_features0[0][match[0], :])
112
+ match_kp1s.append(sp_features1[0][match[1], :])
113
+ match_confidences.append(soft_assignment[0, match[0], match[1]])
114
+ match_kp0s = np.array(match_kp0s)
115
+ match_kp1s = np.array(match_kp1s)
116
+ match_confidences = np.array(match_confidences)
117
+ return match_kp0s, match_kp1s, match_confidences
118
+
119
+ ### Private methods ###
120
+
121
+ def _construct_inputs(
122
+ self,
123
+ width0,
124
+ height0,
125
+ width1,
126
+ height1,
127
+ sp_features0,
128
+ sp_features1,
129
+ dino_descriptors0,
130
+ dino_descriptors1,
131
+ ):
132
+ keypoints0 = sp_features0[0]
133
+ keypoints1 = sp_features1[0]
134
+ descriptors0 = sp_features0[1]
135
+ descriptors1 = sp_features1[1]
136
+ scores0 = sp_features0[2]
137
+ scores1 = sp_features1[2]
138
+ descriptors0_dino = dino_descriptors0
139
+ descriptors1_dino = dino_descriptors1
140
+ if isinstance(keypoints0, torch.Tensor):
141
+ keypoints0 = keypoints0.detach().numpy()
142
+ if isinstance(keypoints1, torch.Tensor):
143
+ keypoints1 = keypoints1.detach().numpy()
144
+ if isinstance(descriptors0, torch.Tensor):
145
+ descriptors0 = descriptors0.detach().numpy()
146
+ if isinstance(descriptors1, torch.Tensor):
147
+ descriptors1 = descriptors1.detach().numpy()
148
+ if isinstance(scores0, torch.Tensor):
149
+ scores0 = scores0.detach().numpy()
150
+ if isinstance(scores1, torch.Tensor):
151
+ scores1 = scores1.detach().numpy()
152
+ if isinstance(descriptors0_dino, torch.Tensor):
153
+ descriptors0_dino = descriptors0_dino.detach().numpy()
154
+ if isinstance(descriptors1_dino, torch.Tensor):
155
+ descriptors1_dino = descriptors1_dino.detach().numpy()
156
+ inputs = {
157
+ "keypoints0": np.expand_dims(keypoints0, axis=0).astype(np.float32),
158
+ "keypoints1": np.expand_dims(keypoints1, axis=0).astype(np.float32),
159
+ "descriptors0": np.expand_dims(descriptors0, axis=0).astype(
160
+ np.float32
161
+ ),
162
+ "descriptors1": np.expand_dims(descriptors1, axis=0).astype(
163
+ np.float32
164
+ ),
165
+ "scores0": np.expand_dims(
166
+ np.expand_dims(scores0, axis=0), axis=-1
167
+ ).astype(np.float32),
168
+ "scores1": np.expand_dims(
169
+ np.expand_dims(scores1, axis=0), axis=-1
170
+ ).astype(np.float32),
171
+ "descriptors0_dino": np.expand_dims(descriptors0_dino, axis=0),
172
+ "descriptors1_dino": np.expand_dims(descriptors1_dino, axis=0),
173
+ "width0": np.expand_dims(width0, axis=0).astype(np.int32),
174
+ "width1": np.expand_dims(width1, axis=0).astype(np.int32),
175
+ "height0": np.expand_dims(height0, axis=0).astype(np.int32),
176
+ "height1": np.expand_dims(height1, axis=0).astype(np.int32),
177
+ }
178
+ return inputs
third_party/omniglue/src/omniglue/superpoint_extract.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Wrapper for performing SuperPoint inference."""
16
+
17
+ import math
18
+ from typing import Optional, Tuple
19
+
20
+ import cv2
21
+ import numpy as np
22
+ from . import utils
23
+ import onnxruntime
24
+
25
+
26
+ class SuperPointExtract:
27
+ """Class to initialize SuperPoint model and extract features from an image.
28
+
29
+ To stay consistent with SuperPoint training and eval configurations, resize
30
+ images to (320x240) or (640x480).
31
+
32
+ Attributes
33
+ model_path: string, filepath to saved SuperPoint ONNX model weights.
34
+ """
35
+
36
+ def __init__(self, model_path: str):
37
+ self.model_path = model_path
38
+ self.net = onnxruntime.InferenceSession(self.model_path)
39
+
40
+ def __call__(
41
+ self,
42
+ image,
43
+ segmentation_mask=None,
44
+ num_features=1024,
45
+ pad_random_features=False,
46
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
47
+ return self.compute(
48
+ image,
49
+ segmentation_mask=segmentation_mask,
50
+ num_features=num_features,
51
+ pad_random_features=pad_random_features,
52
+ )
53
+
54
+ def compute(
55
+ self,
56
+ image: np.ndarray,
57
+ segmentation_mask: Optional[np.ndarray] = None,
58
+ num_features: int = 1024,
59
+ pad_random_features: bool = False,
60
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
61
+ """Feeds image through SuperPoint model to extract keypoints and features.
62
+
63
+ Args:
64
+ image: (H, W, 3) numpy array, decoded image bytes.
65
+ segmentation_mask: (H, W) binary numpy array or None. If not None,
66
+ extracted keypoints are restricted to being within the mask.
67
+ num_features: max number of features to extract (or 0 to indicate keeping
68
+ all extracted features).
69
+ pad_random_features: if True, adds randomly sampled keypoints to the
70
+ output such that there are exactly 'num_features' keypoints. Descriptors
71
+ for these sampled keypoints are taken from the network's descriptor map
72
+ output, and scores are set to 0. No action taken if num_features = 0.
73
+
74
+ Returns:
75
+ keypoints: (N, 2) numpy array, coordinates of keypoints as floats.
76
+ descriptors: (N, 256) numpy array, descriptors for keypoints as floats.
77
+ scores: (N, 1) numpy array, confidence values for keypoints as floats.
78
+ """
79
+
80
+ # Resize image so both dimensions are divisible by 8.
81
+ image, keypoint_scale_factors = self._resize_input_image(image)
82
+ if segmentation_mask is not None:
83
+ segmentation_mask, _ = self._resize_input_image(
84
+ segmentation_mask, interpolation=cv2.INTER_NEAREST
85
+ )
86
+ assert (
87
+ segmentation_mask is None
88
+ or image.shape[:2] == segmentation_mask.shape[:2]
89
+ )
90
+
91
+ # Preprocess and feed-forward image.
92
+ image_preprocessed = self._preprocess_image(image)
93
+ out = self.net.run(
94
+ None,
95
+ {
96
+ self.net.get_inputs()[0].name: np.expand_dims(
97
+ image_preprocessed, 0
98
+ )
99
+ },
100
+ )
101
+ # Format output from network.
102
+ keypoint_map = np.squeeze(out[5])
103
+ descriptor_map = np.squeeze(out[0])
104
+ if segmentation_mask is not None:
105
+ keypoint_map = np.where(segmentation_mask, keypoint_map, 0.0)
106
+ keypoints, descriptors, scores = self._extract_superpoint_output(
107
+ keypoint_map, descriptor_map, num_features, pad_random_features
108
+ )
109
+
110
+ # Rescale keypoint locations to match original input image size, and return.
111
+ keypoints = keypoints / keypoint_scale_factors
112
+ return (keypoints, descriptors, scores)
113
+
114
+ def _resize_input_image(self, image, interpolation=cv2.INTER_LINEAR):
115
+ """Resizes image such that both dimensions are divisble by 8."""
116
+
117
+ # Calculate new image dimensions and per-dimension resizing scale factor.
118
+ new_dim = [-1, -1]
119
+ keypoint_scale_factors = [1.0, 1.0]
120
+ for i in range(2):
121
+ dim_size = image.shape[i]
122
+ mod_eight = dim_size % 8
123
+ if mod_eight < 4:
124
+ # Round down to nearest multiple of 8.
125
+ new_dim[i] = dim_size - mod_eight
126
+ elif mod_eight >= 4:
127
+ # Round up to nearest multiple of 8.
128
+ new_dim[i] = dim_size + (8 - mod_eight)
129
+ keypoint_scale_factors[i] = (new_dim[i] - 1) / (dim_size - 1)
130
+
131
+ # Resize and return image + scale factors.
132
+ new_dim = new_dim[::-1] # Convert from (row, col) to (x,y).
133
+ keypoint_scale_factors = keypoint_scale_factors[::-1]
134
+ image = cv2.resize(image, tuple(new_dim), interpolation=interpolation)
135
+ return image, keypoint_scale_factors
136
+
137
+ def _preprocess_image(self, image):
138
+ """Converts image to grayscale and normalizes values for model input."""
139
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
140
+ image = np.expand_dims(image, 2)
141
+ image = image.astype(np.float32)
142
+ image = image / 255.0
143
+ return image
144
+
145
+ def _extract_superpoint_output(
146
+ self,
147
+ keypoint_map,
148
+ descriptor_map,
149
+ keep_k_points=512,
150
+ pad_random_features=False,
151
+ ):
152
+ """Converts from raw SuperPoint output (feature maps) into numpy arrays.
153
+
154
+ If keep_k_points is 0, then keep all detected keypoints. Otherwise, sort by
155
+ confidence and keep only the top k confidence keypoints.
156
+
157
+ Args:
158
+ keypoint_map: (H, W, 1) numpy array, raw output confidence values from
159
+ SuperPoint model.
160
+ descriptor_map: (H, W, 256) numpy array, raw output descriptors from
161
+ SuperPoint model.
162
+ keep_k_points: int, number of keypoints to keep (or 0 to indicate keeping
163
+ all detected keypoints).
164
+ pad_random_features: if True, adds randomly sampled keypoints to the
165
+ output such that there are exactly 'num_features' keypoints. Descriptors
166
+ for these sampled keypoints are taken from the network's descriptor map
167
+ output, and scores are set to 0. No action taken if keep_k_points = 0.
168
+
169
+ Returns:
170
+ keypoints: (N, 2) numpy array, image coordinates (x, y) of keypoints as
171
+ floats.
172
+ descriptors: (N, 256) numpy array, descriptors for keypoints as floats.
173
+ scores: (N, 1) numpy array, confidence values for keypoints as floats.
174
+ """
175
+
176
+ def _select_k_best(points, k):
177
+ sorted_prob = points[points[:, 2].argsort(), :]
178
+ start = min(k, points.shape[0])
179
+ return sorted_prob[-start:, :2], sorted_prob[-start:, 2]
180
+
181
+ keypoints = np.where(keypoint_map > 0)
182
+ prob = keypoint_map[keypoints[0], keypoints[1]]
183
+ keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1)
184
+
185
+ # Keep only top k points, or all points if keep_k_points param is 0.
186
+ if keep_k_points == 0:
187
+ keep_k_points = keypoints.shape[0]
188
+ keypoints, scores = _select_k_best(keypoints, keep_k_points)
189
+
190
+ # Optionally, pad with random features (and confidence scores of 0).
191
+ image_shape = np.array(keypoint_map.shape[:2])
192
+ if pad_random_features and (keep_k_points > keypoints.shape[0]):
193
+ num_pad = keep_k_points - keypoints.shape[0]
194
+ keypoints_pad = (image_shape - 1) * np.random.uniform(
195
+ size=(num_pad, 2)
196
+ )
197
+ keypoints = np.concatenate((keypoints, keypoints_pad))
198
+ scores_pad = np.zeros((num_pad))
199
+ scores = np.concatenate((scores, scores_pad))
200
+
201
+ # Lookup descriptors via bilinear interpolation.
202
+ # TODO: batch descriptor lookup with bilinear interpolation.
203
+ keypoints[:, [0, 1]] = keypoints[
204
+ :, [1, 0]
205
+ ] # Swap from (row,col) to (x,y).
206
+ descriptors = []
207
+ for kp in keypoints:
208
+ descriptors.append(
209
+ utils.lookup_descriptor_bilinear(kp, descriptor_map)
210
+ )
211
+ descriptors = np.array(descriptors)
212
+ return keypoints, descriptors, scores
third_party/omniglue/src/omniglue/utils.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Shared utility functions for OmniGlue."""
16
+ import cv2
17
+ import torch
18
+ import math
19
+ import numpy as np
20
+ from typing import Optional
21
+
22
+
23
+ def lookup_descriptor_bilinear(
24
+ keypoint: np.ndarray, descriptor_map: np.ndarray
25
+ ) -> np.ndarray:
26
+ """Looks up descriptor value for keypoint from a dense descriptor map.
27
+
28
+ Uses bilinear interpolation to find descriptor value at non-integer
29
+ positions.
30
+
31
+ Args:
32
+ keypoint: 2-dim numpy array containing (x, y) keypoint image coordinates.
33
+ descriptor_map: (H, W, D) numpy array representing a dense descriptor map.
34
+
35
+ Returns:
36
+ D-dim descriptor value at the input 'keypoint' location.
37
+
38
+ Raises:
39
+ ValueError, if kepoint position is out of bounds.
40
+ """
41
+ height, width = descriptor_map.shape[:2]
42
+ if (
43
+ keypoint[0] < 0
44
+ or keypoint[0] > width
45
+ or keypoint[1] < 0
46
+ or keypoint[1] > height
47
+ ):
48
+ raise ValueError(
49
+ "Keypoint position (%f, %f) is out of descriptor map bounds (%i w x"
50
+ " %i h)." % (keypoint[0], keypoint[1], width, height)
51
+ )
52
+
53
+ x_range = [math.floor(keypoint[0])]
54
+ if not keypoint[0].is_integer() and keypoint[0] < width - 1:
55
+ x_range.append(x_range[0] + 1)
56
+ y_range = [math.floor(keypoint[1])]
57
+ if not keypoint[1].is_integer() and keypoint[1] < height - 1:
58
+ y_range.append(y_range[0] + 1)
59
+
60
+ bilinear_descriptor = np.zeros(descriptor_map.shape[2])
61
+ for curr_x in x_range:
62
+ for curr_y in y_range:
63
+ curr_descriptor = descriptor_map[curr_y, curr_x, :]
64
+ bilinear_scalar = (1.0 - abs(keypoint[0] - curr_x)) * (
65
+ 1.0 - abs(keypoint[1] - curr_y)
66
+ )
67
+ bilinear_descriptor += bilinear_scalar * curr_descriptor
68
+ return bilinear_descriptor
69
+
70
+
71
+ def soft_assignment_to_match_matrix(
72
+ soft_assignment: torch.Tensor, match_threshold: float
73
+ ) -> torch.Tensor:
74
+ """Converts a matrix of soft assignment values to binary yes/no match matrix.
75
+
76
+ Searches soft_assignment for row- and column-maximum values, which indicate
77
+ mutual nearest neighbor matches between two unique sets of keypoints. Also,
78
+ ensures that score values for matches are above the specified threshold.
79
+
80
+ Args:
81
+ soft_assignment: (B, N, M) tensor, contains matching likelihood value
82
+ between features of different sets. N is number of features in image0, and
83
+ M is number of features in image1. Higher value indicates more likely to
84
+ match.
85
+ match_threshold: float, thresholding value to consider a match valid.
86
+
87
+ Returns:
88
+ (B, N, M) tensor of binary values. A value of 1 at index (x, y) indicates
89
+ a match between index 'x' (out of N) in image0 and index 'y' (out of M) in
90
+ image 1.
91
+ """
92
+
93
+ def _range_like(x, dim):
94
+ return torch.arange(x.shape[dim], dtype=x.dtype)
95
+
96
+ matches = []
97
+ for i in range(soft_assignment.shape[0]):
98
+ scores = soft_assignment[i, :].unsqueeze(0)
99
+
100
+ max0 = torch.max(scores, dim=2)[0]
101
+ indices0 = torch.argmax(scores, dim=2)
102
+ indices1 = torch.argmax(scores, dim=1)
103
+
104
+ mutual = _range_like(indices0, 1).unsqueeze(0) == indices1.gather(
105
+ 1, indices0
106
+ )
107
+
108
+ kp_ind_pairs = torch.stack(
109
+ [_range_like(indices0, 1), indices0.squeeze()], dim=1
110
+ )
111
+ mutual_max0 = torch.where(
112
+ mutual, max0, torch.zeros_like(max0)
113
+ ).squeeze()
114
+ sparse = torch.sparse_coo_tensor(
115
+ kp_ind_pairs.t(), mutual_max0, scores.shape[1:]
116
+ )
117
+ match_matrix = sparse.to_dense()
118
+ matches.append(match_matrix)
119
+
120
+ match_matrix = torch.stack(matches)
121
+ match_matrix = match_matrix > match_threshold
122
+ return match_matrix
123
+
124
+
125
+ def visualize_matches(
126
+ image0: np.ndarray,
127
+ image1: np.ndarray,
128
+ kp0: np.ndarray,
129
+ kp1: np.ndarray,
130
+ match_matrix: np.ndarray,
131
+ match_labels: Optional[np.ndarray] = None,
132
+ show_keypoints: bool = False,
133
+ highlight_unmatched: bool = False,
134
+ title: Optional[str] = None,
135
+ line_width: int = 1,
136
+ circle_radius: int = 4,
137
+ circle_thickness: int = 2,
138
+ rng: Optional["np.random.Generator"] = None,
139
+ ):
140
+ """Generates visualization of keypoints and matches for two images.
141
+
142
+ Stacks image0 and image1 horizontally. In case the two images have different
143
+ heights, scales image1 (and its keypoints) to match image0's height. Note
144
+ that keypoints must be in (x, y) format, NOT (row, col). If match_matrix
145
+ includes unmatched dustbins, the dustbins will be removed before visualizing
146
+ matches.
147
+
148
+ Args:
149
+ image0: (H, W, 3) array containing image0 contents.
150
+ image1: (H, W, 3) array containing image1 contents.
151
+ kp0: (N, 2) array where each row represents (x, y) coordinates of keypoints
152
+ in image0.
153
+ kp1: (M, 2) array, where each row represents (x, y) coordinates of keypoints
154
+ in image1.
155
+ match_matrix: (N, M) binary array, where values are non-zero for keypoint
156
+ indices making up a match.
157
+ match_labels: (N, M) binary array, where values are non-zero for keypoint
158
+ indices making up a ground-truth match. When None, matches from
159
+ 'match_matrix' are colored randomly. Otherwise, matches from
160
+ 'match_matrix' are colored according to accuracy (compared to labels).
161
+ show_keypoints: if True, all image0 and image1 keypoints (including
162
+ unmatched ones) are visualized.
163
+ highlight_unmatched: if True, highlights unmatched keypoints in blue.
164
+ title: if not None, adds title text to top left of visualization.
165
+ line_width: width of correspondence line, in pixels.
166
+ circle_radius: radius of keypoint circles, if visualized.
167
+ circle_thickness: thickness of keypoint circles, if visualized.
168
+ rng: np random number generator to generate the line colors.
169
+
170
+ Returns:
171
+ Numpy array of image0 and image1 side-by-side, with lines between matches
172
+ according to match_matrix. If show_keypoints is True, keypoints from both
173
+ images are also visualized.
174
+ """
175
+ # initialize RNG
176
+ if rng is None:
177
+ rng = np.random.default_rng()
178
+
179
+ # Make copy of input param that may be modified in this function.
180
+ kp1 = np.copy(kp1)
181
+
182
+ # Detect unmatched dustbins.
183
+ has_unmatched_dustbins = (match_matrix.shape[0] == kp0.shape[0] + 1) and (
184
+ match_matrix.shape[1] == kp1.shape[0] + 1
185
+ )
186
+
187
+ # If necessary, resize image1 so that the pair can be stacked horizontally.
188
+ height0 = image0.shape[0]
189
+ height1 = image1.shape[0]
190
+ if height0 != height1:
191
+ scale_factor = height0 / height1
192
+ if scale_factor <= 1.0:
193
+ interp_method = cv2.INTER_AREA
194
+ else:
195
+ interp_method = cv2.INTER_LINEAR
196
+ new_dim1 = (int(image1.shape[1] * scale_factor), height0)
197
+ image1 = cv2.resize(image1, new_dim1, interpolation=interp_method)
198
+ kp1 *= scale_factor
199
+
200
+ # Create side-by-side image and add lines for all matches.
201
+ viz = cv2.hconcat([image0, image1])
202
+ w0 = image0.shape[1]
203
+ matches = np.argwhere(
204
+ match_matrix[:-1, :-1] if has_unmatched_dustbins else match_matrix
205
+ )
206
+ for match in matches:
207
+ mpt0 = kp0[match[0]]
208
+ mpt1 = kp1[match[1]]
209
+ if isinstance(mpt0, torch.Tensor):
210
+ mpt0 = mpt0.numpy()
211
+ if isinstance(mpt1, torch.Tensor):
212
+ mpt1 = mpt1.numpy()
213
+ pt0 = (int(mpt0[0]), int(mpt0[1]))
214
+ pt1 = (int(mpt1[0] + w0), int(mpt1[1]))
215
+ if match_labels is None:
216
+ color = tuple(rng.integers(0, 255, size=3).tolist())
217
+ else:
218
+ if match_labels[match[0], match[1]]:
219
+ color = (0, 255, 0)
220
+ else:
221
+ color = (255, 0, 0)
222
+ cv2.line(viz, pt0, pt1, color, line_width)
223
+
224
+ # Optionally, add circles to output image to represent each keypoint.
225
+ if show_keypoints:
226
+ for i in range(np.shape(kp0)[0]):
227
+ kp = kp0[i].numpy() if isinstance(kp0[i], torch.Tensor) else kp0[i]
228
+ if (
229
+ highlight_unmatched
230
+ and has_unmatched_dustbins
231
+ and match_matrix[i, -1]
232
+ ):
233
+ cv2.circle(
234
+ viz,
235
+ tuple(kp.astype(np.int32).tolist()),
236
+ circle_radius,
237
+ (255, 0, 0),
238
+ circle_thickness,
239
+ )
240
+ else:
241
+ cv2.circle(
242
+ viz,
243
+ tuple(kp.astype(np.int32).tolist()),
244
+ circle_radius,
245
+ (0, 0, 255),
246
+ circle_thickness,
247
+ )
248
+ for j in range(np.shape(kp1)[0]):
249
+ kp = kp1[j].numpy() if isinstance(kp1[j], torch.Tensor) else kp1[j]
250
+ kp[0] += w0
251
+ if (
252
+ highlight_unmatched
253
+ and has_unmatched_dustbins
254
+ and match_matrix[-1, j]
255
+ ):
256
+ cv2.circle(
257
+ viz,
258
+ tuple(kp.astype(np.int32).tolist()),
259
+ circle_radius,
260
+ (255, 0, 0),
261
+ circle_thickness,
262
+ )
263
+ else:
264
+ cv2.circle(
265
+ viz,
266
+ tuple(kp.astype(np.int32).tolist()),
267
+ circle_radius,
268
+ (0, 0, 255),
269
+ circle_thickness,
270
+ )
271
+ if title is not None:
272
+ viz = cv2.putText(
273
+ viz,
274
+ title,
275
+ (5, 30),
276
+ cv2.FONT_HERSHEY_SIMPLEX,
277
+ 1,
278
+ (0, 0, 255),
279
+ 2,
280
+ cv2.LINE_AA,
281
+ )
282
+ return viz
third_party/omniglue/third_party/dinov2/__init__.py ADDED
File without changes
third_party/omniglue/third_party/dinov2/dino.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ from functools import partial
12
+ import math
13
+ from typing import Callable, Sequence, Tuple, Union
14
+
15
+ from third_party.dinov2 import dino_utils
16
+ import torch
17
+ from torch import nn
18
+ from torch.nn.init import trunc_normal_
19
+ import torch.utils.checkpoint
20
+
21
+
22
+ def named_apply(
23
+ fn: Callable,
24
+ module: nn.Module,
25
+ name="",
26
+ depth_first=True,
27
+ include_root=False,
28
+ ) -> nn.Module:
29
+ if not depth_first and include_root:
30
+ fn(module=module, name=name)
31
+ for child_name, child_module in module.named_children():
32
+ child_name = ".".join((name, child_name)) if name else child_name
33
+ named_apply(
34
+ fn=fn,
35
+ module=child_module,
36
+ name=child_name,
37
+ depth_first=depth_first,
38
+ include_root=True,
39
+ )
40
+ if depth_first and include_root:
41
+ fn(module=module, name=name)
42
+ return module
43
+
44
+
45
+ class BlockChunk(nn.ModuleList):
46
+
47
+ def forward(self, x):
48
+ for b in self:
49
+ x = b(x)
50
+ return x
51
+
52
+
53
+ class DinoVisionTransformer(nn.Module):
54
+
55
+ def __init__(
56
+ self,
57
+ img_size=518,
58
+ patch_size=16,
59
+ in_chans=3,
60
+ embed_dim=768,
61
+ depth=12,
62
+ num_heads=12,
63
+ mlp_ratio=4.0,
64
+ qkv_bias=True,
65
+ ffn_bias=True,
66
+ proj_bias=True,
67
+ drop_path_rate=0.0,
68
+ drop_path_uniform=False,
69
+ init_values=None, # for layerscale: None or 0 => no layerscale
70
+ embed_layer=dino_utils.PatchEmbed,
71
+ act_layer=nn.GELU,
72
+ block_fn=dino_utils.Block,
73
+ ffn_layer="mlp",
74
+ block_chunks=0,
75
+ ):
76
+ """Args:
77
+
78
+ img_size (int, tuple): input image size
79
+ patch_size (int, tuple): patch size
80
+ in_chans (int): number of input channels
81
+ embed_dim (int): embedding dimension
82
+ depth (int): depth of transformer
83
+ num_heads (int): number of attention heads
84
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
85
+ qkv_bias (bool): enable bias for qkv if True
86
+ proj_bias (bool): enable bias for proj in attn if True
87
+ ffn_bias (bool): enable bias for ffn if True
88
+ drop_path_rate (float): stochastic depth rate
89
+ drop_path_uniform (bool): apply uniform drop rate across blocks
90
+ weight_init (str): weight init scheme
91
+ init_values (float): layer-scale init values
92
+ embed_layer (nn.Module): patch embedding layer
93
+ act_layer (nn.Module): MLP activation layer
94
+ block_fn (nn.Module): transformer block class
95
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
96
+ block_chunks: (int) split block sequence into block_chunks units for
97
+ FSDP wrap
98
+ """
99
+ super().__init__()
100
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
101
+
102
+ self.num_features = self.embed_dim = (
103
+ embed_dim # num_features for consistency with other models
104
+ )
105
+ self.num_tokens = 1
106
+ self.n_blocks = depth
107
+ self.num_heads = num_heads
108
+ self.patch_size = patch_size
109
+
110
+ self.patch_embed = embed_layer(
111
+ img_size=img_size,
112
+ patch_size=patch_size,
113
+ in_chans=in_chans,
114
+ embed_dim=embed_dim,
115
+ )
116
+ num_patches = self.patch_embed.num_patches
117
+
118
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
119
+ self.pos_embed = nn.Parameter(
120
+ torch.zeros(1, num_patches + self.num_tokens, embed_dim)
121
+ )
122
+
123
+ if drop_path_uniform is True:
124
+ dpr = [drop_path_rate] * depth
125
+ else:
126
+ dpr = [
127
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
128
+ ] # stochastic depth decay rule
129
+
130
+ if ffn_layer == "mlp":
131
+ ffn_layer = dino_utils.Mlp
132
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
133
+ # ffn_layer = SwiGLUFFNFused
134
+ raise NotImplementedError("FFN only support mlp but using swiglu")
135
+ elif ffn_layer == "identity":
136
+
137
+ def f(*args, **kwargs):
138
+ return nn.Identity()
139
+
140
+ ffn_layer = f
141
+ else:
142
+ raise NotImplementedError
143
+
144
+ blocks_list = [
145
+ block_fn(
146
+ dim=embed_dim,
147
+ num_heads=num_heads,
148
+ mlp_ratio=mlp_ratio,
149
+ qkv_bias=qkv_bias,
150
+ proj_bias=proj_bias,
151
+ ffn_bias=ffn_bias,
152
+ drop_path=dpr[i],
153
+ norm_layer=norm_layer,
154
+ act_layer=act_layer,
155
+ ffn_layer=ffn_layer,
156
+ init_values=init_values,
157
+ )
158
+ for i in range(depth)
159
+ ]
160
+ if block_chunks > 0:
161
+ self.chunked_blocks = True
162
+ chunked_blocks = []
163
+ chunksize = depth // block_chunks
164
+ for i in range(0, depth, chunksize):
165
+ # this is to keep the block index consistent if we chunk the block list
166
+ chunked_blocks.append(
167
+ [nn.Identity()] * i + blocks_list[i : i + chunksize]
168
+ )
169
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
170
+ else:
171
+ self.chunked_blocks = False
172
+ self.blocks = nn.ModuleList(blocks_list)
173
+
174
+ self.norm = norm_layer(embed_dim)
175
+ self.head = nn.Identity()
176
+
177
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
178
+
179
+ self.init_weights()
180
+
181
+ def init_weights(self):
182
+ trunc_normal_(self.pos_embed, std=0.02)
183
+ nn.init.normal_(self.cls_token, std=1e-6)
184
+ named_apply(init_weights_vit_timm, self)
185
+
186
+ def interpolate_pos_encoding(self, x, w, h):
187
+ previous_dtype = x.dtype
188
+ npatch = x.shape[1] - 1
189
+ N = self.pos_embed.shape[1] - 1
190
+ if npatch == N and w == h:
191
+ return self.pos_embed
192
+ pos_embed = self.pos_embed.float()
193
+ class_pos_embed = pos_embed[:, 0]
194
+ patch_pos_embed = pos_embed[:, 1:]
195
+ dim = x.shape[-1]
196
+ w0 = w // self.patch_size
197
+ h0 = h // self.patch_size
198
+ # we add a small number to avoid floating point error in the interpolation
199
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
200
+ w0, h0 = w0 + 0.1, h0 + 0.1
201
+
202
+ patch_pos_embed = nn.functional.interpolate(
203
+ patch_pos_embed.reshape(
204
+ 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
205
+ ).permute(0, 3, 1, 2),
206
+ size=None,
207
+ scale_factor=[w0 / math.sqrt(N), h0 / math.sqrt(N)],
208
+ mode="bicubic",
209
+ )
210
+
211
+ assert (
212
+ int(w0) == patch_pos_embed.shape[-2]
213
+ and int(h0) == patch_pos_embed.shape[-1]
214
+ )
215
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
216
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
217
+ previous_dtype
218
+ )
219
+
220
+ def prepare_tokens_with_masks(self, x, masks=None):
221
+ B, nc, w, h = x.shape
222
+ x = self.patch_embed(x)
223
+ if masks is not None:
224
+ x = torch.where(
225
+ masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
226
+ )
227
+
228
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
229
+ x = x + self.interpolate_pos_encoding(x, w, h)
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [
235
+ self.prepare_tokens_with_masks(x, masks)
236
+ for x, masks in zip(x_list, masks_list)
237
+ ]
238
+ for blk in self.blocks:
239
+ x = blk(x)
240
+
241
+ all_x = x
242
+ output = []
243
+ for x, masks in zip(all_x, masks_list):
244
+ x_norm = self.norm(x)
245
+ output.append({
246
+ "x_norm_clstoken": x_norm[:, 0],
247
+ "x_norm_patchtokens": x_norm[:, 1:],
248
+ "x_prenorm": x,
249
+ "masks": masks,
250
+ })
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_patchtokens": x_norm[:, 1:],
266
+ "x_prenorm": x,
267
+ "masks": masks,
268
+ }
269
+
270
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
271
+ x = self.prepare_tokens_with_masks(x)
272
+ # If n is an int, take the n last blocks. If it's a list, take them
273
+ output, total_block_len = [], len(self.blocks)
274
+ blocks_to_take = (
275
+ range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ )
277
+ for i, blk in enumerate(self.blocks):
278
+ x = blk(x)
279
+ if i in blocks_to_take:
280
+ output.append(x)
281
+ assert len(output) == len(
282
+ blocks_to_take
283
+ ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
284
+ return output
285
+
286
+ def _get_intermediate_layers_chunked(self, x, n=1):
287
+ x = self.prepare_tokens_with_masks(x)
288
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
289
+ # If n is an int, take the n last blocks. If it's a list, take them
290
+ blocks_to_take = (
291
+ range(total_block_len - n, total_block_len) if isinstance(n, int) else n
292
+ )
293
+ for block_chunk in self.blocks:
294
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
295
+ x = blk(x)
296
+ if i in blocks_to_take:
297
+ output.append(x)
298
+ i += 1
299
+ assert len(output) == len(
300
+ blocks_to_take
301
+ ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
302
+ return output
303
+
304
+ def get_intermediate_layers(
305
+ self,
306
+ x: torch.Tensor,
307
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
308
+ reshape: bool = False,
309
+ return_class_token: bool = False,
310
+ norm=True,
311
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
312
+ if self.chunked_blocks:
313
+ outputs = self._get_intermediate_layers_chunked(x, n)
314
+ else:
315
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
316
+ if norm:
317
+ outputs = [self.norm(out) for out in outputs]
318
+ class_tokens = [out[:, 0] for out in outputs]
319
+ outputs = [out[:, 1:] for out in outputs]
320
+ if reshape:
321
+ B, _, w, h = x.shape
322
+ outputs = [
323
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
324
+ .permute(0, 3, 1, 2)
325
+ .contiguous()
326
+ for out in outputs
327
+ ]
328
+ if return_class_token:
329
+ return tuple(zip(outputs, class_tokens))
330
+ return tuple(outputs)
331
+
332
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
333
+ return self.get_intermediate_layers(
334
+ x, n=1, reshape=True, return_class_token=False, norm=True
335
+ )[0]
336
+
337
+ # def forward(self, *args, is_training=False, **kwargs):
338
+ # ret = self.forward_features(*args, **kwargs)
339
+ # if is_training:
340
+ # return ret
341
+ # else:
342
+ # return self.head(ret["x_norm_clstoken"])
343
+
344
+
345
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
346
+ """ViT weight initialization, original timm impl (for reproducibility)"""
347
+ if isinstance(module, nn.Linear):
348
+ trunc_normal_(module.weight, std=0.02)
349
+ if module.bias is not None:
350
+ nn.init.zeros_(module.bias)
351
+
352
+
353
+ def vit_small(patch_size=14, **kwargs):
354
+ model = DinoVisionTransformer(
355
+ img_size=518,
356
+ patch_size=patch_size,
357
+ embed_dim=384,
358
+ depth=12,
359
+ num_heads=6,
360
+ mlp_ratio=4,
361
+ init_values=1e-5,
362
+ block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention),
363
+ **kwargs,
364
+ )
365
+ return model
366
+
367
+
368
+ def vit_base(patch_size=14, **kwargs):
369
+ model = DinoVisionTransformer(
370
+ img_size=518,
371
+ patch_size=patch_size,
372
+ embed_dim=768,
373
+ depth=12,
374
+ num_heads=12,
375
+ mlp_ratio=4,
376
+ init_values=1e-5,
377
+ block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention),
378
+ **kwargs,
379
+ )
380
+ return model
381
+
382
+
383
+ def vit_large(patch_size=14, **kwargs):
384
+ model = DinoVisionTransformer(
385
+ img_size=518,
386
+ patch_size=patch_size,
387
+ embed_dim=1024,
388
+ depth=24,
389
+ num_heads=16,
390
+ mlp_ratio=4,
391
+ init_values=1e-5,
392
+ block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention),
393
+ **kwargs,
394
+ )
395
+ return model
396
+
397
+
398
+ def vit_giant2(patch_size=14, **kwargs):
399
+ """Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64"""
400
+ model = DinoVisionTransformer(
401
+ img_size=518,
402
+ patch_size=patch_size,
403
+ embed_dim=1536,
404
+ depth=40,
405
+ num_heads=24,
406
+ mlp_ratio=4,
407
+ init_values=1e-5,
408
+ block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention),
409
+ **kwargs,
410
+ )
411
+ return model
third_party/omniglue/third_party/dinov2/dino_utils.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+ #
6
+ # References:
7
+ # https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/eval/segmentation_m2f/models/backbones/vit.py
8
+
9
+ from typing import Callable, Optional, Tuple, Union
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+
15
+ class Mlp(nn.Module):
16
+
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
41
+
42
+
43
+ def make_2tuple(x):
44
+ if isinstance(x, tuple):
45
+ assert len(x) == 2
46
+ return x
47
+
48
+ assert isinstance(x, int)
49
+ return (x, x)
50
+
51
+
52
+ class PatchEmbed(nn.Module):
53
+ """2D image to patch embedding: (B,C,H,W) -> (B,N,D)
54
+
55
+ Args:
56
+ img_size: Image size.
57
+ patch_size: Patch token size.
58
+ in_chans: Number of input image channels.
59
+ embed_dim: Number of linear projection output channels.
60
+ norm_layer: Normalization layer.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ img_size: Union[int, Tuple[int, int]] = 224,
66
+ patch_size: Union[int, Tuple[int, int]] = 16,
67
+ in_chans: int = 3,
68
+ embed_dim: int = 768,
69
+ norm_layer: Optional[Callable] = None,
70
+ flatten_embedding: bool = True,
71
+ ) -> None:
72
+ super().__init__()
73
+
74
+ image_HW = make_2tuple(img_size)
75
+ patch_HW = make_2tuple(patch_size)
76
+ patch_grid_size = (
77
+ image_HW[0] // patch_HW[0],
78
+ image_HW[1] // patch_HW[1],
79
+ )
80
+
81
+ self.img_size = image_HW
82
+ self.patch_size = patch_HW
83
+ self.patches_resolution = patch_grid_size
84
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
85
+
86
+ self.in_chans = in_chans
87
+ self.embed_dim = embed_dim
88
+
89
+ self.flatten_embedding = flatten_embedding
90
+
91
+ self.proj = nn.Conv2d(
92
+ in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
93
+ )
94
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ _, _, H, W = x.shape
98
+ patch_H, patch_W = self.patch_size
99
+
100
+ assert (
101
+ H % patch_H == 0
102
+ ), f"Input image height {H} is not a multiple of patch height {patch_H}"
103
+ assert (
104
+ W % patch_W == 0
105
+ ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
106
+
107
+ x = self.proj(x) # B C H W
108
+ H, W = x.size(2), x.size(3)
109
+ x = x.flatten(2).transpose(1, 2) # B HW C
110
+ x = self.norm(x)
111
+ if not self.flatten_embedding:
112
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
113
+ return x
114
+
115
+ def flops(self) -> float:
116
+ Ho, Wo = self.patches_resolution
117
+ flops = (
118
+ Ho
119
+ * Wo
120
+ * self.embed_dim
121
+ * self.in_chans
122
+ * (self.patch_size[0] * self.patch_size[1])
123
+ )
124
+ if self.norm is not None:
125
+ flops += Ho * Wo * self.embed_dim
126
+ return flops
127
+
128
+
129
+ XFORMERS_AVAILABLE = False
130
+
131
+
132
+ class Attention(nn.Module):
133
+
134
+ def __init__(
135
+ self,
136
+ dim: int,
137
+ num_heads: int = 8,
138
+ qkv_bias: bool = False,
139
+ proj_bias: bool = True,
140
+ attn_drop: float = 0.0,
141
+ proj_drop: float = 0.0,
142
+ ) -> None:
143
+ super().__init__()
144
+ self.num_heads = num_heads
145
+ head_dim = dim // num_heads
146
+ self.scale = head_dim**-0.5
147
+
148
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
149
+ self.attn_drop = nn.Dropout(attn_drop)
150
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
151
+ self.proj_drop = nn.Dropout(proj_drop)
152
+
153
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
154
+ B, N, C = x.shape
155
+ qkv = (
156
+ self.qkv(x)
157
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
158
+ .permute(2, 0, 3, 1, 4)
159
+ )
160
+
161
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
162
+ attn = q @ k.transpose(-2, -1)
163
+
164
+ attn = attn.softmax(dim=-1)
165
+ attn = self.attn_drop(attn)
166
+
167
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
168
+ x = self.proj(x)
169
+ x = self.proj_drop(x)
170
+ return x
171
+
172
+
173
+ class MemEffAttention(Attention):
174
+
175
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
176
+ if not XFORMERS_AVAILABLE:
177
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
178
+ return super().forward(x)
179
+ else:
180
+ raise NotImplementedError("MemEffAttention do not support xFormer")
181
+ # B, N, C = x.shape
182
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
183
+
184
+ # q, k, v = unbind(qkv, 2)
185
+
186
+ # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
187
+ # x = x.reshape([B, N, C])
188
+
189
+ # x = self.proj(x)
190
+ # x = self.proj_drop(x)
191
+ # return x
192
+
193
+
194
+ class LayerScale(nn.Module):
195
+
196
+ def __init__(
197
+ self,
198
+ dim: int,
199
+ init_values: Union[float, torch.Tensor] = 1e-5,
200
+ inplace: bool = False,
201
+ ) -> None:
202
+ super().__init__()
203
+ self.inplace = inplace
204
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
205
+
206
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
207
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
208
+
209
+
210
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
211
+ if drop_prob == 0.0 or not training:
212
+ return x
213
+ keep_prob = 1 - drop_prob
214
+ shape = (x.shape[0],) + (1,) * (
215
+ x.ndim - 1
216
+ ) # work with diff dim tensors, not just 2D ConvNets
217
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
218
+ if keep_prob > 0.0:
219
+ random_tensor.div_(keep_prob)
220
+ output = x * random_tensor
221
+ return output
222
+
223
+
224
+ class DropPath(nn.Module):
225
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
226
+
227
+ def __init__(self, drop_prob=None):
228
+ super(DropPath, self).__init__()
229
+ self.drop_prob = drop_prob
230
+
231
+ def forward(self, x):
232
+ return drop_path(x, self.drop_prob, self.training)
233
+
234
+
235
+ class Block(nn.Module):
236
+
237
+ def __init__(
238
+ self,
239
+ dim: int,
240
+ num_heads: int,
241
+ mlp_ratio: float = 4.0,
242
+ qkv_bias: bool = False,
243
+ proj_bias: bool = True,
244
+ ffn_bias: bool = True,
245
+ drop: float = 0.0,
246
+ attn_drop: float = 0.0,
247
+ init_values=None,
248
+ drop_path: float = 0.0,
249
+ act_layer: Callable[..., nn.Module] = nn.GELU,
250
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
251
+ attn_class: Callable[..., nn.Module] = Attention,
252
+ ffn_layer: Callable[..., nn.Module] = Mlp,
253
+ ) -> None:
254
+ super().__init__()
255
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
256
+ self.norm1 = norm_layer(dim)
257
+ self.attn = attn_class(
258
+ dim,
259
+ num_heads=num_heads,
260
+ qkv_bias=qkv_bias,
261
+ proj_bias=proj_bias,
262
+ attn_drop=attn_drop,
263
+ proj_drop=drop,
264
+ )
265
+ self.ls1 = (
266
+ LayerScale(dim, init_values=init_values)
267
+ if init_values
268
+ else nn.Identity()
269
+ )
270
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
271
+
272
+ self.norm2 = norm_layer(dim)
273
+ mlp_hidden_dim = int(dim * mlp_ratio)
274
+ self.mlp = ffn_layer(
275
+ in_features=dim,
276
+ hidden_features=mlp_hidden_dim,
277
+ act_layer=act_layer,
278
+ drop=drop,
279
+ bias=ffn_bias,
280
+ )
281
+ self.ls2 = (
282
+ LayerScale(dim, init_values=init_values)
283
+ if init_values
284
+ else nn.Identity()
285
+ )
286
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
287
+
288
+ self.sample_drop_ratio = drop_path
289
+
290
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
291
+ def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
292
+ return self.ls1(self.attn(self.norm1(x)))
293
+
294
+ def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
295
+ return self.ls2(self.mlp(self.norm2(x)))
296
+
297
+ if self.training and self.sample_drop_ratio > 0.1:
298
+ # the overhead is compensated only for a drop path rate larger than 0.1
299
+ x = drop_add_residual_stochastic_depth(
300
+ x,
301
+ residual_func=attn_residual_func,
302
+ sample_drop_ratio=self.sample_drop_ratio,
303
+ )
304
+ x = drop_add_residual_stochastic_depth(
305
+ x,
306
+ residual_func=ffn_residual_func,
307
+ sample_drop_ratio=self.sample_drop_ratio,
308
+ )
309
+ elif self.training and self.sample_drop_ratio > 0.0:
310
+ x = x + self.drop_path1(attn_residual_func(x))
311
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
312
+ else:
313
+ x = x + attn_residual_func(x)
314
+ x = x + ffn_residual_func(x)
315
+ return x
316
+
317
+
318
+ def drop_add_residual_stochastic_depth(
319
+ x: torch.Tensor,
320
+ residual_func: Callable[[torch.Tensor], torch.Tensor],
321
+ sample_drop_ratio: float = 0.0,
322
+ ) -> torch.Tensor:
323
+ # 1) extract subset using permutation
324
+ b, n, d = x.shape
325
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
326
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
327
+ x_subset = x[brange]
328
+
329
+ # 2) apply residual_func to get residual
330
+ residual = residual_func(x_subset)
331
+
332
+ x_flat = x.flatten(1)
333
+ residual = residual.flatten(1)
334
+
335
+ residual_scale_factor = b / sample_subset_size
336
+
337
+ # 3) add the residual
338
+ x_plus_residual = torch.index_add(
339
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
340
+ )
341
+ return x_plus_residual.view_as(x)