|
from typing import Callable, Optional, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from torchmcubes import marching_cubes |
|
|
|
|
|
class IsosurfaceHelper(nn.Module): |
|
points_range: Tuple[float, float] = (0, 1) |
|
|
|
@property |
|
def grid_vertices(self) -> torch.FloatTensor: |
|
raise NotImplementedError |
|
|
|
|
|
class MarchingCubeHelper(IsosurfaceHelper): |
|
def __init__(self, resolution: int) -> None: |
|
super().__init__() |
|
self.resolution = resolution |
|
self.mc_func: Callable = marching_cubes |
|
self._grid_vertices: Optional[torch.FloatTensor] = None |
|
|
|
@property |
|
def grid_vertices(self) -> torch.FloatTensor: |
|
if self._grid_vertices is None: |
|
|
|
x, y, z = ( |
|
torch.linspace(*self.points_range, self.resolution), |
|
torch.linspace(*self.points_range, self.resolution), |
|
torch.linspace(*self.points_range, self.resolution), |
|
) |
|
x, y, z = torch.meshgrid(x, y, z, indexing="ij") |
|
verts = torch.cat( |
|
[x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1 |
|
).reshape(-1, 3) |
|
self._grid_vertices = verts |
|
return self._grid_vertices |
|
|
|
def forward( |
|
self, |
|
level: torch.FloatTensor, |
|
) -> Tuple[torch.FloatTensor, torch.LongTensor]: |
|
level = -level.view(self.resolution, self.resolution, self.resolution) |
|
v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0) |
|
v_pos = v_pos[..., [2, 1, 0]] |
|
v_pos = v_pos / (self.resolution - 1.0) |
|
return v_pos, t_pos_idx |
|
|