Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import trimesh | |
def dot(x, y):
    """Return the inner product of ``x`` and ``y`` over the last dimension.

    The reduced dimension is kept with size 1, so the result broadcasts back
    against the inputs (e.g. ``dot(n, n) * n``).
    """
    product = x * y
    return product.sum(dim=-1, keepdim=True)
class Mesh:
    """Triangle mesh with lazily computed, cached per-vertex attributes.

    Args:
        v_pos: Vertex positions. Rows are indexed per vertex and treated as
            3-vectors by the normal/tangent code (presumably a ``(V, 3)``
            float tensor -- confirm against callers).
        t_pos_idx: Integer tensor with three vertex indices per face
            (indexed as ``t_pos_idx[:, 0..2]``).
        material: Optional material object. It is stored as-is and not
            otherwise used by this class.
    """

    def __init__(self, v_pos, t_pos_idx, material=None):
        self.v_pos = v_pos
        self.t_pos_idx = t_pos_idx
        self.material = material
        # Lazily computed caches, filled on first access of the matching property.
        self._v_nrm = None
        self._v_tng = None
        self._v_tex = None
        self._t_tex_idx = None
        self._v_rgb = None
        self._edges = None
        # Free-form per-mesh data registered through add_extra().
        self.extras = {}

    def add_extra(self, k, v) -> None:
        """Store ``v`` under key ``k`` in ``self.extras``."""
        self.extras[k] = v

    def remove_outlier(self, n_face_threshold=5):
        """Return a mesh without small connected components.

        Components are computed by ``trimesh.Trimesh.split``. A component is
        kept only if it has MORE than ``n_face_threshold`` faces. Components
        with ``n_face_threshold`` faces or fewer are dropped. If nothing would
        survive, ``self`` is returned unchanged.

        The result is a new Mesh that carries over ``material``. Cached
        attributes (normals, UVs, vertex colors, ...) are not carried over,
        because the vertex ordering changes.
        """
        # Build a geometry-only trimesh (no UV-seam vertex duplication as in
        # as_trimesh) so that connectivity reflects the actual surface.
        trimesh_mesh = trimesh.Trimesh(
            vertices=self.v_pos.detach().cpu().numpy(),
            faces=self.t_pos_idx.detach().cpu().numpy(),
            process=False,
        )
        components = trimesh_mesh.split(only_watertight=False)
        valid_components = [c for c in components if len(c.faces) > n_face_threshold]
        if not valid_components:
            # Nothing large enough survives: keep the original mesh.
            return self
        combined = trimesh.util.concatenate(valid_components)
        return Mesh(
            torch.tensor(
                np.asarray(combined.vertices),
                dtype=self.v_pos.dtype,
                device=self.v_pos.device,
            ),
            torch.tensor(
                np.asarray(combined.faces),
                dtype=self.t_pos_idx.dtype,
                device=self.t_pos_idx.device,
            ),
            material=self.material,
        )

    @property
    def requires_grad(self):
        """Whether ``v_pos`` is tracked by autograd."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Unit per-vertex normals (computed on first access, then cached)."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Unit per-vertex tangents (computed on first access, then cached)."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """UV coordinates, indexed by ``t_tex_idx`` (xatlas unwrap on first access)."""
        if self._v_tex is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._v_tex

    @property
    def t_tex_idx(self):
        """Per-face indices into ``v_tex`` (xatlas unwrap on first access)."""
        if self._t_tex_idx is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._t_tex_idx

    @property
    def v_rgb(self):
        """Per-vertex colors set via ``set_vertex_color``, or ``None``."""
        return self._v_rgb

    @property
    def edges(self):
        """Unique undirected edges as sorted index pairs (cached)."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Compute unit vertex normals by summing unnormalized face normals.

        The face cross products have magnitude proportional to face area, so
        the sum is area-weighted. Vertices whose accumulated normal is
        (near) zero fall back to ``(0, 0, 1)``.
        """
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]
        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]
        # Explicit dim: calling torch.cross without it is deprecated.
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals onto each of the face's three vertices.
        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Replace zero (degenerate) normals with a default before normalizing.
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))
        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute unit vertex tangents from UV gradients, orthogonal to ``v_nrm``.

        Accessing ``v_tex``/``t_tex_idx`` runs the xatlas UV unwrap if it
        has not been done yet.
        """
        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            tex[i] = self.v_tex[self.t_tex_idx[:, i]]
            # t_nrm_idx is always the same as t_pos_idx
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Per-triangle tangent from the position and UV edge vectors.
        uve1 = tex[1] - tex[0]
        uve2 = tex[2] - tex[0]
        pe1 = pos[1] - pos[0]
        pe2 = pos[2] - pos[0]

        nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
        denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]

        # Avoid division by zero for degenerate texture coordinates while
        # keeping the sign of the denominator.
        tang = nom / torch.where(
            denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
        )

        # Accumulate each triangle's tangent onto its 3 vertices.
        for i in range(3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] += tang
            tansum.scatter_add_(0, idx, torch.ones_like(tang))  # tansum[n_i] += 1
        # Vertices referenced by no face have tansum == 0. Clamping keeps them
        # at a zero tangent instead of producing 0/0 = NaN. Referenced
        # vertices (tansum >= 1) are unaffected.
        tangents = tangents / tansum.clamp(min=1.0)

        # Normalize and make the tangent perpendicular to the normal.
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))
        return tangents

    def _unwrap_uv(self, xatlas_chart_options: dict = None, xatlas_pack_options: dict = None):
        """Run xatlas on the mesh and return ``(uvs, face_uv_indices)``.

        Args:
            xatlas_chart_options: Attribute overrides applied to ``xatlas.ChartOptions``.
            xatlas_pack_options: Attribute overrides applied to ``xatlas.PackOptions``.

        Returns:
            ``uvs`` as a float tensor and per-face UV indices as a long tensor,
            both on ``v_pos.device``.
        """
        import xatlas

        if xatlas_chart_options is None:
            xatlas_chart_options = {}
        if xatlas_pack_options is None:
            xatlas_pack_options = {}

        atlas = xatlas.Atlas()
        atlas.add_mesh(
            self.v_pos.detach().cpu().numpy(),
            self.t_pos_idx.cpu().numpy(),
        )
        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        for k, v in xatlas_chart_options.items():
            setattr(co, k, v)
        for k, v in xatlas_pack_options.items():
            setattr(po, k, v)
        atlas.generate(co, po)
        # The vertex mapping (UV vertex -> original vertex) is not needed here.
        _vmapping, indices, uvs = atlas.get_mesh(0)

        uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
        indices = (
            torch.from_numpy(indices.astype(np.uint64, casting="same_kind").view(np.int64))
            .to(self.v_pos.device)
            .long()
        )
        return uvs, indices

    def unwrap_uv(self, xatlas_chart_options: dict = None, xatlas_pack_options: dict = None):
        """(Re)compute UVs with xatlas and cache them on the mesh."""
        self._v_tex, self._t_tex_idx = self._unwrap_uv(
            xatlas_chart_options, xatlas_pack_options
        )

    def set_vertex_color(self, v_rgb):
        """Attach per-vertex colors; their first dimension must equal the vertex count."""
        assert v_rgb.shape[0] == self.v_pos.shape[0]
        self._v_rgb = v_rgb

    def _compute_edges(self):
        """Collect each face's three edges, sort each pair, and deduplicate."""
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Sort within each pair so (a, b) and (b, a) collapse to one edge.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges

    def normal_consistency(self):
        """Mean ``1 - cos`` between the normals at the two ends of every edge."""
        edge_nrm = self.v_nrm[self.edges]
        nc = (1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)).mean()
        return nc

    def _laplacian_uniform(self):
        """Build the uniform graph Laplacian ``D - A`` as a sparse COO tensor."""
        # from stable-dreamfusion
        # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
        verts, faces = self.v_pos, self.t_pos_idx
        num_verts = verts.shape[0]

        # Neighbor indices (both directions of every face edge, deduplicated).
        ii = faces[:, [1, 2, 0]].flatten()
        jj = faces[:, [2, 0, 1]].flatten()
        adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)
        adj_values = torch.ones(adj.shape[1]).to(verts)

        # Diagonal indices: one +1 per neighbor, summed by coalesce into the degree.
        diag_idx = adj[0]

        # Build the sparse matrix.
        idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
        values = torch.cat((-adj_values, adj_values))

        # The coalesce operation sums the duplicate indices, resulting in the
        # correct diagonal.
        return torch.sparse_coo_tensor(idx, values, (num_verts, num_verts)).coalesce()

    def laplacian(self):
        """Mean norm of the uniform Laplacian applied to ``v_pos``.

        The operator is built without gradient tracking; gradients flow only
        through ``v_pos``.
        """
        with torch.no_grad():
            L = self._laplacian_uniform()
        loss = L.mm(self.v_pos)
        loss = loss.norm(dim=1)
        loss = loss.mean()
        return loss

    def to(self, device):
        """Return a new Mesh whose tensors live on ``device``.

        ``material`` is carried over as-is. Cached attributes and tensor
        values in ``extras`` are moved to ``device``. Other extras are
        copied by reference into a new dict.
        """

        def _move(value):
            return value.to(device) if isinstance(value, torch.Tensor) else value

        mesh = Mesh(self.v_pos.to(device), self.t_pos_idx.to(device), material=self.material)
        for name in ("_v_nrm", "_v_tng", "_v_tex", "_t_tex_idx", "_v_rgb", "_edges"):
            setattr(mesh, name, _move(getattr(self, name)))
        mesh.extras = {k: _move(v) for k, v in self.extras.items()}
        return mesh

    def as_trimesh(self):
        """Convert to an unprocessed ``trimesh.Trimesh``.

        If an ``albedo_map`` attribute has been set on the instance (it is
        not set by ``__init__``), UV coordinates are attached as texture
        visuals. trimesh expects one UV per vertex, but ``v_tex`` is indexed
        by ``t_tex_idx``. In that case the mesh is therefore built on the
        UV-vertex layout: faces are ``t_tex_idx``, and each UV vertex takes
        the position of the face corner that references it.
        """
        vertices = self.v_pos.detach().cpu().numpy()
        faces = self.t_pos_idx.detach().cpu().numpy()
        if getattr(self, "albedo_map", None) is None:
            return trimesh.Trimesh(vertices=vertices, faces=faces, process=False)

        uv = self.v_tex.detach().cpu().numpy()  # may trigger xatlas unwrapping
        tex_faces = self.t_tex_idx.detach().cpu().numpy()
        # Map every UV vertex to the position vertex of the corner(s) that use
        # it. Assumes each UV vertex belongs to a single position vertex, as
        # with xatlas output -- NOTE(review): confirm for externally set UVs.
        uv_to_pos = np.zeros(uv.shape[0], dtype=np.int64)
        uv_to_pos[tex_faces.reshape(-1)] = faces.reshape(-1)
        mesh = trimesh.Trimesh(
            vertices=vertices[uv_to_pos],
            faces=tex_faces,
            process=False,
        )
        # NOTE(review): albedo_map itself is not attached to the material; its
        # type is not visible here.
        mesh.visual = trimesh.visual.texture.TextureVisuals(
            uv=uv,
            material=trimesh.visual.material.SimpleMaterial(),
        )
        return mesh
def scale_tensor(x, input_range, target_range):
    """Linearly map ``x`` from ``input_range`` onto ``target_range``.

    Args:
        x: Tensor or number to rescale.
        input_range: ``(low, high)`` pair for the source interval, or ``None``
            for the unit interval ``(0, 1)``.
        target_range: ``(low, high)`` pair for the destination interval, or
            ``None`` for ``(0, 1)``.

    Returns:
        ``x`` remapped so that ``input_range[0]`` goes to ``target_range[0]``
        and ``input_range[1]`` goes to ``target_range[1]``. No check is made
        for a zero-width ``input_range``.
    """
    # None is accepted as shorthand for the unit interval (backward-compatible).
    if input_range is None:
        input_range = (0.0, 1.0)
    if target_range is None:
        target_range = (0.0, 1.0)
    x_unit = (x - input_range[0]) / (input_range[1] - input_range[0])
    x_scaled = x_unit * (target_range[1] - target_range[0]) + target_range[0]
    return x_scaled