from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import craftsman
from .utils import (
    Mesh,
    IsosurfaceHelper,
    MarchingCubeCPUHelper,
    MarchingTetrahedraHelper,
)
from craftsman.utils.base import BaseModule
from craftsman.utils.ops import chunk_batch, scale_tensor
from craftsman.utils.typing import *


class BaseGeometry(BaseModule):
    """Common base class for geometry modules.

    Provides a default ``create_from`` that refuses every conversion and a
    default ``export`` that exports nothing; presumably overridden by
    concrete geometries.
    """

    @dataclass
    class Config(BaseModule.Config):
        pass

    cfg: Config

    @staticmethod
    def create_from(
        other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs
    ) -> "BaseGeometry":
        """Create a geometry of this type from another geometry instance.

        The base implementation supports no conversion.

        Raises:
            TypeError: always, naming the class of ``other``.
        """
        raise TypeError(
            f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}"
        )

    def export(self, *args, **kwargs):
        """Return export data; the base implementation returns an empty dict."""
        return {}


class BaseImplicitGeometry(BaseGeometry):
    """Geometry defined by an implicit field, with isosurface mesh extraction.

    Subclasses implement ``forward_field`` (field values plus an optional
    deformation of the grid vertices) and ``forward_level`` (field -> level
    values whose zero level set is the surface); ``isosurface`` turns those
    into a ``Mesh``.
    """

    @dataclass
    class Config(BaseGeometry.Config):
        # Half-extent of the cubic bounding box registered in ``configure``.
        radius: float = 1.0
        # Whether ``isosurface`` is allowed at all.
        isosurface: bool = True
        # "mc-cpu" (MarchingCubeCPUHelper) or "mt" (MarchingTetrahedraHelper).
        isosurface_method: str = "mt"
        # Grid resolution; for "mt" it also selects load/tets/<res>_tets.npz.
        isosurface_resolution: int = 128
        # A numeric threshold, or "auto" (mean of field values > 1e-5).
        isosurface_threshold: Union[float, str] = 0.0
        # Chunk size forwarded to ``chunk_batch`` when evaluating the field.
        isosurface_chunk: int = 0
        # Run a coarse pass to tighten the bbox before the final pass.
        isosurface_coarse_to_fine: bool = True
        # Not read in this base class.
        isosurface_deformable_grid: bool = False
        # Drop small connected components via ``Mesh.remove_outlier``.
        isosurface_remove_outliers: bool = True
        # Forwarded to ``Mesh.remove_outlier`` as its threshold.
        isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01

    cfg: Config

    def configure(self) -> None:
        """Register the ``bbox`` buffer and reset the lazily built helper."""
        self.bbox: Float[Tensor, "2 3"]
        # Row 0 is the min corner, row 1 the max corner of [-r, r]^3.
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
                    [self.cfg.radius, self.cfg.radius, self.cfg.radius],
                ],
                dtype=torch.float32,
            ),
        )
        self.isosurface_helper: Optional[IsosurfaceHelper] = None
        self.unbounded: bool = False

    def _initilize_isosurface_helper(self):
        """Build the isosurface helper on first use (no-op if already built).

        Nothing is built when ``cfg.isosurface`` is False.

        Raises:
            AttributeError: if ``cfg.isosurface_method`` is not "mc-cpu" or
                "mt". (ValueError would be more apt; the type is kept for
                compatibility with existing callers.)
        """
        if not self.cfg.isosurface or self.isosurface_helper is not None:
            return
        method = self.cfg.isosurface_method
        if method == "mc-cpu":
            self.isosurface_helper = MarchingCubeCPUHelper(
                self.cfg.isosurface_resolution
            ).to(self.device)
        elif method == "mt":
            self.isosurface_helper = MarchingTetrahedraHelper(
                self.cfg.isosurface_resolution,
                f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
            ).to(self.device)
        else:
            # Previously missing the f-prefix, so the placeholder was printed
            # literally instead of the offending method name.
            raise AttributeError(f"Unknown isosurface method {method}")

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Evaluate the geometry at ``points``; must be implemented by subclasses."""
        raise NotImplementedError

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        """Return the implicit field at ``points`` plus an optional deformation.

        The field could be a density or a signed distance; the deformation is
        returned when the grid vertices can be optimized, else None.
        """
        raise NotImplementedError

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        """Map field values to level values whose zero level set is the surface."""
        raise NotImplementedError

    def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh:
        """Extract a mesh from the field sampled on the helper grid inside ``bbox``.

        Args:
            bbox: [min corner, max corner] that the helper's ``points_range``
                is mapped onto.
            fine_stage: set by ``isosurface`` for the second coarse-to-fine
                pass; not read by this base implementation.

        Returns:
            The extracted mesh, vertices scaled into ``bbox``, with ``bbox``
            attached as an extra.

        Raises:
            TypeError: if ``cfg.isosurface_threshold`` is neither a number
                nor "auto".
        """

        def batch_func(x):
            # Grid vertices live in the helper's points_range; map them into bbox.
            field, deformation = self.forward_field(
                scale_tensor(
                    x.to(bbox.device), self.isosurface_helper.points_range, bbox
                ),
            )
            # Move results back to the input's device (which could be the CPU).
            field = field.to(x.device)
            if deformation is not None:
                deformation = deformation.to(x.device)
            return field, deformation

        assert self.isosurface_helper is not None

        field, deformation = chunk_batch(
            batch_func,
            self.cfg.isosurface_chunk,
            self.isosurface_helper.grid_vertices,
        )

        threshold: float
        cfg_threshold = self.cfg.isosurface_threshold
        # Accept ints too (e.g. ``0`` from a config file); bool is excluded
        # because it is an int subclass and was never a valid threshold.
        if isinstance(cfg_threshold, (int, float)) and not isinstance(
            cfg_threshold, bool
        ):
            threshold = float(cfg_threshold)
        elif cfg_threshold == "auto":
            eps = 1.0e-5
            # NOTE(review): if no value exceeds eps this mean is NaN.
            threshold = field[field > eps].mean().item()
            craftsman.info(
                f"Automatically determined isosurface threshold: {threshold}"
            )
        else:
            raise TypeError(
                f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}"
            )

        level = self.forward_level(field, threshold)
        mesh: Mesh = self.isosurface_helper(level, deformation=deformation)
        # Mesh vertices come back in points_range; scale them into bbox.
        mesh.v_pos = scale_tensor(
            mesh.v_pos, self.isosurface_helper.points_range, bbox
        )
        mesh.add_extra("bbox", bbox)

        if self.cfg.isosurface_remove_outliers:
            # Drop connected components with few faces. Controlled solely by
            # the config flag; it also runs on the grad-enabled fine pass.
            mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)

        return mesh

    def isosurface(self) -> Mesh:
        """Extract the surface mesh, optionally in two coarse-to-fine passes.

        With ``isosurface_coarse_to_fine`` a gradient-free pass over the full
        bbox finds the occupied region; it is padded by 10% of its extent per
        axis, clamped to the full bbox, and extracted again.

        Raises:
            NotImplementedError: if ``cfg.isosurface`` is False.
        """
        if not self.cfg.isosurface:
            raise NotImplementedError(
                "Isosurface is not enabled in the current configuration"
            )
        self._initilize_isosurface_helper()

        if not self.cfg.isosurface_coarse_to_fine:
            return self._isosurface(self.bbox)

        craftsman.debug("First run isosurface to get a tight bounding box ...")
        with torch.no_grad():
            mesh_coarse = self._isosurface(self.bbox)

        if mesh_coarse.v_pos.shape[0] == 0:
            # amin/amax fail on an empty tensor; nothing to tighten around,
            # so run the final pass over the full bbox instead of crashing.
            craftsman.debug(
                "Coarse isosurface is empty; using the full bounding box ..."
            )
            return self._isosurface(self.bbox, fine_stage=True)

        vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0)
        margin = (vmax - vmin) * 0.1
        # Tensor.max/min with a tensor argument is elementwise: clamp to bbox.
        vmin_ = (vmin - margin).max(self.bbox[0])
        vmax_ = (vmax + margin).min(self.bbox[1])
        craftsman.debug("Run isosurface again with the tight bounding box ...")
        return self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True)


class BaseExplicitGeometry(BaseGeometry):
    """Base class for explicitly represented geometry with a cubic bbox."""

    @dataclass
    class Config(BaseGeometry.Config):
        # Half-extent of the cubic bounding box registered in ``configure``.
        radius: float = 1.0

    cfg: Config

    def configure(self) -> None:
        """Register the ``bbox`` buffer spanning [-radius, radius]^3."""
        self.bbox: Float[Tensor, "2 3"]
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
                    [self.cfg.radius, self.cfg.radius, self.cfg.radius],
                ],
                dtype=torch.float32,
            ),
        )