Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,799 Bytes
2ac1c2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from typing import Union, Tuple, List
import numpy as np
import torch
from skimage import measure
class MeshExtractResult:
    """Triangle mesh produced by a surface extractor, plus derived normals.

    Face and vertex normals are computed eagerly in the constructor from the
    supplied vertices and faces.
    """

    def __init__(self, verts, faces, vertex_attrs=None, res=64):
        """Store the mesh and precompute its normals.

        Args:
            verts: Vertex positions, one vertex per row (the cross product
                used for normals requires 3 components per vertex).
            faces: Triangle corner indices into ``verts``, one triangle per
                row; stored cast to ``long``.
            vertex_attrs: Optional extra per-vertex data, stored untouched.
            res: Resolution value associated with the mesh, stored untouched.
        """
        self.verts = verts
        self.faces = faces.long()
        self.vertex_attrs = vertex_attrs
        self.face_normal = self.comput_face_normals()
        self.vert_normal = self.comput_v_normals()
        self.res = res
        # The extraction counts as successful only if the mesh is non-empty.
        has_verts = verts.shape[0] != 0
        has_faces = faces.shape[0] != 0
        self.success = has_verts and has_faces
        # training only
        self.tsdf_v = None
        self.tsdf_s = None
        self.reg_loss = None

    def _corners_and_raw_normals(self):
        """Return the three corner-index tensors and unnormalized face normals."""
        corners = tuple(self.faces[..., k].long() for k in range(3))
        first, second, third = (self.verts[idx, :] for idx in corners)
        raw_normals = torch.cross(second - first, third - first, dim=-1)
        return corners, raw_normals

    def comput_face_normals(self):
        """Return unit face normals repeated once per triangle corner.

        The result has one extra axis of length 3 inserted after the face
        axis, holding the same normal for each of the triangle's corners.
        """
        _, raw_normals = self._corners_and_raw_normals()
        unit_normals = torch.nn.functional.normalize(raw_normals, dim=1)
        return unit_normals[:, None, :].repeat(1, 3, 1)

    def comput_v_normals(self):
        """Return unit vertex normals.

        Each vertex accumulates the unnormalized normals of every face that
        references it (so larger faces contribute more), and the sum is then
        normalized.
        """
        corners, raw_normals = self._corners_and_raw_normals()
        accumulated = torch.zeros_like(self.verts)
        for corner_index in corners:
            # Broadcast the vertex index across the 3 normal components.
            scatter_index = corner_index[..., None].repeat(1, 3)
            accumulated.scatter_add_(0, scatter_index, raw_normals)
        return torch.nn.functional.normalize(accumulated, dim=1)
def center_vertices(vertices):
    """Shift ``vertices`` so their axis-aligned bounding box is centred at the origin.

    The bounding box is taken along dimension 0 (one vertex per row); the
    midpoint of its minimum and maximum corners is subtracted from every
    vertex.
    """
    lower_corner, _ = vertices.min(dim=0)
    upper_corner, _ = vertices.max(dim=0)
    midpoint = 0.5 * (lower_corner + upper_corner)
    return vertices - midpoint
class SurfaceExtractor:
    """Base class that converts a batch of grid logits into meshes.

    Subclasses implement :meth:`run` for a single grid; calling the extractor
    applies ``run`` to every item along the first axis of the batch.
    """

    def _compute_box_stat(
        self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int
    ):
        """Derive grid dimensions and bounding-box extents.

        Args:
            bounds: Either a scalar ``b`` (interpreted as the symmetric box
                ``[-b, -b, -b, b, b, b]``) or a 6-element sequence laid out as
                ``[min_x, min_y, min_z, max_x, max_y, max_z]``.
            octree_resolution: Number of cells per axis.

        Returns:
            ``(grid_size, bbox_min, bbox_size)`` where ``grid_size`` is a list
            of three ints (``octree_resolution + 1`` each) and the other two
            are NumPy arrays of length 3.
        """
        # Accept any real scalar (Python or NumPy int/float), not only
        # ``float``: an integer bound such as ``1`` would otherwise fall
        # through to the slicing below and raise ``TypeError``.
        if isinstance(bounds, (int, float, np.integer, np.floating)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        grid_size = [int(octree_resolution) + 1] * 3
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """Extract ``(verts, faces)`` from one grid; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Raise (rather than return) the exception so a missing override is
        # reported as such instead of surfacing later as an unpacking error.
        raise NotImplementedError

    def __call__(self, grid_logits, **kwargs):
        """Run extraction on every grid in the batch.

        Args:
            grid_logits: Batch of grids, iterated along its first axis.
            **kwargs: Forwarded to :meth:`run`; ``octree_resolution`` is also
                recorded on each result.

        Returns:
            A list with one entry per batch item: a ``MeshExtractResult`` on
            success, or ``None`` if extraction raised (the traceback is
            printed and the remaining items are still processed).
        """
        outputs = []
        for i in range(grid_logits.shape[0]):
            try:
                verts, faces = self.run(grid_logits[i], **kwargs)
                outputs.append(
                    MeshExtractResult(
                        verts=verts.float(),
                        faces=faces,
                        res=kwargs["octree_resolution"],
                    )
                )
            except Exception:
                # Deliberate best-effort: one failing item must not abort the
                # whole batch, but the failure is made visible.
                import traceback

                traceback.print_exc()
                outputs.append(None)
        return outputs
class MCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor backed by scikit-image's marching cubes."""

    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """Extract a triangle mesh from a single grid of logits.

        Args:
            grid_logit: Tensor holding the volume; it is moved to the CPU as
                float32 for marching cubes, and the outputs are created on
                its original device.
            mc_level: Iso-level passed to marching cubes.
            bounds: Bounding box specification, see ``_compute_box_stat``.
            octree_resolution: Number of cells per axis.
            **kwargs: Ignored.

        Returns:
            ``(verts, faces)``: float32 vertex positions and ``long`` triangle
            indices, both on ``grid_logit``'s device.
        """
        target_device = grid_logit.device
        volume = grid_logit.float().cpu().numpy()
        mc_verts, mc_faces, _normals, _values = measure.marching_cubes(
            volume, mc_level, method="lewiner"
        )

        grid_size, bbox_min, bbox_size = self._compute_box_stat(
            bounds, octree_resolution
        )
        # Map grid-index coordinates into the bounding box.
        scaled_verts = mc_verts / grid_size * bbox_size + bbox_min

        verts = torch.tensor(scaled_verts, device=target_device, dtype=torch.float32)
        triangles = torch.tensor(
            np.ascontiguousarray(mc_faces), device=target_device, dtype=torch.long
        )
        # Reverse the corner order of every triangle (flips the winding).
        return verts, triangles[:, [2, 1, 0]]
class DMCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor backed by ``diso``'s differentiable dual marching cubes."""

    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """Extract a triangle mesh from a single grid of logits.

        Args:
            grid_logit: Tensor holding the volume.
            octree_resolution: Number of cells per axis; also used to scale
                the logits into the SDF passed to the extractor.
            **kwargs: Must contain ``bounds`` (a scalar or a 6-element
                ``[min_x, min_y, min_z, max_x, max_y, max_z]`` sequence).

        Returns:
            ``(verts, faces)`` as returned by ``DiffDMC``, with ``verts``
            rescaled into the bounding box.

        Raises:
            ImportError: If ``diso`` is not installed.
            KeyError: If ``bounds`` is missing from ``kwargs``.
        """
        device = grid_logit.device
        if not hasattr(self, "dmc"):
            # Lazily build the extractor once and cache it on the instance.
            try:
                from diso import DiffDMC
            except ImportError as err:
                # Only translate genuine import failures; a bare ``except``
                # would also swallow KeyboardInterrupt / SystemExit.
                raise ImportError(
                    "Please install diso via `pip install diso`, or set mc_algo to 'mc'"
                ) from err
            self.dmc = DiffDMC(dtype=torch.float32).to(device)

        # Negate and scale the logits to form the SDF fed to the extractor.
        sdf = -grid_logit / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)

        _, bbox_min, bbox_size = self._compute_box_stat(
            kwargs["bounds"], octree_resolution
        )
        # Rescale the vertices into the bounding box. For a scalar bound ``b``
        # this equals ``verts * b * 2 - b``; using the box stats also supports
        # 6-element bounds.
        # NOTE(review): assumes ``normalize=True`` yields vertices in [0, 1]
        # -- confirm against the diso documentation.
        scale = torch.as_tensor(bbox_size, dtype=verts.dtype, device=verts.device)
        offset = torch.as_tensor(bbox_min, dtype=verts.dtype, device=verts.device)
        verts = verts * scale + offset
        return verts, faces
|