Spaces:
Build error
Build error
import numpy as np | |
import torch | |
import threading | |
import mcubes | |
import trimesh | |
from utils.util_vis import show_att_on_image | |
from utils.camera import get_rotation_sphere | |
def get_dense_3D_grid(opt, var, N=None):
    """Build a dense cubic lattice of 3D query points, replicated per sample.

    Args:
        opt: options object; reads ``opt.eval.vox_res``, ``opt.eval.range``
            (a ``(min, max)`` pair) and ``opt.device``.
        var: batch container; only ``len(var.idx)`` is used, as the batch size.
        N: number of cells per axis. Falls back to ``opt.eval.vox_res`` when
            omitted (or otherwise falsy).

    Returns:
        Tensor of shape ``[B, N+1, N+1, N+1, 3]``. Each axis is sampled at
        ``N+1`` evenly spaced positions that include both ends of the range,
        and the last dimension holds the coordinates in axis order.
    """
    num_samples = len(var.idx)
    resolution = N if N else opt.eval.vox_res
    # e.g. -0.6, 0.6
    lo, hi = opt.eval.range
    # N cells per axis means N+1 sample positions, endpoints included
    axis = torch.linspace(lo, hi, resolution + 1, device=opt.device)
    coords_0, coords_1, coords_2 = torch.meshgrid(axis, axis, axis, indexing='ij')
    lattice = torch.stack((coords_0, coords_1, coords_2), dim=-1)  # [N+1, N+1, N+1, 3]
    # tile along a new leading batch axis -> [B, N+1, N+1, N+1, 3]
    return lattice.repeat(num_samples, 1, 1, 1, 1)
def compute_level_grid(opt, impl_network, latent_depth, latent_semantic, points_3D, images, vis_attn=False):
    """Evaluate the implicit network on a dense cubic grid, one slice at a time.

    Args:
        opt: options object; only read when ``vis_attn`` is set
            (``opt.H``, ``opt.W``, ``opt.arch.win_size``).
        impl_network: callable ``(latent_depth, latent_semantic, points)`` that
            maps ``[B, N*N, 3]`` points to a pair ``(occ, attn)`` with
            ``occ`` of shape ``[B, N*N]`` and ``attn`` of shape
            ``[B, N*N, 1+feat_res**2]``.
        latent_depth: depth latent or ``None``; cast to float32 before use.
        latent_semantic: semantic latent or ``None``; cast to float32 before use.
        points_3D: query points of shape ``[B, N, N, N, 3]``.
        images: per-sample images, indexed as ``images[b]`` and permuted with
            ``(1, 2, 0)`` -- presumably ``[B, 3, H, W]``; only read when
            ``vis_attn`` is set.
        vis_attn: when true, also render attention overlays.

    Returns:
        ``(occ, images_vis)`` where ``occ`` is the sigmoid of the network
        output reshaped to ``[B, N, N, N]`` and ``images_vis`` is ``None``
        unless ``vis_attn`` is set, in which case it is a list (one entry per
        sample) of lists of overlay frames from ``show_att_on_image``.
    """
    # needed for amp
    latent_depth = latent_depth.to(torch.float32) if latent_depth is not None else None
    latent_semantic = latent_semantic.to(torch.float32) if latent_semantic is not None else None

    # process points in sliced way
    batch_size = points_3D.shape[0]
    N = points_3D.shape[1]
    assert N == points_3D.shape[2] == points_3D.shape[3]
    assert points_3D.shape[4] == 3
    points_3D = points_3D.view(batch_size, N, N*N, 3)
    occ = []
    attn = []
    for i in range(N):
        # [B, N*N, 3]
        points_slice = points_3D[:, i]
        # [B, N*N, 3] -> [B, N*N], [B, N*N, 1+feat_res**2]
        occ_slice, attn_slice = impl_network(latent_depth, latent_semantic, points_slice)
        occ.append(occ_slice)
        # Only keep attention when it will be rendered: holding N slices of
        # it is costly, and the output is otherwise never looked at.
        if vis_attn:
            attn.append(attn_slice.detach())
    # [B, N, N*N] -> [B, N, N, N]
    occ = torch.stack(occ, dim=1).view(batch_size, N, N, N)
    occ = torch.sigmoid(occ)

    if not vis_attn:
        return occ, None

    N_global = 1
    feat_res = opt.H // opt.arch.win_size
    attn = torch.stack(attn, dim=1).view(batch_size, N, N, N, N_global+feat_res**2)
    # average along Z, [B, N, N, N_global+feat_res**2]
    attn = torch.mean(attn, dim=3)
    # [B, N, N, N_global] -> [B, N, N, 1]
    attn_global = attn[:, :, :, :N_global].sum(dim=-1, keepdim=True)
    # [B, N, N, feat_res, feat_res]
    attn_local = attn[:, :, :, N_global:].view(batch_size, N, N, feat_res, feat_res)
    # [B, N, N, feat_res, feat_res]
    attn_vis = attn_global.unsqueeze(-1) + attn_local

    # Largest multiple of 8 that is still a valid index. The previous bound
    # (N//8*8) equals N itself when N is a multiple of 8 and indexed out of
    # range; for every other N the two expressions agree.
    last_col = (N - 1) // 8 * 8

    # list of frame lists
    images_vis = []
    for b in range(batch_size):
        images_vis_sample = []
        # visit every 8th row, alternating the column direction per row
        for row in range(0, N, 8):
            if row % 16 == 0:
                col_range = range(0, last_col + 1, 8)
            else:
                col_range = range(last_col, -1, -8)
            for col in col_range:
                # [feat_res, feat_res], x is col
                attn_curr = attn_vis[b, col, row]
                # upsample the attention map to the image resolution
                attn_curr = torch.nn.functional.interpolate(
                    attn_curr.unsqueeze(0).unsqueeze(0), size=(opt.H, opt.W),
                    mode='bilinear', align_corners=False
                ).squeeze(0).squeeze(0).cpu().numpy()
                # scale so the peak is 1
                attn_curr /= attn_curr.max()
                # channels-last numpy image
                image_curr = images[b].permute(1, 2, 0).cpu().numpy()
                # merge the image and the attention
                images_vis_sample.append(show_att_on_image(image_curr, attn_curr))
        images_vis.append(images_vis_sample)
    return occ, images_vis
def standardize_pc(pc):
    """Center each point cloud and rescale it by twice its RMS radius.

    Args:
        pc: tensor of shape ``[B, P, D]`` (batch, points, coordinates).

    Returns:
        Tensor of the same shape: the centroid of every cloud is moved to the
        origin and the result is divided by ``2 * scale``, where ``scale`` is
        the root-mean-square distance of the points to the centroid. A cloud
        whose points all coincide (scale 0) yields zeros instead of NaN.
    """
    assert len(pc.shape) == 3
    centered = pc - pc.mean(dim=1, keepdim=True)
    # mean squared distance to the centroid, [B, 1, 1]; computed directly from
    # the squared norms rather than taking a sqrt and squaring it again
    mean_sq_dist = (centered ** 2).sum(dim=2, keepdim=True).mean(dim=1, keepdim=True)
    scale = torch.sqrt(mean_sq_dist)
    # clamp so a degenerate cloud divides by a tiny number, not by zero
    denom = (scale * 2).clamp_min(1.e-7)
    return centered / denom
def normalize_pc(pc):
    """Center each point cloud and divide it by its larger x/y extent.

    Args:
        pc: tensor of shape ``[B, P, D]`` with at least two coordinates.

    Returns:
        Tensor of the same shape. Every cloud is shifted so its centroid is at
        the origin, then all coordinates are divided by
        ``max(extent_x, extent_y) + 1e-7``. Only the first two coordinates
        contribute to the extent; any further ones are scaled but not measured.
    """
    assert len(pc.shape) == 3
    centered = pc - pc.mean(dim=1, keepdim=True)
    # per-cloud span (max - min over the points) of coordinates 0 and 1
    spans = []
    for axis in (0, 1):
        coord = centered[:, :, axis]  # [B, P]
        spans.append(coord.max(dim=-1).values - coord.min(dim=-1).values)
    # [B, 2] -> [B]
    widest = torch.stack(spans, dim=-1).max(dim=-1).values
    # broadcast as [B, 1, 1]; the epsilon guards against a zero extent
    return centered / (widest[:, None, None] + 1.e-7)
def convert_to_explicit(opt, level_grids, isoval=0., to_pointcloud=False):
    """Convert a batch of level grids to meshes, one worker thread per grid.

    Args:
        opt: options object, forwarded to ``convert_to_explicit_worker``.
        level_grids: indexable batch of level grids; ``level_grids[i]`` is
            handed to worker ``i``.
        isoval: iso-value forwarded to the worker.
        to_pointcloud: when true, the workers also fill in a point cloud per
            grid.

    Returns:
        The list of meshes when ``to_pointcloud`` is false; otherwise the pair
        ``(meshes, pointclouds)`` with the point clouds stacked along a new
        leading axis into a single numpy array.
    """
    num_grids = len(level_grids)
    # output slots written by the workers, indexed by sample
    meshes = [None] * num_grids
    pointclouds = [None] * num_grids if to_pointcloud else None

    workers = []
    for idx in range(num_grids):
        workers.append(threading.Thread(
            target=convert_to_explicit_worker,
            args=(opt, idx, level_grids[idx], isoval, meshes),
            kwargs={'pointclouds': pointclouds},
            daemon=False,
        ))
    # launch everything first, then wait for all of them
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    if not to_pointcloud:
        return meshes
    return meshes, np.stack(pointclouds, axis=0)
def convert_to_explicit_worker(opt, i, level_vox_i, isoval, meshes, pointclouds=None):
    """Extract a mesh from one level grid and store it in ``meshes[i]``.

    Args:
        opt: options object; reads ``opt.eval.range`` (a ``(min, max)`` pair)
            and, when sampling points, ``opt.eval.num_points``.
        i: slot index to write in ``meshes`` / ``pointclouds``.
        level_vox_i: cubic 3D volume passed to ``mcubes.marching_cubes``.
        isoval: iso-value of the surface to extract.
        meshes: output list; ``meshes[i]`` receives the ``trimesh.Trimesh``.
        pointclouds: optional output list; when given, ``pointclouds[i]``
            receives ``opt.eval.num_points`` points sampled on the mesh, or
            all zeros if the mesh has no triangles.

    Returns:
        None; results are written into ``meshes`` and ``pointclouds``.
    """
    # check the volume is cubic before spending time on marching cubes
    assert(level_vox_i.shape[0]==level_vox_i.shape[1]==level_vox_i.shape[2])
    S = level_vox_i.shape[0]
    # use marching cubes to convert implicit surface to mesh
    vertices, faces = mcubes.marching_cubes(level_vox_i, isovalue=isoval)
    range_min, range_max = opt.eval.range
    # Marching cubes treats every cube as unit length, so vertices lie in
    # voxel-index units [0, S-1]. S samples span S-1 cubes; dividing by S-1
    # maps index 0 to range_min and index S-1 to range_max. This assumes the
    # volume was sampled with both endpoints included, as get_dense_3D_grid
    # does. (Dividing by S shrank the mesh by a factor of (S-1)/S.)
    num_cells = max(S - 1, 1)
    vertices = vertices / num_cells * (range_max - range_min) + range_min
    mesh = trimesh.Trimesh(vertices, faces)
    meshes[i] = mesh
    if pointclouds is not None:
        # randomly sample on mesh to get uniform dense point cloud
        if len(mesh.triangles) != 0:
            points = mesh.sample(opt.eval.num_points)
        else:
            # empty surface: fall back to an all-zero cloud of the same size
            points = np.zeros([opt.eval.num_points, 3])
        pointclouds[i] = points