import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import warnings
from warnings import warn

import roma
from roma.utils import get_tuple_transform_ops
from roma.utils.local_correlation import local_correlation
from roma.utils.utils import cls_to_flow_refine
from roma.utils.kde import kde

# Module-wide default device string; also used as the autocast device type.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _default_amp_dtype():
    """Pick the autocast dtype: bf16 if the GPU supports it, else fp16; fp32 without CUDA."""
    if not torch.cuda.is_available():
        return torch.float32
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


def _pixel_center_meshgrid(h, w, device):
    """Return ``(ys, xs)`` grids, each ``(h, w)``, of pixel-centre coords in (-1, 1).

    ``indexing="ij"`` is passed explicitly; it is what ``torch.meshgrid`` did
    by default and avoids the deprecation warning of the implicit form.
    """
    return torch.meshgrid(
        torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
        torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
        indexing="ij",
    )


class ConvRefiner(nn.Module):
    """Convolutional head that refines a flow estimate at one scale.

    The network is ``block1 -> hidden_blocks -> out_conv`` (a 1x1 conv).
    ``forward`` returns the output split into ``(displacement, certainty)``,
    where certainty is the last output channel and displacement is every
    channel before it.

    Note: ``Decoder.forward`` in this file reads displacement channels 0 and
    1, so refiners used there need ``out_dim >= 3``.

    Several constructor arguments (``use_cosine_corr``,
    ``disable_local_corr_grad``, ``is_classifier``) are only stored on the
    instance; nothing in this class reads them.
    """

    def __init__(
        self,
        in_dim=6,
        hidden_dim=16,
        out_dim=2,
        dw=False,
        kernel_size=5,
        hidden_blocks=3,
        displacement_emb=None,
        displacement_emb_dim=None,
        local_corr_radius=None,
        corr_in_other=None,
        no_im_B_fm=False,
        amp=False,
        concat_logits=False,
        use_bias_block_1=True,
        use_cosine_corr=False,
        disable_local_corr_grad=False,
        is_classifier=False,
        sample_mode="bilinear",
        norm_type=nn.BatchNorm2d,
        bn_momentum=0.1,
    ):
        super().__init__()
        # Must be set before create_block() is called; it reads self.bn_momentum.
        self.bn_momentum = bn_momentum
        # NOTE: block1 does not receive ``norm_type`` and therefore always uses
        # create_block's default (BatchNorm2d). Kept as-is to preserve the
        # module structure that existing checkpoints expect.
        self.block1 = self.create_block(
            in_dim,
            hidden_dim,
            dw=dw,
            kernel_size=kernel_size,
            bias=use_bias_block_1,
        )
        self.hidden_blocks = nn.Sequential(
            *[
                self.create_block(
                    hidden_dim,
                    hidden_dim,
                    dw=dw,
                    kernel_size=kernel_size,
                    norm_type=norm_type,
                )
                for _ in range(hidden_blocks)
            ]
        )
        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
        if displacement_emb:
            self.has_displacement_emb = True
            # 1x1 conv embedding the 2-channel displacement field.
            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
        else:
            self.has_displacement_emb = False
        self.local_corr_radius = local_corr_radius
        self.corr_in_other = corr_in_other
        self.no_im_B_fm = no_im_B_fm
        self.amp = amp
        self.concat_logits = concat_logits
        self.use_cosine_corr = use_cosine_corr
        self.disable_local_corr_grad = disable_local_corr_grad
        self.is_classifier = is_classifier
        self.sample_mode = sample_mode
        self.amp_dtype = _default_amp_dtype()

    def create_block(
        self,
        in_dim,
        out_dim,
        dw=False,
        kernel_size=5,
        bias=True,
        norm_type=nn.BatchNorm2d,
    ):
        """Build one ``conv(k x k) -> norm -> ReLU -> conv(1x1)`` block.

        Args:
            in_dim: Input channels.
            out_dim: Output channels.
            dw: If true the first conv is depthwise (``groups=in_dim``), which
                requires ``out_dim`` to be a multiple of ``in_dim``.
            kernel_size: Kernel size of the first conv; padding is
                ``kernel_size // 2``.
            bias: Whether the first conv has a bias term.
            norm_type: Normalisation class. ``nn.BatchNorm2d`` is built with
                ``momentum=self.bn_momentum``; any other class is called as
                ``norm_type(num_channels=out_dim)``.

        Returns:
            The block as an ``nn.Sequential``.
        """
        num_groups = in_dim if dw else 1
        if dw:
            assert (
                out_dim % in_dim == 0
            ), "outdim must be divisible by indim for depthwise"
        conv1 = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=num_groups,
            bias=bias,
        )
        if norm_type is nn.BatchNorm2d:
            norm = norm_type(out_dim, momentum=self.bn_momentum)
        else:
            norm = norm_type(num_channels=out_dim)
        relu = nn.ReLU(inplace=True)
        conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
        return nn.Sequential(conv1, norm, relu, conv2)

    def forward(self, x, y, flow, scale_factor=1, logits=None):
        """Predict a displacement and a certainty update for ``flow``.

        Args:
            x: Feature map, unpacked as ``(b, c, hs, ws)``.
            y: Feature map sampled at ``flow`` with ``F.grid_sample``.
            flow: Sampling grid, channel-first; it is permuted to
                channel-last for ``grid_sample``. Presumably ``(b, 2, hs, ws)``
                in normalised [-1, 1] coordinates - confirm against callers.
            scale_factor: Multiplier applied to the displacement before it is
                embedded.
            logits: Extra channels concatenated to the input when
                ``concat_logits`` is set.

        Returns:
            ``(displacement, certainty)``: all but the last, and the last,
            channel of the ``out_conv`` output.

        Raises:
            NotImplementedError: If ``local_corr_radius`` is set while
                ``corr_in_other`` is falsy.
        """
        b, c, hs, ws = x.shape
        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
            # The warped features are treated as constants w.r.t. the flow.
            with torch.no_grad():
                x_hat = F.grid_sample(
                    y,
                    flow.permute(0, 2, 3, 1),
                    align_corners=False,
                    mode=self.sample_mode,
                )
            if self.has_displacement_emb:
                # Build the grid on the input's device (not the module-level
                # default) so CPU inputs work on machines that also have CUDA.
                grid_y, grid_x = _pixel_center_meshgrid(hs, ws, x.device)
                im_A_coords = torch.stack((grid_x, grid_y))
                im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
                in_displacement = flow - im_A_coords
                # 40 / 32 is a fixed gain inherited from the original code.
                emb_in_displacement = self.disp_emb(
                    40 / 32 * scale_factor * in_displacement
                )
                if self.local_corr_radius:
                    if self.corr_in_other:
                        # Corr in other means take a kxk grid around the
                        # predicted coordinate in other image
                        local_corr = local_correlation(
                            x,
                            y,
                            local_radius=self.local_corr_radius,
                            flow=flow,
                            sample_mode=self.sample_mode,
                        )
                    else:
                        raise NotImplementedError(
                            "Local corr in own frame should not be used."
                        )
                    if self.no_im_B_fm:
                        x_hat = torch.zeros_like(x)
                    d = torch.cat(
                        (x, x_hat, emb_in_displacement, local_corr), dim=1
                    )
                else:
                    d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
            else:
                if self.no_im_B_fm:
                    x_hat = torch.zeros_like(x)
                d = torch.cat((x, x_hat), dim=1)
            if self.concat_logits:
                d = torch.cat((d, logits), dim=1)
            d = self.block1(d)
            d = self.hidden_blocks(d)
        # The output projection always runs in float32, outside autocast.
        d = self.out_conv(d.float())
        displacement, certainty = d[:, :-1], d[:, -1:]
        return displacement, certainty


class CosKernel(nn.Module):  # similar to softmax kernel
    """Kernel ``exp((cos_sim(x, y) - 1) / T)`` with fixed or learnable ``T``."""

    def __init__(self, T, learn_temperature=False):
        super().__init__()
        self.learn_temperature = learn_temperature
        if self.learn_temperature:
            self.T = nn.Parameter(torch.tensor(T))
        else:
            self.T = T

    def __call__(self, x, y, eps=1e-6):
        """Return the ``(b, n, m)`` kernel matrix for ``x`` (b, n, d) and ``y`` (b, m, d).

        ``eps`` is added to the product of norms to avoid division by zero.
        A learnable temperature is used as ``|T| + 0.01``.
        """
        c = torch.einsum("bnd,bmd->bnm", x, y) / (
            x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
        )
        if self.learn_temperature:
            T = self.T.abs() + 0.01
        else:
            T = torch.tensor(self.T, device=c.device)
        K = ((c - 1.0) / T).exp()
        return K


class GP(nn.Module):
    """Gaussian-process regression from features of ``y`` to embedded coordinates.

    ``forward`` returns the posterior mean of the embedded coordinates for
    every location of ``x`` and, unless ``no_cov`` is set, a local
    ``covar_size x covar_size`` window of the posterior covariance appended
    along the channel axis.

    ``only_attention``, ``only_nearest_neighbour`` and ``predict_features``
    are stored but not read anywhere in this class.
    """

    def __init__(
        self,
        kernel,
        T=1,
        learn_temperature=False,
        only_attention=False,
        gp_dim=64,
        basis="fourier",
        covar_size=5,
        only_nearest_neighbour=False,
        sigma_noise=0.1,
        no_cov=False,
        predict_features=False,
    ):
        super().__init__()
        self.K = kernel(T=T, learn_temperature=learn_temperature)
        self.sigma_noise = sigma_noise
        self.covar_size = covar_size
        self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
        self.only_attention = only_attention
        self.only_nearest_neighbour = only_nearest_neighbour
        self.basis = basis
        self.no_cov = no_cov
        self.dim = gp_dim
        self.predict_features = predict_features

    def get_local_cov(self, cov):
        """Gather, for every pixel, the covariance with its K x K neighbourhood.

        Args:
            cov: Tensor of shape ``(b, h, w, h, w)``.

        Returns:
            Tensor of shape ``(b, h, w, K**2)`` with ``K = self.covar_size``.
            Neighbours outside the image read the zero padding.
        """
        K = self.covar_size
        # Both (h, w) pairs are expected to agree; the trailing pair is the
        # one used, exactly as in the original double assignment.
        b, _, _, h, w = cov.shape
        hw = h * w
        cov = F.pad(cov, 4 * (K // 2,))  # pad v_q
        offsets = torch.arange(-(K // 2), K // 2 + 1, device=cov.device)
        delta = torch.stack(
            torch.meshgrid(offsets, offsets, indexing="ij"),
            dim=-1,
        )
        positions = torch.stack(
            torch.meshgrid(
                torch.arange(K // 2, h + K // 2, device=cov.device),
                torch.arange(K // 2, w + K // 2, device=cov.device),
                indexing="ij",
            ),
            dim=-1,
        )
        # (h, w, 1, 1, 2) + (1, K, K, 2) -> (h, w, K, K, 2) padded coordinates.
        neighbours = positions[:, :, None, None, :] + delta[None, :, :]
        points = torch.arange(hw, device=cov.device)[:, None].expand(hw, K**2)
        local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
            :,
            points.flatten(),
            neighbours[..., 0].flatten(),
            neighbours[..., 1].flatten(),
        ].reshape(b, h, w, K**2)
        return local_cov

    def reshape(self, x):
        """Flatten the spatial dims: ``(b, d, h, w) -> (b, h*w, d)``."""
        return rearrange(x, "b d h w -> b (h w) d")

    def project_to_basis(self, x):
        """Embed coordinates with ``pos_conv``, optionally through a cosine.

        Raises:
            ValueError: If ``self.basis`` is neither "fourier" nor "linear".
        """
        if self.basis == "fourier":
            return torch.cos(8 * math.pi * self.pos_conv(x))
        elif self.basis == "linear":
            return self.pos_conv(x)
        else:
            raise ValueError(
                "No other bases other than fourier and linear currently supported in public release"
            )

    def get_pos_enc(self, y):
        """Return the embedded pixel-centre coordinate grid matching ``y``'s size."""
        b, c, h, w = y.shape
        grid_y, grid_x = _pixel_center_meshgrid(h, w, y.device)
        coarse_coords = torch.stack((grid_x, grid_y), dim=-1)[None].expand(
            b, h, w, 2
        )
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        coarse_embedded_coords = self.project_to_basis(coarse_coords)
        return coarse_embedded_coords

    def forward(self, x, y, **kwargs):
        """Compute GP posterior features for ``x`` given ``y``.

        Args:
            x: Query features ``(b, c, h1, w1)``.
            y: Support features ``(b, c, h2, w2)``.
            **kwargs: Accepted and ignored.

        Returns:
            ``(b, gp_dim, h1, w1)`` posterior mean when ``no_cov`` is set,
            otherwise the mean concatenated with ``covar_size**2`` local
            covariance channels.
        """
        b, c, h1, w1 = x.shape
        b, c, h2, w2 = y.shape
        f = self.get_pos_enc(y)
        b, d, h2, w2 = f.shape
        # Kernel matrices and the inverse are computed in float32.
        x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
        K_xx = self.K(x, x)
        K_yy = self.K(y, y)
        K_xy = self.K(x, y)
        K_yx = K_xy.permute(0, 2, 1)
        sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
        with warnings.catch_warnings():
            K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)

        mu_x = K_xy.matmul(K_yy_inv.matmul(f))
        mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
        if self.no_cov:
            return mu_x
        cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
        cov_x = rearrange(
            cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
        )
        local_cov_x = self.get_local_cov(cov_x)
        local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
        gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
        return gp_feats


class Decoder(nn.Module):
    """Coarse-to-fine decoder producing per-scale flow and certainty.

    At each scale, features are optionally projected (``proj``), a global
    match is computed for coarse scales (``gps`` + ``embedding_decoder``),
    and the flow is refined by a ``conv_refiner``. ``proj``, ``gps`` and
    ``conv_refiner`` are indexed by the scale as a string (e.g. "16").

    ``pos_embeddings``, ``warp_noise_std``, ``displacement_dropout_p``,
    ``gm_warp_dropout_p`` and ``num_refinement_steps_per_scale`` are stored
    but not read anywhere in this class.
    """

    def __init__(
        self,
        embedding_decoder,
        gps,
        proj,
        conv_refiner,
        detach=False,
        scales="all",
        pos_embeddings=None,
        num_refinement_steps_per_scale=1,
        warp_noise_std=0.0,
        displacement_dropout_p=0.0,
        gm_warp_dropout_p=0.0,
        flow_upsample_mode="bilinear",
    ):
        super().__init__()
        self.embedding_decoder = embedding_decoder
        self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
        self.gps = gps
        self.proj = proj
        self.conv_refiner = conv_refiner
        self.detach = detach
        self.pos_embeddings = {} if pos_embeddings is None else pos_embeddings
        if scales == "all":
            self.scales = ["32", "16", "8", "4", "2", "1"]
        else:
            self.scales = scales
        self.warp_noise_std = warp_noise_std
        # Divisor applied to refiner outputs before they are added to the flow.
        self.refine_init = 4
        self.displacement_dropout_p = displacement_dropout_p
        self.gm_warp_dropout_p = gm_warp_dropout_p
        self.flow_upsample_mode = flow_upsample_mode
        self.amp_dtype = _default_amp_dtype()

    def get_placeholder_flow(self, b, h, w, device):
        """Return the identity grid ``(b, 2, h, w)`` of pixel-centre coords, x first."""
        grid_y, grid_x = _pixel_center_meshgrid(h, w, device)
        coarse_coords = torch.stack((grid_x, grid_y), dim=-1)[None].expand(
            b, h, w, 2
        )
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        return coarse_coords

    def get_positional_embedding(self, b, h, w, device):
        """Embed the identity coordinate grid with ``self.pos_embedding``.

        NOTE(review): ``self.pos_embedding`` is never assigned in this class
        (only ``self.pos_embeddings`` is), so calling this raises
        ``AttributeError`` unless a subclass or caller sets it. Nothing in
        this file calls the method.
        """
        grid_y, grid_x = _pixel_center_meshgrid(h, w, device)
        coarse_coords = torch.stack((grid_x, grid_y), dim=-1)[None].expand(
            b, h, w, 2
        )
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        coarse_embedded_coords = self.pos_embedding(coarse_coords)
        return coarse_embedded_coords

    def forward(
        self,
        f1,
        f2,
        gt_warp=None,
        gt_prob=None,
        upsample=False,
        flow=None,
        certainty=None,
        scale_factor=1,
    ):
        """Decode flow and certainty from two feature pyramids.

        Args:
            f1: Feature pyramid of the first image, indexed by integer scale.
            f2: Feature pyramid of the second image, same indexing.
            gt_warp: Unused.
            gt_prob: Unused.
            upsample: If true, only scales 8, 4, 2, 1 are processed and the
                given ``flow`` / ``certainty`` are resized to the coarsest
                of those and used as the starting point.
            flow: Initial flow; only read when ``upsample`` is true.
            certainty: Initial certainty; only read when ``upsample`` is true.
            scale_factor: Forwarded to every refiner.

        Returns:
            Dict keyed by integer scale. Each entry holds "certainty" and
            "flow"; in training mode it additionally holds the intermediate
            "gm_cls" / "gm_flow", "gm_certainty", "flow_pre_delta" and
            "delta_flow" tensors where they are produced.
        """
        coarse_scales = self.embedding_decoder.scales()
        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
        sizes = {scale: f1[scale].shape[-2:] for scale in f1}
        h, w = sizes[1]
        b = f1[1].shape[0]
        # Named feat_device so the module-level ``device`` string stays
        # visible for autocast below.
        feat_device = f1[1].device
        coarsest_scale = int(all_scales[0])
        old_stuff = torch.zeros(
            b,
            self.embedding_decoder.hidden_dim,
            *sizes[coarsest_scale],
            device=f1[coarsest_scale].device,
        )
        corresps = {}
        if not upsample:
            flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], feat_device)
            certainty = 0.0
        else:
            flow = F.interpolate(
                flow,
                size=sizes[coarsest_scale],
                align_corners=False,
                mode="bilinear",
            )
            certainty = F.interpolate(
                certainty,
                size=sizes[coarsest_scale],
                align_corners=False,
                mode="bilinear",
            )
        for new_scale in all_scales:
            ins = int(new_scale)
            corresps[ins] = {}
            f1_s, f2_s = f1[ins], f2[ins]
            if new_scale in self.proj:
                # torch.autocast needs a device-type *string*; use the same
                # module-level value that amp_dtype was chosen for.
                with torch.autocast(device, dtype=self.amp_dtype):
                    f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)

            if ins in coarse_scales:
                old_stuff = F.interpolate(
                    old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
                )
                gp_posterior = self.gps[new_scale](f1_s, f2_s)
                gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
                    gp_posterior, f1_s, old_stuff, new_scale
                )

                if self.embedding_decoder.is_classifier:
                    flow = cls_to_flow_refine(
                        gm_warp_or_cls,
                    ).permute(0, 3, 1, 2)
                    if self.training:
                        corresps[ins].update(
                            {
                                "gm_cls": gm_warp_or_cls,
                                "gm_certainty": certainty,
                            }
                        )
                else:
                    if self.training:
                        corresps[ins].update(
                            {
                                "gm_flow": gm_warp_or_cls,
                                "gm_certainty": certainty,
                            }
                        )
                    flow = gm_warp_or_cls.detach()

            if new_scale in self.conv_refiner:
                if self.training:
                    corresps[ins].update({"flow_pre_delta": flow})
                delta_flow, delta_certainty = self.conv_refiner[new_scale](
                    f1_s,
                    f2_s,
                    flow,
                    scale_factor=scale_factor,
                    logits=certainty,
                )
                if self.training:
                    corresps[ins].update({"delta_flow": delta_flow})
                # Scale the refiner output by ins / (refine_init * finest size)
                # per axis before adding it to the flow.
                displacement = ins * torch.stack(
                    (
                        delta_flow[:, 0].float() / (self.refine_init * w),
                        delta_flow[:, 1].float() / (self.refine_init * h),
                    ),
                    dim=1,
                )
                flow = flow + displacement
                certainty = (
                    certainty + delta_certainty
                )  # predict both certainty and displacement
            corresps[ins].update(
                {
                    "certainty": certainty,
                    "flow": flow,
                }
            )
            if new_scale != "1":
                # Carry the estimate to the next (finer) scale.
                flow = F.interpolate(
                    flow,
                    size=sizes[ins // 2],
                    mode=self.flow_upsample_mode,
                )
                certainty = F.interpolate(
                    certainty,
                    size=sizes[ins // 2],
                    mode=self.flow_upsample_mode,
                )
                if self.detach:
                    flow = flow.detach()
                    certainty = certainty.detach()
        return corresps


class RegressionMatcher(nn.Module):
    """Encoder + decoder wrapper exposing dense matching between two images."""

    def __init__(
        self,
        encoder,
        decoder,
        h=448,
        w=448,
        sample_mode="threshold",
        upsample_preds=False,
        symmetric=False,
        name=None,
        attenuate_cert=None,
    ):
        super().__init__()
        self.attenuate_cert = attenuate_cert
        self.encoder = encoder
        self.decoder = decoder
        self.name = name
        self.w_resized = w
        self.h_resized = h
        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
        self.sample_mode = sample_mode
        self.upsample_preds = upsample_preds
        # Resolution (h, w) used for the second pass when upsample_preds is set.
        self.upsample_res = (14 * 16 * 6, 14 * 16 * 6)
        self.symmetric = symmetric
        # Certainties above this value are set to 1 by "threshold" sampling.
        self.sample_thresh = 0.05

    def get_output_resolution(self):
        """Return the ``(h, w)`` at which ``match`` produces its warp."""
        if self.upsample_preds:
            return self.upsample_res
        return self.h_resized, self.w_resized

    def extract_backbone_features(self, batch, batched=True, upsample=False):
        """Run the encoder on ``batch["im_A"]`` and ``batch["im_B"]``.

        With ``batched`` the two images are concatenated along the batch axis
        and encoded in one call; otherwise a tuple of two separate encoder
        outputs is returned.
        """
        x_q = batch["im_A"]
        x_s = batch["im_B"]
        if batched:
            X = torch.cat((x_q, x_s), dim=0)
            return self.encoder(X, upsample=upsample)
        return (
            self.encoder(x_q, upsample=upsample),
            self.encoder(x_s, upsample=upsample),
        )

    def sample(
        self,
        matches,
        certainty,
        num=10000,
    ):
        """Draw up to ``num`` matches with probability given by ``certainty``.

        If ``sample_mode`` contains "threshold", certainties above
        ``sample_thresh`` are first set to 1. If it contains "balanced", four
        times as many candidates are drawn and then re-sampled with weights
        ``1 / (density + 1)`` from a KDE over the candidates.

        Returns:
            ``(matches, certainty)`` for the selected samples; matches are
            rows of 4 values.
        """
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
        matches = matches.reshape(-1, 4)
        certainty = certainty.reshape(-1)
        balanced = "balanced" in self.sample_mode
        expansion_factor = 4 if balanced else 1
        good_samples = torch.multinomial(
            certainty,
            num_samples=min(expansion_factor * num, len(certainty)),
            replacement=False,
        )
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if not balanced:
            return good_matches, good_certainty
        density = kde(good_matches, std=0.1)
        p = 1 / (density + 1)
        # Basically should have at least 10 perfect neighbours, or around 100 ok ones
        p[density < 10] = 1e-7
        balanced_samples = torch.multinomial(
            p, num_samples=min(num, len(good_certainty)), replacement=False
        )
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
        """Encode both images and decode correspondences from A to B.

        ``batch["corresps"]``, if present, is expanded into keyword arguments
        of the decoder.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, batched=batched, upsample=upsample
        )
        if batched:
            # First half of the batch axis is im_A, second half is im_B.
            f_q_pyramid = {
                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
            }
            f_s_pyramid = {
                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
            }
        else:
            f_q_pyramid, f_s_pyramid = feature_pyramid
        corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **batch.get("corresps", {}),
            scale_factor=scale_factor,
        )
        return corresps

    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
        """Decode A->B and B->A in one pass.

        The query pyramid is the concatenated ``[A, B]`` features and the
        support pyramid is the same features with the halves swapped.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, batched=batched, upsample=upsample
        )
        f_q_pyramid = feature_pyramid
        f_s_pyramid = {
            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
            for scale, f_scale in feature_pyramid.items()
        }
        corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **batch.get("corresps", {}),
            scale_factor=scale_factor,
        )
        return corresps

    def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
        """Map normalised [-1, 1] matches ``(..., 4)`` to pixel coordinates.

        Returns:
            ``(kpts_A, kpts_B)``, each ``(..., 2)`` as (x, y).
        """
        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
        kpts_A = torch.stack(
            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), dim=-1
        )
        kpts_B = torch.stack(
            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), dim=-1
        )
        return kpts_A, kpts_B

    def match(
        self,
        im_A_path,
        im_B_path,
        *args,
        batched=False,
        device=None,
    ):
        """Compute a dense warp and certainty between two images.

        Puts the module in eval mode as a side effect.

        Args:
            im_A_path: Path to the first image, or an already loaded image
                (PIL image when ``batched`` is false, tensor ``(b, c, h, w)``
                when it is true).
            im_B_path: Same for the second image.
            *args: Ignored.
            batched: Treat the inputs as pre-processed tensors.
            device: Device to run on; defaults to CUDA when available,
                otherwise CPU.

        Returns:
            ``(warp, certainty)``. The last axis of ``warp`` has 4 entries:
            im_A coordinates followed by im_B coordinates, in [-1, 1]. In
            symmetric mode the A->B and B->A results are concatenated along
            the width axis. The batch axis is dropped unless ``batched``.
        """
        if device is None:
            # Previously this passed ``device`` (i.e. None) to torch.device
            # whenever CUDA was available, which raised.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        from PIL import Image

        inputs_are_paths = isinstance(im_A_path, (str, os.PathLike))
        if inputs_are_paths:
            im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
        else:
            # Assume its not a path
            im_A, im_B = im_A_path, im_B_path
        symmetric = self.symmetric
        self.train(False)
        with torch.no_grad():
            if not batched:
                b = 1
                # Get images in good format
                ws = self.w_resized
                hs = self.h_resized

                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True, clahe=False
                )
                im_A, im_B = test_transform((im_A, im_B))
                batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
            else:
                b, c, h, w = im_A.shape
                b, c, h2, w2 = im_B.shape
                assert w == w2 and h == h2, "For batched images we assume same size"
                batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
                if h != self.h_resized or self.w_resized != w:
                    warn(
                        "Model resolution and batch resolution differ, may produce unexpected results"
                    )
                hs, ws = h, w
            finest_scale = 1
            # Run matcher
            if symmetric:
                corresps = self.forward_symmetric(batch)
            else:
                corresps = self.forward(batch, batched=True)

            if self.upsample_preds:
                hs, ws = self.upsample_res

            if self.attenuate_cert:
                low_res_certainty = F.interpolate(
                    corresps[16]["certainty"],
                    size=(hs, ws),
                    align_corners=False,
                    mode="bilinear",
                )
                cert_clamp = 0
                factor = 0.5
                # Keep only the logits below cert_clamp, scaled by ``factor``.
                low_res_certainty = (
                    factor * low_res_certainty * (low_res_certainty < cert_clamp)
                )

            if self.upsample_preds:
                # Second pass at upsample_res, seeded with the finest
                # correspondences from the first pass.
                finest_corresps = corresps[finest_scale]
                torch.cuda.empty_cache()
                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True
                )
                # Re-load from disk only for paths; images passed in directly
                # are reused (Image.open on a non-path input would fail).
                if inputs_are_paths:
                    im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
                else:
                    im_A, im_B = im_A_path, im_B_path
                im_A, im_B = test_transform((im_A, im_B))
                im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                scale_factor = math.sqrt(
                    self.upsample_res[0]
                    * self.upsample_res[1]
                    / (self.w_resized * self.h_resized)
                )
                batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
                if symmetric:
                    corresps = self.forward_symmetric(
                        batch, upsample=True, batched=True, scale_factor=scale_factor
                    )
                else:
                    corresps = self.forward(
                        batch, batched=True, upsample=True, scale_factor=scale_factor
                    )

            im_A_to_im_B = corresps[finest_scale]["flow"]
            certainty = corresps[finest_scale]["certainty"] - (
                low_res_certainty if self.attenuate_cert else 0
            )
            if finest_scale != 1:
                im_A_to_im_B = F.interpolate(
                    im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
                )
                certainty = F.interpolate(
                    certainty, size=(hs, ws), align_corners=False, mode="bilinear"
                )
            im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
            # Create im_A meshgrid
            grid_y, grid_x = _pixel_center_meshgrid(hs, ws, device)
            im_A_coords = torch.stack((grid_x, grid_y))
            im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
            certainty = certainty.sigmoid()  # logits -> probs
            im_A_coords = im_A_coords.permute(0, 2, 3, 1)
            if (im_A_to_im_B.abs() > 1).any():
                # Matches that land outside the image get zero certainty and
                # are clamped back onto the border.
                wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
                certainty[wrong[:, None]] = 0
            im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
            if symmetric:
                A_to_B, B_to_A = im_A_to_im_B.chunk(2)
                q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
                im_B_coords = im_A_coords
                s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
                warp = torch.cat((q_warp, s_warp), dim=2)
                certainty = torch.cat(certainty.chunk(2), dim=3)
            else:
                warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
            if batched:
                return (warp, certainty[:, 0])
            return (
                warp[0],
                certainty[0, 0],
            )