#
# Copyright (C) 2024, ShanghaiTech
# SVIP research group, https://github.com/svip-lab
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact huangbb@shanghaitech.edu.cn
#
# Copied from https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/mesh_utils.py
import torch
import numpy as np
import os
import math
from tqdm import tqdm
from functools import partial
import open3d as o3d
import trimesh
from utils.depth_utils import depth_to_normal
from utils.render_utils import save_img_f32, save_img_u8  # used by export_image below


def post_process_mesh(mesh, cluster_to_keep=1000):
    """
    Post-process a mesh to filter out floaters and disconnected parts
    """
    import copy
    print("post processing the mesh to have {} clusters".format(cluster_to_keep))
    mesh_0 = copy.deepcopy(mesh)
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
        triangle_clusters, cluster_n_triangles, cluster_area = mesh_0.cluster_connected_triangles()

    triangle_clusters = np.asarray(triangle_clusters)
    cluster_n_triangles = np.asarray(cluster_n_triangles)
    cluster_area = np.asarray(cluster_area)
    # guard against requesting more clusters than actually exist
    n_cluster = np.sort(cluster_n_triangles.copy())[-min(cluster_to_keep, len(cluster_n_triangles))]
    n_cluster = max(n_cluster, 50)  # filter out clusters with fewer than 50 triangles
    triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster
    mesh_0.remove_triangles_by_mask(triangles_to_remove)
    mesh_0.remove_unreferenced_vertices()
    mesh_0.remove_degenerate_triangles()
    print("num vertices raw {}".format(len(mesh.vertices)))
    print("num vertices post {}".format(len(mesh_0.vertices)))
    return mesh_0
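
# Hedged usage sketch for post_process_mesh (file names are illustrative, not
# paths from this repo):
#
#   mesh = o3d.io.read_triangle_mesh("fuse.ply")        # hypothetical input
#   mesh = post_process_mesh(mesh, cluster_to_keep=1)   # keep only the largest component
#   o3d.io.write_triangle_mesh("fuse_post.ply", mesh)   # hypothetical output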


def to_cam_open3d(viewpoint_stack):
    camera_traj = []
    for i, viewpoint_cam in enumerate(viewpoint_stack):
        intrinsic = o3d.camera.PinholeCameraIntrinsic(
            width=viewpoint_cam.image_width,
            height=viewpoint_cam.image_height,
            cx=viewpoint_cam.image_width / 2,
            cy=viewpoint_cam.image_height / 2,
            fx=viewpoint_cam.image_width / (2 * math.tan(viewpoint_cam.FoVx / 2.)),
            fy=viewpoint_cam.image_height / (2 * math.tan(viewpoint_cam.FoVy / 2.)),
        )
        extrinsic = np.asarray((viewpoint_cam.world_view_transform.T).cpu().numpy())
        camera = o3d.camera.PinholeCameraParameters()
        camera.extrinsic = extrinsic
        camera.intrinsic = intrinsic
        camera_traj.append(camera)
    return camera_traj
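
# Note on the intrinsics above: fx = W / (2 * tan(FoVx / 2)) is the standard
# pinhole relation between focal length in pixels and field of view; e.g. an
# 800 px wide image with FoVx = 90 degrees gives fx = 800 / (2 * tan(45 deg)) = 400.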


class GaussianExtractor(object):
    def __init__(self, gaussians, render, pipe, bg_color=None):
        """
        A class that extracts attributes of a scene represented by 2DGS.
        Usage example:
        >>> gaussExtractor = GaussianExtractor(gaussians, render, pipe)
        >>> gaussExtractor.reconstruction(view_points)
        >>> mesh = gaussExtractor.extract_mesh_bounded(...)
        """
        if bg_color is None:
            bg_color = [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.gaussians = gaussians
        self.render = partial(render, pipe=pipe, bg_color=background)
        self.clean()

    def clean(self):
        self.depthmaps = []
        self.alphamaps = []
        self.rgbmaps = []
        self.normals = []
        self.depth_normals = []
        self.viewpoint_stack = []

    @torch.no_grad()
    def reconstruction(self, viewpoint_stack):
        """
        reconstruct radiance field given cameras
        """
        self.clean()
        self.viewpoint_stack = viewpoint_stack
        for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields"):
            render_pkg = self.render(viewpoint_cam, self.gaussians)
            rgb = render_pkg['render']
            alpha = render_pkg['mask']
            normal = torch.nn.functional.normalize(render_pkg['normal'], dim=0)
            depth = render_pkg['middepth']
            depth_normal, _ = depth_to_normal(viewpoint_cam, depth)
            depth_normal = depth_normal.permute(2, 0, 1)
            # depth_normal = render_pkg['surf_normal']
            self.rgbmaps.append(rgb.cpu())
            self.depthmaps.append(depth.cpu())
            self.alphamaps.append(alpha.cpu())
            self.normals.append(normal.cpu())
            self.depth_normals.append(depth_normal.cpu())
        self.rgbmaps = torch.stack(self.rgbmaps, dim=0)
        self.depthmaps = torch.stack(self.depthmaps, dim=0)
        self.alphamaps = torch.stack(self.alphamaps, dim=0)
        self.normals = torch.stack(self.normals, dim=0)
        self.depth_normals = torch.stack(self.depth_normals, dim=0)
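
    # Hedged end-to-end sketch, assuming `gaussians`, `render`, `pipe`, and a
    # camera list `cams` come from the surrounding 2DGS training code:
    #
    #   extractor = GaussianExtractor(gaussians, render, pipe, bg_color=[0, 0, 0])
    #   extractor.reconstruction(cams)
    #   mesh = extractor.extract_mesh_bounded(voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3)
    #   mesh = post_process_mesh(mesh)
    #   o3d.io.write_triangle_mesh("fuse.ply", mesh)  # hypothetical output path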

    @torch.no_grad()
    def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_backgrond=True):
        """
        Perform TSDF fusion given a fixed depth range, used in the paper.
        voxel_size: the voxel size of the volume
        sdf_trunc: truncation value
        depth_trunc: maximum depth range, should depend on the scene's scale
        mask_backgrond: whether to mask the background, only works when the dataset has masks
        return o3d.mesh
        """
        print("Running tsdf volume integration ...")
        print(f'voxel_size: {voxel_size}')
        print(f'sdf_trunc: {sdf_trunc}')
        print(f'depth_trunc: {depth_trunc}')

        volume = o3d.pipelines.integration.ScalableTSDFVolume(
            voxel_length=voxel_size,
            sdf_trunc=sdf_trunc,
            color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
        )

        for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc="TSDF integration progress"):
            rgb = self.rgbmaps[i]
            depth = self.depthmaps[i]

            # if we have a mask provided, use it
            if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None):
                depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0

            # make open3d rgbd
            rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
                o3d.geometry.Image(np.asarray(rgb.permute(1, 2, 0).cpu().numpy() * 255, order="C", dtype=np.uint8)),
                o3d.geometry.Image(np.asarray(depth.permute(1, 2, 0).cpu().numpy(), order="C")),
                depth_trunc=depth_trunc, convert_rgb_to_intensity=False,
                depth_scale=1.0
            )

            volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic)

        mesh = volume.extract_triangle_mesh()
        return mesh
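
    # Hedged parameter sketch: the defaults keep sdf_trunc = 5 * voxel_size
    # (0.02 = 5 * 0.004), the same factor the unbounded variant below uses.
    # For a scene of a different scale (scene_radius is a hypothetical value
    # measured by the user), one might preserve that ratio:
    #
    #   voxel_size = 2 * scene_radius / 1024
    #   mesh = extractor.extract_mesh_bounded(voxel_size=voxel_size,
    #                                         sdf_trunc=5 * voxel_size,
    #                                         depth_trunc=4 * scene_radius)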

    @torch.no_grad()
    def extract_mesh_unbounded(self, resolution=1024):
        """
        Experimental feature: extract meshes from unbounded scenes; not fully tested across datasets.
        #TODO: support color mesh exporting
        resolution: marching-cubes grid resolution along each axis
        return o3d.mesh
        """
        def contract(x):
            mag = torch.linalg.norm(x, ord=2, dim=-1)[..., None]
            return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag))

        def uncontract(y):
            mag = torch.linalg.norm(y, ord=2, dim=-1)[..., None]
            return torch.where(mag < 1, y, (1 / (2 - mag)) * (y / mag))
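
        # The pair above implements a MipNeRF360-style scene contraction:
        # points with ||x|| <= 1 are left unchanged, while points outside are
        # squashed into the shell 1 < ||contract(x)|| < 2 via
        # ||contract(x)|| = 2 - 1/||x||; uncontract inverts this, since
        # 1 / (2 - (2 - 1/m)) = m.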

        def compute_sdf_perframe(i, points, depthmap, rgbmap, normalmap, viewpoint_cam):
            """
            compute per frame sdf
            """
            # project world-space points into the view's NDC
            new_points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) @ viewpoint_cam.full_proj_transform
            z = new_points[..., -1:]
            pix_coords = (new_points[..., :2] / new_points[..., -1:])
            # keep only points that land inside the view frustum
            mask_proj = ((pix_coords > -1.) & (pix_coords < 1.) & (z > 0)).all(dim=-1)
            sampled_depth = torch.nn.functional.grid_sample(depthmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(-1, 1)
            sampled_rgb = torch.nn.functional.grid_sample(rgbmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3, -1).T
            sampled_normal = torch.nn.functional.grid_sample(normalmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3, -1).T
            sdf = (sampled_depth - z)
            return sdf, sampled_rgb, sampled_normal, mask_proj
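
        # Note: `sampled_depth - z` is a projective distance along the viewing
        # direction rather than a true Euclidean SDF; TSDF fusion tolerates this
        # because values are later clamped to a narrow band around the surface.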

        def compute_unbounded_tsdf(samples, inv_contraction, voxel_size, return_rgb=False):
            """
            Fuse all frames, performing adaptive SDF truncation in the contracted space.
            """
            if inv_contraction is not None:
                # adaptive sdf truncation: widen the band for samples in the contracted
                # shell (1 < ||x|| < 2), where the warp stretches space; computed on the
                # contracted coordinates, before unwarping to world space
                mask = torch.linalg.norm(samples, dim=-1) > 1
                sdf_trunc = 5 * voxel_size * torch.ones_like(samples[:, 0])
                sdf_trunc[mask] *= 1 / (2 - torch.linalg.norm(samples, dim=-1)[mask].clamp(max=1.9))
                samples = inv_contraction(samples)
            else:
                sdf_trunc = 5 * voxel_size

            tsdfs = torch.ones_like(samples[:, 0])
            rgbs = torch.zeros((samples.shape[0], 3)).cuda()
            weights = torch.ones_like(samples[:, 0])
            for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="TSDF integration progress"):
                sdf, rgb, normal, mask_proj = compute_sdf_perframe(i, samples,
                                                                   depthmap=self.depthmaps[i],
                                                                   rgbmap=self.rgbmaps[i],
                                                                   normalmap=self.depth_normals[i],
                                                                   viewpoint_cam=self.viewpoint_stack[i])

                # volume integration
                sdf = sdf.flatten()
                mask_proj = mask_proj & (sdf > -sdf_trunc)
                sdf = torch.clamp(sdf / sdf_trunc, min=-1.0, max=1.0)[mask_proj]
                w = weights[mask_proj]
                wp = w + 1
                tsdfs[mask_proj] = (tsdfs[mask_proj] * w + sdf) / wp
                rgbs[mask_proj] = (rgbs[mask_proj] * w[:, None] + rgb[mask_proj]) / wp[:, None]
                # update weight
                weights[mask_proj] = wp

            if return_rgb:
                return tsdfs, rgbs
            return tsdfs
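
        # The integration above is a running average with uniform per-view weights,
        # tsdf_{k+1} = (w_k * tsdf_k + sdf) / (w_k + 1), initialized with tsdf = 1
        # (free space) and w = 1, and applied only where a sample projects inside
        # the frustum and in front of the truncation band (sdf > -sdf_trunc).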

        from utils.render_utils import focus_point_fn
        torch.cuda.empty_cache()
        c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in self.viewpoint_stack])
        poses = c2ws[:, :3, :] @ np.diag([1, -1, -1, 1])
        center = focus_point_fn(poses)
        radius = np.linalg.norm(c2ws[:, :3, 3] - center, axis=-1).min()
        center = torch.from_numpy(center).float().cuda()
        normalize = lambda x: (x - center) / radius
        unnormalize = lambda x: (x * radius) + center
        inv_contraction = lambda x: unnormalize(uncontract(x))

        N = resolution
        voxel_size = (radius * 2 / N)
        print(f"Computing sdf grid resolution {N} x {N} x {N}")
        print(f"Using voxel_size {voxel_size}")
        sdf_function = lambda x: compute_unbounded_tsdf(x, inv_contraction, voxel_size)
        from utils.mcube_utils import marching_cubes_with_contraction
        R = contract(normalize(self.gaussians.get_xyz)).norm(dim=-1).cpu().numpy()
        R = np.quantile(R, q=0.95)
        R = min(R + 0.01, 1.9)

        mesh = marching_cubes_with_contraction(
            sdf=sdf_function,
            bounding_box_min=(-R, -R, -R),
            bounding_box_max=(R, R, R),
            level=0,
            resolution=N,
            inv_contraction=inv_contraction,
        )

        # coloring the mesh
        torch.cuda.empty_cache()
        mesh = mesh.as_open3d
        print("texturing mesh ... ")
        _, rgbs = compute_unbounded_tsdf(torch.tensor(np.asarray(mesh.vertices)).float().cuda(), inv_contraction=None, voxel_size=voxel_size, return_rgb=True)
        mesh.vertex_colors = o3d.utility.Vector3dVector(rgbs.cpu().numpy())
        return mesh
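
    # Hedged usage sketch for the unbounded variant (same assumptions as the
    # bounded sketch above):
    #
    #   extractor.reconstruction(cams)
    #   mesh = extractor.extract_mesh_unbounded(resolution=1024)
    #   o3d.io.write_triangle_mesh("fuse_unbounded.ply", mesh)  # hypothetical path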

    @torch.no_grad()
    def export_image(self, path):
        render_path = os.path.join(path, "renders")
        gts_path = os.path.join(path, "gt")
        vis_path = os.path.join(path, "vis")
        os.makedirs(render_path, exist_ok=True)
        os.makedirs(vis_path, exist_ok=True)
        os.makedirs(gts_path, exist_ok=True)
        for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="export images"):
            gt = viewpoint_cam.original_image[0:3, :, :]
            save_img_u8(gt.permute(1, 2, 0).cpu().numpy(), os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
            save_img_u8(self.rgbmaps[idx].permute(1, 2, 0).cpu().numpy(), os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
            save_img_f32(self.depthmaps[idx][0].cpu().numpy(), os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff"))
            save_img_u8(self.normals[idx].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + ".png"))
            save_img_u8(self.depth_normals[idx].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png"))
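
    # Hedged sketch: after reconstruction(), export_image dumps renders, depth
    # maps, and normal maps for visual inspection (the directory is hypothetical):
    #
    #   extractor.export_image("./output/eval")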