quasi-physical-sims / models /renderer_def_multi_objs_compositional.py
meow
a
710e818
raw
history blame
83.4 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
import mcubes
from icecream import ic
import os
import trimesh
from pysdf import SDF
import models.fields as fields
from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
def batched_index_select(values, indices, dim = 1):
    """Gather entries of ``values`` along ``dim`` using batched ``indices``.

    ``indices`` shares the leading ``dim`` axes with ``values``; every index
    picks one slice along axis ``dim``. The result has shape
    ``indices.shape + values.shape[dim + 1:]``.

    Args:
        values: source tensor.
        indices: integer tensor of positions along ``dim``.
        dim: axis of ``values`` that is indexed.

    Returns:
        Tensor produced by ``gather`` with the broadcast index tensor.
    """
    trailing_shape = values.shape[(dim + 1):]
    index_rank = len(indices.shape)

    # Append one singleton axis per trailing value axis, then broadcast the
    # index over those trailing axes so that ``gather`` copies whole slices.
    gather_index = indices[(...,) + (None,) * len(trailing_shape)]
    gather_index = gather_index.expand(*([-1] * index_rank), *trailing_shape)

    # Index axes beyond ``dim`` have no counterpart in ``values``; insert
    # singleton axes for them and broadcast to the index extent.
    extra_axes = index_rank - (dim + 1)
    source = values[(slice(None),) * dim + (None,) * extra_axes + (...,)]
    target_shape = [-1] * len(source.shape)
    broadcast_axes = slice(dim, dim + extra_axes)
    target_shape[broadcast_axes] = gather_index.shape[broadcast_axes]
    source = source.expand(*target_shape)

    return source.gather(dim + extra_axes, gather_index)
def create_mt_variable(device):
    """Build the lookup tables passed to ``marching_tets_tetmesh``.

    Args:
        device: torch device on which the tables are allocated.

    Returns:
        Tuple ``(triangle_table, num_triangles_table, base_tet_edges, v_id)``:
        a 16 x 6 table of edge indices padded with -1, the number of triangles
        for each of the 16 rows, the 6 tetrahedron edges as a flat list of
        vertex-index pairs, and the powers of two ``[1, 2, 4, 8]``.
    """
    long_kwargs = dict(dtype=torch.long, device=device)

    # One row per case; each row holds up to two triangles, padded with -1.
    rows = [
        [-1, -1, -1, -1, -1, -1],
        [1, 0, 2, -1, -1, -1],
        [4, 0, 3, -1, -1, -1],
        [1, 4, 2, 1, 3, 4],
        [3, 1, 5, -1, -1, -1],
        [2, 3, 0, 2, 5, 3],
        [1, 4, 0, 1, 5, 4],
        [4, 2, 5, -1, -1, -1],
        [4, 5, 2, -1, -1, -1],
        [4, 1, 0, 4, 5, 1],
        [3, 2, 0, 3, 5, 2],
        [1, 3, 5, -1, -1, -1],
        [4, 1, 2, 4, 3, 1],
        [3, 0, 4, -1, -1, -1],
        [2, 0, 1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1],
    ]
    triangles_per_row = [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0]
    edge_endpoints = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]

    triangle_table = torch.tensor(rows, **long_kwargs)
    num_triangles_table = torch.tensor(triangles_per_row, **long_kwargs)
    base_tet_edges = torch.tensor(edge_endpoints, **long_kwargs)
    v_id = torch.pow(2, torch.arange(4, **long_kwargs))
    return triangle_table, num_triangles_table, base_tet_edges, v_id
def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
    """Extract a surface with marching tets from SDF values queried on a tet grid.

    The tetrahedral grid is loaded from a hard-coded ``.npz`` file, its
    vertices are rescaled per axis to ``[-1, 1]``, and ``query_func`` is
    evaluated on them in chunks of ``64 ** 3`` points. A second surface is
    extracted from reference SDF values loaded from a hard-coded ``.npy`` file.

    Args:
        bound_min: tensor; only its ``device`` is used.
        bound_max: unused; kept for interface compatibility.
        resolution: unused; the ``100`` grid file is always loaded.
        query_func: callable mapping a chunk of points to SDF values.
        def_func: optional callable applied to each chunk of points before
            ``query_func``; must return the transformed points.

    Returns:
        Tuple ``(verts, faces, sdf_values, GT_verts, GT_faces)``.
    """
    device = bound_min.device

    tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
    if not os.path.exists(tet_fn):
        tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
    # ``np.load`` on an .npz returns an NpzFile that keeps the archive open;
    # the context manager releases the handle once both arrays are read.
    with np.load(tet_fn) as tets:
        verts = torch.from_numpy(tets['vertices']).float().to(device)
        indices = torch.from_numpy(tets['tets']).long().to(device)

    # Normalise every axis of the grid to [-1, 1].
    minn_verts, _ = torch.min(verts, dim=0)
    maxx_verts, _ = torch.max(verts, dim=0)
    scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0)
    scaled_verts = scaled_verts * 2. - 1.

    # Query the SDF chunk by chunk to bound peak memory.
    query_bundles = 64 ** 3
    sdf_values = []
    for cur_bundle_st in range(0, scaled_verts.size(0), query_bundles):
        cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_st + query_bundles]
        if def_func is not None:
            cur_query_pts = def_func(cur_query_pts)
        sdf_values.append(query_func(cur_query_pts))
    sdf_values = torch.cat(sdf_values, dim=0).squeeze(-1)

    gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
    if not os.path.exists(gt_sdf_fn):
        gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
    GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
    GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)

    triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
    tet_table, num_tets_table = create_tetmesh_variables(device)

    # Surface from the queried SDF; the returned tet mesh is not needed here.
    verts, faces, _, _ = marching_tets_tetmesh(
        scaled_verts, sdf_values, indices, triangle_table, num_triangles_table,
        base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts,
        num_tets_table=num_tets_table, tet_table=tet_table)
    # Surface from the reference SDF values on the same grid.
    GT_verts, GT_faces, _, _ = marching_tets_tetmesh(
        scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table,
        base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts,
        num_tets_table=num_tets_table, tet_table=tet_table)
    return verts, faces, sdf_values, GT_verts, GT_faces
def extract_fields_from_tets_selector(bound_min, bound_max, resolution, query_func, def_func=None):
    """Marching-tets extraction for deformation functions that return a selector.

    Same pipeline as the non-selector variant, except that ``def_func`` is
    expected to return a pair whose first element is the transformed points;
    the second element is discarded.

    Args:
        bound_min: tensor; only its ``device`` is used.
        bound_max: unused; kept for interface compatibility.
        resolution: unused; the ``100`` grid file is always loaded.
        query_func: callable mapping a chunk of points to SDF values.
        def_func: optional callable returning ``(points, selector)`` for a
            chunk of points.

    Returns:
        Tuple ``(verts, faces, sdf_values, GT_verts, GT_faces)``.
    """
    device = bound_min.device

    tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
    if not os.path.exists(tet_fn):
        tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
    # ``np.load`` on an .npz returns an NpzFile that keeps the archive open;
    # the context manager releases the handle once both arrays are read.
    with np.load(tet_fn) as tets:
        verts = torch.from_numpy(tets['vertices']).float().to(device)
        indices = torch.from_numpy(tets['tets']).long().to(device)

    # Normalise every axis of the grid to [-1, 1].
    minn_verts, _ = torch.min(verts, dim=0)
    maxx_verts, _ = torch.max(verts, dim=0)
    scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0)
    scaled_verts = scaled_verts * 2. - 1.

    # Query the SDF chunk by chunk to bound peak memory.
    query_bundles = 64 ** 3
    sdf_values = []
    for cur_bundle_st in range(0, scaled_verts.size(0), query_bundles):
        cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_st + query_bundles]
        if def_func is not None:
            # Only the deformed points are needed; the selector is dropped.
            cur_query_pts, _ = def_func(cur_query_pts)
        sdf_values.append(query_func(cur_query_pts))
    sdf_values = torch.cat(sdf_values, dim=0).squeeze(-1)

    gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
    if not os.path.exists(gt_sdf_fn):
        gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
    GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
    GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)

    triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
    tet_table, num_tets_table = create_tetmesh_variables(device)

    # Surface from the queried SDF; the returned tet mesh is not needed here.
    verts, faces, _, _ = marching_tets_tetmesh(
        scaled_verts, sdf_values, indices, triangle_table, num_triangles_table,
        base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts,
        num_tets_table=num_tets_table, tet_table=tet_table)
    # Surface from the reference SDF values on the same grid.
    GT_verts, GT_faces, _, _ = marching_tets_tetmesh(
        scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table,
        base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts,
        num_tets_table=num_tets_table, tet_table=tet_table)
    return verts, faces, sdf_values, GT_verts, GT_faces
def extract_fields(bound_min, bound_max, resolution, query_func):
    """Sample ``query_func`` on a regular ``resolution ** 3`` grid.

    The grid spans ``bound_min`` to ``bound_max`` along each axis and is
    evaluated in sub-blocks of at most 64 samples per axis, with gradients
    disabled.

    Args:
        bound_min: indexable with the lower bound of each of the 3 axes.
        bound_max: indexable with the upper bound of each of the 3 axes.
        resolution: number of samples per axis.
        query_func: callable mapping an ``(n, 3)`` point tensor to ``n`` values.

    Returns:
        ``float32`` numpy array of shape ``(resolution, resolution, resolution)``.
    """
    block = 64
    x_parts = torch.linspace(bound_min[0], bound_max[0], resolution).split(block)
    y_parts = torch.linspace(bound_min[1], bound_max[1], resolution).split(block)
    z_parts = torch.linspace(bound_min[2], bound_max[2], resolution).split(block)

    grid = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for ix, x_vals in enumerate(x_parts):
            x_lo, nx = ix * block, len(x_vals)
            for iy, y_vals in enumerate(y_parts):
                y_lo, ny = iy * block, len(y_vals)
                for iz, z_vals in enumerate(z_parts):
                    z_lo, nz = iz * block, len(z_vals)
                    gx, gy, gz = torch.meshgrid(x_vals, y_vals, z_vals)
                    coords = [axis.reshape(-1, 1) for axis in (gx, gy, gz)]
                    pts = torch.cat(coords, dim=-1)
                    values = query_func(pts).reshape(nx, ny, nz)
                    grid[x_lo: x_lo + nx, y_lo: y_lo + ny, z_lo: z_lo + nz] = \
                        values.detach().cpu().numpy()
    return grid
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
    """Extract an iso-surface with marching cubes on a regular grid.

    Args:
        bound_min: tensor with the lower corner of the sampled box.
        bound_max: tensor with the upper corner of the sampled box.
        resolution: number of grid samples per axis.
        threshold: iso-value passed to ``mcubes.marching_cubes``.
        query_func: callable evaluated on the grid points.

    Returns:
        Tuple ``(vertices, triangles)`` with vertices mapped from grid index
        coordinates back into the ``[bound_min, bound_max]`` box.
    """
    print('threshold: {}'.format(threshold))

    field = extract_fields(bound_min, bound_max, resolution, query_func)
    vertices, triangles = mcubes.marching_cubes(field, threshold)

    upper = bound_max.detach().cpu().numpy()
    lower = bound_min.detach().cpu().numpy()
    # Grid indices run over [0, resolution - 1]; rescale them to the box.
    extent = (upper - lower)[None, :]
    vertices = vertices / (resolution - 1.0) * extent + lower[None, :]
    return vertices, triangles
def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None, selector=False):
    """Extract geometry with marching tets, optionally via the selector variant.

    Args:
        bound_min: forwarded to the tet-based extractor.
        bound_max: forwarded to the tet-based extractor.
        resolution: forwarded to the tet-based extractor.
        threshold: not used by this function.
        query_func: SDF query callable, forwarded.
        def_func: optional deformation callable, forwarded.
        selector: when truthy, use ``extract_fields_from_tets_selector``
            (whose ``def_func`` returns a pair); otherwise use
            ``extract_fields_from_tets``.

    Returns:
        Tuple ``(vertices, triangles, tet_sdf_values, GT_verts, GT_faces)``.
    """
    extractor = extract_fields_from_tets_selector if selector else extract_fields_from_tets
    vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extractor(
        bound_min, bound_max, resolution, query_func, def_func=def_func)
    return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
### sample pdfs ###
def sample_pdf(bins, weights, n_samples, det=False):
    """Sample positions by inverting the piecewise-linear CDF built from weights.

    This implementation is from NeRF. The uniform samples are created with the
    dtype and device of the CDF so that ``torch.searchsorted`` receives
    matching tensors even when the inputs do not live on the default device
    or are not ``float32``.

    Args:
        bins: ``(batch, n_bins + 1)`` tensor of bin edges.
        weights: ``(batch, n_bins)`` tensor of non-negative bin weights.
        n_samples: number of samples drawn per batch row.
        det: if true, use evenly spaced mid-point quantiles instead of random
            uniform draws.

    Returns:
        ``(batch, n_samples)`` tensor of sampled positions.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples, matching the CDF's dtype and device.
    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples,
                           dtype=cdf.dtype, device=cdf.device)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape, dtype=cdf.dtype, device=cdf.device)

    # Invert CDF: locate the CDF interval that contains each sample.
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    # Interpolate linearly inside the interval; a (near) zero-width CDF
    # interval gets a denominator of 1 to avoid dividing by ~0.
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
def load_GT_vertices(GT_meshes_folder):
    """Load every ``.obj`` mesh in a folder and merge them into one mesh.

    Args:
        GT_meshes_folder: directory that is scanned (non-recursively) for
            files ending in ``.obj``.

    Returns:
        Tuple ``(vertices, faces)`` of numpy arrays: all vertices concatenated
        in listing order, and all faces with their vertex indices shifted so
        that they address the concatenated vertex array.
    """
    obj_names = [name for name in os.listdir(GT_meshes_folder) if name.endswith(".obj")]

    merged_verts = []
    merged_faces = []
    vert_offset = 0  # number of vertices contributed by earlier meshes
    for name in obj_names:
        mesh = trimesh.load(os.path.join(GT_meshes_folder, name), process=False)
        cur_verts = np.array(mesh.vertices)
        cur_faces = np.array(mesh.faces)
        merged_verts.append(cur_verts)
        merged_faces.append(cur_faces + vert_offset)
        vert_offset += cur_verts.shape[0]

    merged_verts = np.concatenate(merged_verts, axis=0)
    merged_faces = np.concatenate(merged_faces, axis=0)
    return merged_verts, merged_faces
class NeuSRenderer:
def __init__(self,
             nerf,
             sdf_network,
             deviation_network,
             color_network,
             n_samples,
             n_importance,
             n_outside,
             up_sample_steps,
             perturb):
    """Store the networks and sampling settings and build the reference SDF.

    The reference meshes are read from a hard-coded asset folder (with a
    fallback location), normalised with fixed bounds and wrapped in a
    ``pysdf.SDF`` object. Deformation-related attributes are initialised to
    the ``...`` placeholder and are expected to be assigned by the caller
    before use.
    """
    # Networks.
    self.nerf = nerf
    self.sdf_network = sdf_network
    self.deviation_network = deviation_network
    self.color_network = color_network

    # Sampling configuration.
    self.n_samples = n_samples
    self.n_importance = n_importance
    self.n_outside = n_outside
    self.up_sample_steps = up_sample_steps
    self.perturb = perturb

    # Reference meshes -> SDF.
    GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
    if not os.path.exists(GT_meshes_folder):
        GT_meshes_folder = "/data/xueyi/diffsim/DiffHand/assets/hand"
    raw_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)

    # Fixed normalisation bounds applied to the raw mesh coordinates.
    lower_bound = -15.
    upper_bound = 25.
    self.mesh_vertices = (raw_vertices - lower_bound) / (upper_bound - lower_bound)
    self.gt_sdf = SDF(self.mesh_vertices, self.mesh_faces)

    self.minn_pts = 0
    self.maxx_pts = 1.

    # Background-point deformation state; placeholders until assigned.
    self.bkg_pts = ...  # TODO
    self.cur_fr_bkg_pts_defs = ...  # TODO: set the cur_bkg_pts_defs for each frame
    self.dist_interp_thres = ...  # TODO: set the cur_bkg_pts_defs

    # Bending-network state; placeholders until assigned. Note that ``...``
    # is truthy, so an unassigned ``use_bending_network`` selects the
    # bending-network branch of the deform methods.
    self.bending_network = ...  # TODO: add the bending network
    self.use_bending_network = ...  # TODO: set the property
    self.use_delta_bending = ...  # TODO
    self.prev_sdf_network = ...  # TODO
    self.use_selector = False

    # Per-timestep meshes; placeholders until assigned.
    self.timestep_to_passive_mesh = ...  # TODO
    self.timestep_to_active_mesh = ...  # TODO
def deform_pts(self, pts, pts_ts=0, update_tot_def=True): # deform pts #
if self.use_bending_network:
if len(pts.size()) == 3:
nnb, nns = pts.size(0), pts.size(1)
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
else:
pts_exp = pts
# pts_ts #
if self.use_delta_bending:
if isinstance(self.bending_network, list):
pts_offsets = []
for i_obj, cur_bending_network in enumerate(self.bending_network):
if isinstance(cur_bending_network, fields.BendingNetwork):
for cur_pts_ts in range(pts_ts, -1, -1):
cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
else:
raise ValueError('Encountered with unexpected bending network class...')
pts_offsets.append(cur_pts_exp - pts_exp)
pts_offsets = torch.stack(pts_offsets, dim=0)
pts_offsets = torch.sum(pts_offsets, dim=0)
pts_exp = pts_exp + pts_offsets
# for cur_pts_ts in range(pts_ts, -1, -1):
# if isinstance(self.bending_network, list): # pts ts #
# for i_obj, cur_bending_network in enumerate(self.bending_network):
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
# else:
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
else:
if isinstance(self.bending_network, list): # prev sdf network #
pts_offsets = []
for i_obj, cur_bending_network in enumerate(self.bending_network):
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
pts_offsets.append(bended_pts_exp - pts_exp)
pts_offsets = torch.stack(pts_offsets, dim=0)
pts_offsets = torch.sum(pts_offsets, dim=0)
pts_exp = pts_exp + pts_offsets
else:
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
if len(pts.size()) == 3:
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
else:
pts = pts_exp
return pts
# pts: nn_batch x nn_samples x 3
if len(pts.size()) == 3:
nnb, nns = pts.size(0), pts.size(1)
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
else:
pts_exp = pts
# print(f"prior to deforming: {pts.size()}")
dist_pts_to_bkg_pts = torch.sum(
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
)
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
dist_mask_float = dist_mask.float()
# dist_mask_float #
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
cur_fr_pts_def = torch.sum(
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
)
dist_mask_float_summ = torch.sum(
dist_mask_float, dim=1
)
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
pts_exp = pts_exp - cur_fr_pts_def
if len(pts.size()) == 3:
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
else:
pts = pts_exp
return pts #
def deform_pts_with_selector(self, pts, pts_ts=0, update_tot_def=True): # deform pts #
if self.use_bending_network:
if len(pts.size()) == 3:
nnb, nns = pts.size(0), pts.size(1)
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
else:
pts_exp = pts
# pts_ts #
if self.use_delta_bending:
if isinstance(self.bending_network, list):
bended_pts = []
queries_sdfs_selector = []
for i_obj, cur_bending_network in enumerate(self.bending_network):
if cur_bending_network.use_opt_rigid_translations:
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
else:
# bended_pts_exp = pts_exp.clone()
if i_obj == 1 and pts_ts == 0:
bended_pts_exp = pts_exp
elif i_obj == 1:
for cur_pts_ts in range(pts_ts, 0, -1): ### before 0 ###
if isinstance(cur_bending_network, fields.BendingNetwork):
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
elif isinstance(cur_bending_network, fields.BendingNetworkForceForward) or isinstance(cur_bending_network, fields.BendingNetworkRigidTransForward):
# input_pts, input_pts_ts, timestep_to_passive_mesh, act_sdf_net=None, details=None, special_loss_return=False
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh)
elif isinstance(cur_bending_network, fields.BendingNetworkForceFieldForward):
# input_pts, input_pts_ts, timestep_to_passive_mesh, passive_sdf_net, details=None, special_loss_return=False
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1])
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForward):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV2):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV3):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV4):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV5):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV6):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV7):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV8):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0], update_tot_def=update_tot_def)
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV9):
# active_bending_net, active_sdf_net,
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0], update_tot_def=update_tot_def)
else:
raise ValueError(f"Unrecognized bending network type: {type(cur_bending_network)}")
else:
for cur_pts_ts in range(pts_ts, -1, -1):
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
# _, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
_, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector_ndelta(bended_pts_exp, i_net=i_obj)
bended_pts.append(bended_pts_exp)
queries_sdfs_selector.append(cur_bended_pts_selecotr)
bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts # # bended_pts #
queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
# queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
#### get the final sdf_selector from queries_sdfs_selector ####
# sdf_selector = queries_sdfs_selector[:, -1]
# neg_neg = ((queries_sdfs_selector[:, 0] == 0).float() + (queries_sdfs_selector[:, -1] == 1).float()) > 1.5 #### both inside of the object
sdf_selector = 1 - queries_sdfs_selector[:, 0]
# neg_neg
# sdf_selector = queries_sdfs_selector
# delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
# print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
# pts_exp = bended_pts.squeeze(1)
pts_exp = bended_pts
# for cur_pts_ts in range(pts_ts, -1, -1):
# if isinstance(self.bending_network, list):
# for i_obj, cur_bending_network in enumerate(self.bending_network):
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
# else:
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
else:
if isinstance(self.bending_network, list): # prev sdf network #
# pts_offsets = []
bended_pts = []
queries_sdfs_selector = []
for i_obj, cur_bending_network in enumerate(self.bending_network):
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
# pts_offsets.append(bended_pts_exp - pts_exp)
_, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
bended_pts.append(bended_pts_exp)
queries_sdfs_selector.append(cur_bended_pts_selecotr)
bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
# queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
sdf_selector = queries_sdfs_selector[:, -1]
# sdf_selector = queries_sdfs_selector
# delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
# print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
pts_exp = bended_pts.squeeze(1)
# pts_offsets = torch.stack(pts_offsets, dim=0)
# pts_offsets = torch.sum(pts_offsets, dim=0)
# pts_exp = pts_exp + pts_offsets
else:
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
if len(pts.size()) == 3:
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
else:
pts = pts_exp
return pts, sdf_selector
else:
# pts: nn_batch x nn_samples x 3
if len(pts.size()) == 3:
nnb, nns = pts.size(0), pts.size(1)
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
else:
pts_exp = pts
# print(f"prior to deforming: {pts.size()}")
dist_pts_to_bkg_pts = torch.sum(
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
)
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
dist_mask_float = dist_mask.float()
# dist_mask_float #
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
cur_fr_pts_def = torch.sum(
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
)
dist_mask_float_summ = torch.sum(
dist_mask_float, dim=1
)
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
pts_exp = pts_exp - cur_fr_pts_def
if len(pts.size()) == 3:
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
else:
pts = pts_exp
return pts #
def deform_pts_passive(self, pts, pts_ts=0):
    """Deform query points using only the last entry of ``self.bending_network``.

    Args:
        pts: Either a 3-D tensor (commented in the code as
            nn_batch x nn_samples x 3) or an already flattened 2-D tensor.
            3-D inputs are flattened for the deformation and reshaped back
            before being returned.
        pts_ts: Timestep index handed to the bending network.

    Returns:
        The deformed points, laid out like ``pts``.

    Raises:
        ValueError: In the delta-bending path, when
            ``self.bending_network[-1]`` is none of the ``fields`` bending
            network classes handled below.
    """
    is_batched = len(pts.size()) == 3
    if is_batched:
        nnb, nns = pts.size(0), pts.size(1)
        flat_pts = pts.contiguous().view(nnb * nns, -1).contiguous()
    else:
        flat_pts = pts

    if self.use_bending_network:
        if self.use_delta_bending:
            # Delta bending: apply the per-timestep deformation repeatedly,
            # walking from `pts_ts` down to timestep 1.
            if pts_ts > 0:
                passive_net = self.bending_network[-1]
                for cur_pts_ts in range(pts_ts, 0, -1):
                    # Each family of bending networks expects a different set
                    # of extra keyword arguments.
                    if isinstance(passive_net, fields.BendingNetwork):
                        extra_kwargs = {}
                    elif isinstance(passive_net, (fields.BendingNetworkForceForward,
                                                  fields.BendingNetworkRigidTransForward)):
                        extra_kwargs = dict(
                            timestep_to_passive_mesh=self.timestep_to_passive_mesh,
                        )
                    elif isinstance(passive_net, fields.BendingNetworkForceFieldForward):
                        extra_kwargs = dict(
                            timestep_to_passive_mesh=self.timestep_to_passive_mesh,
                            passive_sdf_net=self.sdf_network[1],
                        )
                    elif isinstance(passive_net, (fields.BendingNetworkActiveForceFieldForward,
                                                  fields.BendingNetworkActiveForceFieldForwardV2,
                                                  fields.BendingNetworkActiveForceFieldForwardV3)):
                        extra_kwargs = dict(
                            timestep_to_passive_mesh=self.timestep_to_passive_mesh,
                            passive_sdf_net=self.sdf_network[1],
                            active_bending_net=self.bending_network[0],
                            active_sdf_net=self.sdf_network[0],
                        )
                    elif isinstance(passive_net, (fields.BendingNetworkActiveForceFieldForwardV4,
                                                  fields.BendingNetworkActiveForceFieldForwardV5,
                                                  fields.BendingNetworkActiveForceFieldForwardV6,
                                                  fields.BendingNetworkActiveForceFieldForwardV7,
                                                  fields.BendingNetworkActiveForceFieldForwardV8,
                                                  fields.BendingNetworkActiveForceFieldForwardV9)):
                        extra_kwargs = dict(
                            timestep_to_active_mesh=self.timestep_to_active_mesh,
                            timestep_to_passive_mesh=self.timestep_to_passive_mesh,
                            passive_sdf_net=self.sdf_network[1],
                            active_bending_net=self.bending_network[0],
                            active_sdf_net=self.sdf_network[0],
                        )
                    else:
                        raise ValueError(f"Unrecognized bending network type: {type(self.bending_network[-1])}")
                    flat_pts = passive_net(flat_pts, input_pts_ts=cur_pts_ts, **extra_kwargs)
        else:
            # Direct bending: a single call with the requested timestep.
            flat_pts = self.bending_network[-1](flat_pts, input_pts_ts=pts_ts)
    else:
        # No bending network: average the stored deformations of all
        # background points whose squared distance to the query point is
        # within `dist_interp_thres`, and subtract that average.
        sq_dists = torch.sum(
            (flat_pts.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1
        )  # nn_pts_exp x nn_bkg_pts
        near_mask = (sq_dists <= self.dist_interp_thres).float()
        summed_defs = torch.sum(
            self.cur_fr_bkg_pts_defs.unsqueeze(0) * near_mask.unsqueeze(-1), dim=1
        )
        # Clamp the neighbour count so points without neighbours stay put.
        nn_neighbours = torch.clamp(torch.sum(near_mask, dim=1), min=1)
        flat_pts = flat_pts - summed_defs / nn_neighbours.unsqueeze(-1)

    if is_batched:
        return flat_pts.contiguous().view(nnb, nns, -1).contiguous()
    return flat_pts
# delta mesh as passive mesh #
def query_pts_sdf_fn_for_selector(self, pts):
    """Combine the current and previous SDF fields into a residual SDF.

    Both ``self.sdf_network`` and ``self.prev_sdf_network`` are queried at
    ``pts``. The result is built per point from the signs of the two values:

    * both negative                  -> 1.0
    * exactly one of them negative   -> the current SDF value
    * both non-negative              -> the smaller of the two values

    Returns:
        ``(res_sdf, res_sdf_selector)`` where the selector is a long tensor
        that is 1 where the current SDF is <= 0 and the previous SDF is > 0,
        and 0 elsewhere.
    """
    cur_sdf = self.sdf_network.sdf(pts)
    prev_sdf = self.prev_sdf_network.sdf(pts)

    cur_neg, cur_nonneg = cur_sdf < 0., cur_sdf >= 0.
    prev_neg, prev_nonneg = prev_sdf < 0., prev_sdf >= 0.

    inside_both = cur_neg & prev_neg
    inside_one = (cur_neg & prev_nonneg) | (cur_nonneg & prev_neg)
    outside_both = cur_nonneg & prev_nonneg

    smaller_sdf, _ = torch.min(torch.stack([cur_sdf, prev_sdf], dim=-1), dim=-1)

    res_sdf = torch.zeros_like(cur_sdf)
    res_sdf[inside_both] = 1.
    res_sdf[inside_one] = cur_sdf[inside_one]
    res_sdf[outside_both] = smaller_sdf[outside_both]

    # Residual region: on/inside the current surface, strictly outside the previous one.
    res_sdf_selector = ((cur_sdf <= 0.) & (prev_sdf > 0.)).long()
    return res_sdf, res_sdf_selector
def query_pts_sdf_fn_for_selector_ndelta(self, pts, i_net):
    """Query one SDF network and flag the points on or inside its surface.

    Args:
        pts: Query points passed to ``self.sdf_network[i_net].sdf``.
        i_net: Index of the SDF network to query.

    Returns:
        ``(passive_sdf, passive_sdf_selector)``: the SDF values with the
        trailing dimension squeezed, and a long tensor of the same shape
        that is 1 where the SDF is <= 0 and 0 elsewhere.
    """
    # The residual-SDF logic that used to follow the return statement was
    # unreachable and has been removed.
    passive_sdf = self.sdf_network[i_net].sdf(pts).squeeze(-1)
    passive_sdf_selector = (passive_sdf <= 0.).long()
    return passive_sdf, passive_sdf_selector
def query_func_sdf(self, pts):
    """Return the SDF at ``pts``, taking the minimum over all object networks.

    When ``self.sdf_network`` is a list, every network is queried and the
    element-wise minimum is returned; otherwise the single network's SDF is
    returned unchanged.
    """
    if not isinstance(self.sdf_network, list):
        return self.sdf_network.sdf(pts)

    per_obj_sdf = [cur_sdf_network.sdf(pts) for cur_sdf_network in self.sdf_network]
    min_sdf, _ = torch.min(torch.stack(per_obj_sdf, dim=-1), dim=-1)
    return min_sdf
def query_func_sdf_passive(self, pts):
    """Return the SDF at ``pts`` from the last network in ``self.sdf_network``."""
    passive_sdf_network = self.sdf_network[-1]
    return passive_sdf_network.sdf(pts)
def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
    """
    Render background

    Samples are taken at the midpoint of every z-section, deformed with the
    renderer's deformation, re-parameterised as ``(x / r, 1 / r)`` with
    ``r`` clipped to at least 1, and fed to ``nerf`` together with the ray
    directions.

    Args:
        rays_o, rays_d: Ray origins and directions, indexed as (n_rays, 3).
        z_vals: Sample depths of shape (batch_size, n_samples).
        sample_dist: Length assigned to the final section of every ray.
        nerf: A callable ``(pts, dirs) -> (density, color)``; when
            ``self.use_selector`` is set, an iterable of such callables.
        background_rgb: Optional colour blended in where the ray is not
            fully absorbed.
        pts_ts: Timestep forwarded to the deformation.

    Returns:
        Dict with ``color``, ``sampled_color``, ``alpha`` and ``weights``.
    """
    batch_size, n_samples = z_vals.shape

    # Section length
    gaps = z_vals[..., 1:] - z_vals[..., :-1]
    last_gap = torch.Tensor([sample_dist]).expand(gaps[..., :1].shape)
    dists = torch.cat([gaps, last_gap], -1)
    mid_z_vals = z_vals + dists * 0.5

    # Section midpoints: batch_size, n_samples, 3
    pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]
    pts = pts * 2 - 1

    if self.use_selector:
        pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
    else:
        pts = self.deform_pts(pts=pts, pts_ts=pts_ts)

    dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
    pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1)  # batch_size, n_samples, 4

    dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
    pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
    dirs = dirs.reshape(-1, 3)

    if self.use_selector:
        # Evaluate every background model, then keep, per point, the output
        # of the model picked by the selector.
        per_nerf_outputs = [cur_nerf(pts, dirs) for cur_nerf in nerf]
        tot_density = torch.stack([cur_density for cur_density, _ in per_nerf_outputs], dim=1)
        tot_sampled_color = torch.stack([cur_color for _, cur_color in per_nerf_outputs], dim=1)
        selector_idx = sdf_selector.unsqueeze(-1)
        density = batched_index_select(values=tot_density, indices=selector_idx, dim=1).squeeze(1)
        sampled_color = batched_index_select(values=tot_sampled_color, indices=selector_idx, dim=1).squeeze(1)
    else:
        density, sampled_color = nerf(pts, dirs)

    sampled_color = torch.sigmoid(sampled_color)
    alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
    alpha = alpha.reshape(batch_size, n_samples)

    # Accumulated transmittance in front of every sample.
    transmittance = torch.cumprod(
        torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1
    )[:, :-1]
    weights = alpha * transmittance

    sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
    color = (weights[:, :, None] * sampled_color).sum(dim=1)
    if background_rgb is not None:
        color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))

    return {
        'color': color,
        'sampled_color': sampled_color,
        'alpha': alpha,
        'weights': weights,
    }
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
    """
    Up sampling give a fixed inv_s

    Builds per-section weights from the SDF values at the existing samples
    and draws ``n_importance`` new depths from them with ``sample_pdf``.

    Args:
        rays_o, rays_d: Ray origins and directions, indexed as (n_rays, 3).
        z_vals: Current sample depths, shape (batch_size, n_samples).
        sdf: SDF values at the current samples; reshaped to
            (batch_size, n_samples).
        n_importance: Number of new samples requested from ``sample_pdf``.
        inv_s: Sharpness factor applied to the SDF before the sigmoid.
        pts_ts: Timestep forwarded to the deformation.

    Returns:
        The detached new sample depths returned by ``sample_pdf``.
    """
    batch_size, n_samples = z_vals.shape

    # NOTE(review): unlike `render`, the points are not normalised with
    # `self.minn_pts` / `self.maxx_pts` before the `* 2 - 1` mapping.
    pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]  # n_rays, n_samples, 3
    pts = pts * 2 - 1
    if self.use_selector:
        pts, _ = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
    else:
        pts = self.deform_pts(pts=pts, pts_ts=pts_ts)

    # A section counts as inside the unit sphere if either end point is.
    radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
    inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)

    sdf = sdf.reshape(batch_size, n_samples)
    sdf_near, sdf_far = sdf[:, :-1], sdf[:, 1:]
    z_near, z_far = z_vals[:, :-1], z_vals[:, 1:]
    section_len = z_far - z_near
    mid_sdf = (sdf_near + sdf_far) * 0.5
    slope = (sdf_far - sdf_near) / (section_len + 1e-5)

    # ----------------------------------------------------------------------------------------------------------
    # Use min value of [ cos, prev_cos ]
    # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
    # robust when meeting situations like below:
    #
    # SDF
    # ^
    # |\          -----x----...
    # | \        /
    # |  x      x
    # |---\----/-------------> 0 level
    # |    \  /
    # |     \/
    # |
    # ----------------------------------------------------------------------------------------------------------
    shifted_slope = torch.cat([torch.zeros([batch_size, 1]), slope[:, :-1]], dim=-1)
    cos_val, _ = torch.min(torch.stack([shifted_slope, slope], dim=-1), dim=-1, keepdim=False)
    cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere

    half_step = cos_val * section_len * 0.5
    prev_cdf = torch.sigmoid((mid_sdf - half_step) * inv_s)
    next_cdf = torch.sigmoid((mid_sdf + half_step) * inv_s)
    alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)

    transmittance = torch.cumprod(
        torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1
    )[:, :-1]
    weights = alpha * transmittance

    return sample_pdf(z_vals, weights, n_importance, det=True).detach()
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
    """Merge newly sampled depths into ``z_vals`` and keep ``sdf`` aligned.

    Args:
        rays_o, rays_d: Ray origins and directions, indexed as (n_rays, 3).
        z_vals: Existing depths, shape (batch_size, n_samples).
        new_z_vals: New depths, shape (batch_size, n_importance).
        sdf: SDF values at ``z_vals``.
        last: When true, the SDF is not evaluated at the new samples and
            ``sdf`` is returned untouched.
        pts_ts: Timestep forwarded to the deformation.

    Returns:
        ``(z_vals, sdf)``: the sorted, merged depths and, unless ``last``,
        the SDF values reordered to match them. With a list of SDF networks
        the per-point minimum over the networks is used for new samples.
    """
    batch_size, n_samples = z_vals.shape
    _, n_importance = new_z_vals.shape

    # The deformation is evaluated even when `last` is set, as before.
    # NOTE(review): unlike `render`, the points are not normalised with
    # `self.minn_pts` / `self.maxx_pts` before the `* 2 - 1` mapping.
    pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
    pts = pts * 2 - 1
    if self.use_selector:
        pts, _ = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
    else:
        pts = self.deform_pts(pts=pts, pts_ts=pts_ts)

    merged_z_vals, sort_order = torch.sort(torch.cat([z_vals, new_z_vals], dim=-1), dim=-1)
    if last:
        return merged_z_vals, sdf

    flat_pts = pts.reshape(-1, 3)
    if isinstance(self.sdf_network, list):
        per_obj_sdf = [
            cur_sdf_network.sdf(flat_pts).reshape(batch_size, n_importance)
            for cur_sdf_network in self.sdf_network
        ]
        new_sdf, _ = torch.min(torch.stack(per_obj_sdf, dim=-1), dim=-1)
    else:
        new_sdf = self.sdf_network.sdf(flat_pts).reshape(batch_size, n_importance)

    # Reorder the concatenated SDF values with the same permutation as the depths.
    merged_sdf = torch.gather(torch.cat([sdf, new_sdf], dim=-1), 1, sort_order)
    return merged_z_vals, merged_sdf
def render_core(self,
                rays_o,
                rays_d,
                z_vals,
                sample_dist,
                sdf_network,
                deviation_network,
                color_network,
                background_alpha=None,
                background_sampled_color=None,
                background_rgb=None,
                cos_anneal_ratio=0.0,
                pts_ts=0):
    """Volume-render the SDF surface along the given rays.

    Args:
        rays_o, rays_d: Ray origins and directions, indexed as (n_rays, 3).
        z_vals: Sample depths, shape (batch_size, n_samples).
        sample_dist: Length assigned to the final section of every ray.
        sdf_network: One SDF network or a list of them. A network is called
            on the points (first output column is the SDF, the rest a
            feature vector) and must offer ``gradient(pts)``.
        deviation_network: One network, or a list of networks when
            ``self.use_selector`` is set, producing the sharpness ``inv_s``.
        color_network: One colour network, or an iterable of them when
            ``self.use_selector`` is set.
        background_alpha, background_sampled_color: Optional outputs of the
            background renderer; used for samples outside the unit sphere
            and appended behind the foreground samples.
        background_rgb: Optional fixed background colour.
        cos_anneal_ratio: Blend factor between the relaxed and the true
            cosine term (see the comment at its use below).
        pts_ts: Timestep forwarded to the deformation.

    Returns:
        Dict with ``color``, ``sdf``, ``dists``, ``gradients``, ``s_val``,
        ``mid_z_vals``, ``weights``, ``cdf``, ``gradient_error`` and
        ``inside_sphere``.
    """
    batch_size, n_samples = z_vals.shape

    # Section length; the last section is given the nominal `sample_dist`.
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
    mid_z_vals = z_vals + dists * 0.5

    # Section midpoints
    pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]  # n_rays, n_samples, 3
    dirs = rays_d[:, None, :].expand(pts.shape)

    pts = pts.reshape(-1, 3)
    dirs = dirs.reshape(-1, 3)

    # Normalise with the stored bounds, then map [0, 1] to [-1, 1].
    pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
    pts = pts * 2 - 1

    if self.use_selector:
        pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
    else:
        pts = self.deform_pts(pts=pts, pts_ts=pts_ts)

    if isinstance(sdf_network, list):
        tot_sdf = []
        tot_feature_vector = []
        tot_gradients = []
        for cur_sdf_network in sdf_network:
            cur_sdf_nn_output = cur_sdf_network(pts)
            tot_sdf.append(cur_sdf_nn_output[:, :1])
            tot_feature_vector.append(cur_sdf_nn_output[:, 1:])
            tot_gradients.append(cur_sdf_network.gradient(pts).squeeze())
        tot_sdf = torch.stack(tot_sdf, dim=-1)

        # Pick one object per point: the selector's choice if available,
        # otherwise the object with the smallest SDF.
        if self.use_selector:
            sdf = batched_index_select(tot_sdf, sdf_selector.unsqueeze(1).unsqueeze(1), dim=2).squeeze(-1)
            obj_sel = sdf_selector.unsqueeze(1)
        else:
            sdf, obj_sel = torch.min(tot_sdf, dim=-1)

        feature_vector = torch.stack(tot_feature_vector, dim=1)
        feature_vector = batched_index_select(values=feature_vector, indices=obj_sel, dim=1).squeeze(1)

        tot_gradients = torch.stack(tot_gradients, dim=1)
        gradients = batched_index_select(values=tot_gradients, indices=obj_sel, dim=1).squeeze(1)
    else:
        sdf_nn_output = sdf_network(pts)
        sdf = sdf_nn_output[:, :1]
        feature_vector = sdf_nn_output[:, 1:]
        gradients = sdf_network.gradient(pts).squeeze()

    if self.use_selector:
        tot_sampled_color = [
            cur_color_network(pts, gradients, dirs, feature_vector)
            for cur_color_network in color_network
        ]
        tot_sampled_color = torch.stack(tot_sampled_color, dim=1)
        sampled_color = batched_index_select(
            values=tot_sampled_color, indices=sdf_selector.unsqueeze(-1), dim=1
        ).squeeze(1).reshape(batch_size, n_samples, 3)
    else:
        sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)

    if self.use_selector and isinstance(deviation_network, list):
        tot_inv_s = [
            cur_deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
            for cur_deviation_network in deviation_network
        ]
        tot_inv_s = torch.stack(tot_inv_s, dim=1)
        # Each network is queried with a single dummy point, so the stacked
        # tensor has a leading size of 1 while the selector has one entry
        # per sample point. The gather inside `batched_index_select` does
        # not broadcast, so expand to one row per point first.
        tot_inv_s = tot_inv_s.expand(sdf_selector.size(0), *tot_inv_s.shape[1:])
        inv_s = batched_index_select(values=tot_inv_s, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
    else:
        inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)  # Single parameter
    inv_s = inv_s.expand(batch_size * n_samples, 1)

    true_cos = (dirs * gradients).sum(-1, keepdim=True)

    # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
    # the cos value "not dead" at the beginning training iterations, for better convergence.
    iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
                 F.relu(-true_cos) * cos_anneal_ratio)  # always non-positive

    # Estimate signed distances at section points
    estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
    estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5

    prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
    next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

    p = prev_cdf - next_cdf
    c = prev_cdf

    alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)

    pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
    inside_sphere = (pts_norm < 1.0).float().detach()
    relax_inside_sphere = (pts_norm < 1.2).float().detach()

    # Render with background
    if background_alpha is not None:
        alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
        alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
        sampled_color = sampled_color * inside_sphere[:, :, None] +\
            background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
        sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)

    weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
    weights_sum = weights.sum(dim=-1, keepdim=True)

    color = (sampled_color * weights[:, :, None]).sum(dim=1)
    if background_rgb is not None:  # Fixed background, usually black
        color = color + background_rgb * (1.0 - weights_sum)

    # Eikonal loss
    gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
                                        dim=-1) - 1.0) ** 2
    gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)

    return {
        'color': color,
        'sdf': sdf,
        'dists': dists,
        'gradients': gradients.reshape(batch_size, n_samples, 3),
        's_val': 1.0 / inv_s,
        'mid_z_vals': mid_z_vals,
        'weights': weights,
        'cdf': c.reshape(batch_size, n_samples),
        'gradient_error': gradient_error,
        'inside_sphere': inside_sphere
    }
def per_sdf_query(self, pts):
    """Query every SDF network at ``pts`` and stack the results column-wise.

    Returns:
        A tensor whose dimension 1 indexes the networks in
        ``self.sdf_network``; each network's output has its trailing
        dimension squeezed before stacking.
    """
    per_net_sdf = [
        cur_sdf_network.sdf(pts).squeeze(-1)
        for cur_sdf_network in self.sdf_network
    ]
    return torch.stack(per_net_sdf, dim=1)
def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
    """Render a batch of rays: coarse sampling, optional importance sampling, shading.

    Args:
        rays_o, rays_d: Ray origins and directions, indexed as (n_rays, 3).
        near, far: Per-ray depth bounds, broadcast against the
            (1, n_samples) sample grid.
        pts_ts: Timestep forwarded to every deformation call.
        perturb_overwrite: When >= 0, replaces ``self.perturb``.
        background_rgb: Optional fixed background colour.
        cos_anneal_ratio: Forwarded to ``render_core``.
        use_gt_sdf: Accepted for interface compatibility; not read here.

    Returns:
        Dict with ``color_fine``, ``s_val``, ``cdf_fine``, ``weight_sum``,
        ``weight_max``, ``gradients``, ``weights``, ``gradient_error``,
        ``inside_sphere`` and ``tot_sdfs`` (per-network SDF values at the
        deformed coarse samples).
    """
    batch_size = len(rays_o)
    sample_dist = 2.0 / self.n_samples  # Assuming the region of interest is a unit sphere
    z_vals = torch.linspace(0.0, 1.0, self.n_samples)
    z_vals = near + (far - near) * z_vals[None, :]

    z_vals_outside = None
    if self.n_outside > 0:
        z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)

    n_samples = self.n_samples
    perturb = self.perturb

    if perturb_overwrite >= 0:
        perturb = perturb_overwrite
    if perturb > 0:
        # Jitter all samples of a ray by one common random offset.
        t_rand = (torch.rand([batch_size, 1]) - 0.5)
        z_vals = z_vals + t_rand * 2.0 / self.n_samples

        if self.n_outside > 0:
            # Stratified jitter of the outside samples.
            mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
            upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
            lower = torch.cat([z_vals_outside[..., :1], mids], -1)
            t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
            z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand

    if self.n_outside > 0:
        z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples

    background_alpha = None
    background_sampled_color = None

    with torch.no_grad():
        # Deformed coarse samples. These are needed for `tot_sdfs` below
        # whether or not importance sampling is enabled, so they are computed
        # unconditionally; previously `pts_exp` was only assigned when
        # `n_importance > 0` and the later use raised a NameError otherwise.
        pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
        pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
        pts = pts * 2 - 1

        if self.use_selector:
            pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
        else:
            pts = self.deform_pts(pts=pts, pts_ts=pts_ts)

        pts_exp = pts.reshape(-1, 3)

        # Up sample
        if self.n_importance > 0:
            if isinstance(self.sdf_network, list):
                if self.use_selector:
                    # Per point, keep the SDF of the object chosen by the selector.
                    tot_sdf_values = [
                        cur_sdf_network.sdf(pts_exp).squeeze(-1)
                        for cur_sdf_network in self.sdf_network
                    ]
                    tot_sdf_values = torch.stack(tot_sdf_values, dim=1)
                    tot_sdf_values = batched_index_select(tot_sdf_values, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1)
                    sdf = tot_sdf_values.reshape(batch_size, self.n_samples)
                else:
                    # Per point, keep the smallest SDF over all objects.
                    tot_sdf_values = [
                        cur_sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
                        for cur_sdf_network in self.sdf_network
                    ]
                    tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
                    sdf, _ = torch.min(tot_sdf_values, dim=-1)
            else:
                sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)

            for i in range(self.up_sample_steps):
                new_z_vals = self.up_sample(rays_o,
                                            rays_d,
                                            z_vals,
                                            sdf,
                                            self.n_importance // self.up_sample_steps,
                                            64 * 2**i,
                                            pts_ts=pts_ts)
                z_vals, sdf = self.cat_z_vals(rays_o,
                                              rays_d,
                                              z_vals,
                                              new_z_vals,
                                              sdf,
                                              last=(i + 1 == self.up_sample_steps),
                                              pts_ts=pts_ts)

            n_samples = self.n_samples + self.n_importance

    # Background model
    if self.n_outside > 0:
        z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
        z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
        ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)

        background_sampled_color = ret_outside['sampled_color']
        background_alpha = ret_outside['alpha']

    # Per-network SDF values at the deformed coarse samples.
    tot_sdfs = self.per_sdf_query(pts_exp)

    # Render core
    ret_fine = self.render_core(rays_o,
                                rays_d,
                                z_vals,
                                sample_dist,
                                self.sdf_network,
                                self.deviation_network,
                                self.color_network,
                                background_rgb=background_rgb,
                                background_alpha=background_alpha,
                                background_sampled_color=background_sampled_color,
                                cos_anneal_ratio=cos_anneal_ratio,
                                pts_ts=pts_ts)

    color_fine = ret_fine['color']
    weights = ret_fine['weights']
    weights_sum = weights.sum(dim=-1, keepdim=True)
    gradients = ret_fine['gradients']
    s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)

    return {
        'color_fine': color_fine,
        's_val': s_val,
        'cdf_fine': ret_fine['cdf'],
        'weight_sum': weights_sum,
        'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
        'gradients': gradients,
        'weights': weights,
        'gradient_error': ret_fine['gradient_error'],
        'inside_sphere': ret_fine['inside_sphere'],
        'tot_sdfs': tot_sdfs,
    }
def render_def(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False, update_tot_def=True):
    """Sample points along the rays and return only their deformed positions.

    The depth sampling mirrors ``render``. The outside-sample depths are
    still computed (including their random draws) although nothing reads
    them afterwards.

    Args:
        rays_o, rays_d: Ray origins and directions, indexed as (n_rays, 3).
        near, far: Per-ray depth bounds.
        pts_ts: Timestep forwarded to the deformation.
        perturb_overwrite: When >= 0, replaces ``self.perturb``.
        background_rgb, cos_anneal_ratio, use_gt_sdf: Accepted for
            signature parity with ``render``; not read here.
        update_tot_def: Forwarded to the deformation call.

    Returns:
        ``{'defed_pts': pts}`` with the deformed sample points.
    """
    batch_size = len(rays_o)
    unit_steps = torch.linspace(0.0, 1.0, self.n_samples)
    z_vals = near + (far - near) * unit_steps[None, :]

    has_outside = self.n_outside > 0
    z_vals_outside = None
    if has_outside:
        z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)

    perturb = perturb_overwrite if perturb_overwrite >= 0 else self.perturb
    if perturb > 0:
        t_rand = (torch.rand([batch_size, 1]) - 0.5)
        z_vals = z_vals + t_rand * 2.0 / self.n_samples

        if self.n_outside > 0:
            mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
            upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
            lower = torch.cat([z_vals_outside[..., :1], mids], -1)
            t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
            z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand

    if self.n_outside > 0:
        z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples

    # Sample positions, normalised with the stored bounds and mapped to [-1, 1].
    pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
    pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
    pts = pts * 2 - 1

    if self.use_selector:
        pts, _ = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts, update_tot_def=update_tot_def)
    else:
        pts = self.deform_pts(pts=pts, pts_ts=pts_ts, update_tot_def=update_tot_def)

    return {
        'defed_pts': pts
    }
#
def extract_fields_from_tets_selector_self(self, bound_min, bound_max, resolution, i_ts, passive=False):
    """Extract a mesh at timestep ``i_ts`` by marching tetrahedra on a fixed tet grid.

    The tet grid is loaded from a hard-coded ``100_compress.npz`` file, its
    vertices are rescaled to [-1, 1], deformed with
    ``deform_pts_with_selector`` and evaluated with the SDF networks. A
    second mesh is extracted from SDF values stored in a hard-coded ``.npy``
    file.

    Args:
        bound_min: Only its ``device`` is used (plus the unused
            ``scale_bounds`` difference below).
        bound_max: Only used for the unused ``scale_bounds`` difference.
        resolution: Not read; the resolution-100 grid is always loaded.
        i_ts: Timestep forwarded to the deformation.
        passive: When true, query ``self.sdf_network[1]`` directly instead
            of selecting per point among all networks.

    Returns:
        ``(verts, faces, sdf_values, GT_verts, GT_faces)``.
    """
    device = bound_min.device

    # Load the tet grid, preferring the first location and falling back to the second.
    tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
    if not os.path.exists(tet_fn):
        tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
    tet_data = np.load(tet_fn)
    verts = torch.from_numpy(tet_data['vertices']).float().to(device)
    indices = torch.from_numpy(tet_data['tets']).long().to(device)

    minn_verts, _ = torch.min(verts, dim=0)
    maxx_verts, _ = torch.max(verts, dim=0)  # (3, )
    scale_bounds = bound_max - bound_min  # computed as before, but not used below

    # Rescale the grid vertices to [0, 1] per axis and then to [-1, 1].
    scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0)
    scaled_verts = scaled_verts * 2. - 1.

    # Query the SDF in chunks of 64^3 vertices.
    chunk_size = 64 ** 3
    nn_verts = scaled_verts.size(0)
    sdf_chunks = []
    for chunk_st in range(0, nn_verts, chunk_size):
        chunk_ed = min(chunk_st + chunk_size, nn_verts)
        cur_query_pts = scaled_verts[chunk_st: chunk_ed]
        cur_query_pts, sdf_selector = self.deform_pts_with_selector(cur_query_pts, pts_ts=i_ts)

        if passive:
            cur_query_vals = self.sdf_network[1].sdf(cur_query_pts)
        else:
            # Per point, keep the SDF of the object chosen by the selector.
            per_obj_sdf = torch.stack(
                [cur_sdf_network.sdf(cur_query_pts).squeeze(-1) for cur_sdf_network in self.sdf_network],
                dim=1,
            )
            selected_sdf = batched_index_select(per_obj_sdf, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1)
            cur_query_vals = selected_sdf.unsqueeze(1)
        sdf_chunks.append(cur_query_vals)
    sdf_values = torch.cat(sdf_chunks, dim=0)

    # Reference SDF values for the same grid, loaded from disk.
    # NOTE(review): `allow_pickle=True` runs pickle on the file; only load trusted files.
    gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
    if not os.path.exists(gt_sdf_fn):
        gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
    GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
    GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)

    triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
    tet_table, num_tets_table = create_tetmesh_variables(device)

    sdf_values = sdf_values.squeeze(-1)

    # Marching tets on the queried SDF values.
    verts, faces, tet_verts, out_tets = marching_tets_tetmesh(
        scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id,
        return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
    # Marching tets on the reference SDF values.
    GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(
        scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id,
        return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)

    return verts, faces, sdf_values, GT_verts, GT_faces
def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
    """Run the module-level ``extract_geometry`` on this renderer's SDF query.

    ``bound_min``, ``bound_max``, ``resolution`` and ``threshold`` are
    forwarded unchanged. The query function handed to the extractor is the
    negated output of ``self.query_func_sdf``.

    Returns whatever the module-level ``extract_geometry`` returns.
    """
    def negated_sdf(pts):
        # The extractor is given the sign-flipped SDF value.
        return -self.query_func_sdf(pts)

    # The bare name below resolves to the module-level function, not this method.
    return extract_geometry(
        bound_min,
        bound_max,
        resolution=resolution,
        threshold=threshold,
        query_func=negated_sdf,
    )
# if self.deform_pts_with_selector:
# pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
    """Run the module-level ``extract_geometry_tets`` with ``selector=True``.

    The query function is the negated output of ``self.query_func_sdf``.
    When ``wdef`` is truthy, a deformation callback is also supplied as
    ``def_func``; it dispatches on ``self.use_selector`` at call time and
    forwards ``pts_ts`` to the chosen deformation method. When ``wdef`` is
    falsy, no ``def_func`` is passed at all.

    Returns whatever the module-level ``extract_geometry_tets`` returns.
    """
    def negated_sdf(pts):
        return -self.query_func_sdf(pts)

    extractor_kwargs = {
        "resolution": resolution,
        "threshold": threshold,
        "query_func": negated_sdf,
        "selector": True,
    }

    if wdef:
        def deform(pts):
            # ``use_selector`` is read when the callback runs, not when it is built.
            if self.use_selector:
                return self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
            return self.deform_pts(pts, pts_ts=pts_ts)

        extractor_kwargs["def_func"] = deform

    # The bare name below resolves to the module-level function, not this method.
    return extract_geometry_tets(bound_min, bound_max, **extractor_kwargs)
def extract_geometry_tets_passive(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
    """Run the module-level ``extract_geometry_tets`` with ``selector=False``.

    When ``wdef`` is truthy, the query function is the negated output of
    ``self.query_func_sdf_passive`` and ``self.deform_pts_passive`` (with
    ``pts_ts`` forwarded) is supplied as ``def_func``. When ``wdef`` is
    falsy, the query function is the negated output of
    ``self.query_func_sdf`` and no ``def_func`` is passed.

    Returns whatever the module-level ``extract_geometry_tets`` returns.
    """
    if wdef:
        def negated_sdf(pts):
            return -self.query_func_sdf_passive(pts)

        def deform(pts):
            return self.deform_pts_passive(pts, pts_ts=pts_ts)

        extra_kwargs = {"def_func": deform}
    else:
        # NOTE(review): this branch queries ``query_func_sdf`` rather than
        # ``query_func_sdf_passive`` -- confirm this is intended for the
        # passive extractor; behaviour is kept as-is here.
        def negated_sdf(pts):
            return -self.query_func_sdf(pts)

        extra_kwargs = {}

    return extract_geometry_tets(
        bound_min,
        bound_max,
        resolution=resolution,
        threshold=threshold,
        query_func=negated_sdf,
        selector=False,
        **extra_kwargs,
    )