# Hosting-page residue kept as a comment (not part of the module):
# uploader "alvin888", commit e2ebf5f ("Upload 1461 files", verified).
import os
import numpy as np
import open3d as o3d
import torch
from pytorch3d.renderer import RasterizationSettings, MeshRasterizer
from sugar.sugar_scene.gs_model import GaussianSplattingWrapper
from sugar.sugar_scene.sugar_model import SuGaR
from sugar.sugar_utils.general_utils import str2bool
from sugar.sugar_utils.spherical_harmonics import SH2RGB
from rich.console import Console
def _parse_bbox_corner(value):
    """Turn one bounding-box corner into a tuple of floats.

    Accepts the command-line form ``"(x,y,z)"`` (parentheses optional) as well
    as an already-parsed sequence of numbers, so calling the extractor twice
    with the same ``args`` object does not fail.
    """
    if isinstance(value, str):
        text = value.strip()
        if text.startswith("("):
            text = text[1:]
        if text.endswith(")"):
            text = text[:-1]
        return tuple(float(coord) for coord in text.split(","))
    return tuple(float(coord) for coord in value)


def _parse_custom_bbox(args):
    """Return ``(bbox_min, bbox_max)`` float tuples, or None if no bbox was given.

    The parsed tuples are written back to ``args.bboxmin`` / ``args.bboxmax``,
    as the original implementation did.

    Raises:
        ValueError: if ``args.bboxmin`` is set but ``args.bboxmax`` is not.
    """
    if args.bboxmin is None:
        return None
    if args.bboxmax is None:
        raise ValueError("You need to specify both bboxmin and bboxmax.")
    args.bboxmin = _parse_bbox_corner(args.bboxmin)
    args.bboxmax = _parse_bbox_corner(args.bboxmax)
    return args.bboxmin, args.bboxmax


def _tag(value):
    """Format a surface level / decimation target for use in a file name ('0.3' -> '03')."""
    return str(value).replace('.', '')


def _mesh_name_prefix(use_vanilla_3dgs, sugar_checkpoint_path):
    """Return the file-name prefix shared by every mesh written for this model."""
    if use_vanilla_3dgs:
        return 'sugarmesh_vanilla3dgs'
    # The parent directory of the checkpoint file identifies the coarse model.
    return 'sugarmesh_' + sugar_checkpoint_path.split('/')[-2].replace('sugarcoarse_', '')


def _load_sugar_model(nerfmodel, use_vanilla_3dgs, sugar_checkpoint_path, console):
    """Build the SuGaR model, either from the vanilla 3DGS gaussians or from a coarse checkpoint."""
    if use_vanilla_3dgs:
        console.print(f"\nUsing the vanilla 3DGS model for meshing...")
        console.print("Initializing model from trained 3DGS...")
        gaussians = nerfmodel.gaussians
        with torch.no_grad():
            points = gaussians.get_xyz.detach().float().cuda()
            colors = SH2RGB(gaussians.get_features[:, 0].detach().float().cuda())
        checkpoint = None
    else:
        console.print(f"\nLoading the coarse SuGaR model from path {sugar_checkpoint_path}...")
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(sugar_checkpoint_path, map_location=nerfmodel.device)
        points = checkpoint['state_dict']['_points']
        colors = SH2RGB(checkpoint['state_dict']['_sh_coordinates_dc'][:, 0, :])

    sugar = SuGaR(
        nerfmodel=nerfmodel,
        points=points,
        colors=colors,
        initialize=True,
        sh_levels=nerfmodel.gaussians.active_sh_degree+1,
        keep_track_of_knn=True,
        knn_to_track=16,
        beta_mode='average',  # 'learnable', 'average', 'weighted_average'
        primitive_types='diamond',  # 'diamond', 'square'
        surface_mesh_to_bind=None,  # Open3D mesh
    )

    if checkpoint is None:
        # Copy the remaining 3DGS parameters into the freshly initialized model.
        with torch.no_grad():
            sugar._scales[...] = nerfmodel.gaussians._scaling.detach()
            sugar._quaternions[...] = nerfmodel.gaussians._rotation.detach()
            sugar.all_densities[...] = nerfmodel.gaussians._opacity.detach()
            sugar._sh_coordinates_dc[...] = nerfmodel.gaussians._features_dc.detach()
            sugar._sh_coordinates_rest[...] = nerfmodel.gaussians._features_rest.detach()
    else:
        sugar.load_state_dict(checkpoint['state_dict'])
    sugar.eval()
    return sugar


def _print_opacity_stats(sugar, console, n_quantiles=10):
    """Print min/max/mean and evenly spaced quantiles of ``sugar.strengths``."""
    strengths = sugar.strengths
    console.print("Opacities min/max/mean:", strengths.min(), strengths.max(), strengths.mean())
    for i in range(n_quantiles):
        console.print(f'Quantile {i/n_quantiles}:', strengths.quantile(i/n_quantiles).item())


def _sample_level_set_points(sugar, nerfmodel, rasterizer, surface_levels, n_pts_per_frame,
                             knn_to_track, use_fast_method, level_set_kwargs, console):
    """Collect level-set points, colors and normals from every training camera.

    For each camera, an RGB image is rendered, the level-set intersection
    points are computed for all ``surface_levels``, and at most
    ``n_pts_per_frame`` randomly chosen points per level are kept.

    Returns:
        dict mapping each surface level to a dict with the concatenated
        'points', 'colors' and 'normals' tensors.
    """
    device = sugar.device
    kept_keys = ('points', 'colors', 'normals')
    # Chunks are gathered in lists and concatenated once at the end instead of
    # re-copying an ever-growing tensor on every frame. Each list is seeded
    # with an empty (0, 3) tensor so the final cat is valid even with no data.
    buffers = {
        level: {key: [torch.zeros(0, 3, device=device)] for key in kept_keys}
        for level in surface_levels
    }
    n_collected = {level: 0 for level in surface_levels}

    cameras_to_use = nerfmodel.training_cameras
    n_cameras = len(cameras_to_use)
    sugar.knn_to_track = knn_to_track

    with torch.no_grad():
        for cam_idx in range(n_cameras):
            if cam_idx % 30 == 0:
                console.print(f"Processing frame {cam_idx}/{n_cameras}...")
                for level in surface_levels:
                    console.print(f"Current point cloud for level {level} has {n_collected[level]} points.")

            # Render RGB image with Gaussian splatting
            rgb = sugar.render_image_gaussian_rasterizer(
                nerf_cameras=cameras_to_use,
                camera_indices=cam_idx,
                bg_color=None,
                sh_deg=0,  # nerfmodel.gaussians.active_sh_degree,
                compute_color_in_rasterizer=True,
                compute_covariance_in_rasterizer=True,
                return_2d_radii=False,
                use_same_scale_in_all_directions=False,
            ).clamp(min=0., max=1.).contiguous()

            # Kept at the same point as the original: after the first render.
            if cam_idx == 0:
                sugar.reset_neighbors(knn_to_track=knn_to_track)

            # Compute surface level points for the current frame
            if use_fast_method:
                frame_outputs = sugar.compute_level_surface_points_from_camera_fast(
                    cam_idx=cam_idx,
                    rasterizer=rasterizer,
                    surface_levels=surface_levels,
                    n_surface_points=2*n_pts_per_frame,  # TODO: 2*n_pts_per_frame is safe to avoid empty pixels
                    return_pixel_idx=True,
                    **level_set_kwargs,
                )
            else:
                frame_outputs = sugar.compute_level_surface_points_from_camera_efficient(
                    cam_idx=cam_idx,
                    rasterizer=rasterizer,
                    surface_levels=surface_levels,
                    return_depth=True,
                    **level_set_kwargs,
                )

            pixel_colors = rgb.view(-1, 3)
            for level in surface_levels:
                level_outputs = frame_outputs[level]
                frame_points = level_outputs['intersection_points']
                frame_normals = level_outputs['normals']
                if use_fast_method:
                    frame_colors = pixel_colors[level_outputs['pixel_idx']]
                else:
                    frame_colors = pixel_colors[~level_outputs['empty_pixels']]

                # Random subsample so the total stays close to the point budget.
                keep = torch.randperm(len(frame_points), device=device)[:n_pts_per_frame]
                buffers[level]['points'].append(frame_points[keep])
                buffers[level]['colors'].append(frame_colors[keep])
                buffers[level]['normals'].append(frame_normals[keep])
                n_collected[level] += len(keep)

    return {
        level: {key: torch.cat(chunks, dim=0) for key, chunks in level_buffers.items()}
        for level, level_buffers in buffers.items()
    }


def _split_foreground_background(sugar, points, colors, normals, custom_bbox, center_bbox,
                                 fg_bbox_factor, bg_bbox_factor, console):
    """Split a point cloud into foreground and background subsets.

    Foreground points lie strictly inside the foreground box (the user box if
    ``custom_bbox`` is given, otherwise a cube of half-size
    ``fg_bbox_factor * cameras_spatial_extent``). Background points are the
    non-foreground points within ``bg_bbox_factor * cameras_spatial_extent``
    (max-norm). With ``center_bbox`` both boxes are shifted by the average
    camera position.

    Returns:
        ``((fg_points, fg_colors, fg_normals), (bg_points, bg_colors, bg_normals))``
    """
    # no_grad: ``points`` may be a model parameter, and the results are later
    # converted to numpy, which is not possible for tensors requiring grad.
    with torch.no_grad():
        if custom_bbox is not None:
            console.print("Using provided bounding box.")
            bbox_min = torch.tensor(custom_bbox[0]).to(sugar.device)
            bbox_max = torch.tensor(custom_bbox[1]).to(sugar.device)
        else:
            console.print("Using default, camera based bounding box.")
            half_size = fg_bbox_factor * sugar.get_cameras_spatial_extent() * torch.ones(1, 3, device=sugar.device)
            bbox_min, bbox_max = -half_size, half_size

        if center_bbox:
            # Computed for both bbox kinds so the background mask below always
            # has the camera center available.
            extent, camera_center = sugar.get_cameras_spatial_extent(return_average_xyz=True)
            console.print("Centering bounding box.")
            bbox_min = bbox_min + camera_center
            bbox_max = bbox_max + camera_center
            centered_points = points - camera_center
        else:
            extent = sugar.get_cameras_spatial_extent()
            centered_points = points

        fg_mask = (points > bbox_min).all(dim=-1) & (points < bbox_max).all(dim=-1)
        bg_mask = (centered_points.abs().max(dim=-1)[0] < bg_bbox_factor * extent) & ~fg_mask

        foreground = (points[fg_mask], colors[fg_mask], normals[fg_mask])
        background = (points[bg_mask], colors[bg_mask], normals[bg_mask])

    console.print("Foreground points:", *(t.shape for t in foreground))
    console.print("Background points:", *(t.shape for t in background))
    return foreground, background


def _poisson_mesh(points, colors, normals, poisson_depth, label, console):
    """Run Poisson reconstruction on one point set; return None if it is empty.

    Statistical outliers are removed first, and the 10% of vertices with the
    lowest Poisson density are dropped afterwards.
    """
    console.print(f"\n-----{label} mesh-----")
    if points.shape[0] == 0:
        console.print(f"\n[WARNING] {label} is empty.")
        return None

    console.print("Computing points, colors and normals...")
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.double().cpu().numpy())
    pcd.colors = o3d.utility.Vector3dVector(colors.double().cpu().numpy())
    pcd.normals = o3d.utility.Vector3dVector(normals.double().cpu().numpy())

    # outliers removal
    _, inlier_indices = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=20.)
    console.print("Cleaning Point Cloud...")
    pcd = pcd.select_by_index(inlier_indices)
    console.print("Finished computing points, colors and normals.")

    console.print("Now computing mesh...")
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=poisson_depth)

    console.print("Removing vertices with low densities...")
    densities = np.asarray(densities)
    mesh.remove_vertices_by_mask(densities < np.quantile(densities, 0.1))
    return mesh


def _poisson_fg_bg_meshes(sugar, points, colors, normals, custom_bbox, center_bbox,
                          fg_bbox_factor, bg_bbox_factor, poisson_depth, console):
    """Split the points into foreground/background and reconstruct one mesh for each."""
    foreground, background = _split_foreground_background(
        sugar, points, colors, normals, custom_bbox, center_bbox,
        fg_bbox_factor, bg_bbox_factor, console,
    )
    fg_mesh = _poisson_mesh(*foreground, poisson_depth, "Foreground", console)
    bg_mesh = _poisson_mesh(*background, poisson_depth, "Background", console)
    console.print("Finished computing meshes.")
    console.print("Foreground mesh:", fg_mesh)
    console.print("Background mesh:", bg_mesh)
    return fg_mesh, bg_mesh


def _decimate_and_clean(mesh, label, decimation_target, decimate, clean, console):
    """Return a decimated and cleaned version of ``mesh`` (None stays None)."""
    if mesh is None:
        return None
    if decimate:
        console.print(f"Decimating {label} mesh...")
        result = mesh.simplify_quadric_decimation(decimation_target)
        console.print(f"Finished decimating {label} mesh.")
    else:
        # Work on a copy so cleaning never alters the mesh reused for other targets.
        result = o3d.geometry.TriangleMesh(mesh)
    if clean:
        result.remove_degenerate_triangles()
        result.remove_duplicated_triangles()
        result.remove_duplicated_vertices()
        result.remove_non_manifold_edges()
    return result


def _export_decimated_meshes(fg_mesh, bg_mesh, decimation_targets, mesh_output_dir, name_stem,
                             decimate, clean, console):
    """Decimate, clean, merge and save the meshes once per decimation target.

    Files are written to ``mesh_output_dir`` as ``<name_stem>_decim<target>.ply``.

    Returns:
        list of the written file paths.

    Raises:
        ValueError: if both ``fg_mesh`` and ``bg_mesh`` are None.
    """
    console.print("\n-----Decimating and cleaning meshes-----")
    written_paths = []
    for decimation_target in decimation_targets:
        console.print("\nProcessing decimation target:", decimation_target)
        if clean:
            console.print("Cleaning mesh...")
        fg = _decimate_and_clean(fg_mesh, "foreground", decimation_target, decimate, clean, console)
        bg = _decimate_and_clean(bg_mesh, "background", decimation_target, decimate, clean, console)

        if fg is not None and bg is not None:
            console.print("Merging foreground and background meshes.")
            merged = fg + bg
        elif fg is not None:
            console.print("Using foreground mesh only, since background mesh is empty.")
            merged = fg
        elif bg is not None:
            console.print("Using background mesh only, since foreground mesh is empty.")
            merged = bg
        else:
            raise ValueError("Both foreground and background meshes are empty. Please provide a valid bounding box for the scene.")

        # Built directly rather than by placeholder substitution, so a model
        # name containing the placeholder letters cannot be corrupted.
        mesh_path = os.path.join(mesh_output_dir, f"{name_stem}_decim{_tag(decimation_target)}.ply")
        o3d.io.write_triangle_mesh(
            mesh_path, merged,
            write_triangle_uvs=True, write_vertex_colors=True, write_vertex_normals=True,
        )
        console.print("Mesh saved at", mesh_path)
        written_paths.append(mesh_path)
    return written_paths


def _marching_cubes_mesh(sugar, density_th, bbox_factor, resolution, n_pts_per_pass, console,
                         mask_inner_box=False):
    """Extract a colored mesh with marching cubes on a regular density grid.

    The grid spans ``[-bbox_factor * extent, bbox_factor * extent]`` on each
    axis, where ``extent`` is ``sugar.get_cameras_spatial_extent()``. With
    ``mask_inner_box`` the densities strictly inside ``[-extent, extent]^3``
    are zeroed, so the result only covers the region outside that cube.
    """
    import mcubes

    with torch.no_grad():
        extent = sugar.get_cameras_spatial_extent()
        axis = torch.linspace(-1, 1, resolution) * bbox_factor * extent
        xx, yy, zz = torch.meshgrid(axis, axis, axis, indexing='ij')
        pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(sugar.device)

        console.print("Computing densities...")
        density_chunks = [torch.zeros(0, device=sugar.device)]
        for start in range(0, len(pts), n_pts_per_pass):
            console.print("\nPts:", start, 'to', start + n_pts_per_pass)
            density_chunks.append(sugar.compute_density(pts[start:start + n_pts_per_pass]))
        densities = torch.cat(density_chunks, dim=0)
        console.print("Finished computing densities.")

        if mask_inner_box:
            # Removing pts in foreground
            inside = (pts > -extent).all(dim=-1) & (pts < extent).all(dim=-1)
            densities[inside] = 0.
        densities = densities.reshape(resolution, resolution, resolution)

        console.print(f"Computing mesh for surface level {density_th}...")
        vertices, triangles = mcubes.marching_cubes(densities.cpu().numpy(), density_th)
        # Vertices come back in grid-index units; the last index (resolution - 1)
        # corresponds to the +1 end of the linspace, hence the divisor.
        verts = -bbox_factor * extent + (torch.tensor(vertices) / (resolution - 1)) * 2 * bbox_factor * extent
        faces = torch.tensor(triangles.tolist())

        # Color each vertex with the DC color of its closest gaussian.
        closest_gaussians = sugar.get_gaussians_closest_to_samples(verts.float().to(sugar.device))
        verts_colors = SH2RGB(sugar._sh_coordinates_dc[closest_gaussians[:, 0]][:, 0, :])

    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(verts.cpu().numpy())
    mesh.triangles = o3d.utility.Vector3iVector(faces.cpu().numpy())
    mesh.vertex_colors = o3d.utility.Vector3dVector(verts_colors.cpu().numpy())
    mesh.compute_vertex_normals()
    console.print("Finished computing mesh.")
    return mesh


def extract_mesh_from_coarse_sugar(args):
    """Extract triangle meshes from a coarse SuGaR model (or a vanilla 3DGS model).

    Three extraction strategies are available, selected by flags on ``args``:

    * default: sample level-set points from every training camera and run
      Poisson reconstruction, once per surface level;
    * ``use_centers_to_extract_mesh``: Poisson reconstruction on the gaussian
      centers (ablation only);
    * ``use_marching_cubes``: marching cubes on a regular density grid, using
      only the first surface level and the first decimation target.

    Args:
        args: namespace-like object. Attributes read: ``scene_path``,
            ``checkpoint_path``, ``iteration_to_load``, ``eval``,
            ``coarse_model_path``, ``surface_level``, ``decimation_target``,
            ``mesh_output_dir``, ``bboxmin``, ``bboxmax``, ``center_bbox``,
            ``use_centers_to_extract_mesh``, ``use_marching_cubes``,
            ``use_vanilla_3dgs`` and ``gpu``. ``mesh_output_dir``, ``bboxmin``
            and ``bboxmax`` are overwritten with their resolved values.

    Returns:
        list[str]: paths of all mesh files written.

    Raises:
        ValueError: if ``bboxmin`` is given without ``bboxmax``, or if both
            the foreground and the background point sets are empty.
    """
    # Validate the bounding box first, so bad arguments fail before any
    # directory is created or any model is loaded.
    custom_bbox = _parse_custom_bbox(args)
    use_custom_bbox = custom_bbox is not None

    CONSOLE = Console(width=120)
    all_sugar_mesh_paths = []

    # ========== Parameters ==========
    n_skip_images_for_eval_split = 8
    low_opacity_gaussian_pruning_threshold = 0.5

    # Surface level extraction parameters
    n_total_points = 10_000_000
    surface_level_knn_to_track = 16
    use_fast_method = True  # TODO: Was False before, but True seems better
    # Keyword arguments shared by both level-set sampling methods.
    level_set_kwargs = dict(
        primitive_types='diamond',
        triangle_scale=2.,
        splat_mesh=True,
        n_points_in_range=21,
        range_size=3.0,
        n_points_per_pass=2_000_000,
        density_factor=1.,
        return_gaussian_idx=True,
        return_normals=True,
        compute_flat_normals=False,
        use_gaussian_depth=False,
    )

    # Mesh computation parameters
    fg_bbox_factor = 1.
    bg_bbox_factor = 4.
    poisson_depth = 10
    decimate_mesh = True
    clean_mesh = True

    # Vanilla 3DGS data
    source_path = args.scene_path
    gs_checkpoint_path = args.checkpoint_path
    iteration_to_load = args.iteration_to_load
    use_train_test_split = args.eval

    # Coarse model path
    sugar_checkpoint_path = args.coarse_model_path

    # Surface levels to extract
    surface_levels = [0.1, 0.3, 0.5] if args.surface_level is None else [args.surface_level]

    # Decimation targets
    decimation_targets = [200_000, 1_000_000] if args.decimation_target is None else [args.decimation_target]

    # Mesh output dir: defaults to a folder named after the scene directory
    # (the second-to-last component is used when the path ends with "/").
    if args.mesh_output_dir is None:
        path_parts = args.scene_path.split("/")
        scene_name = path_parts[-1] if len(path_parts[-1]) > 0 else path_parts[-2]
        args.mesh_output_dir = os.path.join("./output/coarse_mesh", scene_name)
    mesh_output_dir = args.mesh_output_dir
    os.makedirs(mesh_output_dir, exist_ok=True)

    center_bbox = args.center_bbox
    use_centers_to_extract_mesh = args.use_centers_to_extract_mesh
    use_marching_cubes = args.use_marching_cubes
    use_vanilla_3dgs = args.use_vanilla_3dgs

    CONSOLE.print("-----Parameters-----")
    CONSOLE.print("Source path:", source_path)
    CONSOLE.print("Gaussian Splatting Checkpoint path:", gs_checkpoint_path)
    CONSOLE.print("Coarse model Checkpoint path:", sugar_checkpoint_path)
    CONSOLE.print("Mesh output path:", mesh_output_dir)
    CONSOLE.print("Surface levels:", surface_levels)
    CONSOLE.print("Decimation targets:", decimation_targets)
    CONSOLE.print("Use custom bbox:", use_custom_bbox)
    CONSOLE.print("Use eval split:", use_train_test_split)
    CONSOLE.print("GPU:", args.gpu)
    CONSOLE.print("Use centers to extract mesh:", use_centers_to_extract_mesh)
    CONSOLE.print("Use marching cubes:", use_marching_cubes)
    CONSOLE.print("Use vanilla 3DGS:", use_vanilla_3dgs)
    CONSOLE.print("--------------------")

    # Set the GPU
    torch.cuda.set_device(args.gpu)

    # Load the initial 3DGS model
    CONSOLE.print(f"Loading the initial 3DGS model from path {gs_checkpoint_path}...")
    nerfmodel = GaussianSplattingWrapper(
        source_path=source_path,
        output_path=gs_checkpoint_path,
        iteration_to_load=iteration_to_load,
        load_gt_images=False,
        eval_split=use_train_test_split,
        eval_split_interval=n_skip_images_for_eval_split,
    )
    CONSOLE.print(f'{len(nerfmodel.training_cameras)} training images detected.')
    CONSOLE.print(f'The model has been trained for {iteration_to_load} steps.')

    # Load the coarse model
    sugar = _load_sugar_model(nerfmodel, use_vanilla_3dgs, sugar_checkpoint_path, CONSOLE)
    CONSOLE.print("Coarse model loaded.")
    CONSOLE.print("Coarse model parameters:")
    for name, param in sugar.named_parameters():
        CONSOLE.print(name, param.shape, param.requires_grad)

    with torch.no_grad():
        # Pruning low opacity gaussians
        CONSOLE.print("Number of gaussians:", sugar.n_points)
        _print_opacity_stats(sugar, CONSOLE)
        CONSOLE.print("\nStarting pruning low opacity gaussians...")
        sugar.drop_low_opacity_points(low_opacity_gaussian_pruning_threshold)
        CONSOLE.print("Number of gaussians left:", sugar.n_points)
        _print_opacity_stats(sugar, CONSOLE)

        # Build the triangle soup that will be used for splatting
        # sugar.primitive_types = 'square'
        sugar.primitive_types = 'diamond'
        sugar.triangle_scale = 2.
        sugar.update_texture_features()
        # The original read this attribute without using the value; the access
        # is kept in case it has side effects not visible from this file.
        _ = sugar.mesh

    name_prefix = _mesh_name_prefix(use_vanilla_3dgs, sugar_checkpoint_path)

    if use_marching_cubes:
        CONSOLE.print("\nWARNING: Using marching cubes to extract mesh.")
        sugar.reset_neighbors(knn_to_track=16)
        resolution = 512
        n_pts_per_pass = 2_000_000
        # Only the first surface level and decimation target are used here.
        surface_level = surface_levels[0]

        # Foreground mesh
        CONSOLE.print("\n-----Foreground mesh-----")
        fg_mesh = _marching_cubes_mesh(
            sugar, surface_level, fg_bbox_factor, resolution, n_pts_per_pass, CONSOLE,
        )

        # Background mesh
        CONSOLE.print("\n-----Background mesh-----")
        bg_mesh = _marching_cubes_mesh(
            sugar, surface_level, bg_bbox_factor, resolution, n_pts_per_pass, CONSOLE,
            mask_inner_box=True,
        )

        # The missing underscore before "marchingcubes" is kept on purpose so
        # the output file names stay the same as before.
        all_sugar_mesh_paths.extend(_export_decimated_meshes(
            fg_mesh, bg_mesh, decimation_targets[:1], mesh_output_dir,
            f"{name_prefix}marchingcubes_level{_tag(surface_level)}",
            decimate_mesh, clean_mesh, CONSOLE,
        ))

    elif use_centers_to_extract_mesh:
        CONSOLE.print("\nWARNING: Using centers of gaussians to extract mesh.")
        CONSOLE.print("Results will look bad, this is not the best way to extract a mesh.")
        CONSOLE.print("You should use this option only for ablation.")
        with torch.no_grad():
            surface_points = sugar.points
            surface_colors = SH2RGB(sugar._sh_coordinates_dc[:, 0, :])
            surface_normals = sugar.get_normals(estimate_from_points=True)

        # As in the original, this path does not apply ``center_bbox``.
        fg_mesh, bg_mesh = _poisson_fg_bg_meshes(
            sugar, surface_points, surface_colors, surface_normals,
            custom_bbox, False, fg_bbox_factor, bg_bbox_factor, poisson_depth, CONSOLE,
        )
        all_sugar_mesh_paths.extend(_export_decimated_meshes(
            fg_mesh, bg_mesh, decimation_targets, mesh_output_dir,
            f"{name_prefix}_poissoncenters",
            decimate_mesh, clean_mesh, CONSOLE,
        ))

    else:
        # Create a mesh renderer
        faces_per_pixel = 10
        max_faces_per_bin = 50_000
        mesh_raster_settings = RasterizationSettings(
            image_size=(sugar.image_height, sugar.image_width),
            blur_radius=0.0,
            faces_per_pixel=faces_per_pixel,
            max_faces_per_bin=max_faces_per_bin,
        )
        rasterizer = MeshRasterizer(
            cameras=nerfmodel.training_cameras.p3d_cameras[0],
            raster_settings=mesh_raster_settings,
        )

        # Compute surface levels point clouds
        n_pts_per_frame = int(n_total_points / len(nerfmodel.training_cameras)) + 1
        level_samples = _sample_level_set_points(
            sugar, nerfmodel, rasterizer, surface_levels, n_pts_per_frame,
            surface_level_knn_to_track, use_fast_method, level_set_kwargs, CONSOLE,
        )

        # -----Processing surface levels-----
        for surface_level in surface_levels:
            samples = level_samples[surface_level]
            CONSOLE.print("\n========== Processing surface level", surface_level, "==========")
            CONSOLE.print(f"Final point cloud for level {surface_level} has {len(samples['points'])} points.")

            fg_mesh, bg_mesh = _poisson_fg_bg_meshes(
                sugar, samples['points'], samples['colors'], samples['normals'],
                custom_bbox, center_bbox, fg_bbox_factor, bg_bbox_factor, poisson_depth, CONSOLE,
            )
            all_sugar_mesh_paths.extend(_export_decimated_meshes(
                fg_mesh, bg_mesh, decimation_targets, mesh_output_dir,
                f"{name_prefix}_level{_tag(surface_level)}",
                decimate_mesh, clean_mesh, CONSOLE,
            ))

    return all_sugar_mesh_paths