# Source: gim-online/third_party/lanet/loss_function.py
# Last commit: 8b973ee ("fix: roma") by Vincentqyw
import torch
def build_descriptor_loss(
    source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False
):
    """
    Desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf.

    For every batch element the descriptors are flattened to ``(C, N)``, L2
    normalised, and each source descriptor is ranked against all target
    descriptors. The positive for source column ``i`` is target column ``i``;
    the negative is the closest-ranked target whose coordinates lie outside the
    ``relax_field`` window around the true position.

    Parameters
    ----------
    source_des: torch.Tensor (B,C,H/8,W/8)
        Source image descriptors. ``C`` is read from the tensor (256 in the
        original configuration).
    target_des: torch.Tensor (B,C,H/8,W/8)
        Target image descriptors, cell-aligned with ``source_des``.
    tar_points_un: torch.Tensor (B,2,H/8,W/8)
        Target image keypoints unnormalized (channel 0 = x, channel 1 = y).
    top_kk: torch.Tensor or None
        Per-cell selection mask, presumably boolean with shape (B,H/8,W/8) --
        confirm against the caller. ``None`` uses every cell. Batch elements
        with fewer than 20 selected cells are skipped.
    relax_field: int
        Half-width (in the units of ``tar_points_un``) of the window inside
        which a candidate still counts as a correct match.
    eval_only: bool
        Computes only recall without the loss.

    Returns
    -------
    loss: torch.Tensor
        Scalar descriptor loss averaged over the batch (zero when
        ``eval_only`` is set or every batch element was skipped).
    recall: float
        Fraction of cells whose nearest descriptor sits at exactly the true
        coordinates, averaged over the batch. Skipped batch elements
        contribute zero but still count in the denominator.
    """
    device = source_des.device
    batch_size = source_des.size(0)
    channels = source_des.size(1)
    # Always hand back a tensor, even if no batch element contributes.
    loss = source_des.new_zeros(())
    recall = 0.0
    relax_field_size = [relax_field]
    margins = [1.0]
    weights = [1.0]
    use_all_cells = top_kk is None

    for b_id in range(batch_size):
        # Flatten the spatial grid: one descriptor / coordinate pair per column.
        ref_desc = source_des[b_id].reshape(channels, -1)
        tar_desc = target_des[b_id].reshape(channels, -1)
        tar_points_raw = tar_points_un[b_id].reshape(2, -1)

        if not use_all_cells:
            top_k = top_kk[b_id].reshape(-1)
            if top_k.sum().item() < 20:
                # Too few usable cells for a meaningful ranking.
                continue
            ref_desc = ref_desc[:, top_k]
            tar_desc = tar_desc[:, top_k]
            tar_points_raw = tar_points_raw[:, top_k]

        n_points = ref_desc.size(1)
        if n_points == 0:
            continue

        # normalize() clamps the norm from below, so an all-zero descriptor
        # stays zero instead of turning into NaN.
        ref_desc = torch.nn.functional.normalize(ref_desc, p=2, dim=0)
        tar_desc = torch.nn.functional.normalize(tar_desc, p=2, dim=0)

        # Compute dense descriptor distance matrix and find nearest neighbor.
        # Only the integer ranking is used downstream, so no graph is needed.
        with torch.no_grad():
            similarity = torch.mm(ref_desc.t(), tar_desc)
            dmat = torch.sqrt(2 - 2 * torch.clamp(similarity, min=-1, max=1))
            _, idx = torch.sort(dmat, dim=1)

        # candidates[k, i] = index of the k-th closest target for source i.
        candidates = idx.t()
        match_k_x = tar_points_raw[0, candidates]
        match_k_y = tar_points_raw[1, candidates]
        tru_x = tar_points_raw[0]
        tru_y = tar_points_raw[1]

        # Recall: the top-ranked candidate must sit exactly on the true point.
        exact_hit = ((match_k_x[0] - tru_x).abs() == 0) & (
            (match_k_y[0] - tru_y).abs() == 0
        )
        recall += (1.0 / batch_size) * (float(exact_hit.float().sum()) / n_points)

        if eval_only:
            continue

        # Strictly positive, decreasing rank weights (N .. 1): the argmax of
        # weight * incorrect is then the best-ranked incorrect candidate, even
        # when that candidate is the very last one. If every candidate is
        # "correct" the product is all zeros and rank 0 is picked.
        rank_weight = torch.arange(
            candidates.shape[0], 0, -1, device=device
        ).unsqueeze(1)

        for radius, margin, weight in zip(relax_field_size, margins, weights):
            correct_k = ((match_k_x - tru_x).abs() <= radius) & (
                (match_k_y - tru_y).abs() <= radius
            )
            incorrect_first = torch.argmax(rank_weight * (~correct_k).long(), dim=0)
            incorrect_first_index = candidates.gather(
                0, incorrect_first.unsqueeze(0)
            ).squeeze(0)

            neg_var = tar_desc[:, incorrect_first_index]
            triplet = torch.nn.functional.triplet_margin_loss(
                ref_desc.t(), tar_desc.t(), neg_var.t(), margin=margin
            )
            loss = loss + (1.0 / batch_size) * triplet.mul(weight)

    return loss, recall
class KeypointLoss(object):
    """
    Training objective that bundles the location loss, the descriptor loss,
    the score loss and the correspondence loss into one weighted sum.

    The four weights and the correspondence distance threshold are read from
    the ``config`` object passed to the constructor.
    """

    def __init__(self, config):
        # Copy the loss hyper-parameters off the config object.
        for name in (
            "score_weight",
            "loc_weight",
            "desc_weight",
            "corres_weight",
            "corres_threshold",
        ):
            setattr(self, name, getattr(config, name))

    def __call__(self, data):
        """
        Evaluate all loss terms for one batch.

        ``data`` is a mapping holding the network outputs and warped targets
        (keys read here: source_score, target_coord_warped, target_coord,
        confidence_matrix, border_mask, source_desc, target_desc_warped,
        target_score, target_score_warped, target_aware_warped, source_aware).

        Returns a 5-tuple: total loss, then the weighted location,
        descriptor, score and correspondence terms.
        """
        functional = torch.nn.functional
        source_score = data["source_score"]
        border = data["border_mask"]  # reshaped below as (B, hc, wc)
        confidence = data["confidence_matrix"]

        B, _, hc, wc = source_score.shape
        n_cells = hc * wc

        # Pairwise distance between every warped target point and every
        # target point: result is (B, N, N).
        warped_pts = data["target_coord_warped"].view(B, 2, -1).unsqueeze(3)
        target_pts = data["target_coord"].view(B, 2, -1).unsqueeze(2)
        pair_dist = torch.norm(torch.abs(warped_pts - target_pts), p=2, dim=1)
        nearest_dist, nearest_idx = pair_dist.min(dim=2)

        # construct pseudo ground truth matching matrix: a pair is positive
        # when it is the row-wise nearest neighbour and at most 1.0 apart,
        # negative when it is at least 4.0 apart.
        is_row_minimum = pair_dist.eq(nearest_dist.unsqueeze(dim=-1))
        pos_mask = is_row_minimum & pair_dist.le(1.0)
        neg_mask = pair_dist.ge(4.0)
        pos_term = -torch.log(confidence[pos_mask])
        neg_term = -torch.log(1.0 - confidence[neg_mask])
        corres_loss = pos_term.mean() + 5e5 * neg_term.mean()

        # Keep cells whose nearest match is closer than the configured
        # correspondence threshold and that lie inside the border mask.
        valid_flat = nearest_dist.lt(self.corres_threshold) & border.view(B, n_cells)

        # location loss
        loc_loss = nearest_dist[valid_flat].mean()

        # desc Head Loss, per-pixel level triplet loss from
        # https://arxiv.org/pdf/1902.11046.pdf.
        desc_loss, _ = build_descriptor_loss(
            data["source_desc"],
            data["target_desc_warped"],
            data["target_coord_warped"].detach(),
            top_kk=border,
            relax_field=8,
        )

        # score loss: pick, for every cell, the target score at its nearest
        # matched location.
        border_map = border.unsqueeze(1)
        matched_target_score = (
            data["target_score"]
            .view(B, n_cells)
            .gather(1, nearest_idx)
            .view(B, hc, wc)
            .unsqueeze(1)
        )
        valid_map = valid_flat.view(B, hc, wc).unsqueeze(1) & border_map
        loc_err = nearest_dist.view(B, hc, wc).unsqueeze(1)[valid_map]

        # repeatable_constrain in score loss
        score_sum = matched_target_score[valid_map] + source_score[valid_map]
        repeatable_constrain = (score_sum * (loc_err - loc_err.mean())).mean()

        # consistent_constrain in score_loss
        consistent_constrain = (
            functional.mse_loss(
                data["target_score_warped"][border_map],
                source_score[border_map],
            ).mean()
            * 2
        )

        # Same consistency term for the two-channel "aware" maps.
        border_two_channel = border_map.repeat(1, 2, 1, 1)
        aware_consistent_loss = (
            functional.mse_loss(
                data["target_aware_warped"][border_two_channel],
                data["source_aware"][border_two_channel],
            ).mean()
            * 2
        )

        score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss

        weighted_loc = self.loc_weight * loc_loss
        weighted_desc = self.desc_weight * desc_loss
        weighted_score = self.score_weight * score_loss
        weighted_corres = self.corres_weight * corres_loss
        total = weighted_loc + weighted_desc + weighted_score + weighted_corres

        return (total, weighted_loc, weighted_desc, weighted_score, weighted_corres)