import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def _default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _unit(vectors, eps=1e-8):
    """Scale vectors along dim 3 to (approximately) unit length."""
    length = torch.sqrt(torch.sum(vectors ** 2, dim=3, keepdim=True))
    return vectors / (length + eps)


def _flip_towards_origin(normals, xyz):
    """Negate every normal whose dot product with its point ``xyz`` is positive."""
    facing_away = torch.sum(normals * xyz, dim=3, keepdim=True) > 0
    return torch.where(facing_away, -normals, normals)


def get_surface_normalv2(xyz, patch_size=5):
    """Estimate per-pixel surface normals from a point map.

    Two normals are computed per pixel as the cross product of a horizontal
    and a vertical difference vector taken inside a ``patch_size`` window
    (one with the outermost neighbours, one with a tighter left/top
    neighbour).  Both are oriented so that ``normal . xyz <= 0``, normalised,
    summed, normalised again and re-oriented.

    Args:
        xyz: point coordinates, shape ``[b, h, w, 3]``.
        patch_size: odd window size; the map is zero-padded by
            ``patch_size // 2`` on each side.

    Returns:
        Normals with shape ``[h, w, 3, b]``.
    """
    b, h, w, c = xyz.shape
    half = patch_size // 2
    padded = torch.zeros(
        (b, h + patch_size - 1, w + patch_size - 1, c),
        dtype=xyz.dtype,
        device=xyz.device,
    )
    # Explicit end indices so that the slice stays valid even when half == 0.
    padded[:, half:half + h, half:half + w, :] = xyz

    rows = slice(half, half + h)
    cols = slice(half, half + w)

    # Outer stencil: neighbours at distance `half` (p4, p6, p2, p8).
    left = padded[:, rows, :w, :]
    right = padded[:, rows, -w:, :]
    top = padded[:, :h, cols, :]
    bottom = padded[:, -h:, cols, :]
    horizon = left - right
    vertical = top - bottom

    # Inner stencil.
    # NOTE(review): the right/bottom slices below start at patch_size - 1,
    # which is the same offset as the outer stencil, so only the left/top
    # neighbours are actually closer.  Kept as-is to preserve the numerics.
    left_in = padded[:, rows, 1:w + 1, :]
    right_in = padded[:, rows, patch_size - 1:patch_size - 1 + w, :]
    top_in = padded[:, 1:h + 1, cols, :]
    bottom_in = padded[:, patch_size - 1:patch_size - 1 + h, cols, :]
    horizon_in = left_in - right_in
    vertical_in = top_in - bottom_in

    normal_inner = torch.cross(horizon_in, vertical_in, dim=3)
    normal_outer = torch.cross(horizon, vertical, dim=3)

    # Orient consistently before normalising and averaging.
    normal_inner = _unit(_flip_towards_origin(normal_inner, xyz))
    normal_outer = _unit(_flip_towards_origin(normal_outer, xyz))

    averaged = _unit(normal_inner + normal_outer)
    averaged = _flip_towards_origin(averaged, xyz)

    return averaged.permute((1, 2, 3, 0))  # [h, w, c, b]


def init_image_coor(height, width, device=None):
    """Build pixel-coordinate grids centred on the image midpoint.

    Args:
        height: image height in pixels.
        width: image width in pixels.
        device: device for the returned tensors; defaults to CUDA when it is
            available, otherwise the CPU.

    Returns:
        ``(u_u0, v_v0)``: float32 tensors of shape ``[1, height, width]``
        holding ``column - width / 2`` and ``row - height / 2``.
    """
    if device is None:
        device = _default_device()

    x = np.tile(np.arange(0, width), (height, 1))[np.newaxis, :, :]
    x = torch.from_numpy(x.astype(np.float32)).to(device)
    u_u0 = x - width / 2.0

    y = np.tile(np.arange(0, height), (width, 1)).T[np.newaxis, :, :]
    y = torch.from_numpy(np.ascontiguousarray(y.astype(np.float32))).to(device)
    v_v0 = y - height / 2.0
    return u_u0, v_v0


def depth_to_xyz(depth, focal_length):
    """Back-project a depth map to 3D points with a pinhole model.

    Args:
        depth: depth map, shape ``[b, 1, h, w]``.
        focal_length: value broadcastable against ``depth``.

    Returns:
        Points of shape ``[b, h, w, 3]`` ordered as (x, y, z).
    """
    b, c, h, w = depth.shape
    u_u0, v_v0 = init_image_coor(h, w, device=depth.device)
    x = u_u0 * depth / focal_length
    y = v_v0 * depth / focal_length
    z = depth
    return torch.cat([x, y, z], 1).permute(0, 2, 3, 1)  # [b, h, w, c]


def surface_normal_from_depth(depth, focal_length, valid_mask=None):
    """Compute surface normals from a depth map.

    The depth is smoothed with two 3x3 average-pooling passes, back-projected
    to 3D and converted to normals sample by sample.

    Args:
        depth: depth map, shape ``[b, 1, h, w]``.
        focal_length: 1-D tensor with one focal length per sample.
        valid_mask: optional boolean mask of shape ``[b, 1, h, w]``; normals
            at positions where it is False are set to zero.

    Returns:
        Normals of shape ``[b, 3, h, w]``.
    """
    b, c, h, w = depth.shape
    focal_length = focal_length[:, None, None, None]
    depth_filter = F.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
    depth_filter = F.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1)
    xyz = depth_to_xyz(depth_filter, focal_length)

    sn_batch = [get_surface_normalv2(xyz[i, :][None, :, :, :]) for i in range(b)]
    sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1))  # [b, c, h, w]

    if valid_mask is not None:
        mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
        sn_batch[mask_invalid] = 0.0

    return sn_batch


def ind2sub(idx, cols):
    """Convert flat row-major indices to ``(row, col)`` for a width of ``cols``."""
    # Floor division: `/` would yield floats (and a column of 0) on integers.
    r = idx // cols
    c = idx - r * cols
    return r, c


def sub2ind(r, c, cols):
    """Convert ``(row, col)`` to a flat row-major index for a width of ``cols``."""
    idx = r * cols + c
    return idx


###########
# EDGE-GUIDED SAMPLING
###########
def edgeGuidedSampling(inputs, targets, edges_img, thetas_img, masks, h, w):
    """Sample point pairs along the gradient direction of edge pixels.

    Edge pixels are those whose magnitude is at least 10% of the maximum.
    For every randomly drawn anchor, four points (a, b, c, d) are placed on
    the line through the anchor at angle theta: two on the negative side and
    two on the positive side, at random distances in [3, 19] pixels.  The
    pairs (a, b), (b, c) and (c, d) are returned.

    Args:
        inputs: predicted values, shape ``[C, h * w]``.
        targets: ground-truth values, shape ``[C, h * w]``.
        edges_img: flattened edge magnitudes, shape ``[h * w]``.
        thetas_img: flattened edge orientations in radians, shape ``[h * w]``.
        masks: flattened validity mask, shape ``[h * w]``.
        h: image height.
        w: image width.

    Returns:
        ``(inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B,
        sample_num, row, col)`` where ``row`` and ``col`` have shape
        ``[4, sample_num]``.
    """
    device = edges_img.device

    # find edges
    edges_max = edges_img.max()
    edges_mask = edges_img.ge(edges_max * 0.1)
    edges_loc = edges_mask.nonzero()

    thetas_edge = torch.masked_select(thetas_img, edges_mask)
    minlen = thetas_edge.size()[0]

    # find anchor points (i.e. edge points); random numbers are drawn on the
    # CPU generator and then moved, so seeded runs stay reproducible.
    sample_num = minlen
    index_anchors = torch.randint(0, minlen, (sample_num,), dtype=torch.long).to(device)
    theta_anchors = torch.gather(thetas_edge, 0, index_anchors)
    row_anchors, col_anchors = ind2sub(edges_loc[index_anchors].squeeze(1), w)

    # signed distances of the 4 points from the anchor, magnitudes in [3, 19]
    distance_matrix = torch.randint(3, 20, (4, sample_num)).to(device)
    pos_or_neg = torch.ones(4, sample_num, device=device)
    pos_or_neg[:2, :] = -pos_or_neg[:2, :]
    distance_matrix = distance_matrix.float() * pos_or_neg

    col_offset = torch.round(distance_matrix.double() * torch.cos(theta_anchors).unsqueeze(0)).long()
    row_offset = torch.round(distance_matrix.double() * torch.sin(theta_anchors).unsqueeze(0)).long()
    col = col_anchors.unsqueeze(0).expand(4, sample_num).long() + col_offset
    row = row_anchors.unsqueeze(0).expand(4, sample_num).long() + row_offset

    # constrain 0 <= col <= w - 1 and 0 <= row <= h - 1
    col = torch.clamp(col, 0, w - 1)
    row = torch.clamp(row, 0, h - 1)

    # a-b, b-c, c-d
    a = sub2ind(row[0, :], col[0, :], w)
    b = sub2ind(row[1, :], col[1, :], w)
    c = sub2ind(row[2, :], col[2, :], w)
    d = sub2ind(row[3, :], col[3, :], w)
    A = torch.cat((a, b, c), 0)
    B = torch.cat((b, c, d), 0)

    inputs_A = inputs[:, A]
    inputs_B = inputs[:, B]
    targets_A = targets[:, A]
    targets_B = targets[:, B]
    masks_A = torch.gather(masks, 0, A.long())
    masks_B = torch.gather(masks, 0, B.long())

    return inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num, row, col


###########
# RANDOM SAMPLING
###########
def randomSamplingNormal(inputs, targets, masks, sample_num):
    """Draw up to ``sample_num`` random pairs of valid pixels.

    Args:
        inputs: predicted values, shape ``[C, N]``.
        targets: ground-truth values, shape ``[C, N]``.
        masks: boolean validity mask, shape ``[N]``.
        sample_num: requested number of pairs (int or 0-dim integer tensor).

    Returns:
        ``(inputs_A, inputs_B, targets_A, targets_B)``, each ``[C, K]`` with
        ``K <= sample_num``; A and B index disjoint pixels.
    """
    sample_num = int(sample_num)
    num_effect_pixels = int(torch.sum(masks))
    shuffle_effect_pixels = torch.randperm(num_effect_pixels).to(inputs.device)

    valid_inputs = inputs[:, masks]
    valid_targets = targets[:, masks]

    # Even positions of the permutation feed A, odd positions feed B.
    picks_a = shuffle_effect_pixels[0:sample_num * 2:2]
    picks_b = shuffle_effect_pixels[1:sample_num * 2:2]

    inputs_A = valid_inputs[:, picks_a]
    inputs_B = valid_inputs[:, picks_b]
    targets_A = valid_targets[:, picks_a]
    targets_B = valid_targets[:, picks_b]

    # With an odd number of available pixels A can be one longer than B.
    if inputs_A.shape[1] != inputs_B.shape[1]:
        num_min = min(targets_A.shape[1], targets_B.shape[1])
        inputs_A = inputs_A[:, :num_min]
        inputs_B = inputs_B[:, :num_min]
        targets_A = targets_A[:, :num_min]
        targets_B = targets_B[:, :num_min]

    return inputs_A, inputs_B, targets_A, targets_B


class EdgeguidedNormalRegressionLoss(nn.Module):
    """Pair-wise normal regression loss with edge-guided point sampling.

    Surface normals are derived from the predicted and ground-truth depth.
    Point pairs are sampled around image edges and around normal edges, plus
    a set of random pairs, and the loss penalises the absolute difference
    between the |cosine| of the predicted and ground-truth normal pairs.
    """

    def __init__(self, point_pairs=10000, cos_theta1=0.3, cos_theta2=0.95, cos_theta3=0.5, cos_theta4=0.86,
                 mask_value=-1e-8, max_threshold=10.1):
        super(EdgeguidedNormalRegressionLoss, self).__init__()
        self.point_pairs = point_pairs  # number of point pairs
        self.mask_value = mask_value
        self.max_threshold = max_threshold
        self.cos_theta1 = cos_theta1  # pairs below this |cos| count as strongly different (~73 degrees)
        self.cos_theta2 = cos_theta2  # pairs above this |cos| count as nearly parallel (~18 degrees)
        self.cos_theta3 = cos_theta3  # 60 degrees (not used in this file's forward pass)
        self.cos_theta4 = cos_theta4  # ~30 degrees (not used in this file's forward pass)
        # 3x3 box kernel used to dilate the depth-edge mask.
        self.kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=_default_device())

    def scale_shift_pred_depth(self, pred, gt):
        """Align ``pred`` to ``gt`` with a per-sample least-squares scale and shift.

        Args:
            pred: predicted depth, shape ``[b, 1, h, w]``.
            gt: ground-truth depth, same shape; only pixels with
                ``mask_value < gt < max_threshold`` enter the fit.

        Returns:
            ``scale * pred + shift`` with shape ``[b, 1, h, w]``.
        """
        b, c, h, w = pred.shape
        mask = (gt > self.mask_value) & (gt < self.max_threshold)  # [b, c, h, w]
        # Small ridge term keeps the 2x2 normal matrix invertible.
        EPS = 1e-6 * torch.eye(2, dtype=pred.dtype, device=pred.device)
        scale_shift_batch = []
        ones_img = torch.ones((1, h, w), dtype=pred.dtype, device=pred.device)
        for i in range(b):
            mask_i = mask[i, ...]
            pred_valid_i = pred[i, ...][mask_i]
            ones_i = ones_img[mask_i]
            pred_valid_ones_i = torch.stack((pred_valid_i, ones_i), dim=0)  # [c+1, n]
            A_i = torch.matmul(pred_valid_ones_i, pred_valid_ones_i.permute(1, 0))  # [2, 2]
            A_inverse = torch.inverse(A_i + EPS)

            gt_i = gt[i, ...][mask_i]
            B_i = torch.matmul(pred_valid_ones_i, gt_i)[:, None]  # [2, 1]
            scale_shift_i = torch.matmul(A_inverse, B_i)  # [2, 1]
            scale_shift_batch.append(scale_shift_i)
        scale_shift_batch = torch.stack(scale_shift_batch, dim=0)  # [b, 2, 1]
        ones = torch.ones_like(pred)
        pred_ones = torch.cat((pred, ones), dim=1)  # [b, 2, h, w]
        pred_scale_shift = torch.matmul(pred_ones.permute(0, 2, 3, 1).reshape(b, h * w, 2), scale_shift_batch)  # [b, h*w, 1]
        pred_scale_shift = pred_scale_shift.permute(0, 2, 1).reshape((b, c, h, w))
        return pred_scale_shift

    def getEdge(self, images):
        """Sobel edge magnitude and orientation.

        For 3-channel input only the first channel is used; otherwise the
        input is expected to have a single channel.

        Returns:
            ``(edges, thetas)``, both ``[n, 1, h, w]`` with a zero border.
        """
        n, c, h, w = images.size()
        a = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                         dtype=images.dtype, device=images.device).view((1, 1, 3, 3))
        b = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
                         dtype=images.dtype, device=images.device).view((1, 1, 3, 3))
        if c == 3:
            gradient_x = F.conv2d(images[:, 0, :, :].unsqueeze(1), a)
            gradient_y = F.conv2d(images[:, 0, :, :].unsqueeze(1), b)
        else:
            gradient_x = F.conv2d(images, a)
            gradient_y = F.conv2d(images, b)
        edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
        edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
        thetas = torch.atan2(gradient_y, gradient_x)
        thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)
        return edges, thetas

    def getNormalEdge(self, normals):
        """Sobel edges of a normal map, averaged over its channels.

        Per-channel absolute gradients are averaged across channels before
        the magnitude and orientation are computed.

        Returns:
            ``(edges, thetas)``, both ``[n, 1, h, w]`` with a zero border.
        """
        n, c, h, w = normals.size()
        a = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                         dtype=normals.dtype, device=normals.device).view((1, 1, 3, 3)).repeat(c, 1, 1, 1)
        b = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
                         dtype=normals.dtype, device=normals.device).view((1, 1, 3, 3)).repeat(c, 1, 1, 1)
        gradient_x = torch.abs(F.conv2d(normals, a, groups=c))
        gradient_y = torch.abs(F.conv2d(normals, b, groups=c))
        gradient_x = gradient_x.mean(dim=1, keepdim=True)
        gradient_y = gradient_y.mean(dim=1, keepdim=True)
        edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
        edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
        thetas = torch.atan2(gradient_y, gradient_x)
        thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)
        return edges, thetas

    def forward(self, pred_depths, gt_depths, images, focal_length):
        """Compute the loss.

        Args:
            pred_depths: predicted depth, shape ``[n, 1, h, w]``.
            gt_depths: ground-truth depth, shape ``[n, 1, h, w]``; pixels not
                greater than ``mask_value`` are treated as invalid.
            images: RGB images, shape ``[n, 3, h, w]``.
            focal_length: 1-D tensor with one focal length per sample.

        Returns:
            Scalar loss averaged over the batch, or a zero that is still
            connected to ``pred_depths`` when the accumulated loss is 0.
        """
        device = pred_depths.device
        masks = gt_depths > self.mask_value
        inputs = surface_normal_from_depth(pred_depths, focal_length, valid_mask=masks)
        targets = surface_normal_from_depth(gt_depths, focal_length, valid_mask=masks)

        # find edges from RGB
        edges_img, thetas_img = self.getEdge(images)

        # find edges from normals; normals near the image border are unreliable
        edges_normal, thetas_normal = self.getNormalEdge(targets)
        mask_img_border = torch.ones_like(edges_normal)
        mask_img_border[:, :, 5:-5, 5:-5] = 0
        edges_normal[mask_img_border.bool()] = 0

        # find edges from depth and suppress sampling on (dilated) depth edges
        edges_depth, _ = self.getEdge(gt_depths)
        edges_depth_mask = edges_depth.ge(edges_depth.max() * 0.1)
        kernel = self.kernel.to(device=edges_depth.device)
        edges_mask_dilate = torch.clamp(
            F.conv2d(edges_depth_mask.float(), kernel, padding=(1, 1)), 0, 1).bool()
        edges_normal[edges_mask_dilate] = 0
        edges_img[edges_mask_dilate] = 0

        # flatten the spatial dimensions
        n, c, h, w = targets.size()
        inputs = inputs.contiguous().view(n, c, -1).double()
        targets = targets.contiguous().view(n, c, -1).double()
        masks = masks.contiguous().view(n, -1)
        edges_img = edges_img.contiguous().view(n, -1).double()
        thetas_img = thetas_img.contiguous().view(n, -1).double()
        edges_normal = edges_normal.contiguous().view(n, -1).double()
        thetas_normal = thetas_normal.contiguous().view(n, -1).double()

        loss = torch.zeros(1, dtype=torch.float64, device=device)

        for i in range(n):
            # Edge-guided sampling on image edges (EGS) and normal edges (EGNS)
            inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num, row_img, col_img = \
                edgeGuidedSampling(inputs[i, :], targets[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w)
            normal_inputs_A, normal_inputs_B, normal_targets_A, normal_targets_B, normal_masks_A, normal_masks_B, \
                normal_sample_num, row_normal, col_normal = \
                edgeGuidedSampling(inputs[i, :], targets[i, :], edges_normal[i], thetas_normal[i], masks[i, :], h, w)

            # Combine EGS + EGNS
            inputs_A = torch.cat((inputs_A, normal_inputs_A), 1)
            inputs_B = torch.cat((inputs_B, normal_inputs_B), 1)
            targets_A = torch.cat((targets_A, normal_targets_A), 1)
            targets_B = torch.cat((targets_B, normal_targets_B), 1)
            masks_A = torch.cat((masks_A, normal_masks_A), 0)
            masks_B = torch.cat((masks_B, normal_masks_B), 0)

            # only compute losses of point pairs where both points have valid GT
            consistency_mask = masks_A & masks_B

            # |cos| of the angle between the two normals of each pair
            target_cos = torch.abs(torch.sum(targets_A * targets_B, dim=0))
            input_cos = torch.abs(torch.sum(inputs_A * inputs_B, dim=0))

            # keep pairs whose GT normals are strongly different or nearly parallel
            select_far = (target_cos < self.cos_theta1) & consistency_mask
            select_near = (target_cos > self.cos_theta2) & consistency_mask

            loss += torch.sum(torch.abs(target_cos[select_far] - input_cos[select_far])) / \
                (torch.sum(select_far) + 1e-8)
            loss += torch.sum(torch.abs(target_cos[select_near] - input_cos[select_near])) / \
                (torch.sum(select_near) + 1e-8)

            # Random sampling regression with as many pairs as were selected above
            random_sample_num = torch.sum(select_near) + torch.sum(select_far)
            random_inputs_A, random_inputs_B, random_targets_A, random_targets_B = \
                randomSamplingNormal(inputs[i, :], targets[i, :], masks[i, :], random_sample_num)
            random_target_cos = torch.abs(torch.sum(random_targets_A * random_targets_B, dim=0))
            random_input_cos = torch.abs(torch.sum(random_inputs_A * random_inputs_B, dim=0))
            loss += torch.sum(torch.abs(random_target_cos - random_input_cos)) / \
                (random_target_cos.shape[0] + 1e-8)

        if loss[0] != 0:
            return loss[0].float() / n
        else:
            # Keep the result attached to pred_depths so backward() still works.
            return pred_depths.sum() * 0.0