# stable-fast-3d / sf3d / models / isosurface.py
# mboss: Initial commit (d945eeb)
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from jaxtyping import Float, Integer
from torch import Tensor
from .mesh import Mesh
class IsosurfaceHelper(nn.Module):
    """Common base for helpers that extract an isosurface from grid samples.

    The base class carries no grid of its own: reading ``grid_vertices``
    raises ``NotImplementedError`` until a subclass overrides the property.
    """

    # Pair of bounds for the grid point coordinates; the marching-tetrahedra
    # subclass in this file uses ``points_range[1] - points_range[0]`` to scale
    # vertex deformations.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def requires_instance_per_batch(self) -> bool:
        """Flag reported by the helper; the base implementation is ``False``."""
        return False

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Grid vertex positions; must be supplied by subclasses."""
        raise NotImplementedError
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Marching-tetrahedra mesh extraction on a pre-computed tetrahedral grid.

    The grid (vertex positions and per-tetrahedron vertex indices) is loaded
    from ``tets_path``; all lookup tables and grid tensors are registered as
    non-persistent buffers so they follow the module across devices without
    being written to checkpoints.
    """

    def __init__(self, resolution: int, tets_path: str):
        """Load the tetrahedral grid and build the lookup tables.

        Args:
            resolution: Grid resolution; only used to scale vertex
                deformations in ``normalize_grid_deformation``.
            tets_path: Location handed to ``np.load``. The loaded object must
                expose ``"vertices"`` and ``"indices"`` entries.
        """
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # Row = 4-bit sign code of a tetrahedron (see ``_forward``); entries
        # select among the tetrahedron's 6 edges, -1 pads unused slots.
        self.triangle_table: Integer[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )

        # Number of triangles emitted for each of the 16 sign codes.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )

        # Flattened (start, end) local vertex pairs of a tetrahedron's 6 edges.
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        tets = np.load(self.tets_path)
        try:
            grid_vertices = torch.from_numpy(tets["vertices"]).float()
            tet_indices = torch.from_numpy(tets["indices"]).long()
        finally:
            # An .npz archive is returned as a lazily-read NpzFile that keeps a
            # file handle open until closed; plain arrays have no close().
            close = getattr(tets, "close", None)
            if close is not None:
                close()

        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer("_grid_vertices", grid_vertices, persistent=False)

        self.indices: Integer[Tensor, "..."]
        self.register_buffer("indices", tet_indices, persistent=False)

        # Lazily computed by the ``all_edges`` property.
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center_indices, boundary_indices = self.get_center_boundary_index(
            self._grid_vertices
        )
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center_indices, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary_indices, persistent=False)

    def get_center_boundary_index(self, verts):
        """Return the "center" vertex index and the indices of boundary vertices.

        The center is the vertex with the smallest squared norm, i.e. the one
        closest to the origin. A vertex counts as boundary when at least one
        of its coordinates equals the global minimum or maximum of ``verts``.

        NOTE(review): the closest-to-origin vertex is the grid center only if
        the grid is centred on the origin — confirm against the stored grid.
        """
        magn = torch.sum(verts**2, dim=-1)
        center_idx = torch.argmin(magn)
        boundary_pos = verts == verts.max()
        boundary_neg = verts == verts.min()
        boundary = torch.bitwise_or(boundary_pos, boundary_neg)
        # Per-vertex count of coordinates lying on a boundary plane.
        boundary = torch.sum(boundary.float(), dim=-1)
        boundary_idx = torch.nonzero(boundary)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw offsets with ``tanh`` and scale them to the grid cell size."""
        return (
            (self.points_range[1] - self.points_range[0])
            / self.resolution  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed grid vertex positions loaded from the tets file."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique edges of the whole grid, endpoints sorted ascending.

        The result is cached. Because the cache is a plain attribute (not a
        buffer) it does not follow ``Module.to``; it is therefore recomputed
        whenever its device no longer matches ``self.indices``.
        """
        cached = self._all_edges
        if cached is None or cached.device != self.indices.device:
            # Runs on the device of ``self.indices``; on CPU the unique
            # operation is VERY SLOW for large grids.
            edges = self.indices[:, self.base_tet_edges].reshape(-1, 2)
            edges = torch.sort(edges, dim=1)[0]
            cached = torch.unique(edges, dim=0)
            self._all_edges = cached
        return cached

    def sort_edges(self, edges_ex2):
        """Order the two endpoints of every edge ascending.

        For an ``(E, 2)`` input the result has shape ``(E, 1, 2)``: the
        singleton axis produced by ``gather`` is kept.
        """
        with torch.no_grad():
            # 1 where the endpoints are in descending order, else 0.
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)
            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Extract surface vertices and triangles from per-vertex values.

        Args:
            pos_nx3: Grid vertex positions.
            sdf_n: Scalar value per grid vertex; the surface is placed where
                the value crosses zero.
            tet_fx4: Vertex indices of each tetrahedron.

        Returns:
            ``(verts, faces)``: interpolated surface vertex positions and
            triangles as indices into ``verts``.
        """
        # Topology is discrete, so it is computed without gradient tracking.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Keep tetrahedra whose vertices do not all share the same sign.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
            unique_edges = unique_edges.long()
            # An edge carries a surface vertex when exactly one endpoint is positive.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = torch.full(
                (unique_edges.shape[0],),
                -1,
                dtype=torch.long,
                device=pos_nx3.device,
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts
            interp_v = unique_edges[mask_edges]

        # Interpolation happens outside no_grad so gradients reach both the
        # positions and the scalar values.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        # Negate the second endpoint so both terms share a sign; the flipped,
        # normalised pair then gives the linear zero-crossing weights.
        edges_to_interp_sdf[:, -1] *= -1
        denominator = edges_to_interp_sdf.sum(1, keepdim=True)
        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # 4-bit code per tetrahedron: bit i is set when local vertex i is positive.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        one_tri = num_triangles == 1
        two_tri = num_triangles == 2
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[one_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[one_tri]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[two_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[two_tri]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )
        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Build a mesh from per-vertex level values on the (optionally deformed) grid.

        Args:
            level: Scalar value per grid vertex.
            deformation: Optional raw per-vertex offsets; they are normalised
                with ``normalize_grid_deformation`` before being added to the
                grid vertices.

        Returns:
            A ``Mesh`` with the extracted vertices and faces, plus the grid
            vertices, grid edges, level values and raw deformation as extras.
        """
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )
        return mesh