ReubenSun's picture
1
2ac1c2d
from typing import Union, Tuple, List
import numpy as np
import torch
from skimage import measure
class MeshExtractResult:
    """Triangle mesh together with per-face and per-vertex normals.

    Args:
        verts: vertex positions, indexed row-wise (``verts[idx, :]``);
            presumably a float tensor of shape (V, 3) — TODO confirm with callers.
        faces: integer tensor of triangle corner indices, (F, 3); stored as ``long``.
        vertex_attrs: optional extra per-vertex data, stored as-is.
        res: resolution the mesh was extracted at, stored as-is (default 64).

    Attributes computed here:
        face_normal: unit normal of each face, repeated for its 3 corners, (F, 3, 3).
        vert_normal: unit vertex normals, (V, 3). Each is the normalized sum of the
            unnormalized cross products of all incident faces (so larger faces
            weigh more). Vertices that belong to no face stay zero, because
            ``normalize`` clamps the norm with an epsilon.
        success: True when the mesh has at least one vertex and one face.
    """

    def __init__(self, verts, faces, vertex_attrs=None, res=64):
        self.verts = verts
        self.faces = faces.long()
        self.vertex_attrs = vertex_attrs
        self.face_normal = self.comput_face_normals()
        self.vert_normal = self.comput_v_normals()
        self.res = res
        self.success = verts.shape[0] != 0 and faces.shape[0] != 0
        # training only
        self.tsdf_v = None
        self.tsdf_s = None
        self.reg_loss = None

    def _face_corners_and_cross(self):
        """Return ``((i0, i1, i2), n)``: per-face corner indices and the
        unnormalized face normal ``n = (v1 - v0) x (v2 - v0)``."""
        # Shared by both normal computations, which previously duplicated it.
        corners = tuple(self.faces[..., k].long() for k in range(3))
        v0, v1, v2 = (self.verts[idx, :] for idx in corners)
        return corners, torch.cross(v1 - v0, v2 - v0, dim=-1)

    def comput_face_normals(self):
        """Return unit face normals repeated once per corner, shape (F, 3, 3)."""
        _, cross = self._face_corners_and_cross()
        unit = torch.nn.functional.normalize(cross, dim=1)
        return unit[:, None, :].repeat(1, 3, 1)

    def comput_v_normals(self):
        """Return unit vertex normals: the normalized sum of incident face cross products."""
        corners, cross = self._face_corners_and_cross()
        v_normals = torch.zeros_like(self.verts)
        for idx in corners:
            # Broadcast each face's vertex index across the xyz columns so every
            # component of the face's cross product lands on that vertex's row.
            v_normals.scatter_add_(0, idx[..., None].repeat(1, 3), cross)
        return torch.nn.functional.normalize(v_normals, dim=1)
def center_vertices(vertices):
    """Shift ``vertices`` so their axis-aligned bounding box is centred at the origin.

    The min/max reductions run along dim 0, so each coordinate column is
    centred independently. A new tensor is returned; the input is untouched.
    """
    lowest = vertices.min(dim=0).values
    highest = vertices.max(dim=0).values
    midpoint = 0.5 * (lowest + highest)
    return vertices - midpoint
class SurfaceExtractor:
    """Base class that turns a batch of grid logits into ``MeshExtractResult`` meshes.

    Subclasses implement :meth:`run`, which extracts ``(verts, faces)`` from a
    single grid; :meth:`__call__` applies it to every item of a batch.
    """

    def _compute_box_stat(
        self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int
    ):
        """Return ``(grid_size, bbox_min, bbox_size)`` for the extraction box.

        Args:
            bounds: either a real scalar ``b`` (the cube ``[-b, b]^3``) or a
                6-sequence ``(xmin, ymin, zmin, xmax, ymax, zmax)``.
            octree_resolution: grid resolution; converted with ``int()``.

        Returns:
            grid_size: list of three ints, each ``int(octree_resolution) + 1``.
            bbox_min: numpy array of the three minimum coordinates.
            bbox_size: numpy array ``bbox_max - bbox_min``.
        """
        # Accept ints and numpy scalars as well as floats; previously a scalar
        # like ``bounds=1`` fell through to slicing and raised TypeError.
        if isinstance(bounds, (int, float, np.integer, np.floating)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        points_per_axis = int(octree_resolution) + 1
        grid_size = [points_per_axis, points_per_axis, points_per_axis]
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """Extract ``(verts, faces)`` from a single grid; subclasses must override.

        Raises:
            NotImplementedError: always, on the base class.
        """
        # Previously this *returned* the exception class, which only failed
        # later with a confusing tuple-unpacking TypeError.
        raise NotImplementedError(
            f"{type(self).__name__}.run must be implemented by a subclass"
        )

    def __call__(self, grid_logits, **kwargs):
        """Run :meth:`run` on each item along dim 0 of ``grid_logits``.

        ``kwargs`` are forwarded to :meth:`run`; ``kwargs["octree_resolution"]``
        is also stored as each result's ``res``.

        Returns:
            A list with one entry per item: a ``MeshExtractResult`` on success,
            or ``None`` if extraction raised. Failures are best-effort: the
            traceback is printed and processing continues with the next item.
        """
        outputs = []
        for i in range(grid_logits.shape[0]):
            try:
                verts, faces = self.run(grid_logits[i], **kwargs)
                outputs.append(
                    MeshExtractResult(
                        verts=verts.float(),
                        faces=faces,
                        res=kwargs["octree_resolution"],
                    )
                )
            except Exception:
                import traceback

                traceback.print_exc()
                outputs.append(None)
        return outputs
class MCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor using scikit-image's marching cubes (Lewiner variant)."""

    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """Extract a triangle mesh from one grid at iso-level ``mc_level``.

        Args:
            grid_logit: 3-D tensor of logits; it is moved to the CPU and cast to
                float for scikit-image.
            mc_level: iso-value passed to ``marching_cubes``.
            bounds: scalar half-extent or 6-sequence; see ``_compute_box_stat``.
            octree_resolution: grid resolution; see ``_compute_box_stat``.

        Returns:
            ``(verts, faces)`` on ``grid_logit.device``: float32 vertex positions
            mapped into the bounding box, and long face indices with each
            triangle's corner order reversed.
        """
        volume = grid_logit.float().cpu().numpy()
        mc_verts, mc_faces, _normals, _values = measure.marching_cubes(
            volume, mc_level, method="lewiner"
        )
        grid_size, bbox_min, bbox_size = self._compute_box_stat(
            bounds, octree_resolution
        )
        # NOTE(review): marching_cubes reports vertices in voxel-index units
        # spanning [0, shape - 1] per axis, but this divides by grid_size
        # (= octree_resolution + 1), so the far face maps slightly inside
        # bbox_max. Confirm against how the grid was sampled before changing.
        world_verts = mc_verts / grid_size * bbox_size + bbox_min

        target = grid_logit.device
        verts = torch.tensor(world_verts, device=target, dtype=torch.float32)
        faces = torch.tensor(
            np.ascontiguousarray(mc_faces), device=target, dtype=torch.long
        )
        # Reverse the corner order of every triangle, flipping its winding;
        # presumably to match the orientation convention used downstream.
        return verts, faces[:, [2, 1, 0]]
class DMCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor using differentiable dual marching cubes (``diso.DiffDMC``)."""

    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """Extract a triangle mesh from one grid with ``DiffDMC``.

        Args:
            grid_logit: 3-D tensor of logits for a single item.
            octree_resolution: divides the negated logits to form the SDF and is
                passed to ``_compute_box_stat``.
            **kwargs: must contain ``bounds`` (a scalar half-extent or a
                6-sequence, see ``_compute_box_stat``).

        Returns:
            ``(verts, faces)`` as produced by ``DiffDMC``, with vertices mapped
            into the bounding box.

        Raises:
            ImportError: if ``diso`` cannot be imported.
        """
        device = grid_logit.device
        if not hasattr(self, "dmc"):
            # Lazily construct the DMC module on first use (on this call's device).
            try:
                from diso import DiffDMC
            except Exception as err:  # was a bare ``except:``; keep the cause
                raise ImportError(
                    "Please install diso via `pip install diso`, or set mc_algo to 'mc'"
                ) from err
            self.dmc = DiffDMC(dtype=torch.float32).to(device)
        # Negate and scale the logits; presumably converts them to the sign
        # convention DiffDMC expects — TODO confirm.
        sdf = (-grid_logit / octree_resolution).to(torch.float32).contiguous()
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)

        _, bbox_min, bbox_size = self._compute_box_stat(
            kwargs["bounds"], octree_resolution
        )
        # With normalize=True the vertices are presumably in [0, 1] — TODO
        # confirm. For a scalar bound b this equals the previous
        # ``verts * b * 2 - b``, and it also supports 6-element bounds, which
        # previously raised.
        bbox_min = torch.as_tensor(bbox_min, dtype=verts.dtype, device=verts.device)
        bbox_size = torch.as_tensor(bbox_size, dtype=verts.dtype, device=verts.device)
        verts = verts * bbox_size + bbox_min
        return verts, faces