Vincentqyw committed on
Commit 2947428
1 Parent(s): a9f1fc6
common/utils.py CHANGED
@@ -49,7 +49,7 @@ def gen_examples():
         "topicfm",
         "superpoint+superglue",
         "disk+dualsoftmax",
-        "lanet",
+        "roma",
     ]
 
 def gen_images_pairs(path: str, count: int = 5):
@@ -452,12 +452,11 @@ ransac_zoo = {
 
 # Matchers collections
 matcher_zoo = {
-    "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
-    "sold2": {"config": match_dense.confs["sold2"], "dense": True},
     # 'dedode-sparse': {
     #     'config': match_dense.confs['dedode_sparse'],
     #     'dense': True  # dense mode, we need 2 images
     # },
+    "roma": {"config": match_dense.confs["roma"], "dense": True},
     "loftr": {"config": match_dense.confs["loftr"], "dense": True},
     "topicfm": {"config": match_dense.confs["topicfm"], "dense": True},
     "aspanformer": {"config": match_dense.confs["aspanformer"], "dense": True},
@@ -556,6 +555,7 @@ matcher_zoo = {
         "config_feature": extract_features.confs["sift"],
         "dense": False,
     },
-    # "roma": {"config": match_dense.confs["roma"], "dense": True},
+    "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
+    "sold2": {"config": match_dense.confs["sold2"], "dense": True},
     # "DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
 }
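
These hunks replace "lanet" with "roma" in the examples list and move the RoMa, GlueStick, and SOLD2 entries into active matcher_zoo slots. A minimal sketch of how such an entry might be inspected downstream; only the entry shape ({"config": ..., "dense": bool}) comes from the diff above, and the describe_matcher helper is hypothetical:

from common import utils

def describe_matcher(name: str) -> str:
    # Look up the registered entry; e.g. "roma" exists after this commit.
    entry = utils.matcher_zoo[name]
    # "dense" matchers consume both images in one forward pass;
    # sparse ones pair a feature config with a matcher config.
    mode = "dense (both images at once)" if entry["dense"] else "sparse (detect-then-match)"
    return f"{name}: {mode}"

for name in ("roma", "gluestick", "sold2"):
    print(describe_matcher(name))
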
third_party/Roma/roma/models/encoders.py CHANGED
@@ -6,59 +6,37 @@ import torch.nn.functional as F
 import torchvision.models as tvm
 import gc
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
 
 class ResNet50(nn.Module):
-    def __init__(
-        self,
-        pretrained=False,
-        high_res=False,
-        weights=None,
-        dilation=None,
-        freeze_bn=True,
-        anti_aliased=False,
-        early_exit=False,
-        amp=False,
-    ) -> None:
+    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
         super().__init__()
         if dilation is None:
-            dilation = [False, False, False]
+            dilation = [False,False,False]
         if anti_aliased:
             pass
         else:
             if weights is not None:
-                self.net = tvm.resnet50(
-                    weights=weights, replace_stride_with_dilation=dilation
-                )
+                self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
             else:
-                self.net = tvm.resnet50(
-                    pretrained=pretrained, replace_stride_with_dilation=dilation
-                )
-
+                self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+
         self.high_res = high_res
         self.freeze_bn = freeze_bn
         self.early_exit = early_exit
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
-        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             net = self.net
-            feats = {1: x}
+            feats = {1:x}
             x = net.conv1(x)
             x = net.bn1(x)
             x = net.relu(x)
-            feats[2] = x
+            feats[2] = x
             x = net.maxpool(x)
             x = net.layer1(x)
-            feats[4] = x
+            feats[4] = x
             x = net.layer2(x)
             feats[8] = x
             if self.early_exit:
@@ -77,48 +55,35 @@ class ResNet50(nn.Module):
                     m.eval()
         pass
 
-
 class VGG19(nn.Module):
-    def __init__(self, pretrained=False, amp=False) -> None:
+    def __init__(self, pretrained=False, amp = False) -> None:
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
-        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             feats = {}
             scale = 1
             for layer in self.layers:
                 if isinstance(layer, nn.MaxPool2d):
                     feats[scale] = x
-                    scale = scale * 2
+                    scale = scale*2
                 x = layer(x)
             return feats
 
-
 class CNNandDinov2(nn.Module):
-    def __init__(self, cnn_kwargs=None, amp=False, use_vgg=False, dinov2_weights=None):
+    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
         super().__init__()
         if dinov2_weights is None:
-            dinov2_weights = torch.hub.load_state_dict_from_url(
-                "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
-                map_location="cpu",
-            )
+            dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
         from .transformer import vit_large
-
-        vit_kwargs = dict(
-            img_size=518,
-            patch_size=14,
-            init_values=1.0,
-            ffn_layer="mlp",
-            block_chunks=0,
+        vit_kwargs = dict(img_size= 518,
+            patch_size= 14,
+            init_values = 1.0,
+            ffn_layer = "mlp",
+            block_chunks = 0,
         )
 
         dinov2_vitl14 = vit_large(**vit_kwargs).eval()
@@ -129,38 +94,25 @@ class CNNandDinov2(nn.Module):
         else:
             self.cnn = VGG19(**cnn_kwargs)
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         if self.amp:
             dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
-        self.dinov2_vitl14 = [dinov2_vitl14]  # ugly hack to not show parameters to DDP
-
+        self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
+
+
     def train(self, mode: bool = True):
         return self.cnn.train(mode)
-
-    def forward(self, x, upsample=False):
-        B, C, H, W = x.shape
+
+    def forward(self, x, upsample = False):
+        B,C,H,W = x.shape
         feature_pyramid = self.cnn(x)
-
+
         if not upsample:
             with torch.no_grad():
                 if self.dinov2_vitl14[0].device != x.device:
-                    self.dinov2_vitl14[0] = (
-                        self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
-                    )
-                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(
-                    x.to(self.amp_dtype)
-                )
-                features_16 = (
-                    dinov2_features_16["x_norm_patchtokens"]
-                    .permute(0, 2, 1)
-                    .reshape(B, 1024, H // 14, W // 14)
-                )
+                    self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
                 del dinov2_features_16
                 feature_pyramid[16] = features_16
-        return feature_pyramid
+        return feature_pyramid
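
This file syncs the vendored encoders back to upstream RoMa: the "+" side hardcodes "cuda" in torch.autocast and picks amp_dtype from torch.cuda.is_bf16_supported() without first checking CUDA availability, so the CPU fallback on the "-" side is lost. A self-contained sketch of the device-aware pattern being removed; the bfloat16 CPU branch here is an assumption for runnability (the deleted code used float32 and relied on amp being disabled on CPU):

import torch
import torch.nn.functional as F

# Pick a device and an autocast dtype that also works without a GPU,
# mirroring the deleted module-level `device` and `amp_dtype` logic.
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    amp_dtype = torch.bfloat16  # CPU autocast supports bfloat16

x = torch.randn(1, 3, 32, 32)
weight = torch.randn(8, 3, 3, 3)
with torch.autocast(device, enabled=True, dtype=amp_dtype):
    y = F.conv2d(x, weight)  # stand-in for the backbone forward pass
print(y.dtype)  # reduced-precision dtype inside the autocast region
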
third_party/Roma/roma/models/matcher.py CHANGED
@@ -14,9 +14,6 @@ from roma.utils.local_correlation import local_correlation
 from roma.utils.utils import cls_to_flow_refine
 from roma.utils.kde import kde
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
 class ConvRefiner(nn.Module):
     def __init__(
         self,
@@ -26,29 +23,25 @@ class ConvRefiner(nn.Module):
         dw=False,
         kernel_size=5,
         hidden_blocks=3,
-        displacement_emb=None,
-        displacement_emb_dim=None,
-        local_corr_radius=None,
-        corr_in_other=None,
-        no_im_B_fm=False,
-        amp=False,
-        concat_logits=False,
-        use_bias_block_1=True,
-        use_cosine_corr=False,
-        disable_local_corr_grad=False,
-        is_classifier=False,
-        sample_mode="bilinear",
-        norm_type=nn.BatchNorm2d,
-        bn_momentum=0.1,
+        displacement_emb = None,
+        displacement_emb_dim = None,
+        local_corr_radius = None,
+        corr_in_other = None,
+        no_im_B_fm = False,
+        amp = False,
+        concat_logits = False,
+        use_bias_block_1 = True,
+        use_cosine_corr = False,
+        disable_local_corr_grad = False,
+        is_classifier = False,
+        sample_mode = "bilinear",
+        norm_type = nn.BatchNorm2d,
+        bn_momentum = 0.1,
     ):
         super().__init__()
         self.bn_momentum = bn_momentum
         self.block1 = self.create_block(
-            in_dim,
-            hidden_dim,
-            dw=dw,
-            kernel_size=kernel_size,
-            bias=use_bias_block_1,
+            in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
         )
         self.hidden_blocks = nn.Sequential(
             *[
@@ -66,7 +59,7 @@ class ConvRefiner(nn.Module):
         self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
         if displacement_emb:
             self.has_displacement_emb = True
-            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
+            self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
         else:
             self.has_displacement_emb = False
         self.local_corr_radius = local_corr_radius
@@ -78,22 +71,16 @@ class ConvRefiner(nn.Module):
         self.disable_local_corr_grad = disable_local_corr_grad
         self.is_classifier = is_classifier
         self.sample_mode = sample_mode
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
-
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def create_block(
         self,
         in_dim,
         out_dim,
         dw=False,
         kernel_size=5,
-        bias=True,
-        norm_type=nn.BatchNorm2d,
+        bias = True,
+        norm_type = nn.BatchNorm2d,
     ):
         num_groups = 1 if not dw else in_dim
         if dw:
@@ -109,56 +96,38 @@ class ConvRefiner(nn.Module):
             groups=num_groups,
             bias=bias,
         )
-        norm = (
-            norm_type(out_dim, momentum=self.bn_momentum)
-            if norm_type is nn.BatchNorm2d
-            else norm_type(num_channels=out_dim)
-        )
+        norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
         relu = nn.ReLU(inplace=True)
         conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
         return nn.Sequential(conv1, norm, relu, conv2)
-
-    def forward(self, x, y, flow, scale_factor=1, logits=None):
-        b, c, hs, ws = x.shape
-        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
+
+    def forward(self, x, y, flow, scale_factor = 1, logits = None):
+        b,c,hs,ws = x.shape
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             with torch.no_grad():
-                x_hat = F.grid_sample(
-                    y,
-                    flow.permute(0, 2, 3, 1),
-                    align_corners=False,
-                    mode=self.sample_mode,
-                )
+                x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
-                    (
-                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
-                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
-                    )
+                    (
+                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
                 im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
-                in_displacement = flow - im_A_coords
-                emb_in_displacement = self.disp_emb(
-                    40 / 32 * scale_factor * in_displacement
-                )
+                in_displacement = flow-im_A_coords
+                emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
                 if self.local_corr_radius:
                     if self.corr_in_other:
                         # Corr in other means take a kxk grid around the predicted coordinate in other image
-                        local_corr = local_correlation(
-                            x,
-                            y,
-                            local_radius=self.local_corr_radius,
-                            flow=flow,
-                            sample_mode=self.sample_mode,
-                        )
+                        local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
+                                                       sample_mode = self.sample_mode)
                     else:
-                        raise NotImplementedError(
-                            "Local corr in own frame should not be used."
-                        )
+                        raise NotImplementedError("Local corr in own frame should not be used.")
                     if self.no_im_B_fm:
                         x_hat = torch.zeros_like(x)
                     d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
-                else:
+                else:
                     d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
             else:
                 if self.no_im_B_fm:
@@ -172,7 +141,6 @@ class ConvRefiner(nn.Module):
         displacement, certainty = d[:, :-1], d[:, -1:]
         return displacement, certainty
 
-
 class CosKernel(nn.Module):  # similar to softmax kernel
     def __init__(self, T, learn_temperature=False):
         super().__init__()
@@ -193,7 +161,6 @@ class CosKernel(nn.Module):  # similar to softmax kernel
         K = ((c - 1.0) / T).exp()
         return K
 
-
 class GP(nn.Module):
     def __init__(
         self,
@@ -207,7 +174,7 @@ class GP(nn.Module):
         only_nearest_neighbour=False,
         sigma_noise=0.1,
         no_cov=False,
-        predict_features=False,
+        predict_features = False,
     ):
         super().__init__()
         self.K = kernel(T=T, learn_temperature=learn_temperature)
@@ -295,9 +262,7 @@ class GP(nn.Module):
         mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
         if not self.no_cov:
             cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
-            cov_x = rearrange(
-                cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
-            )
+            cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
             local_cov_x = self.get_local_cov(cov_x)
             local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
             gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
@@ -305,22 +270,11 @@ class GP(nn.Module):
             gp_feats = mu_x
         return gp_feats
 
-
 class Decoder(nn.Module):
     def __init__(
-        self,
-        embedding_decoder,
-        gps,
-        proj,
-        conv_refiner,
-        detach=False,
-        scales="all",
-        pos_embeddings=None,
-        num_refinement_steps_per_scale=1,
-        warp_noise_std=0.0,
-        displacement_dropout_p=0.0,
-        gm_warp_dropout_p=0.0,
-        flow_upsample_mode="bilinear",
+        self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
+        num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
+        flow_upsample_mode = "bilinear"
     ):
         super().__init__()
         self.embedding_decoder = embedding_decoder
@@ -342,14 +296,8 @@ class Decoder(nn.Module):
         self.displacement_dropout_p = displacement_dropout_p
         self.gm_warp_dropout_p = gm_warp_dropout_p
         self.flow_upsample_mode = flow_upsample_mode
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
-
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def get_placeholder_flow(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
             (
@@ -362,8 +310,8 @@ class Decoder(nn.Module):
         ].expand(b, h, w, 2)
         coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
         return coarse_coords
-
-    def get_positional_embedding(self, b, h, w, device):
+
+    def get_positional_embedding(self, b, h ,w, device):
         coarse_coords = torch.meshgrid(
             (
                 torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
@@ -378,29 +326,16 @@ class Decoder(nn.Module):
         coarse_embedded_coords = self.pos_embedding(coarse_coords)
         return coarse_embedded_coords
 
-    def forward(
-        self,
-        f1,
-        f2,
-        gt_warp=None,
-        gt_prob=None,
-        upsample=False,
-        flow=None,
-        certainty=None,
-        scale_factor=1,
-    ):
+    def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
         coarse_scales = self.embedding_decoder.scales()
-        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
+        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
         sizes = {scale: f1[scale].shape[-2:] for scale in f1}
         h, w = sizes[1]
         b = f1[1].shape[0]
         device = f1[1].device
         coarsest_scale = int(all_scales[0])
         old_stuff = torch.zeros(
-            b,
-            self.embedding_decoder.hidden_dim,
-            *sizes[coarsest_scale],
-            device=f1[coarsest_scale].device,
+            b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
         )
         corresps = {}
         if not upsample:
@@ -408,24 +343,24 @@ class Decoder(nn.Module):
             certainty = 0.0
         else:
             flow = F.interpolate(
-                flow,
-                size=sizes[coarsest_scale],
-                align_corners=False,
-                mode="bilinear",
-            )
+                flow,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
             certainty = F.interpolate(
-                certainty,
-                size=sizes[coarsest_scale],
-                align_corners=False,
-                mode="bilinear",
-            )
+                certainty,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
         displacement = 0.0
         for new_scale in all_scales:
             ins = int(new_scale)
             corresps[ins] = {}
             f1_s, f2_s = f1[ins], f2[ins]
             if new_scale in self.proj:
-                with torch.autocast(device, self.amp_dtype):
+                with torch.autocast("cuda", self.amp_dtype):
                     f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
 
             if ins in coarse_scales:
@@ -436,59 +371,32 @@ class Decoder(nn.Module):
                 gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
                     gp_posterior, f1_s, old_stuff, new_scale
                 )
-
+
                 if self.embedding_decoder.is_classifier:
                     flow = cls_to_flow_refine(
                         gm_warp_or_cls,
-                    ).permute(0, 3, 1, 2)
-                    corresps[ins].update(
-                        {
-                            "gm_cls": gm_warp_or_cls,
-                            "gm_certainty": certainty,
-                        }
-                    ) if self.training else None
+                    ).permute(0,3,1,2)
+                    corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
                 else:
-                    corresps[ins].update(
-                        {
-                            "gm_flow": gm_warp_or_cls,
-                            "gm_certainty": certainty,
-                        }
-                    ) if self.training else None
+                    corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
                     flow = gm_warp_or_cls.detach()
-
+
             if new_scale in self.conv_refiner:
-                corresps[ins].update(
-                    {"flow_pre_delta": flow}
-                ) if self.training else None
+                corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
                 delta_flow, delta_certainty = self.conv_refiner[new_scale](
-                    f1_s,
-                    f2_s,
-                    flow,
-                    scale_factor=scale_factor,
-                    logits=certainty,
-                )
-                corresps[ins].update(
-                    {
-                        "delta_flow": delta_flow,
-                    }
-                ) if self.training else None
-                displacement = ins * torch.stack(
-                    (
-                        delta_flow[:, 0].float() / (self.refine_init * w),
-                        delta_flow[:, 1].float() / (self.refine_init * h),
-                    ),
-                    dim=1,
-                )
+                    f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
+                )
+                corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
+                displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
+                                                delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
                 flow = flow + displacement
                 certainty = (
                     certainty + delta_certainty
                 )  # predict both certainty and displacement
-                corresps[ins].update(
-                    {
-                        "certainty": certainty,
-                        "flow": flow,
-                    }
-                )
+                corresps[ins].update({
+                    "certainty": certainty,
+                    "flow": flow,
+                })
             if new_scale != "1":
                 flow = F.interpolate(
                     flow,
@@ -503,7 +411,7 @@ class Decoder(nn.Module):
             if self.detach:
                 flow = flow.detach()
                 certainty = certainty.detach()
-        # torch.cuda.empty_cache()
+        #torch.cuda.empty_cache()
         return corresps
 
 
@@ -514,11 +422,11 @@ class RegressionMatcher(nn.Module):
         decoder,
         h=448,
         w=448,
-        sample_mode="threshold",
-        upsample_preds=False,
-        symmetric=False,
-        name=None,
-        attenuate_cert=None,
+        sample_mode = "threshold",
+        upsample_preds = False,
+        symmetric = False,
+        name = None,
+        attenuate_cert = None,
     ):
         super().__init__()
         self.attenuate_cert = attenuate_cert
@@ -530,26 +438,24 @@ class RegressionMatcher(nn.Module):
         self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
         self.sample_mode = sample_mode
         self.upsample_preds = upsample_preds
-        self.upsample_res = (14 * 16 * 6, 14 * 16 * 6)
+        self.upsample_res = (14*16*6, 14*16*6)
         self.symmetric = symmetric
         self.sample_thresh = 0.05
-
+
     def get_output_resolution(self):
         if not self.upsample_preds:
             return self.h_resized, self.w_resized
         else:
             return self.upsample_res
-
-    def extract_backbone_features(self, batch, batched=True, upsample=False):
+
+    def extract_backbone_features(self, batch, batched = True, upsample = False):
         x_q = batch["im_A"]
         x_s = batch["im_B"]
         if batched:
-            X = torch.cat((x_q, x_s), dim=0)
-            feature_pyramid = self.encoder(X, upsample=upsample)
+            X = torch.cat((x_q, x_s), dim = 0)
+            feature_pyramid = self.encoder(X, upsample = upsample)
         else:
-            feature_pyramid = self.encoder(x_q, upsample=upsample), self.encoder(
-                x_s, upsample=upsample
-            )
+            feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
         return feature_pyramid
 
     def sample(
@@ -567,28 +473,22 @@ class RegressionMatcher(nn.Module):
             certainty.reshape(-1),
         )
         expansion_factor = 4 if "balanced" in self.sample_mode else 1
-        good_samples = torch.multinomial(
-            certainty,
-            num_samples=min(expansion_factor * num, len(certainty)),
-            replacement=False,
-        )
+        good_samples = torch.multinomial(certainty,
+                                         num_samples = min(expansion_factor*num, len(certainty)),
+                                         replacement=False)
         good_matches, good_certainty = matches[good_samples], certainty[good_samples]
         if "balanced" not in self.sample_mode:
             return good_matches, good_certainty
         density = kde(good_matches, std=0.1)
-        p = 1 / (density + 1)
-        p[
-            density < 10
-        ] = 1e-7  # Basically should have at least 10 perfect neighbours, or around 100 ok ones
-        balanced_samples = torch.multinomial(
-            p, num_samples=min(num, len(good_certainty)), replacement=False
-        )
+        p = 1 / (density+1)
+        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+        balanced_samples = torch.multinomial(p,
+                                             num_samples = min(num,len(good_certainty)),
+                                             replacement=False)
         return good_matches[balanced_samples], good_certainty[balanced_samples]
 
-    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
-        feature_pyramid = self.extract_backbone_features(
-            batch, batched=batched, upsample=upsample
-        )
+    def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
+        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
         if batched:
             f_q_pyramid = {
                 scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
@@ -598,42 +498,32 @@ class RegressionMatcher(nn.Module):
             }
         else:
             f_q_pyramid, f_s_pyramid = feature_pyramid
-        corresps = self.decoder(
-            f_q_pyramid,
-            f_s_pyramid,
-            upsample=upsample,
-            **(batch["corresps"] if "corresps" in batch else {}),
-            scale_factor=scale_factor,
-        )
-
+        corresps = self.decoder(f_q_pyramid,
+                                f_s_pyramid,
+                                upsample = upsample,
+                                **(batch["corresps"] if "corresps" in batch else {}),
+                                scale_factor=scale_factor)
+
         return corresps
 
-    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
-        feature_pyramid = self.extract_backbone_features(
-            batch, batched=batched, upsample=upsample
-        )
+    def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
+        feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
         f_q_pyramid = feature_pyramid
         f_s_pyramid = {
-            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
+            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
             for scale, f_scale in feature_pyramid.items()
         }
-        corresps = self.decoder(
-            f_q_pyramid,
-            f_s_pyramid,
-            upsample=upsample,
-            **(batch["corresps"] if "corresps" in batch else {}),
-            scale_factor=scale_factor,
-        )
+        corresps = self.decoder(f_q_pyramid,
+                                f_s_pyramid,
+                                upsample = upsample,
+                                **(batch["corresps"] if "corresps" in batch else {}),
+                                scale_factor=scale_factor)
         return corresps
-
+
     def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
-        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
-        kpts_A = torch.stack(
-            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), axis=-1
-        )
-        kpts_B = torch.stack(
-            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), axis=-1
-        )
+        kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+        kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
+        kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
         return kpts_A, kpts_B
 
     def match(
@@ -642,12 +532,11 @@ class RegressionMatcher(nn.Module):
         im_B_path,
         *args,
         batched=False,
-        device=None,
+        device = None,
     ):
         if device is None:
-            device = torch.device(device if torch.cuda.is_available() else "cpu")
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         from PIL import Image
-
         if isinstance(im_A_path, (str, os.PathLike)):
             im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
         else:
@@ -663,9 +552,9 @@ class RegressionMatcher(nn.Module):
             # Get images in good format
             ws = self.w_resized
             hs = self.h_resized
-
+
             test_transform = get_tuple_transform_ops(
-                resize=(hs, ws), normalize=True, clahe=False
+                resize=(hs, ws), normalize=True, clahe = False
             )
             im_A, im_B = test_transform((im_A, im_B))
             batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
@@ -675,32 +564,25 @@ class RegressionMatcher(nn.Module):
             assert w == w2 and h == h2, "For batched images we assume same size"
             batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
             if h != self.h_resized or self.w_resized != w:
-                warn(
-                    "Model resolution and batch resolution differ, may produce unexpected results"
-                )
+                warn("Model resolution and batch resolution differ, may produce unexpected results")
             hs, ws = h, w
         finest_scale = 1
         # Run matcher
         if symmetric:
-            corresps = self.forward_symmetric(batch)
+            corresps = self.forward_symmetric(batch)
         else:
-            corresps = self.forward(batch, batched=True)
+            corresps = self.forward(batch, batched = True)
 
         if self.upsample_preds:
             hs, ws = self.upsample_res
-
+
         if self.attenuate_cert:
             low_res_certainty = F.interpolate(
-                corresps[16]["certainty"],
-                size=(hs, ws),
-                align_corners=False,
-                mode="bilinear",
+                corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
             )
             cert_clamp = 0
             factor = 0.5
-            low_res_certainty = (
-                factor * low_res_certainty * (low_res_certainty < cert_clamp)
-            )
+            low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
 
         if self.upsample_preds:
             finest_corresps = corresps[finest_scale]
@@ -711,38 +593,30 @@ class RegressionMatcher(nn.Module):
             im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
             im_A, im_B = test_transform((im_A, im_B))
            im_A, im_B = im_A[None].to(device), im_B[None].to(device)
-            scale_factor = math.sqrt(
-                self.upsample_res[0]
-                * self.upsample_res[1]
-                / (self.w_resized * self.h_resized)
-            )
+            scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
            batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
            if symmetric:
-                corresps = self.forward_symmetric(
-                    batch, upsample=True, batched=True, scale_factor=scale_factor
-                )
+                corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
            else:
-                corresps = self.forward(
-                    batch, batched=True, upsample=True, scale_factor=scale_factor
-                )
-
-        im_A_to_im_B = corresps[finest_scale]["flow"]
-        certainty = corresps[finest_scale]["certainty"] - (
-            low_res_certainty if self.attenuate_cert else 0
-        )
+                corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
+
+        im_A_to_im_B = corresps[finest_scale]["flow"]
+        certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
        if finest_scale != 1:
            im_A_to_im_B = F.interpolate(
-                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
            )
            certainty = F.interpolate(
-                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
-            )
-        im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
+                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
+            )
+        im_A_to_im_B = im_A_to_im_B.permute(
+            0, 2, 3, 1
+        )
        # Create im_A meshgrid
        im_A_coords = torch.meshgrid(
            (
-                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
-                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
            )
        )
        im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -751,21 +625,25 @@ class RegressionMatcher(nn.Module):
        im_A_coords = im_A_coords.permute(0, 2, 3, 1)
        if (im_A_to_im_B.abs() > 1).any() and True:
            wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
-            certainty[wrong[:, None]] = 0
+            certainty[wrong[:,None]] = 0
        im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
        if symmetric:
            A_to_B, B_to_A = im_A_to_im_B.chunk(2)
            q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
            im_B_coords = im_A_coords
            s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
-            warp = torch.cat((q_warp, s_warp), dim=2)
+            warp = torch.cat((q_warp, s_warp),dim=2)
            certainty = torch.cat(certainty.chunk(2), dim=3)
        else:
            warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
        if batched:
-            return (warp, certainty[:, 0])
+            return (
+                warp,
+                certainty[:, 0]
+            )
        else:
            return (
                warp[0],
                certainty[0, 0],
            )
+
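
Like encoders.py, this file reverts to upstream RoMa's CUDA-only code: autocast regions and coordinate grids hardcode "cuda", and the match() device fallback bug on the "-" side (it passed the None device back into torch.device) is fixed on the "+" side. A hedged usage sketch for the RegressionMatcher API touched above; the roma_outdoor constructor and the num= keyword of sample() are assumptions carried over from upstream RoMa, while match(), sample(), and to_pixel_coordinates() appear in this file:

import torch
from roma import roma_outdoor  # assumed upstream entry point, not shown in this diff

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
matcher = roma_outdoor(device=device)

# Dense warp and per-pixel certainty; coordinates are normalized to [-1, 1]
warp, certainty = matcher.match("im_A.jpg", "im_B.jpg", device=device)

# Certainty-weighted sampling of correspondences
matches, confidence = matcher.sample(warp, certainty, num=5000)

# to_pixel_coordinates applies x_px = W/2 * (x + 1), y_px = H/2 * (y + 1)
H_A, W_A = 864, 1152  # hypothetical image sizes
H_B, W_B = 864, 1152
kpts_A, kpts_B = matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)

Note that since the "+" side builds its meshgrids with device="cuda", this sketch would only run end to end on a GPU machine after this commit.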