Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
import logging | |
import mcubes | |
from icecream import ic | |
import os | |
import trimesh | |
from pysdf import SDF | |
import models.fields as fields | |
from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables | |
def batched_index_select(values, indices, dim=1):
    """Gather entries of ``values`` along axis ``dim`` using batched ``indices``.

    The leading ``dim`` axes are treated as batch axes shared by ``values`` and
    ``indices``. All axes of ``values`` after ``dim`` are carried along. Any
    index axes beyond ``dim + 1`` are broadcast into ``values`` before the
    gather.

    Example: ``values`` of shape ``(B, K, C)`` and ``indices`` of shape
    ``(B, M)`` give a result of shape ``(B, M, C)`` with
    ``out[b, m] = values[b, indices[b, m]]``.
    """
    trailing = values.shape[dim + 1:]
    n_idx_axes = indices.dim()
    # One singleton axis per trailing value axis, then broadcast to that size.
    idx = indices[(Ellipsis,) + (None,) * len(trailing)]
    idx = idx.expand(*([-1] * n_idx_axes), *trailing)
    # Index axes that have no counterpart in ``values``; insert and broadcast them.
    n_extra = n_idx_axes - (dim + 1)
    src = values[(slice(None),) * dim + (None,) * n_extra + (Ellipsis,)]
    target_shape = [-1] * src.dim()
    target_shape[dim:dim + n_extra] = idx.shape[dim:dim + n_extra]
    src = src.expand(*target_shape)
    return src.gather(dim + n_extra, idx)
def create_mt_variable(device):
    """Build the constant lookup tables passed to ``marching_tets_tetmesh``.

    Returns ``(triangle_table, num_triangles_table, base_tet_edges, v_id)``.
    All four are ``torch.long`` tensors on ``device``:

    * ``triangle_table`` (16 x 6): per case, up to two triangles of three
      entries each. ``-1`` marks unused slots. Entries are presumably indices
      into the six tet edges.
    * ``num_triangles_table`` (16,): number of triangles for each case.
    * ``base_tet_edges`` (12,): the six vertex pairs (0,1), (0,2), (0,3),
      (1,2), (1,3), (2,3), flattened.
    * ``v_id`` (4,): ``[1, 2, 4, 8]``. Presumably used to encode a 4-bit case
      index from per-vertex signs.
    """
    empty_row = [-1, -1, -1, -1, -1, -1]
    triangle_rows = [
        empty_row,
        [1, 0, 2, -1, -1, -1],
        [4, 0, 3, -1, -1, -1],
        [1, 4, 2, 1, 3, 4],
        [3, 1, 5, -1, -1, -1],
        [2, 3, 0, 2, 5, 3],
        [1, 4, 0, 1, 5, 4],
        [4, 2, 5, -1, -1, -1],
        [4, 5, 2, -1, -1, -1],
        [4, 1, 0, 4, 5, 1],
        [3, 2, 0, 3, 5, 2],
        [1, 3, 5, -1, -1, -1],
        [4, 1, 2, 4, 3, 1],
        [3, 0, 4, -1, -1, -1],
        [2, 0, 1, -1, -1, -1],
        empty_row,
    ]
    triangle_counts = [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0]
    edge_pairs = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]

    def as_long(data):
        return torch.tensor(data, dtype=torch.long, device=device)

    v_id = 2 ** torch.arange(4, dtype=torch.long, device=device)
    return as_long(triangle_rows), as_long(triangle_counts), as_long(edge_pairs), v_id
def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None,
                             tet_fn="/home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz",
                             gt_sdf_fn="/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy",
                             query_chunk_size=64 ** 3):
    """Run marching tetrahedra on a queried SDF and on a stored ground-truth SDF.

    The vertices of the tetrahedral grid in ``tet_fn`` are normalised per axis
    to [-1, 1]. ``query_func`` is evaluated at those vertices in chunks of
    ``query_chunk_size`` points, after ``def_func`` when it is given. Both the
    queried SDF and the ground-truth SDF from ``gt_sdf_fn`` are then passed to
    ``marching_tets_tetmesh``.

    Args:
        bound_min: tensor whose ``.device`` is used for all computation.
        bound_max: unused. NOTE(review): the grid is always normalised to
            [-1, 1] regardless of the bounds; confirm this is intended.
        resolution: unused; the grid resolution is fixed by ``tet_fn``.
        query_func: callable mapping a chunk of points to SDF values.
        def_func: optional callable applied to each chunk of points before it
            is queried.
        tet_fn: ``.npz`` archive with ``vertices`` and ``tets`` arrays.
            Defaults to the path previously hard-coded here.
        gt_sdf_fn: ``.npy`` file with the ground-truth SDF values. These are
            presumably one value per vertex of ``tet_fn``; confirm.
        query_chunk_size: maximum number of points per ``query_func`` call.

    Returns:
        ``(verts, faces, sdf_values, GT_verts, GT_faces)``. The mesh outputs
        come from ``marching_tets_tetmesh``; ``sdf_values`` are the queried
        values with the trailing singleton axis squeezed.
    """
    device = bound_min.device
    # ``np.load`` on an .npz returns an NpzFile holding an open file handle;
    # close it as soon as the arrays have been read.
    with np.load(tet_fn) as tet_data:
        tet_vertices = torch.from_numpy(tet_data['vertices']).float().to(device)
        tet_indices = torch.from_numpy(tet_data['tets']).long().to(device)

    # Normalise each axis of the grid independently to [-1, 1].
    minn_verts, _ = torch.min(tet_vertices, dim=0)
    maxx_verts, _ = torch.max(tet_vertices, dim=0)
    scaled_verts = (tet_vertices - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0)
    scaled_verts = scaled_verts * 2. - 1.

    # Query the field in bounded chunks to limit peak memory.
    sdf_chunks = []
    for cur_query_pts in torch.split(scaled_verts, query_chunk_size, dim=0):
        if def_func is not None:
            cur_query_pts = def_func(cur_query_pts)
        sdf_chunks.append(query_func(cur_query_pts))
    sdf_values = torch.cat(sdf_chunks, dim=0).squeeze(-1)

    # NOTE(review): allow_pickle=True lets a crafted .npy execute code on load;
    # keep gt_sdf_fn pointing at trusted data (plain float arrays don't need it).
    GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
    GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)

    triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
    tet_table, num_tets_table = create_tetmesh_variables(device)

    verts, faces, _, _ = marching_tets_tetmesh(
        scaled_verts, sdf_values, tet_indices, triangle_table, num_triangles_table,
        base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts,
        num_tets_table=num_tets_table, tet_table=tet_table)
    # Same extraction with the ground-truth SDF values, for comparison.
    GT_verts, GT_faces, _, _ = marching_tets_tetmesh(
        scaled_verts, GT_sdf_values, tet_indices, triangle_table, num_triangles_table,
        base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts,
        num_tets_table=num_tets_table, tet_table=tet_table)
    return verts, faces, sdf_values, GT_verts, GT_faces
def extract_fields(bound_min, bound_max, resolution, query_func):
    """Sample ``query_func`` on a dense ``resolution``^3 grid spanning the bounds.

    Each axis is sampled with ``torch.linspace(bound_min[d], bound_max[d],
    resolution)``. The grid is evaluated in blocks of at most 64 points per axis,
    under ``torch.no_grad()``.

    Returns:
        ``np.float32`` array ``u`` of shape ``(resolution,) * 3`` where
        ``u[i, j, k]`` is the query value at ``(x_i, y_j, z_k)``.
    """
    block = 64
    axes = [torch.linspace(bound_min[d], bound_max[d], resolution).split(block) for d in range(3)]
    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(axes[0]):
            x0 = xi * block
            for yi, ys in enumerate(axes[1]):
                y0 = yi * block
                for zi, zs in enumerate(axes[2]):
                    z0 = zi * block
                    gx, gy, gz = torch.meshgrid(xs, ys, zs)
                    # Flatten the block into an (M, 3) list of query points.
                    pts = torch.stack([gx.reshape(-1), gy.reshape(-1), gz.reshape(-1)], dim=-1)
                    vals = query_func(pts).reshape(len(xs), len(ys), len(zs))
                    u[x0:x0 + len(xs), y0:y0 + len(ys), z0:z0 + len(zs)] = vals.detach().cpu().numpy()
    return u
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
    """Extract an iso-surface of ``query_func`` with marching cubes.

    The field is sampled on a ``resolution``^3 grid by ``extract_fields``.
    ``mcubes.marching_cubes`` is then run at level ``threshold``. The resulting
    vertices, in grid-index units, are mapped linearly into
    ``[bound_min, bound_max]`` per axis.

    Args:
        bound_min, bound_max: tensors holding the per-axis lower/upper bounds.
        resolution: samples per axis.
        threshold: iso-level passed to marching cubes (also printed).
        query_func: callable mapping an (M, 3) point tensor to M values.

    Returns:
        ``(vertices, triangles)`` as numpy arrays.
    """
    print(f'threshold: {threshold}')
    grid = extract_fields(bound_min, bound_max, resolution, query_func)
    vertices, triangles = mcubes.marching_cubes(grid, threshold)
    lo = bound_min.detach().cpu().numpy()
    hi = bound_max.detach().cpu().numpy()
    # Grid index i in [0, resolution - 1]  ->  lo + i / (resolution - 1) * (hi - lo).
    vertices = vertices / (resolution - 1.0) * (hi - lo)[None, :] + lo[None, :]
    return vertices, triangles
def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None):
    """Extract meshes with marching tetrahedra instead of marching cubes.

    This is a thin wrapper around ``extract_fields_from_tets``. ``threshold`` is
    accepted but not used.

    Returns:
        ``(vertices, triangles, tet_sdf_values, GT_verts, GT_faces)``, exactly as
        returned by ``extract_fields_from_tets``.
    """
    return extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
def sample_pdf(bins, weights, n_samples, det=False):
    """Draw ``n_samples`` positions per ray by inverting the piecewise-constant CDF.

    This implementation is from NeRF.

    Args:
        bins: ``(..., K + 1)`` bin edges.
        weights: ``(..., K)`` non-negative weights, one per bin.
        n_samples: number of samples per ray.
        det: if True, use evenly spaced quantiles instead of uniform random ones.

    Returns:
        ``(..., n_samples)`` tensor of sample positions.

    The uniform quantiles are created on ``weights``' device and dtype. They
    previously used the global default, which made ``torch.searchsorted`` fail
    for float64 weights or for weights on a non-default device. Any number of
    leading batch dims is accepted; 2-D input behaves as before.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples
    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples,
                           dtype=cdf.dtype, device=cdf.device)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape, dtype=cdf.dtype, device=cdf.device)

    # Invert CDF: find the bin each quantile falls into.
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # (..., n_samples, 2)

    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g)

    # Linear interpolation inside the bin; degenerate bins fall back to denom 1.
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples
def load_GT_vertices(GT_meshes_folder):
    """Load every ``.obj`` mesh in a folder and merge them into a single mesh.

    Meshes are read with ``trimesh.load(..., process=False)`` in sorted
    filename order, so the merged arrays do not depend on the arbitrary order
    of ``os.listdir``. Each mesh's face indices are offset by the number of
    vertices loaded before it, so merged faces index into the merged vertices.

    Args:
        GT_meshes_folder: directory scanned (non-recursively) for ``*.obj`` files.

    Returns:
        ``(verts, faces)``: numpy arrays formed by concatenating the per-mesh
        vertex arrays and offset face arrays along axis 0.

    Raises:
        ValueError: if the folder contains no ``.obj`` file.
    """
    mesh_fns = sorted(fn for fn in os.listdir(GT_meshes_folder) if fn.endswith(".obj"))
    if not mesh_fns:
        # np.concatenate([]) would fail with an unhelpful message otherwise.
        raise ValueError(f"No .obj meshes found in {GT_meshes_folder!r}")
    tot_mesh_verts = []
    tot_mesh_faces = []
    n_tot_verts = 0
    for fn in mesh_fns:
        obj_mesh = trimesh.load(os.path.join(GT_meshes_folder, fn), process=False)
        verts_obj = np.array(obj_mesh.vertices)
        faces_obj = np.array(obj_mesh.faces)
        tot_mesh_verts.append(verts_obj)
        tot_mesh_faces.append(faces_obj + n_tot_verts)
        n_tot_verts += verts_obj.shape[0]
    return np.concatenate(tot_mesh_verts, axis=0), np.concatenate(tot_mesh_faces, axis=0)
class NeuSRenderer: | |
def __init__(self,
             nerf,
             sdf_network,
             deviation_network,
             color_network,
             n_samples,
             n_importance,
             n_outside,
             up_sample_steps,
             perturb,
             gt_meshes_folder="/home/xueyi/diffsim/DiffHand/assets/hand"):
    """Store the networks and sampling settings, and build a ground-truth SDF.

    Args:
        nerf: presumably the background (outside) network; ``render_core_outside``
            calls its ``nerf`` argument as ``nerf(pts, dirs)``.
        sdf_network: SDF network, or a list of them (several methods accept both).
        deviation_network, color_network: stored as-is.
        n_samples, n_importance, n_outside, up_sample_steps, perturb: sampling
            settings, stored as-is.
        gt_meshes_folder: folder whose ``.obj`` meshes are merged by
            ``load_GT_vertices`` into the ground-truth mesh. Defaults to the path
            that used to be hard-coded here.
    """
    self.nerf = nerf  # multiple sdf networks and deviation networks and xxx #
    self.sdf_network = sdf_network
    self.deviation_network = deviation_network
    self.color_network = color_network
    self.n_samples = n_samples
    self.n_importance = n_importance
    self.n_outside = n_outside
    self.up_sample_steps = up_sample_steps
    self.perturb = perturb

    self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=gt_meshes_folder)
    # Linearly map raw mesh coordinates so that -15 -> 0 and 25 -> 1.
    maxx_pts = 25.
    minn_pts = -15.
    self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
    # pysdf SDF object built from the normalised ground-truth mesh.
    self.gt_sdf = SDF(self.mesh_vertices, self.mesh_faces)

    # Range that render_core maps query points from before remapping to [-1, 1].
    self.minn_pts = 0
    self.maxx_pts = 1.

    # Placeholders that callers must assign before rendering.
    # NOTE(review): ``...`` (Ellipsis) is truthy, so flags left unset (e.g.
    # ``use_bending_network``) behave as enabled.
    self.bkg_pts = ...  # TODO: the bkg pts
    self.cur_fr_bkg_pts_defs = ...  # TODO: set the cur_bkg_pts_defs for each frame
    self.dist_interp_thres = ...  # TODO: set the cur_bkg_pts_defs
    self.bending_network = ...  # TODO: add the bending network
    self.use_bending_network = ...  # TODO: set the property
    self.use_delta_bending = ...  # TODO
    self.prev_sdf_network = ...  # TODO
    self.use_selector = False
# use bending network # | |
# two bending netwrok | |
# two sdf networks | |
# get the pts and render the pts # | |
# pts and the rendering pts # | |
def deform_pts(self, pts, pts_ts=0):  # deform pts #
    """Deform query points for time step ``pts_ts``.

    If ``self.use_bending_network`` is truthy, the points go through
    ``self.bending_network`` (one network or a list of them). Otherwise each
    point is shifted by the mean of ``self.cur_fr_bkg_pts_defs`` over those
    background points ``self.bkg_pts`` whose squared distance to it is at most
    ``self.dist_interp_thres``.

    Args:
        pts: ``(N, D)`` or ``(B, S, D)`` tensor (presumably ``D == 3``). A 3-D
            input is flattened to ``(B * S, D)`` and reshaped back at the end.
        pts_ts: time-step index forwarded to the bending network(s).

    Returns:
        The deformed points, shaped like the input.

    Raises:
        ValueError: with delta bending, if a list entry is neither a
            ``fields.BendingNetwork`` nor a ``fields.BendingNetworkRigidTrans``.
    """
    is_batched = len(pts.size()) == 3
    if is_batched:
        nnb, nns = pts.size(0), pts.size(1)
        pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
    else:
        pts_exp = pts

    if self.use_bending_network:
        if self.use_delta_bending:
            if isinstance(self.bending_network, list):
                pts_offsets = []
                for cur_bending_network in self.bending_network:
                    if isinstance(cur_bending_network, fields.BendingNetwork):
                        # Chain the network over steps pts_ts, pts_ts - 1, ..., 0.
                        cur_pts_exp = pts_exp
                        for cur_pts_ts in range(pts_ts, -1, -1):
                            cur_pts_exp = cur_bending_network(cur_pts_exp, input_pts_ts=cur_pts_ts)
                    elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
                        # Evaluate at the requested step. The previous code passed
                        # ``cur_pts_ts``, which is unbound here (UnboundLocalError)
                        # or left at 0 by a preceding BendingNetwork's loop.
                        cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
                    else:
                        raise ValueError('Encountered with unexpected bending network class...')
                    pts_offsets.append(cur_pts_exp - pts_exp)
                # The per-network offsets are summed onto the input points.
                pts_exp = pts_exp + torch.stack(pts_offsets, dim=0).sum(dim=0)
            # NOTE(review): with delta bending and a single (non-list) network
            # the points are returned unchanged — confirm this is intended.
        elif isinstance(self.bending_network, list):
            pts_offsets = [
                cur_bending_network(pts_exp, input_pts_ts=pts_ts) - pts_exp
                for cur_bending_network in self.bending_network
            ]
            pts_exp = pts_exp + torch.stack(pts_offsets, dim=0).sum(dim=0)
        else:
            pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
    else:
        # Squared distance from every query point to every background point.
        dist_pts_to_bkg_pts = torch.sum(
            (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1
        )
        dist_mask_float = (dist_pts_to_bkg_pts <= self.dist_interp_thres).float()
        # Broadcast instead of materialising a per-point copy of all deformations.
        cur_fr_pts_def = torch.sum(
            self.cur_fr_bkg_pts_defs.unsqueeze(0) * dist_mask_float.unsqueeze(-1), dim=1
        )
        # Mean over neighbours; points without neighbours get a zero shift.
        n_neighbours = torch.clamp(torch.sum(dist_mask_float, dim=1), min=1)
        pts_exp = pts_exp - cur_fr_pts_def / n_neighbours.unsqueeze(-1)

    if is_batched:
        return pts_exp.contiguous().view(nnb, nns, -1).contiguous()
    return pts_exp
def deform_pts_with_selector(self, pts, pts_ts=0):  # deform pts #
    """Deform query points, choosing between several bending networks per point.

    This behaves like ``deform_pts`` except when ``self.bending_network`` is a
    list. In that case every network produces a candidate, each candidate is
    scored with ``self.query_pts_sdf_fn_for_selector``, and
    ``batched_index_select`` picks one candidate per point. The pick uses the
    selector computed for the *last* network's candidate.

    Args:
        pts: ``(N, D)`` or ``(B, S, D)`` tensor. A 3-D input is flattened and the
            result reshaped back to ``(B, S, -1)``.
        pts_ts: time-step index forwarded to the bending network(s).

    Returns:
        Deformed points shaped like the input.
    """
    three_d = len(pts.size()) == 3
    if three_d:
        nnb, nns = pts.size(0), pts.size(1)
        flat = pts.contiguous().view(nnb * nns, -1).contiguous()
    else:
        flat = pts

    def choose(candidates, selectors):
        # Stack candidates along axis 1 and index them with the last selector.
        stacked_pts = torch.stack(candidates, dim=1)
        stacked_sel = torch.stack(selectors, dim=1)
        last_sel = stacked_sel[:, -1]
        picked = batched_index_select(values=stacked_pts, indices=last_sel.unsqueeze(1), dim=1).squeeze(1)
        return picked.squeeze(1)

    if not self.use_bending_network:
        # Background-point interpolation: subtract the mean deformation of the
        # background points within the squared-distance threshold.
        sq_dist = torch.sum((flat.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1)
        within = (sq_dist <= self.dist_interp_thres).float()
        defs_per_pt = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(flat.size(0), 1, 1).contiguous()
        shift = torch.sum(defs_per_pt * within.unsqueeze(-1), dim=1)
        counts = torch.clamp(torch.sum(within, dim=1), min=1)
        flat = flat - shift / counts.unsqueeze(-1)
    elif self.use_delta_bending:
        if isinstance(self.bending_network, list):
            candidates, selectors = [], []
            for net in self.bending_network:
                if net.use_opt_rigid_translations:
                    bended = net(flat, input_pts_ts=pts_ts)
                else:
                    # Chain the network over steps pts_ts, pts_ts - 1, ..., 0.
                    for step in range(pts_ts, -1, -1):
                        bended = net(flat if step == pts_ts else bended, input_pts_ts=step)
                _, sel = self.query_pts_sdf_fn_for_selector(bended)
                candidates.append(bended)
                selectors.append(sel)
            flat = choose(candidates, selectors)
        # A single (non-list) network leaves the points unchanged on this path.
    elif isinstance(self.bending_network, list):
        candidates, selectors = [], []
        for net in self.bending_network:
            bended = net(flat, input_pts_ts=pts_ts)
            _, sel = self.query_pts_sdf_fn_for_selector(bended)
            candidates.append(bended)
            selectors.append(sel)
        flat = choose(candidates, selectors)
    else:
        flat = self.bending_network(flat, input_pts_ts=pts_ts)

    if three_d:
        return flat.contiguous().view(nnb, nns, -1).contiguous()
    return flat
def deform_pts_passive(self, pts, pts_ts=0):
    """Deform query points for the passive object at time step ``pts_ts``.

    With a bending network and delta bending, every network (in list order,
    or the single network) is applied once per step, for steps ``pts_ts`` down
    to 0. Without delta bending, only ``self.bending_network[-1]`` is applied.
    Without a bending network, the background-point interpolation of
    ``deform_pts`` is used.

    Args:
        pts: ``(N, D)`` or ``(B, S, D)`` tensor. A 3-D input is flattened and the
            result reshaped back to ``(B, S, -1)``.
        pts_ts: time-step index forwarded to the bending network(s).

    Returns:
        Deformed points shaped like the input.
    """
    three_d = len(pts.size()) == 3
    if three_d:
        nnb, nns = pts.size(0), pts.size(1)
        flat = pts.contiguous().view(nnb * nns, -1).contiguous()
    else:
        flat = pts

    if not self.use_bending_network:
        sq_dist = torch.sum((flat.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1)
        within = (sq_dist <= self.dist_interp_thres).float()
        defs_per_pt = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(flat.size(0), 1, 1).contiguous()
        shift = torch.sum(defs_per_pt * within.unsqueeze(-1), dim=1)
        counts = torch.clamp(torch.sum(within, dim=1), min=1)
        flat = flat - shift / counts.unsqueeze(-1)
    elif self.use_delta_bending:
        nets = self.bending_network if isinstance(self.bending_network, list) else [self.bending_network]
        for step in range(pts_ts, -1, -1):
            for net in nets:
                flat = net(flat, input_pts_ts=step)
    else:
        # Only the last network of the list deforms the passive object here.
        flat = self.bending_network[-1](flat, input_pts_ts=pts_ts)

    if three_d:
        return flat.contiguous().view(nnb, nns, -1).contiguous()
    return flat
def query_pts_sdf_fn_for_selector(self, pts):
    """Combine the current and previous SDFs and flag points for the residual mesh.

    Let ``cur`` be ``self.sdf_network.sdf(pts)`` and ``prev`` be
    ``self.prev_sdf_network.sdf(pts)``. The combined value is:

    * ``1`` where ``cur < 0`` and ``prev < 0``;
    * ``cur`` where exactly one of ``cur < 0`` / ``prev < 0`` holds;
    * ``min(cur, prev)`` where ``cur >= 0`` and ``prev >= 0``;
    * ``0`` elsewhere (e.g. NaN inputs).

    The selector is a ``long`` tensor that is 1 where ``cur <= 0`` and
    ``prev > 0``, and 0 otherwise.

    Returns:
        ``(res_sdf, res_sdf_selector)``, both shaped like ``cur``.
    """
    cur_sdf = self.sdf_network.sdf(pts)
    prev_sdf = self.prev_sdf_network.sdf(pts)
    cur_neg, prev_neg = cur_sdf < 0., prev_sdf < 0.
    cur_nonneg, prev_nonneg = cur_sdf >= 0., prev_sdf >= 0.
    neg_neg = cur_neg & prev_neg
    mixed = (cur_neg & prev_nonneg) | (cur_nonneg & prev_neg)
    pos_pos = cur_nonneg & prev_nonneg

    pair_min, _ = torch.min(torch.stack([cur_sdf, prev_sdf], dim=-1), dim=-1)
    res_sdf = torch.where(pos_pos, pair_min, torch.zeros_like(cur_sdf))
    res_sdf = torch.where(mixed, cur_sdf, res_sdf)
    res_sdf = torch.where(neg_neg, torch.ones_like(cur_sdf), res_sdf)

    # Points inside the current shape (inclusive) but strictly outside the previous one.
    res_sdf_selector = ((cur_sdf <= 0.) & (prev_sdf > 0.)).long()
    return res_sdf, res_sdf_selector
def query_func_sdf(self, pts):
    """Return the SDF at ``pts``.

    If ``self.sdf_network`` is a list, the SDFs of all networks are stacked on
    a new last axis and their element-wise minimum (the union of the shapes)
    is returned.
    """
    networks = self.sdf_network
    if not isinstance(networks, list):
        return networks.sdf(pts)
    per_network = torch.stack([net.sdf(pts) for net in networks], dim=-1)
    return torch.min(per_network, dim=-1)[0]
def query_func_sdf_passive(self, pts):
    """Return the SDF of the last network in ``self.sdf_network`` at ``pts``.

    ``self.sdf_network`` must be indexable (e.g. a list); only its final entry
    is queried.
    """
    passive_network = self.sdf_network[-1]
    return passive_network.sdf(pts)
def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
    """
    Render background

    Points at the section midpoints of ``z_vals`` are remapped with
    ``pts * 2 - 1`` and deformed (``deform_pts_with_selector`` if
    ``self.use_selector``, else ``deform_pts``). Each point is then encoded as
    ``(x / r, 1 / r)`` with ``r = max(|x|, 1)`` and fed to ``nerf(pts, dirs)``.
    Colors are alpha-composited along each ray.

    Args:
        rays_o, rays_d: ``(batch_size, 3)`` ray origins and directions.
        z_vals: ``(batch_size, n_samples)`` sample depths.
        sample_dist: length assigned to the last section of every ray.
        nerf: callable returning ``(density, color)`` per point.
        background_rgb: optional color added with the leftover transmittance.
        pts_ts: time step forwarded to the deformation.

    Returns:
        dict with ``color`` ``(batch_size, 3)``, ``sampled_color``, ``alpha`` and
        ``weights``.

    Constants are now created on ``z_vals``' device and dtype instead of the
    global default, so the method also works without a CUDA default tensor type.
    """
    batch_size, n_samples = z_vals.shape

    # Section length
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, dists.new_full(dists[..., :1].shape, sample_dist)], -1)
    mid_z_vals = z_vals + dists * 0.5

    # Section midpoints
    pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]  # batch_size, n_samples, 3
    pts = pts * 2 - 1
    if self.use_selector:
        pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
    else:
        pts = self.deform_pts(pts=pts, pts_ts=pts_ts)

    dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
    pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1)  # batch_size, n_samples, 4

    dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
    pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
    dirs = dirs.reshape(-1, 3)

    density, sampled_color = nerf(pts, dirs)
    sampled_color = torch.sigmoid(sampled_color)
    alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
    # Transmittance before each sample: running product of (1 - alpha).
    transmittance = torch.cumprod(
        torch.cat([alpha.new_ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
    weights = alpha * transmittance
    sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
    color = (weights[:, :, None] * sampled_color).sum(dim=1)
    if background_rgb is not None:
        color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))

    return {
        'color': color,
        'sampled_color': sampled_color,
        'alpha': alpha,
        'weights': weights,
    }
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
    """
    Up sampling give a fixed inv_s

    Estimates per-section alphas from the SDF values at ``z_vals`` (after the
    ``pts * 2 - 1`` remap and deformation used elsewhere). Weights are
    composited from those alphas, and ``n_importance`` new depths per ray are
    drawn with ``sample_pdf(..., det=True)``.

    Args:
        rays_o, rays_d: ``(batch_size, 3)`` ray origins and directions.
        z_vals: ``(batch_size, n_samples)`` current sample depths.
        sdf: SDF values at those samples (reshaped to ``(batch_size, n_samples)``).
        n_importance: number of new samples per ray.
        inv_s: sharpness of the sigmoid used as the SDF-to-CDF mapping.
        pts_ts: time step forwarded to the deformation.

    Returns:
        Detached ``(batch_size, n_importance)`` tensor of new depths.

    Constant tensors are created on ``cos_val``/``alpha``'s device and dtype
    rather than the global default.
    """
    batch_size, n_samples = z_vals.shape
    pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]  # n_rays, n_samples, 3
    pts = pts * 2 - 1
    if self.use_selector:
        pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
    else:
        pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
    radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
    # A section counts as inside when either endpoint lies in the unit sphere.
    inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
    sdf = sdf.reshape(batch_size, n_samples)
    prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
    prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
    mid_sdf = (prev_sdf + next_sdf) * 0.5
    cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)

    # ----------------------------------------------------------------------
    # Use min value of [ cos, prev_cos ]
    # Though it makes the sampling (not rendering) a little bit biased, this
    # strategy can make the sampling more robust when meeting situations like
    # below:
    #
    # SDF
    # ^
    # |\          -----x----...
    # |  \        /
    # |   x      x
    # |---\----/-------------> 0 level
    # |    \  /
    # |     \/
    # |
    # ----------------------------------------------------------------------
    prev_cos_val = torch.cat([cos_val.new_zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
    cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
    cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
    cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere

    dist = (next_z_vals - prev_z_vals)
    prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
    next_esti_sdf = mid_sdf + cos_val * dist * 0.5
    prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
    next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
    alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
    weights = alpha * torch.cumprod(
        torch.cat([alpha.new_ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]

    z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
    return z_samples
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
    """Merge ``new_z_vals`` into ``z_vals`` (sorted per ray) and extend ``sdf``.

    Unless ``last`` is set, the SDF is evaluated at the new samples. This uses
    the ``pts * 2 - 1`` remap and deformation, and the minimum over networks if
    ``self.sdf_network`` is a list. The new values are appended to ``sdf`` and
    reordered with the same permutation that sorted the depths.

    Args:
        rays_o, rays_d: ``(batch_size, 3)`` ray origins and directions.
        z_vals: ``(batch_size, n_samples)`` existing depths.
        new_z_vals: ``(batch_size, n_importance)`` depths to insert.
        sdf: ``(batch_size, n_samples)`` SDF values at ``z_vals``.
        last: if True, only the depths are merged and ``sdf`` is returned as-is.
        pts_ts: time step forwarded to the deformation.

    Returns:
        ``(z_vals, sdf)`` with ``z_vals`` sorted along the last axis.
    """
    batch_size, n_samples = z_vals.shape
    _, n_importance = new_z_vals.shape
    z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
    z_vals, index = torch.sort(z_vals, dim=-1)

    if not last:
        # The deformed points are only needed to query the SDF, so skip this when last.
        pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
        pts = pts * 2 - 1
        if self.use_selector:
            pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
        else:
            pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
        flat_pts = pts.reshape(-1, 3)

        if isinstance(self.sdf_network, list):
            per_network = [
                cur_sdf_network.sdf(flat_pts).reshape(batch_size, n_importance)
                for cur_sdf_network in self.sdf_network
            ]
            new_sdf, _ = torch.min(torch.stack(per_network, dim=-1), dim=-1)
        else:
            new_sdf = self.sdf_network.sdf(flat_pts).reshape(batch_size, n_importance)

        sdf = torch.cat([sdf, new_sdf], dim=-1)
        # Apply the sort permutation row-wise; gather keeps everything on sdf's
        # device (the old CPU ``torch.arange`` row index could mismatch it).
        sdf = torch.gather(sdf, -1, index).reshape(batch_size, n_samples + n_importance)

    return z_vals, sdf
def render_core(self,
                rays_o,
                rays_d,
                z_vals,
                sample_dist,
                sdf_network,
                deviation_network,
                color_network,
                background_alpha=None,
                background_sampled_color=None,
                background_rgb=None,
                cos_anneal_ratio=0.0,
                pts_ts=0):
    """Volume-render color along each ray from the SDF / color fields.

    Samples are taken at section midpoints ``z_vals + dists / 2``. The points
    are normalized with ``self.minn_pts`` / ``self.maxx_pts``, mapped to
    ``[-1, 1]`` and deformed for time ``pts_ts`` before the fields are
    queried.

    Args:
        rays_o, rays_d: ray origins / directions, indexed as ``[:, None, :]``
            (presumably ``(batch, 3)`` -- TODO confirm).
        z_vals: ``(batch, n_samples)`` sample depths along each ray.
        sample_dist: length assigned to the last section of every ray.
        sdf_network: one network or a list of networks. With a list, the
            per-point minimum SDF is used, and the feature vector / gradient
            are taken from the network that attains it.
        deviation_network: queried once; column 0 of its output is used as
            ``inv_s`` (clipped to ``[1e-6, 1e6]``).
        color_network: called as
            ``color_network(pts, gradients, dirs, feature_vector)``.
        background_alpha, background_sampled_color: optional background
            samples. The first ``n_samples`` columns are blended in outside the
            unit sphere, and the remaining columns are appended.
        background_rgb: optional constant color added with weight
            ``1 - sum(weights)``.
        cos_anneal_ratio: blend factor for the cosine annealing below.
        pts_ts: time index forwarded to the deformation.

    Returns:
        dict with ``color``, ``sdf``, ``dists``, ``gradients``
        ``(batch, n_samples, 3)``, ``s_val`` (``1 / inv_s``), ``mid_z_vals``,
        ``weights``, ``cdf``, ``gradient_error`` (Eikonal term) and
        ``inside_sphere``.
    """
    batch_size, n_samples = z_vals.shape

    # Section lengths. The last section is padded with ``sample_dist``. The pad
    # is shaped like ``z_vals[..., :1]``, so it follows the input device/dtype
    # and is non-empty even when n_samples == 1 (``dists[..., :1]`` is empty
    # in that case).
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(z_vals[..., :1], sample_dist)], -1)
    mid_z_vals = z_vals + dists * 0.5

    # Section midpoints
    pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]  # n_rays, n_samples, 3
    dirs = rays_d[:, None, :].expand(pts.shape)
    pts = pts.reshape(-1, 3)
    dirs = dirs.reshape(-1, 3)

    # Normalize with the stored bounds, map to [-1, 1], then deform.
    pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
    pts = pts * 2 - 1
    if self.use_selector:
        pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
    else:
        pts = self.deform_pts(pts=pts, pts_ts=pts_ts)

    if isinstance(sdf_network, list):
        tot_sdf = []
        tot_feature_vector = []
        tot_gradients = []
        for cur_sdf_network in sdf_network:
            cur_sdf_nn_output = cur_sdf_network(pts)
            tot_sdf.append(cur_sdf_nn_output[:, :1])
            tot_feature_vector.append(cur_sdf_nn_output[:, 1:])
            # reshape instead of a bare squeeze(): squeeze() would also drop
            # the point dimension when only a single point is rendered.
            tot_gradients.append(cur_sdf_network.gradient(pts).reshape(-1, 3))
        tot_sdf = torch.stack(tot_sdf, dim=-1)
        # Union of objects: keep the smallest SDF and remember which one won.
        sdf, obj_sel = torch.min(tot_sdf, dim=-1)
        feature_vector = torch.stack(tot_feature_vector, dim=1)
        feature_vector = batched_index_select(values=feature_vector, indices=obj_sel, dim=1).squeeze(1)
        tot_gradients = torch.stack(tot_gradients, dim=1)
        gradients = batched_index_select(values=tot_gradients, indices=obj_sel, dim=1).squeeze(1)
    else:
        sdf_nn_output = sdf_network(pts)
        sdf = sdf_nn_output[:, :1]
        feature_vector = sdf_nn_output[:, 1:]
        gradients = sdf_network.gradient(pts).reshape(-1, 3)

    sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)

    # Single sharpness parameter; the dummy input is placed on pts' device.
    inv_s = deviation_network(torch.zeros([1, 3], device=pts.device))[:, :1].clip(1e-6, 1e6)
    inv_s = inv_s.expand(batch_size * n_samples, 1)

    true_cos = (dirs * gradients).sum(-1, keepdim=True)

    # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
    # the cos value "not dead" at the beginning training iterations, for better convergence.
    iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
                 F.relu(-true_cos) * cos_anneal_ratio)  # always non-positive

    # Estimate signed distances at the section end points
    estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
    estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5

    prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
    next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

    p = prev_cdf - next_cdf
    c = prev_cdf

    alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)

    pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
    inside_sphere = (pts_norm < 1.0).float().detach()
    relax_inside_sphere = (pts_norm < 1.2).float().detach()

    # Render with background: use background samples outside the unit sphere
    # and append the extra background-only samples.
    if background_alpha is not None:
        alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
        alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
        sampled_color = sampled_color * inside_sphere[:, :, None] + \
            background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
        sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)

    # Front-to-back compositing; the leading ones column follows alpha's device/dtype.
    weights = alpha * torch.cumprod(torch.cat([torch.ones_like(alpha[:, :1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
    weights_sum = weights.sum(dim=-1, keepdim=True)

    color = (sampled_color * weights[:, :, None]).sum(dim=1)
    if background_rgb is not None:  # Fixed background, usually black
        color = color + background_rgb * (1.0 - weights_sum)

    # Eikonal loss, restricted to points inside the relaxed sphere
    gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
                                        dim=-1) - 1.0) ** 2
    gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)

    return {
        'color': color,
        'sdf': sdf,
        'dists': dists,
        'gradients': gradients.reshape(batch_size, n_samples, 3),
        's_val': 1.0 / inv_s,
        'mid_z_vals': mid_z_vals,
        'weights': weights,
        'cdf': c.reshape(batch_size, n_samples),
        'gradient_error': gradient_error,
        'inside_sphere': inside_sphere
    }
def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
    """Render a batch of rays: stratified sampling, SDF-guided up-sampling, then ``render_core``.

    All sampling tensors are created on ``rays_o.device``. The ground-truth
    SDF path moves its numpy result back to that device instead of calling
    ``.cuda()`` unconditionally.

    Args:
        rays_o, rays_d: ray origins / directions; ``len(rays_o)`` is the batch size.
        near, far: per-ray depth bounds, broadcast against ``(1, n_samples)``
            (presumably ``(batch, 1)`` -- TODO confirm).
        pts_ts: time index forwarded to the deformation functions.
        perturb_overwrite: if ``>= 0`` it overrides ``self.perturb``.
        background_rgb: optional constant background color for ``render_core``.
        cos_anneal_ratio: forwarded to ``render_core``.
        use_gt_sdf: evaluate ``self.gt_sdf`` (on a numpy copy of the points)
            instead of the learned SDF when choosing up-sampling locations.

    Returns:
        dict with ``color_fine``, ``s_val`` (per-ray mean), ``cdf_fine``,
        ``weight_sum``, ``weight_max``, ``gradients``, ``weights``,
        ``gradient_error`` and ``inside_sphere``.
    """
    batch_size = len(rays_o)
    device = rays_o.device
    sample_dist = 2.0 / self.n_samples  # Assuming the region of interest is a unit sphere
    z_vals = torch.linspace(0.0, 1.0, self.n_samples, device=device)
    z_vals = near + (far - near) * z_vals[None, :]

    z_vals_outside = None
    if self.n_outside > 0:
        z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside, device=device)

    n_samples = self.n_samples
    perturb = self.perturb

    if perturb_overwrite >= 0:
        perturb = perturb_overwrite
    if perturb > 0:
        t_rand = (torch.rand([batch_size, 1], device=device) - 0.5)
        z_vals = z_vals + t_rand * 2.0 / self.n_samples

        if self.n_outside > 0:
            # Jitter each outside sample within its stratum.
            mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
            upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
            lower = torch.cat([z_vals_outside[..., :1], mids], -1)
            t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]], device=device)
            z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand

    if self.n_outside > 0:
        z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples

    background_alpha = None
    background_sampled_color = None

    # Up sample
    if self.n_importance > 0:
        with torch.no_grad():
            # Same point transform as render_core: normalize, map to [-1, 1], deform.
            pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
            pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
            pts = pts * 2 - 1
            if self.use_selector:
                pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
            else:
                pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
            pts_exp = pts.reshape(-1, 3)

            if use_gt_sdf:
                # Ground-truth SDF is evaluated on numpy; bring it back to the rays' device.
                sdf = self.gt_sdf(pts_exp.detach().cpu().numpy())
                sdf = torch.from_numpy(sdf).float().to(device)
                sdf = sdf.reshape(batch_size, self.n_samples)
            elif isinstance(self.sdf_network, list):
                # Union of objects: per-point minimum over all SDF networks.
                tot_sdf_values = torch.stack(
                    [cur_sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
                     for cur_sdf_network in self.sdf_network],
                    dim=-1)
                sdf, _ = torch.min(tot_sdf_values, dim=-1)
            else:
                sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)

            for i in range(self.up_sample_steps):
                new_z_vals = self.up_sample(rays_o,
                                            rays_d,
                                            z_vals,
                                            sdf,
                                            self.n_importance // self.up_sample_steps,
                                            64 * 2**i,
                                            pts_ts=pts_ts)
                z_vals, sdf = self.cat_z_vals(rays_o,
                                              rays_d,
                                              z_vals,
                                              new_z_vals,
                                              sdf,
                                              last=(i + 1 == self.up_sample_steps),
                                              pts_ts=pts_ts)

        n_samples = self.n_samples + self.n_importance

    # Background model
    if self.n_outside > 0:
        z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
        z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
        ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)

        background_sampled_color = ret_outside['sampled_color']
        background_alpha = ret_outside['alpha']

    # Render core
    ret_fine = self.render_core(rays_o,
                                rays_d,
                                z_vals,
                                sample_dist,
                                self.sdf_network,
                                self.deviation_network,
                                self.color_network,
                                background_rgb=background_rgb,
                                background_alpha=background_alpha,
                                background_sampled_color=background_sampled_color,
                                cos_anneal_ratio=cos_anneal_ratio,
                                pts_ts=pts_ts)

    color_fine = ret_fine['color']
    weights = ret_fine['weights']
    weights_sum = weights.sum(dim=-1, keepdim=True)
    gradients = ret_fine['gradients']
    s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)

    return {
        'color_fine': color_fine,
        's_val': s_val,
        'cdf_fine': ret_fine['cdf'],
        'weight_sum': weights_sum,
        'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
        'gradients': gradients,
        'weights': weights,
        'gradient_error': ret_fine['gradient_error'],
        'inside_sphere': ret_fine['inside_sphere']
    }
def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
    """Delegate to the module-level ``extract_geometry`` with this model's SDF.

    The query function passed down is the negation of ``self.query_func_sdf``.
    """
    def negated_sdf(pts):
        return -self.query_func_sdf(pts)

    return extract_geometry(
        bound_min,
        bound_max,
        resolution=resolution,
        threshold=threshold,
        query_func=negated_sdf,
    )
# if self.deform_pts_with_selector: | |
# pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts) | |
def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
    """Delegate to the module-level ``extract_geometry_tets`` with this model's SDF.

    The query function is the negation of ``self.query_func_sdf``. When
    ``wdef`` is True, a ``def_func`` is also passed. It deforms points for
    ``pts_ts`` with ``deform_pts_with_selector`` if ``self.use_selector`` is
    set (checked at call time), and with ``deform_pts`` otherwise.
    """
    def negated_sdf(pts):
        return -self.query_func_sdf(pts)

    extra_kwargs = {}
    if wdef:
        def deform(pts):
            if self.use_selector:
                return self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
            return self.deform_pts(pts, pts_ts=pts_ts)
        extra_kwargs['def_func'] = deform

    return extract_geometry_tets(
        bound_min,
        bound_max,
        resolution=resolution,
        threshold=threshold,
        query_func=negated_sdf,
        **extra_kwargs,
    )
def extract_geometry_tets_passive(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
    """Extract the passive object's surface via the module-level ``extract_geometry_tets``.

    Both branches query the negated passive SDF (``self.query_func_sdf_passive``),
    so the ``wdef`` flag only controls whether points are first deformed with
    ``self.deform_pts_passive`` for time ``pts_ts``.
    """
    def negated_passive_sdf(pts):
        return -self.query_func_sdf_passive(pts)

    extra_kwargs = {}
    if wdef:
        extra_kwargs['def_func'] = lambda pts: self.deform_pts_passive(pts, pts_ts=pts_ts)

    return extract_geometry_tets(
        bound_min,
        bound_max,
        resolution=resolution,
        threshold=threshold,
        query_func=negated_passive_sdf,
        **extra_kwargs,
    )