import sys from pathlib import Path from typing import Union import h5py import numpy as np import open3d as o3d import torch from rich.progress import track from salad.utils.paths import SPAGHETTI_DIR from salad.utils import nputil, thutil, sysutil, meshutil # TODO rewrite SPAGHETTI's relative path dependecies. # Too lazy to refactorize SPAGHETTI's relative paths.. def add_spaghetti_path(spaghetti_path=SPAGHETTI_DIR): spaghetti_path = str(spaghetti_path) if spaghetti_path not in sys.path: sys.path.append(spaghetti_path) def delete_spaghetti_path( spaghetti_path=SPAGHETTI_DIR, ): spaghetti_path = str(spaghetti_path) if spaghetti_path in sys.path: sys.path.remove(spaghetti_path) def load_spaghetti(device, tag="chairs_large"): assert tag in [ "chairs_large", "airplanes", "tables", ], f"tag should be 'chairs_large', 'airplanes' or 'tables'." add_spaghetti_path() from salad.spaghetti.options import Options from salad.spaghetti.ui import occ_inference opt = Options() opt.dataset_size = 1 opt.device = device opt.tag = tag infer_module = occ_inference.Inference(opt) spaghetti = infer_module.model.to(device) spaghetti.eval() for p in spaghetti.parameters(): p.requires_grad_(False) delete_spaghetti_path() return spaghetti def load_mesher( device, min_res=64, ): from salad.spaghetti.utils.mcubes_meshing import MarchingCubesMeshing mesher = MarchingCubesMeshing(device=device, min_res=min_res) delete_spaghetti_path() return mesher def get_mesh_and_pc(spaghetti, mesher, zc): vert, face = get_mesh_from_spaghetti(spaghetti, mesher, zc) pc = poisson_sampling(vert, face) return vert, face, pc def get_mesh_from_spaghetti(spaghetti, mesher, zc, res=256): mesh = mesher.occ_meshing( decoder=get_occ_func(spaghetti, zc), res=res, get_time=False, verbose=False ) vert, face = list(map(lambda x: thutil.th2np(x), mesh)) return vert, face def poisson_sampling(vert: np.array, face: np.array): vert_o3d = o3d.utility.Vector3dVector(vert) face_o3d = o3d.utility.Vector3iVector(face) mesh_o3d = o3d.geometry.TriangleMesh(vert_o3d, face_o3d) pc_o3d = mesh_o3d.sample_points_poisson_disk(2048) pc = np.asarray(pc_o3d.points).astype(np.float32) return pc def get_occ_func(spaghetti, zc): device = spaghetti.device zc = nputil.np2th(zc).to(device) def forward(x): nonlocal zc x = x.unsqueeze(0) out = spaghetti.occupancy_network(x, zc)[0, :] out = 2 * out.sigmoid_() - 1 return out if zc.dim() == 2: zc = zc.unsqueeze(0) return forward def generate_zc_from_sj_gaus( spaghetti, sj: Union[torch.Tensor, np.ndarray], gaus: Union[torch.Tensor, np.ndarray], ): """ Input: sj: [B,16,512] or [16,512] gaus: [B,16,16] or [16,16] Output: zc: [B,16,512] """ device = spaghetti.device sj = nputil.np2th(sj) gaus = nputil.np2th(gaus) assert sj.dim() == gaus.dim() if sj.dim() == 2: sj = sj.unsqueeze(0) batch_sj = sj.to(device) batch_gmms = batch_gaus_to_gmms(gaus, device) zcs, _ = spaghetti.merge_zh(batch_sj, batch_gmms) return zcs def generate_zc_from_za(spaghetti, za: Union[torch.Tensor, np.ndarray]): device = spaghetti.device za = nputil.np2th(za).to(device) sjs, gmms = spaghetti.decomposition_control(za) zcs, _ = spaghetti.merge_zh(sjs, gmms) return zcs def generate_gaus_from_za(spaghetti, za): # device = spaghetti.device # za = nputil.np2th(za).to(device) sjs, gmms = spaghetti.decomposition_control(za) if isinstance(gmms[0], list): gaus = gmms[0] else: gaus = list(gmms) gaus = [flatten_gmms_item(x) for x in gaus] gaus = torch.cat(gaus, -1) # gaus = batch_gmms_to_gaus(gmms) return gaus def generate_zc_from_single_phase_latent( spaghetti, sj_gaus: Union[torch.Tensor, np.ndarray] ): device = spaghetti.device sj_gaus = nputil.np2th(sj_gaus).to(device) sj, gaus = sj_gaus.split(split_size=[512, 16], dim=-1) zcs = generate_zc_from_sj_gaus(spaghetti, sj, gaus) return zcs def flatten_gmms_item(x): """ Input: [B,1,G,*shapes] Output: [B,G,-1] """ return x.reshape(x.shape[0], x.shape[2], -1) @torch.no_grad() def batch_gmms_to_gaus(gmms): """ Input: [T(B,1,G,3), T(B,1,G,3,3), T(B,1,G), T(B,1,G,3)] Output: T(B,G,16) """ if isinstance(gmms[0], list): gaus = gmms[0].copy() else: gaus = list(gmms).copy() gaus = [flatten_gmms_item(x) for x in gaus] return torch.cat(gaus, -1) @torch.no_grad() def batch_gaus_to_gmms(gaus, device="cpu"): """ Input: T(B,G,16) Output: [mu: T(B,1,G,3), eivec: T(B,1,G,3,3), pi: T(B,1,G), eival: T(B,1,G,3)] """ gaus = nputil.np2th(gaus).to(device) if len(gaus.shape) < 3: gaus = gaus.unsqueeze(0) # expand dim for batch B, G, _ = gaus.shape mu = gaus[:, :, :3].reshape(B, 1, G, 3) eivec = gaus[:, :, 3:12].reshape(B, 1, G, 3, 3) pi = gaus[:, :, 12].reshape(B, 1, G) eival = gaus[:, :, 13:16].reshape(B, 1, G, 3) return [mu, eivec, pi, eival] def reflect_and_concat_gmms(gmms: torch.Tensor): """ Input: gmms: (B, 8, 16). A batch of GMMs Output: new_gmms: (B, 16, 16) """ gmms = nputil.np2th(gmms) gmms = gmms.clone() if gmms.dim() == 2: gmms = gmms.unsqueeze(0) affine = torch.eye(3).to(gmms) affine[0, 0] = -1.0 mu, p, phi, eigen = torch.split(gmms, [3, 9, 1, 3], dim=2) if affine.ndim == 2: affine = affine.unsqueeze(0).expand(mu.size(0), *affine.shape) bs, n_part, _ = mu.shape p = p.reshape(bs, n_part, 3, 3) mu_r = torch.einsum("bad, bnd -> bna", affine, mu) p_r = torch.einsum("bad, bncd -> bnca", affine, p) p_r = p_r.reshape(bs, n_part, -1) gmms_t = torch.cat([mu_r, p_r, phi, eigen], dim=2) assert ( gmms.shape == gmms_t.shape ), "Input and reflected gmms shapes must be the same" return torch.cat([gmms, gmms_t], dim=1) def clip_eigenvalues(gaus: Union[torch.Tensor, np.ndarray], eps=1e-4): """ Input: gaus: [B,G,16] or [G,16] Output: gaus_clipped: [B,G,16] or [G,16] torch.Tensor """ gaus = nputil.np2th(gaus) clipped_gaus = gaus.clone() clipped_gaus[..., 13:16] = torch.clamp_min(clipped_gaus[..., 13:16], eps) return clipped_gaus def project_eigenvectors(gaus: Union[torch.Tensor, np.ndarray]): """ Input: gaus: [B,G,16] or [G,16] Output: gaus_projected: [B,G,16] or [1,G,16] """ gaus = nputil.np2th(gaus).clone() if gaus.ndim == 2: gaus = gaus.unsqueeze(0) B, G = gaus.shape[:2] eigvec = gaus[:, :, 3:12] eigvec_projected = get_orthonormal_bases_svd(eigvec) gaus[:, :, 3:12] = eigvec_projected return gaus def get_orthonormal_bases_svd(vs: torch.Tensor): """ Implements the solution for the Orthogonal Procrustes problem, which projects a matrix to the closest rotation matrix / reflection matrix using SVD. Args: vs: Tensor of shape (B, M, 9) Returns: p: Tensor of shape (B, M, 9). """ # Compute SVDs of matrices in batch b, m, _ = vs.shape vs_ = vs.reshape(b * m, 3, 3) U, _, Vh = torch.linalg.svd(vs_) # Determine the diagonal matrix to make determinants 1 sigma = torch.eye(3)[None, ...].repeat(b * m, 1, 1).to(vs_.device) det = torch.linalg.det(torch.bmm(U, Vh)) # Compute determinants of UVT #### # Do not set the sign of determinants to 1. # Inputs contain reflection matrices. # sigma[:, 2, 2] = det #### # Construct orthogonal matrices p = torch.bmm(torch.bmm(U, sigma), Vh) return p.reshape(b, m, 9) def save_meshes_and_pointclouds( spaghetti, mesher, zcs, save_top_dir, mesh_save_dir=None, pc_save_dir=None, num_shapes=2000, ): save_top_dir = Path(save_top_dir) print(f"Save dir is: {save_top_dir}") if mesh_save_dir is None: mesh_save_dir = save_top_dir / "meshes" mesh_save_dir.mkdir(exist_ok=True) if pc_save_dir is None: pc_save_dir = save_top_dir / "pointclouds" pc_save_dir.mkdir(exist_ok=True) mesh_save_dir = Path(mesh_save_dir) pc_save_dir = Path(pc_save_dir) all_pointclouds = np.zeros((num_shapes, 2048, 3)) for i in track(range(num_shapes), description="extracting pc and mesh"): zc = zcs[i] vert_np, face_np, pc_np = get_mesh_and_pc(spaghetti, mesher, zc) sysutil.clean_gpu() all_pointclouds[i] = pc_np meshutil.write_obj_triangle(mesh_save_dir / f"{i}.obj", vert_np, face_np) np.save(pc_save_dir / f"{i}.npy", pc_np) if i == 1000: with h5py.File(save_top_dir / "o3d_all_pointclouds.hdf5", "w") as f: f["data"] = all_pointclouds[:1000] with h5py.File(save_top_dir / "o3d_all_pointclouds.hdf5", "w") as f: f["data"] = all_pointclouds