import matplotlib
# Select the non-interactive backend *before* pyplot is imported: older
# matplotlib releases ignore ``matplotlib.use`` (with a warning) once pyplot
# has already been imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # noqa: E402

import numpy as np  # noqa: E402

import torch  # noqa: E402
import torch.nn.functional as F  # noqa: E402

from lib.utils import (  # noqa: E402
    grid_positions,
    upscale_positions,
    downscale_positions,
    savefig,
    imshow_image
)
from lib.exceptions import NoGradientError, EmptyTensorError  # noqa: E402


def loss_function(
        model, batch, device, margin=1, safe_radius=4, scaling_steps=3,
        plot=False
):
    """Compute a detection-score-weighted margin ranking loss for a batch.

    For each image pair, the dense feature-map grid of image 1 is lifted to
    image coordinates and warped into image 2 (see ``warp``). For every
    surviving correspondence the per-pair loss is::

        sum(s1 * s2 * relu(margin + d_pos - min(d_neg1, d_neg2))) / sum(s1 * s2)

    where ``d_pos`` is the descriptor distance ``2 - 2 * <d1, d2>`` of the
    true match and ``d_neg*`` is the hardest negative found in each image
    outside ``safe_radius``. Pairs with no warp result, or with fewer than
    128 correspondences, are skipped. The final loss is the mean over the
    pairs that were not skipped.

    Args:
        model: callable taking ``{'image1': ..., 'image2': ...}`` and
            returning a dict with ``dense_features1``, ``dense_features2``
            (``[B, c, h, w]`` per the ``.size()`` unpacking below),
            ``scores1`` and ``scores2``.
        batch: dict with ``image*``, ``depth*``, ``intrinsics*``, ``pose*``
            and ``bbox*`` for both views; when ``plot`` is set it must also
            contain ``batch_idx``, ``log_interval``, ``preprocessing``,
            ``train`` and ``epoch_idx``.
        device: torch device to run on.
        margin: margin of the ranking loss.
        safe_radius: feature-map positions whose Chebyshev distance to the
            true match is ``<= safe_radius`` get +10 added to their
            descriptor distance, which keeps them out of the negative search.
        scaling_steps: forwarded to ``upscale_positions`` /
            ``downscale_positions`` to convert between feature-map and image
            coordinates.
        plot: when True, save a visualisation for every pair on batches
            where ``batch_idx % log_interval == 0``.

    Returns:
        A 1-element float32 tensor holding the mean loss.

    Raises:
        NoGradientError: if no pair in the batch contributed a loss term.
    """
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    # 1-element tensor (not a 0-d scalar) to keep the historical return shape.
    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        # Unit-norm descriptors, one column per feature-map cell.
        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)
        pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
        try:
            pos1, pos2, ids = warp(
                pos1,
                depth1, intrinsics1, pose1, bbox1,
                depth2, intrinsics2, pose2, bbox2
            )
        except EmptyTensorError:
            continue

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Keep only the grid cells of image 1 that survived the warp.
        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = all_descriptors1[:, ids]
        scores1 = scores1[ids]

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)
        ).long()
        descriptors2 = F.normalize(
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
            dim=0
        )
        # Batched dot product of matching columns -> distance of true match.
        positive_distance = 2 - 2 * (
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze()

        # Hardest negatives in image 2 for descriptors1, and in image 1 for
        # descriptors2.
        negative_distance2 = _hardest_negative_distance(
            fmap_pos2, descriptors1, all_descriptors2,
            h2, w2, safe_radius, device
        )
        negative_distance1 = _hardest_negative_distance(
            fmap_pos1, descriptors2, all_descriptors1,
            h1, w1, safe_radius, device
        )

        diff = positive_distance - torch.min(
            negative_distance1, negative_distance2
        )

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (
            torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
            torch.sum(scores1 * scores2)
        )

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            _save_correspondence_plot(
                batch, output, idx_in_batch, pos1, pos2
            )

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss


def _hardest_negative_distance(
        fmap_pos, descriptors, all_descriptors, h, w, safe_radius, device
):
    """Return, per query column, the smallest distance to a far-away cell.

    Args:
        fmap_pos: 2 x N feature-map positions of the true matches.
        descriptors: c x N query descriptors (from the *other* image).
        all_descriptors: c x (h * w) descriptors of the searched image.
        h, w: feature-map size of the searched image.
        safe_radius: cells whose Chebyshev distance to ``fmap_pos`` is
            ``<= safe_radius`` get +10 added before taking the minimum.
        device: device for the grid of candidate positions.

    Returns:
        Tensor of N minimum distances.
    """
    all_fmap_pos = grid_positions(h, w, device)
    # Max over the coordinate axis of |delta| == Chebyshev distance, N x (h*w).
    position_distance = torch.max(
        torch.abs(
            fmap_pos.unsqueeze(2).float() -
            all_fmap_pos.unsqueeze(1)
        ),
        dim=0
    )[0]
    is_out_of_safe_radius = position_distance > safe_radius
    distance_matrix = 2 - 2 * (descriptors.t() @ all_descriptors)
    return torch.min(
        distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
        dim=1
    )[0]


def _save_correspondence_plot(batch, output, idx_in_batch, pos1, pos2):
    """Save a 4-panel figure: image 1, scores 1, image 2, scores 2.

    Matching positions in ``pos1`` / ``pos2`` (row 0 = vertical, row 1 =
    horizontal coordinate) share a random colour.
    """
    pos1_aux = pos1.cpu().numpy()
    pos2_aux = pos2.cpu().numpy()
    k = pos1_aux.shape[1]
    col = np.random.rand(k, 3)
    n_sp = 4
    plt.figure()
    plt.subplot(1, n_sp, 1)
    im1 = imshow_image(
        batch['image1'][idx_in_batch].cpu().numpy(),
        preprocessing=batch['preprocessing']
    )
    plt.imshow(im1)
    plt.scatter(
        pos1_aux[1, :], pos1_aux[0, :],
        s=0.25**2, c=col, marker=',', alpha=0.5
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 2)
    plt.imshow(
        output['scores1'][idx_in_batch].detach().cpu().numpy(),
        cmap='Reds'
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 3)
    im2 = imshow_image(
        batch['image2'][idx_in_batch].cpu().numpy(),
        preprocessing=batch['preprocessing']
    )
    plt.imshow(im2)
    plt.scatter(
        pos2_aux[1, :], pos2_aux[0, :],
        s=0.25**2, c=col, marker=',', alpha=0.5
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 4)
    plt.imshow(
        output['scores2'][idx_in_batch].detach().cpu().numpy(),
        cmap='Reds'
    )
    plt.axis('off')
    savefig('train_vis/%s.%02d.%02d.%d.png' % (
        'train' if batch['train'] else 'valid',
        batch['epoch_idx'],
        batch['batch_idx'] // batch['log_interval'],
        idx_in_batch
    ), dpi=300)
    plt.close()


def interpolate_depth(pos, depth):
    """Bilinearly interpolate ``depth`` at sub-pixel positions.

    A position is kept only if its four neighbouring integer pixels all lie
    inside ``depth`` and all have depth > 0.

    Args:
        pos: 2 x N tensor; ``pos[0]`` indexes rows, ``pos[1]`` columns.
        depth: [h, w] depth map.

    Returns:
        ``[interpolated_depth, pos, ids]``: the interpolated depths of the
        kept positions, those positions (2 x M) and their indices into the
        input ``pos``.

    Raises:
        EmptyTensorError: if no position survives either filter.
    """
    device = pos.device

    ids = torch.arange(0, pos.size(1), device=device)

    h, w = depth.size()

    i = pos[0, :]
    j = pos[1, :]

    # Valid corners (logical AND of boolean masks)
    i_top_left = torch.floor(i).long()
    j_top_left = torch.floor(j).long()
    valid_top_left = (i_top_left >= 0) & (j_top_left >= 0)

    i_top_right = torch.floor(i).long()
    j_top_right = torch.ceil(j).long()
    valid_top_right = (i_top_right >= 0) & (j_top_right < w)

    i_bottom_left = torch.ceil(i).long()
    j_bottom_left = torch.floor(j).long()
    valid_bottom_left = (i_bottom_left < h) & (j_bottom_left >= 0)

    i_bottom_right = torch.ceil(i).long()
    j_bottom_right = torch.ceil(j).long()
    valid_bottom_right = (i_bottom_right < h) & (j_bottom_right < w)

    valid_corners = (
        valid_top_left & valid_top_right &
        valid_bottom_left & valid_bottom_right
    )

    i_top_left = i_top_left[valid_corners]
    j_top_left = j_top_left[valid_corners]

    i_top_right = i_top_right[valid_corners]
    j_top_right = j_top_right[valid_corners]

    i_bottom_left = i_bottom_left[valid_corners]
    j_bottom_left = j_bottom_left[valid_corners]

    i_bottom_right = i_bottom_right[valid_corners]
    j_bottom_right = j_bottom_right[valid_corners]

    ids = ids[valid_corners]
    if ids.size(0) == 0:
        raise EmptyTensorError

    # Valid depth: all four neighbours must carry a positive depth.
    valid_depth = (
        (depth[i_top_left, j_top_left] > 0) &
        (depth[i_top_right, j_top_right] > 0) &
        (depth[i_bottom_left, j_bottom_left] > 0) &
        (depth[i_bottom_right, j_bottom_right] > 0)
    )

    i_top_left = i_top_left[valid_depth]
    j_top_left = j_top_left[valid_depth]

    i_top_right = i_top_right[valid_depth]
    j_top_right = j_top_right[valid_depth]

    i_bottom_left = i_bottom_left[valid_depth]
    j_bottom_left = j_bottom_left[valid_depth]

    i_bottom_right = i_bottom_right[valid_depth]
    j_bottom_right = j_bottom_right[valid_depth]

    ids = ids[valid_depth]
    if ids.size(0) == 0:
        raise EmptyTensorError

    # Interpolation
    i = i[ids]
    j = j[ids]
    dist_i_top_left = i - i_top_left.float()
    dist_j_top_left = j - j_top_left.float()
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    interpolated_depth = (
        w_top_left * depth[i_top_left, j_top_left] +
        w_top_right * depth[i_top_right, j_top_right] +
        w_bottom_left * depth[i_bottom_left, j_bottom_left] +
        w_bottom_right * depth[i_bottom_right, j_bottom_right]
    )

    pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    return [interpolated_depth, pos, ids]


def uv_to_pos(uv):
    """Swap a 2 x N ``(u, v)`` tensor into ``(row, column) = (v, u)`` order."""
    return torch.cat([uv[1, :].view(1, -1), uv[0, :].view(1, -1)], dim=0)


def warp(
        pos1,
        depth1, intrinsics1, pose1, bbox1,
        depth2, intrinsics2, pose2, bbox2
):
    """Warp image-1 positions into image 2 and keep depth-consistent ones.

    Each position is back-projected with its interpolated ``depth1`` and
    ``intrinsics1`` (after adding ``bbox1`` and a +0.5 pixel-centre offset),
    moved into view 2 by ``pose2 @ inverse(pose1)``, projected with
    ``intrinsics2`` and shifted back by ``bbox2``. A correspondence is kept
    when ``depth2`` interpolated at the projection differs from the
    transformed depth by less than 0.05.

    Args:
        pos1: 2 x N (row, column) positions in image 1.
        depth1, depth2: depth maps of the two views.
        intrinsics1, intrinsics2: [3, 3] camera matrices.
        pose1, pose2: [4, 4] pose matrices.
        bbox1, bbox2: 2-element offsets; index 0 is added to rows and index 1
            to columns (presumably the crop origin — confirm with the
            dataset).

    Returns:
        ``(pos1, pos2, ids)``: kept positions in image 1, their matches in
        image 2 (both 2 x M), and their indices into the input ``pos1``.

    Raises:
        EmptyTensorError: if no correspondence survives.
    """
    device = pos1.device

    Z1, pos1, ids = interpolate_depth(pos1, depth1)

    # COLMAP convention
    u1 = pos1[1, :] + bbox1[1] + .5
    v1 = pos1[0, :] + bbox1[0] + .5

    X1 = (u1 - intrinsics1[0, 2]) * (Z1 / intrinsics1[0, 0])
    Y1 = (v1 - intrinsics1[1, 2]) * (Z1 / intrinsics1[1, 1])

    XYZ1_hom = torch.cat([
        X1.view(1, -1),
        Y1.view(1, -1),
        Z1.view(1, -1),
        torch.ones(1, Z1.size(0), device=device)
    ], dim=0)
    # ``torch.chain_matmul`` is deprecated; ``@`` computes the same product
    # left to right (4x4 @ 4x4 first, then @ 4xN).
    XYZ2_hom = pose2 @ torch.inverse(pose1) @ XYZ1_hom
    XYZ2 = XYZ2_hom[: -1, :] / XYZ2_hom[-1, :].view(1, -1)

    uv2_hom = torch.matmul(intrinsics2, XYZ2)
    uv2 = uv2_hom[: -1, :] / uv2_hom[-1, :].view(1, -1)

    u2 = uv2[0, :] - bbox2[1] - .5
    v2 = uv2[1, :] - bbox2[0] - .5
    uv2 = torch.cat([u2.view(1, -1), v2.view(1, -1)], dim=0)

    annotated_depth, pos2, new_ids = interpolate_depth(uv_to_pos(uv2), depth2)

    ids = ids[new_ids]
    pos1 = pos1[:, new_ids]
    estimated_depth = XYZ2[2, new_ids]

    inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05

    ids = ids[inlier_mask]
    if ids.size(0) == 0:
        raise EmptyTensorError

    pos2 = pos2[:, inlier_mask]
    pos1 = pos1[:, inlier_mask]

    return pos1, pos2, ids