# ZeroShape/utils/eval_3D.py
import numpy as np
import torch
import threading
import mcubes
import trimesh
from utils.util_vis import show_att_on_image
from utils.camera import get_rotation_sphere

@torch.no_grad()
def get_dense_3D_grid(opt, var, N=None):
    batch_size = len(var.idx)
    N = N or opt.eval.vox_res
    # sampling range of the cube, e.g. (-0.6, 0.6)
    range_min, range_max = opt.eval.range
    grid = torch.linspace(range_min, range_max, N+1, device=opt.device)
    # note: N+1 samples per axis, so the grid is [N+1, N+1, N+1, 3]
    points_3D = torch.stack(torch.meshgrid(grid, grid, grid, indexing='ij'), dim=-1)
    points_3D = points_3D.repeat(batch_size, 1, 1, 1, 1)  # [B, N+1, N+1, N+1, 3]
    return points_3D
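
# A minimal usage sketch (not part of the original file). The SimpleNamespace
# stand-ins for `opt` and `var` are assumptions; the real configs carry many
# more fields than the ones read above.
def _demo_get_dense_3D_grid():
    from types import SimpleNamespace
    opt = SimpleNamespace(device='cpu', eval=SimpleNamespace(range=(-0.6, 0.6), vox_res=4))
    var = SimpleNamespace(idx=torch.arange(2))
    pts = get_dense_3D_grid(opt, var)  # [2, 5, 5, 5, 3]: vox_res+1 samples per axis
    # the first grid corner sits at range_min on all three axes
    assert torch.allclose(pts[0, 0, 0, 0], torch.tensor([-0.6, -0.6, -0.6]))
    return pts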

@torch.no_grad()
def compute_level_grid(opt, impl_network, latent_depth, latent_semantic, points_3D, images, vis_attn=False):
    # cast the latents back to fp32 (needed when running under amp)
    latent_depth = latent_depth.to(torch.float32) if latent_depth is not None else None
    latent_semantic = latent_semantic.to(torch.float32) if latent_semantic is not None else None
    # query the implicit network one slice at a time
    batch_size = points_3D.shape[0]
    N = points_3D.shape[1]
    assert N == points_3D.shape[2] == points_3D.shape[3]
    assert points_3D.shape[4] == 3
    points_3D = points_3D.view(batch_size, N, N*N, 3)
    occ = []
    attn = []
    for i in range(N):
        # [B, N*N, 3]
        points_slice = points_3D[:, i]
        # [B, N*N, 3] -> [B, N*N], [B, N*N, 1+feat_res**2]
        occ_slice, attn_slice = impl_network(latent_depth, latent_semantic, points_slice)
        occ.append(occ_slice)
        attn.append(attn_slice.detach())
    # [B, N, N*N] -> [B, N, N, N]
    occ = torch.stack(occ, dim=1).view(batch_size, N, N, N)
    occ = torch.sigmoid(occ)
    if vis_attn:
        N_global = 1
        feat_res = opt.H // opt.arch.win_size
        attn = torch.stack(attn, dim=1).view(batch_size, N, N, N, N_global+feat_res**2)
        # average along Z, [B, N, N, N_global+feat_res**2]
        attn = torch.mean(attn, dim=3)
        # [B, N, N, N_global] -> [B, N, N, 1]
        attn_global = attn[:, :, :, :N_global].sum(dim=-1, keepdim=True)
        # [B, N, N, feat_res, feat_res]
        attn_local = attn[:, :, :, N_global:].view(batch_size, N, N, feat_res, feat_res)
        # broadcast the global attention onto the local map, [B, N, N, feat_res, feat_res]
        attn_vis = attn_global.unsqueeze(-1) + attn_local
        # one list of frames per sample
        images_vis = []
        for b in range(batch_size):
            images_vis_sample = []
            # traverse the XY plane in snake order so consecutive frames stay adjacent
            for row in range(0, N, 8):
                if row % 16 == 0:
                    col_range = range(0, N//8*8+1, 8)
                else:
                    col_range = range(N//8*8, -1, -8)
                for col in col_range:
                    # [feat_res, feat_res], x is col
                    attn_curr = attn_vis[b, col, row]
                    attn_curr = torch.nn.functional.interpolate(
                        attn_curr.unsqueeze(0).unsqueeze(0), size=(opt.H, opt.W),
                        mode='bilinear', align_corners=False
                    ).squeeze(0).squeeze(0).cpu().numpy()
                    attn_curr /= attn_curr.max()
                    # [H, W, 3]
                    image_curr = images[b].permute(1, 2, 0).cpu().numpy()
                    # overlay the attention map on the image
                    images_vis_sample.append(show_att_on_image(image_curr, attn_curr))
            images_vis.append(images_vis_sample)
    return occ, (images_vis if vis_attn else None)
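
# A minimal sketch (an assumption, not the original pipeline) of driving
# compute_level_grid with a stub implicit function. Only the call signature
# used above is mirrored; the real network lives elsewhere in the repo.
def _demo_compute_level_grid():
    from types import SimpleNamespace
    feat_res = 7
    def stub_impl_network(latent_depth, latent_semantic, points):
        B, M, _ = points.shape
        # occupancy logits for a ball of radius 0.4 around the origin
        occ_logit = 10.0 * (0.4 - points.norm(dim=-1))
        attn = torch.zeros(B, M, 1 + feat_res**2)
        return occ_logit, attn
    opt = SimpleNamespace(device='cpu', eval=SimpleNamespace(range=(-0.6, 0.6), vox_res=16))
    var = SimpleNamespace(idx=torch.arange(1))
    points_3D = get_dense_3D_grid(opt, var)  # [1, 17, 17, 17, 3]
    occ, _ = compute_level_grid(opt, stub_impl_network, None, None, points_3D, images=None)
    return occ  # [1, 17, 17, 17] occupancy probabilities in [0, 1]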

@torch.no_grad()
def standardize_pc(pc):
    assert len(pc.shape) == 3
    # zero-center each cloud, then divide by twice its RMS distance to the origin
    pc_mean = pc.mean(dim=1, keepdim=True)
    pc_zmean = pc - pc_mean
    origin_distance = (pc_zmean**2).sum(dim=2, keepdim=True).sqrt()
    scale = torch.sqrt(torch.sum(origin_distance**2, dim=1, keepdim=True) / pc.shape[1])
    pc_standardized = pc_zmean / (scale * 2)
    return pc_standardized
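
# Sanity-check sketch (an assumption about intent, not original code): after
# standardize_pc, the RMS distance of each cloud to the origin is 0.5, since
# the zero-mean points are divided by twice the RMS radius.
def _demo_standardize_pc():
    pc = torch.randn(2, 1024, 3) * 3.0 + 1.0
    pc_std = standardize_pc(pc)
    rms = pc_std.norm(dim=-1).pow(2).mean(dim=-1).sqrt()  # [2]
    assert torch.allclose(rms, torch.full_like(rms, 0.5), atol=1e-5)
    return pc_std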

@torch.no_grad()
def normalize_pc(pc):
    assert len(pc.shape) == 3
    pc_mean = pc.mean(dim=1, keepdim=True)
    pc_zmean = pc - pc_mean
    length_x = pc_zmean[:, :, 0].max(dim=-1)[0] - pc_zmean[:, :, 0].min(dim=-1)[0]
    length_y = pc_zmean[:, :, 1].max(dim=-1)[0] - pc_zmean[:, :, 1].min(dim=-1)[0]
    length_max = torch.stack([length_x, length_y], dim=-1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
    pc_normalized = pc_zmean / (length_max + 1.e-7)
    return pc_normalized
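
# Sanity-check sketch (assumption, not original code): normalize_pc scales by
# the larger of the X/Y extents only, so that extent maps to ~1 while the Z
# extent is deliberately left out of the scale computation.
def _demo_normalize_pc():
    torch.manual_seed(0)
    pc = torch.rand(1, 512, 3) * torch.tensor([4.0, 2.0, 8.0])
    pc_norm = normalize_pc(pc)
    extent_x = pc_norm[0, :, 0].max() - pc_norm[0, :, 0].min()
    # X is the larger horizontal extent here, so it normalizes to ~1
    assert abs(extent_x.item() - 1.0) < 1e-3
    return pc_norm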

def convert_to_explicit(opt, level_grids, isoval=0., to_pointcloud=False):
    N = len(level_grids)
    meshes = [None]*N
    pointclouds = [None]*N if to_pointcloud else None
    threads = [threading.Thread(target=convert_to_explicit_worker,
                                args=(opt, i, level_grids[i], isoval, meshes),
                                kwargs=dict(pointclouds=pointclouds),
                                daemon=False) for i in range(N)]
    for t in threads: t.start()
    for t in threads: t.join()
    if to_pointcloud:
        pointclouds = np.stack(pointclouds, axis=0)
        return meshes, pointclouds
    else:
        return meshes
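
# A minimal end-to-end sketch (assumption: opt only needs the fields read in
# the worker below) that meshes a synthetic occupancy ball and samples points.
def _demo_convert_to_explicit():
    from types import SimpleNamespace
    opt = SimpleNamespace(device='cpu',
                          eval=SimpleNamespace(range=(-0.6, 0.6), vox_res=32, num_points=2048))
    var = SimpleNamespace(idx=torch.arange(1))
    grid = get_dense_3D_grid(opt, var)  # [1, 33, 33, 33, 3]
    # synthetic occupancy probabilities: a ball of radius 0.4 around the origin
    occ = (10.0 * (0.4 - grid.norm(dim=-1))).sigmoid()
    level_grids = occ.cpu().numpy()
    meshes, pointclouds = convert_to_explicit(opt, level_grids, isoval=0.5, to_pointcloud=True)
    return meshes[0], pointclouds  # (trimesh.Trimesh, [1, 2048, 3] numpy array)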

def convert_to_explicit_worker(opt, i, level_vox_i, isoval, meshes, pointclouds=None):
    assert level_vox_i.shape[0] == level_vox_i.shape[1] == level_vox_i.shape[2]
    # use marching cubes to convert the implicit surface to a mesh
    vertices, faces = mcubes.marching_cubes(level_vox_i, isovalue=isoval)
    S = level_vox_i.shape[0]
    range_min, range_max = opt.eval.range
    # marching cubes treats every cube as unit length, so rescale to world units
    vertices = vertices/S*(range_max-range_min)+range_min
    mesh = trimesh.Trimesh(vertices, faces)
    meshes[i] = mesh
    if pointclouds is not None:
        # randomly sample on the mesh to get a uniform dense point cloud
        if len(mesh.triangles) != 0:
            points = mesh.sample(opt.eval.num_points)
        else:
            points = np.zeros([opt.eval.num_points, 3])
        pointclouds[i] = points
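
# Hypothetical smoke test (not part of the original file): exercise the demo
# sketches above on CPU; needs only numpy, torch, PyMCubes and trimesh.
if __name__ == "__main__":
    _demo_get_dense_3D_grid()
    _demo_compute_level_grid()
    _demo_standardize_pc()
    _demo_normalize_pc()
    _demo_convert_to_explicit()
    print("all eval_3D demo sketches ran without errors")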