File size: 7,810 Bytes
a7299bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import pickle
import os
import h5py
import numpy as np
import cv2
import torch
import torch.nn as nn
import glob
def init_image_coor(height, width, device="cuda"):
    """Build pixel-coordinate grids measured from the image centre.

    Args:
        height: number of image rows.
        width: number of image columns.
        device: device on which the grids are created. Defaults to ``"cuda"``
            (the current CUDA device), matching the previous hard-coded
            ``.cuda()`` behaviour; pass e.g. ``"cpu"`` or ``depth.device`` to
            build the grids elsewhere.

    Returns:
        ``(u_u0, v_v0)``: two float32 tensors of shape ``[1, height, width]``
        with ``u_u0[0, r, c] == c - width / 2`` and
        ``v_v0[0, r, c] == r - height / 2``.
    """
    cols = torch.arange(width, dtype=torch.float32, device=device)
    rows = torch.arange(height, dtype=torch.float32, device=device)
    # repeat (not expand) so the results are contiguous, like the original tiles.
    u_u0 = cols.repeat(height, 1)[None] - width / 2.0
    v_v0 = rows[:, None].repeat(1, width)[None] - height / 2.0
    return u_u0, v_v0
def depth_to_xyz(depth, focal_length):
    """Back-project a depth map into a per-pixel point map.

    Args:
        depth: 4-D tensor ``[b, c, h, w]``.
        focal_length: indexable; element 0 scales the x coordinates and
            element 1 the y coordinates (presumably fx and fy - confirm with
            callers).

    Returns:
        Tensor ``[b, h, w, 3 * c]`` holding (x, y, z) per pixel, i.e.
        ``[b, h, w, 3]`` for a single-channel depth map.
    """
    _, _, height, width = depth.shape
    u_offset, v_offset = init_image_coor(height, width)
    fx, fy = focal_length[0], focal_length[1]
    points = torch.cat((u_offset * depth / fx, v_offset * depth / fy, depth), dim=1)
    # channels-last layout: [b, h, w, c]
    return points.permute(0, 2, 3, 1)
def get_surface_normal(xyz, patch_size=5):
    """Estimate per-pixel surface normals by least-squares plane fitting.

    For every pixel, the 3-D points inside a ``patch_size x patch_size``
    window are fitted with a plane ``n . p = 1`` by solving the normal
    equations ``(A^T A) n = A^T 1``, where the rows of ``A`` are the window's
    points. Window sums are computed with a box convolution using zero
    padding, so border windows simply contain fewer points.

    Args:
        xyz: point map of shape ``[1, h, w, 3]`` (a single image, x/y/z on
            the last axis).
        patch_size: odd side length of the fitting window.

    Returns:
        Tensor of shape ``[h, w, 3, 1]`` with unit normals, flipped so that
        ``n . p <= 0``. Pixels whose system yields a zero solution (e.g. an
        all-zero window) get a zero vector instead of NaN.

    Raises:
        ValueError: if ``xyz`` is not of shape ``[1, h, w, 3]``.
    """
    if xyz.dim() != 4 or xyz.shape[0] != 1 or xyz.shape[-1] != 3:
        raise ValueError(f"xyz must have shape [1, h, w, 3], got {tuple(xyz.shape)}")
    h, w = xyz.shape[1], xyz.shape[2]
    pad = patch_size // 2
    # Box filter built on the input's device/dtype (previously hard-coded .cuda()).
    box = torch.ones((1, 1, patch_size, patch_size), device=xyz.device, dtype=xyz.dtype)

    def window_sum(t):
        # [h, w] -> [h, w]: sum of t over each window (zero padded at borders).
        return nn.functional.conv2d(t[None, None], weight=box, padding=pad)[0, 0]

    x, y, z = torch.unbind(xyz[0], dim=-1)  # each [h, w]
    sxx, syy, szz = window_sum(x * x), window_sum(y * y), window_sum(z * z)
    sxy, sxz, syz = window_sum(x * y), window_sum(x * z), window_sum(y * z)
    ata = torch.stack([sxx, sxy, sxz, sxy, syy, syz, sxz, syz, szz], dim=-1).reshape(h, w, 3, 3)
    # Small ridge term keeps every 3x3 system solvable (flat or empty windows).
    ata = ata + 1e-6 * torch.eye(3, device=ata.device, dtype=ata.dtype)
    at1 = torch.stack([window_sum(x), window_sum(y), window_sum(z)], dim=-1).reshape(h, w, 3, 1)

    # Each pixel's system is independent, so no tile overlap is needed. Solving
    # the flattened batch in 16 chunks (like the previous 4x4 tiling) bounds the
    # size of each linalg.solve call while covering every pixel, including the
    # rows/columns left over when h or w is not a multiple of 4.
    ata_flat = ata.reshape(-1, 3, 3)
    at1_flat = at1.reshape(-1, 3, 1)
    solved = [torch.linalg.solve(a, b) for a, b in zip(ata_flat.chunk(16), at1_flat.chunk(16))]
    n_img = torch.cat(solved, dim=0).reshape(h, w, 3, 1)

    # Unit length; normalize() divides by max(norm, eps), so zero vectors stay 0.
    n_img = nn.functional.normalize(n_img, p=2, dim=2)

    # Re-orient normals consistently: flip those with a positive dot product
    # with their point.
    flip = (n_img[..., 0] * xyz[0]).sum(dim=-1) > 0  # [h, w]
    return torch.where(flip[..., None, None], -n_img, n_img)
def get_surface_normalv2(xyz, patch_size=5):
    """Estimate surface normals from cross products of neighbour differences.

    Neighbourhood labels (centre pixel is p5):
        p1 p2 p3
        p4 p5 p6
        p7 p8 p9

    With ``half = patch_size // 2`` and a zero-padded point map, two normal
    estimates are formed per pixel:

    * ``n_img_2 = (p4 - p6) x (p2 - p8)`` with p4/p6 at column offsets
      ``-half``/``+half`` and p2/p8 at row offsets ``-half``/``+half``;
    * ``n_img_1``: the same cross product, but with the left/top samples taken
      at offset ``1 - half`` (the right/bottom samples are the same slices as
      for ``n_img_2``).

    Each estimate is flipped so its dot product with the point is not
    positive, normalised, then the two are averaged, renormalised and
    re-oriented.

    Args:
        xyz: point map of shape ``[b, h, w, c]``; the cross products use
            dim=3, so c is presumably 3.
        patch_size: window size; the padding/slicing arithmetic only yields
            ``h x w`` slices when it is odd and >= 3.

    Returns:
        Unit normals of shape ``[h, w, c, b]``.

    NOTE(review): the inner pair is asymmetric about the centre pixel
    (offsets -1/+2 for patch_size=5, 0/+1 for patch_size=3). This may be
    intentional or an off-by-one - confirm before relying on it.
    """
    b, h, w, c = xyz.shape
    half_patch = patch_size // 2
    # Zero-pad by half_patch on each spatial side so shifted h x w views can be sliced out.
    xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
    xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz
    # Outer samples at distance half_patch from the centre pixel.
    xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :]  # p4, column offset -half_patch
    xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :]  # p6, column offset +half_patch
    xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :]  # p2, row offset -half_patch
    xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :]  # p8, row offset +half_patch
    xyz_horizon = xyz_left - xyz_right  # p4 - p6
    xyz_vertical = xyz_top - xyz_bottom  # p2 - p8
    # Inner variant: left/top move one pixel toward the centre; right/bottom are unchanged.
    xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :]  # p4, column offset 1 - half_patch
    xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :]  # p6, column offset +half_patch
    xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :]  # p2, row offset 1 - half_patch
    xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :]  # p8, row offset +half_patch
    xyz_horizon_in = xyz_left_in - xyz_right_in  # p4 - p6
    xyz_vertical_in = xyz_top_in - xyz_bottom_in  # p2 - p8
    n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
    n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)
    # Re-orient each estimate: flip those with a positive dot product with their point.
    orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
    n_img_1[orient_mask] *= -1
    orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
    n_img_2[orient_mask] *= -1
    # Normalise both estimates (1e-8 guards against division by zero).
    n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True))
    n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)
    n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True))
    n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)
    # Average the two unit normals and renormalise.
    n_img_aver = n_img1_norm + n_img2_norm
    n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True))
    n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
    # Apply the same orientation rule to the averaged normal.
    orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
    n_img_aver_norm[orient_mask] *= -1
    n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0))  # [h, w, c, b]
    return n_img_aver_norm_out
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
    """Compute surface normals for a batch of depth maps.

    The depth is smoothed with a 3x3 average pool, back-projected with
    ``depth_to_xyz`` and each image's normals are estimated with
    ``get_surface_normal``.

    Args:
        depth: depth map of shape ``[b, c, h, w]`` (presumably c == 1 -
            TODO confirm).
        focal_length: 1-D tensor. It is reshaped to ``[n, 1, 1, 1]`` and
            ``depth_to_xyz`` divides x by element 0 and y by element 1
            (presumably fx and fy; verify against callers).
        valid_mask: optional mask broadcastable to ``[b, 3, h, w]`` (e.g.
            ``[b, 1, h, w]``); normals where it is false/zero are set to 0.
            Bool and integer masks are both accepted.

    Returns:
        Tensor of shape ``[b, 3, h, w]``.
    """
    b = depth.shape[0]
    focal_length = focal_length[:, None, None, None]
    # Light 3x3 smoothing of the depth before back-projection.
    depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
    xyz = depth_to_xyz(depth_filter, focal_length)
    # get_surface_normal works on one [1, h, w, 3] image and returns [h, w, 3, 1].
    normals = [get_surface_normal(xyz[i:i + 1]) for i in range(b)]
    sn_batch = torch.cat(normals, dim=3).permute((3, 2, 0, 1))  # [b, c, h, w]
    if valid_mask is not None:
        # .bool() first: `~` on an integer mask is a bitwise NOT, which would
        # otherwise make every entry truthy and zero the whole output.
        sn_batch = sn_batch.masked_fill(~valid_mask.bool(), 0.0)
    return sn_batch
|