# A reimplemented version in public environments by Xiao Fu and Mu Hu

import pickle
import os
import h5py
import numpy as np
import cv2
import torch
import torch.nn as nn
import glob


def init_image_coor(height, width):
    # Pixel coordinates relative to the principal point, assumed at the image center.
    x_row = np.arange(0, width)
    x = np.tile(x_row, (height, 1))
    x = x[np.newaxis, :, :]
    x = x.astype(np.float32)
    x = torch.from_numpy(x.copy()).cuda()
    u_u0 = x - width / 2.0

    y_col = np.arange(0, height)
    y = np.tile(y_col, (width, 1)).T
    y = y[np.newaxis, :, :]
    y = y.astype(np.float32)
    y = torch.from_numpy(y.copy()).cuda()
    v_v0 = y - height / 2.0
    return u_u0, v_v0


def depth_to_xyz(depth, focal_length):
    # Back-project a depth map to a 3D point cloud with the pinhole camera model.
    b, c, h, w = depth.shape
    u_u0, v_v0 = init_image_coor(h, w)
    x = u_u0 * depth / focal_length[0]
    y = v_v0 * depth / focal_length[1]
    z = depth
    pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1)  # [b, h, w, c]
    return pw


def get_surface_normal(xyz, patch_size=5):
    # xyz: [1, h, w, 3]
    # Least-squares plane fit over each patch_size x patch_size neighborhood:
    # for every pixel, solve (A^T A) n = A^T 1, where the rows of A are the 3D
    # points of the neighborhood. h and w are assumed divisible by patch_num below.
    x, y, z = torch.unbind(xyz, dim=3)
    x = torch.unsqueeze(x, 0)
    y = torch.unsqueeze(y, 0)
    z = torch.unsqueeze(z, 0)

    xx = x * x
    yy = y * y
    zz = z * z
    xy = x * y
    xz = x * z
    yz = y * z
    # Box filter: summing the products over each patch yields the entries of A^T A.
    patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda()
    xx_patch = nn.functional.conv2d(xx, weight=patch_weight, padding=int(patch_size / 2))
    yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2))
    zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2))
    xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2))
    xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2))
    yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2))
    ATA = torch.stack([xx_patch, xy_patch, xz_patch,
                       xy_patch, yy_patch, yz_patch,
                       xz_patch, yz_patch, zz_patch], dim=4)
    ATA = torch.squeeze(ATA)
    ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3))
    # Regularize to keep the per-pixel 3x3 systems well conditioned.
    eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1])
    ATA = ATA + eps_identity
    x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2))
    y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2))
    z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2))
    AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4)
    AT1 = torch.squeeze(AT1)
    AT1 = torch.unsqueeze(AT1, 3)

    # Solve the systems tile by tile (a patch_num x patch_num grid with overlap)
    # to bound peak memory; n_img is a placeholder that is fully overwritten.
    patch_num = 4
    patch_x = int(AT1.size(1) / patch_num)
    patch_y = int(AT1.size(0) / patch_num)
    n_img = torch.randn(AT1.shape).cuda()
    overlap = patch_size // 2 + 1
    for x in range(int(patch_num)):
        for y in range(int(patch_num)):
            left_flg = 0 if x == 0 else 1
            right_flg = 0 if x == patch_num - 1 else 1
            top_flg = 0 if y == 0 else 1
            btm_flg = 0 if y == patch_num - 1 else 1
            at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
                      x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
            ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
                      x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
            # n_img_tmp, _ = torch.solve(at1, ata)
            n_img_tmp = torch.linalg.solve(ata, at1)
            n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap,
                                         left_flg * overlap:patch_x + left_flg * overlap, :, :]
            n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select

    n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
    n_img_norm = n_img / n_img_L2

    # re-orient normals consistently (towards the camera)
    orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0
    n_img_norm[orient_mask] *= -1

    return n_img_norm


def get_surface_normalv2(xyz, patch_size=5):
    """
    xyz: xyz coordinates, [b, h, w, 3]
    patch: [p1, p2, p3,
            p4, p5, p6,
            p7, p8, p9]
    The normal is the cross product (p4 - p6) x (p2 - p8), computed once at the
    patch border and once one pixel inside it; the two unit normals are averaged.
    return: normal [h, w, 3, b]
    """
    b, h, w, c = xyz.shape
    half_patch = patch_size // 2
    xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
    xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz

    # xyz_left_top = xyz_pad[:, :h, :w, :]  # p1
    # xyz_right_bottom = xyz_pad[:, -h:, -w:, :]  # p9
    # xyz_left_bottom = xyz_pad[:, -h:, :w, :]  # p7
    # xyz_right_top = xyz_pad[:, :h, -w:, :]  # p3
    # xyz_cross1 = xyz_left_top - xyz_right_bottom  # p1p9
    # xyz_cross2 = xyz_left_bottom - xyz_right_top  # p7p3

    xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :]  # p4
    xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :]  # p6
    xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :]  # p2
    xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :]  # p8
    xyz_horizon = xyz_left - xyz_right  # p4p6
    xyz_vertical = xyz_top - xyz_bottom  # p2p8

    xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w + 1, :]  # p4
    xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size - 1:patch_size - 1 + w, :]  # p6
    xyz_top_in = xyz_pad[:, 1:h + 1, half_patch:half_patch + w, :]  # p2
    xyz_bottom_in = xyz_pad[:, patch_size - 1:patch_size - 1 + h, half_patch:half_patch + w, :]  # p8
    xyz_horizon_in = xyz_left_in - xyz_right_in  # p4p6
    xyz_vertical_in = xyz_top_in - xyz_bottom_in  # p2p8

    n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
    n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)

    # re-orient normals consistently
    orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
    n_img_1[orient_mask] *= -1
    orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
    n_img_2[orient_mask] *= -1

    n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True))
    n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)

    n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True))
    n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)

    # average the two normal estimates
    n_img_aver = n_img1_norm + n_img2_norm
    n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True))
    n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
    # re-orient normals consistently
    orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
    n_img_aver_norm[orient_mask] *= -1
    n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0))  # [h, w, c, b]

    # a = torch.sum(n_img1_norm_out * n_img2_norm_out, dim=2).cpu().numpy().squeeze()
    # plt.imshow(np.abs(a), cmap='rainbow')
    # plt.show()
    return n_img_aver_norm_out  # n_img1_norm.permute((1, 2, 3, 0))


def surface_normal_from_depth(depth, focal_length, valid_mask=None):
    # param depth: depth map, [b, c, h, w]
    b, c, h, w = depth.shape
    focal_length = focal_length[:, None, None, None]
    # Light smoothing before back-projection to suppress depth noise.
    depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
    # depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1)
    xyz = depth_to_xyz(depth_filter, focal_length)
    sn_batch = []
    for i in range(b):
        xyz_i = xyz[i, :][None, :, :, :]
        # normal = get_surface_normalv2(xyz_i)
        normal = get_surface_normal(xyz_i)
        sn_batch.append(normal)
    sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1))  # [b, c, h, w]
    if valid_mask is not None:
        mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
        sn_batch[mask_invalid] = 0.0

    return sn_batch
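

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline): build a synthetic
# depth batch and query per-pixel surface normals. Assumptions: a CUDA device
# is available (the functions above call .cuda() internally), the spatial size
# is divisible by the 4x4 tiling in get_surface_normal, and focal lengths are
# given in pixels; all values below are illustrative only.
if __name__ == '__main__':
    b, h, w = 2, 96, 128
    depth = torch.rand(b, 1, h, w).cuda() * 10.0 + 0.5   # [b, 1, h, w], synthetic metric depth
    # depth_to_xyz indexes focal_length[0] and focal_length[1] after the reshape in
    # surface_normal_from_depth, so this sketch uses two samples with the same focal length.
    focal_length = torch.full((b,), 500.0).cuda()
    valid_mask = depth > 1.0                              # [b, 1, h, w] boolean validity mask
    normals = surface_normal_from_depth(depth, focal_length, valid_mask)
    print(normals.shape)                                  # torch.Size([2, 3, 96, 128])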