import torch def build_descriptor_loss( source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False ): """ Desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf. Parameters ---------- source_des: torch.Tensor (B,256,H/8,W/8) Source image descriptors. target_des: torch.Tensor (B,256,H/8,W/8) Target image descriptors. source_points: torch.Tensor (B,H/8,W/8,2) Source image keypoints tar_points: torch.Tensor (B,H/8,W/8,2) Target image keypoints tar_points_un: torch.Tensor (B,2,H/8,W/8) Target image keypoints unnormalized eval_only: bool Computes only recall without the loss. Returns ------- loss: torch.Tensor Descriptor loss. recall: torch.Tensor Descriptor match recall. """ device = source_des.device loss = 0 batch_size = source_des.size(0) recall = 0.0 relax_field_size = [relax_field] margins = [1.0] weights = [1.0] isource_dense = top_kk is None for b_id in range(batch_size): if isource_dense: ref_desc = source_des[b_id].squeeze().view(256, -1) tar_desc = target_des[b_id].squeeze().view(256, -1) tar_points_raw = tar_points_un[b_id].view(2, -1) else: top_k = top_kk[b_id].squeeze() n_feat = top_k.sum().item() if n_feat < 20: continue ref_desc = source_des[b_id].squeeze()[:, top_k] tar_desc = target_des[b_id].squeeze()[:, top_k] tar_points_raw = tar_points_un[b_id][:, top_k] # Compute dense descriptor distance matrix and find nearest neighbor ref_desc = ref_desc.div(torch.norm(ref_desc, p=2, dim=0)) tar_desc = tar_desc.div(torch.norm(tar_desc, p=2, dim=0)) dmat = torch.mm(ref_desc.t(), tar_desc) dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1)) _, idx = torch.sort(dmat, dim=1) # Compute triplet loss and recall for pyramid in range(len(relax_field_size)): candidates = idx.t() match_k_x = tar_points_raw[0, candidates] match_k_y = tar_points_raw[1, candidates] tru_x = tar_points_raw[0] tru_y = tar_points_raw[1] if pyramid == 0: correct2 = (abs(match_k_x[0] - tru_x) == 0) & ( abs(match_k_y[0] - tru_y) == 0 ) correct2_cnt = correct2.float().sum() recall += float(1.0 / batch_size) * ( float(correct2_cnt) / float(ref_desc.size(1)) ) if eval_only: continue correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & ( abs(match_k_y - tru_y) <= relax_field_size[pyramid] ) incorrect_index = ( torch.arange(start=correct_k.shape[0] - 1, end=-1, step=-1) .unsqueeze(1) .repeat(1, correct_k.shape[1]) .to(device) ) incorrect_first = torch.argmax( incorrect_index * (1 - correct_k.long()), dim=0 ) incorrect_first_index = candidates.gather( 0, incorrect_first.unsqueeze(0) ).squeeze() anchor_var = ref_desc posource_var = tar_desc neg_var = tar_desc[:, incorrect_first_index] loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss( anchor_var.t(), posource_var.t(), neg_var.t(), margin=margins[pyramid] ).mul(weights[pyramid]) return loss, recall class KeypointLoss(object): """ Loss function class encapsulating the location loss, the descriptor loss, and the score loss. """ def __init__(self, config): self.score_weight = config.score_weight self.loc_weight = config.loc_weight self.desc_weight = config.desc_weight self.corres_weight = config.corres_weight self.corres_threshold = config.corres_threshold def __call__(self, data): B, _, hc, wc = data["source_score"].shape loc_mat_abs = torch.abs( data["target_coord_warped"].view(B, 2, -1).unsqueeze(3) - data["target_coord"].view(B, 2, -1).unsqueeze(2) ) l2_dist_loc_mat = torch.norm(loc_mat_abs, p=2, dim=1) l2_dist_loc_min, l2_dist_loc_min_index = l2_dist_loc_mat.min(dim=2) # construct pseudo ground truth matching matrix loc_min_mat = torch.repeat_interleave( l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1 ) pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.0) neg_mask = l2_dist_loc_mat.ge(4.0) pos_corres = -torch.log(data["confidence_matrix"][pos_mask]) neg_corres = -torch.log(1.0 - data["confidence_matrix"][neg_mask]) corres_loss = pos_corres.mean() + 5e5 * neg_corres.mean() # corresponding distance threshold is 4 dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data[ "border_mask" ].view(B, hc * wc) # location loss loc_loss = l2_dist_loc_min[dist_norm_valid_mask].mean() # desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf. desc_loss, _ = build_descriptor_loss( data["source_desc"], data["target_desc_warped"], data["target_coord_warped"].detach(), top_kk=data["border_mask"], relax_field=8, ) # score loss target_score_associated = ( data["target_score"] .view(B, hc * wc) .gather(1, l2_dist_loc_min_index) .view(B, hc, wc) .unsqueeze(1) ) dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data[ "border_mask" ].unsqueeze(1) l2_dist_loc_min = l2_dist_loc_min.view(B, hc, wc).unsqueeze(1) loc_err = l2_dist_loc_min[dist_norm_valid_mask] # repeatable_constrain in score loss repeatable_constrain = ( ( target_score_associated[dist_norm_valid_mask] + data["source_score"][dist_norm_valid_mask] ) * (loc_err - loc_err.mean()) ).mean() # consistent_constrain in score_loss consistent_constrain = ( torch.nn.functional.mse_loss( data["target_score_warped"][data["border_mask"].unsqueeze(1)], data["source_score"][data["border_mask"].unsqueeze(1)], ).mean() * 2 ) aware_consistent_loss = ( torch.nn.functional.mse_loss( data["target_aware_warped"][ data["border_mask"].unsqueeze(1).repeat(1, 2, 1, 1) ], data["source_aware"][ data["border_mask"].unsqueeze(1).repeat(1, 2, 1, 1) ], ).mean() * 2 ) score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss loss = ( self.loc_weight * loc_loss + self.desc_weight * desc_loss + self.score_weight * score_loss + self.corres_weight * corres_loss ) return ( loss, self.loc_weight * loc_loss, self.desc_weight * desc_loss, self.score_weight * score_loss, self.corres_weight * corres_loss, )