Real3D / tsr /models /isosurface.py
dmitriitochilkin's picture
add dependencies
ff49a48
raw
history blame contribute delete
No virus
1.68 kB
from typing import Callable, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torchmcubes import marching_cubes
class IsosurfaceHelper(nn.Module):
points_range: Tuple[float, float] = (0, 1)
@property
def grid_vertices(self) -> torch.FloatTensor:
raise NotImplementedError
class MarchingCubeHelper(IsosurfaceHelper):
def __init__(self, resolution: int) -> None:
super().__init__()
self.resolution = resolution
self.mc_func: Callable = marching_cubes
self._grid_vertices: Optional[torch.FloatTensor] = None
@property
def grid_vertices(self) -> torch.FloatTensor:
if self._grid_vertices is None:
# keep the vertices on CPU so that we can support very large resolution
x, y, z = (
torch.linspace(*self.points_range, self.resolution),
torch.linspace(*self.points_range, self.resolution),
torch.linspace(*self.points_range, self.resolution),
)
x, y, z = torch.meshgrid(x, y, z, indexing="ij")
verts = torch.cat(
[x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
).reshape(-1, 3)
self._grid_vertices = verts
return self._grid_vertices
def forward(
self,
level: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
level = -level.view(self.resolution, self.resolution, self.resolution)
v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
v_pos = v_pos[..., [2, 1, 0]]
v_pos = v_pos / (self.resolution - 1.0)
return v_pos, t_pos_idx