datnguyentien204's picture
Upload 338 files
8e0b903 verified
raw
history blame
1.98 kB
import warnings
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torchmcubes import marching_cubes
class IsosurfaceHelper(nn.Module):
    """Base module for extracting an isosurface from a sampled scalar field.

    Subclasses provide the sample positions via :attr:`grid_vertices`.
    ``points_range`` holds the (low, high) coordinate bounds used to lay out
    that grid along each axis.
    """

    # (low, high) coordinate bounds of the sampling grid on every axis.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        """Positions at which the scalar field is sampled; subclasses must override."""
        raise NotImplementedError()
class MarchingCubeHelper(IsosurfaceHelper):
    """Extract a triangle mesh from a dense cubic grid of scalar values.

    The grid has ``resolution`` samples per axis, spread uniformly over
    ``points_range``. The field is expected to be evaluated at
    :attr:`grid_vertices` and the resulting values passed to :meth:`forward`.
    """

    def __init__(self, resolution: int) -> None:
        """
        Args:
            resolution: Number of grid samples along each of the three axes.
        """
        super().__init__()
        self.resolution = resolution
        # Marching-cubes backend (torchmcubes' ``marching_cubes``).
        self.mc_func: Callable = marching_cubes
        # Built lazily on first access of ``grid_vertices`` and cached.
        self._grid_vertices: Optional[torch.FloatTensor] = None

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        """Grid sample positions as a ``(resolution**3, 3)`` CPU tensor.

        Rows follow ``ij`` (first-axis-major) order: with ``R = resolution``,
        row ``(i * R + j) * R + k`` is ``(x[i], y[j], z[k])``, which lines up
        with element ``[i, j, k]`` of the ``(R, R, R)`` volume in
        :meth:`forward`. The tensor is computed once and then cached.
        """
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            axis = torch.linspace(*self.points_range, self.resolution)
            xs, ys, zs = torch.meshgrid(axis, axis, axis, indexing="ij")
            self._grid_vertices = torch.stack((xs, ys, zs), dim=-1).reshape(-1, 3)
        return self._grid_vertices

    def forward(
        self,
        level: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Run marching cubes on ``level`` and return the extracted mesh.

        Args:
            level: ``resolution**3`` scalar values, one per row of
                :attr:`grid_vertices` (any shape with that many elements).
                The values are negated and the surface is extracted at 0.

        Returns:
            ``(v_pos, t_pos_idx)``. ``v_pos`` holds vertex positions in the same
            ``points_range`` frame as :attr:`grid_vertices`. ``t_pos_idx`` holds
            the backend's triangle indices into ``v_pos``. Both are moved to
            ``level``'s device.
        """
        res = self.resolution
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
        volume = -level.reshape(res, res, res)
        try:
            v_pos, t_pos_idx = self.mc_func(volume.detach(), 0.0)
        except AttributeError:
            # An AttributeError is treated as torchmcubes lacking CUDA support,
            # and extraction is retried on CPU. ``warnings`` reports this once
            # per call site rather than printing on every forward pass.
            warnings.warn(
                "torchmcubes was not compiled with CUDA support, use CPU version instead."
            )
            v_pos, t_pos_idx = self.mc_func(volume.detach().cpu(), 0.0)
        # Swap coordinate columns 0 and 2. Presumably the backend emits them in
        # reversed axis order; this restores the axis order of
        # ``grid_vertices``. Verify against the torchmcubes docs.
        v_pos = v_pos[..., [2, 1, 0]]
        # Map grid-index coordinates [0, res - 1] onto ``points_range`` so the
        # mesh shares the frame of ``grid_vertices``. For the default range
        # (0, 1) this reduces exactly to dividing by ``res - 1``.
        lo, hi = self.points_range
        v_pos = v_pos / (res - 1.0) * (hi - lo) + lo
        return v_pos.to(level.device), t_pos_idx.to(level.device)