alvin888's picture
Upload 1461 files
e2ebf5f verified
raw
history blame contribute delete
11.6 kB
import sys
sys.path.append('./gaussian_splatting')
import os
import torch
import plotly.graph_objs as go
from sugar.gaussian_splatting.scene.gaussian_model import GaussianModel
from sugar.gaussian_splatting.gaussian_renderer import render as gs_render
from sugar.gaussian_splatting.scene.dataset_readers import fetchPly
from sugar.sugar_utils.spherical_harmonics import SH2RGB
from sugar.sugar_scene.cameras import CamerasWrapper, load_gs_cameras
class ModelParams():
"""Parameters of the Gaussian Splatting model.
Largely inspired by the original implementation of the 3D Gaussian Splatting paper:
https://github.com/graphdeco-inria/gaussian-splatting
"""
def __init__(self):
self.sh_degree = 3
self.source_path = ""
self.model_path = ""
self.images = "images"
self.resolution = -1
self.white_background = False
self.data_device = "cuda"
self.eval = False
class PipelineParams():
"""Parameters of the Gaussian Splatting pipeline.
Largely inspired by the original implementation of the 3D Gaussian Splatting paper:
https://github.com/graphdeco-inria/gaussian-splatting
"""
def __init__(self):
self.convert_SHs_python = False
self.compute_cov3D_python = False
self.debug = False
class OptimizationParams():
"""Parameters of the Gaussian Splatting optimization.
Largely inspired by the original implementation of the 3D Gaussian Splatting paper:
https://github.com/graphdeco-inria/gaussian-splatting
"""
def __init__(self):
self.iterations = 30_000
self.position_lr_init = 0.00016
self.position_lr_final = 0.0000016
self.position_lr_delay_mult = 0.01
self.position_lr_max_steps = 30_000
self.feature_lr = 0.0025
self.opacity_lr = 0.05
self.scaling_lr = 0.005
self.rotation_lr = 0.001
self.percent_dense = 0.01
self.lambda_dssim = 0.2
self.densification_interval = 100
self.opacity_reset_interval = 3000
self.densify_from_iter = 500
self.densify_until_iter = 15_000
self.densify_grad_threshold = 0.0002
class GaussianSplattingWrapper:
"""Class to wrap original Gaussian Splatting models and facilitates both usage and integration with PyTorch3D.
"""
def __init__(self,
source_path: str,
output_path: str,
iteration_to_load:int=30_000,
model_params: ModelParams=None,
pipeline_params: PipelineParams=None,
opt_params: OptimizationParams=None,
load_gt_images=True,
eval_split=False,
eval_split_interval=8,
) -> None:
"""Initialize the Gaussian Splatting model wrapper.
Args:
source_path (str): Path to the directory containing the source images.
output_path (str): Path to the directory containing the output of the Gaussian Splatting optimization.
iteration_to_load (int, optional): Checkpoint to load. Should be 7000 or 30_000. Defaults to 30_000.
model_params (ModelParams, optional): Model parameters. Defaults to None.
pipeline_params (PipelineParams, optional): Pipeline parameters. Defaults to None.
opt_params (OptimizationParams, optional): Optimization parameters. Defaults to None.
load_gt_images (bool, optional): If True, will load all GT images in the source folder.
Useful for evaluating the model, but loading can take a few minutes. Defaults to True.
eval_split (bool, optional): If True, will split images and cameras into a training set and an evaluation set.
Defaults to False.
eval_split_interval (int, optional): Every eval_split_interval images, an image is added to the evaluation set.
Defaults to 8 (following standard practice).
"""
self.source_path = source_path
self.output_path = output_path
self.loaded_iteration = iteration_to_load
if model_params is None:
model_params = ModelParams()
if pipeline_params is None:
pipeline_params = PipelineParams()
if opt_params is None:
opt_params = OptimizationParams()
self.model_params = model_params
self.pipeline_params = pipeline_params
self.opt_params = opt_params
self._C0 = 0.28209479177387814
cam_list = load_gs_cameras(
source_path=source_path,
gs_output_path=output_path,
load_gt_images=load_gt_images,
)
if eval_split:
self.cam_list = []
self.test_cam_list = []
for i, cam in enumerate(cam_list):
if i % eval_split_interval == 0:
self.test_cam_list.append(cam)
else:
self.cam_list.append(cam)
# test_ns_cameras = convert_camera_from_gs_to_nerfstudio(self.test_cam_list)
# self.test_cameras = NeRFCameras.from_ns_cameras(test_ns_cameras)
self.test_cameras = CamerasWrapper(self.test_cam_list)
else:
self.cam_list = cam_list
self.test_cam_list = None
self.test_cameras = None
# ns_cameras = convert_camera_from_gs_to_nerfstudio(self.cam_list)
# self.training_cameras = NeRFCameras.from_ns_cameras(ns_cameras)
self.training_cameras = CamerasWrapper(self.cam_list)
self.gaussians = GaussianModel(self.model_params.sh_degree)
self.gaussians.load_ply(
os.path.join(
output_path,
"point_cloud",
"iteration_" + str(iteration_to_load),
"point_cloud.ply"
)
)
@property
def device(self):
with torch.no_grad():
return self.gaussians.get_xyz.device
@property
def image_height(self):
return self.cam_list[0].image_height
@property
def image_width(self):
return self.cam_list[0].image_width
def render_image(
self,
nerf_cameras:CamerasWrapper=None,
camera_indices:int=0,
return_whole_package=False):
"""Render an image with Gaussian Splatting rasterizer.
Args:
nerf_cameras (CamerasWrapper, optional): Set of cameras.
If None, uses the training cameras, but can be any set of cameras. Defaults to None.
camera_indices (int, optional): Index of the camera to render in the set of cameras.
Defaults to 0.
return_whole_package (bool, optional): If True, returns the whole output package
as computed in the original rasterizer from 3D Gaussian Splatting paper. Defaults to False.
Returns:
Tensor or Dict: A tensor of the rendered RGB image, or the whole output package.
"""
if nerf_cameras is None:
gs_cameras = self.cam_list
else:
gs_cameras = nerf_cameras.gs_cameras
camera = gs_cameras[camera_indices]
render_pkg = gs_render(camera, self.gaussians,
self.pipeline_params,
bg_color=torch.zeros(3, device='cuda'))
if return_whole_package:
return render_pkg
else:
image = render_pkg["render"]
return image.permute(1, 2, 0)
def get_gt_image(self, camera_indices:int, to_cuda=False):
"""Returns the ground truth image corresponding to the training camera at the given index.
Args:
camera_indices (int): Index of the camera in the set of cameras.
to_cuda (bool, optional): If True, moves the image to GPU. Defaults to False.
Returns:
Tensor: The ground truth image.
"""
gt_image = self.cam_list[camera_indices].original_image
if to_cuda:
gt_image = gt_image.cuda()
return gt_image.permute(1, 2, 0)
def get_test_gt_image(self, camera_indices:int, to_cuda=False):
"""Returns the ground truth image corresponding to the test camera at the given index.
Args:
camera_indices (int): Index of the camera in the set of cameras.
to_cuda (bool, optional): If True, moves the image to GPU. Defaults to False.
Returns:
Tensor: The ground truth image.
"""
gt_image = self.test_cam_list[camera_indices].original_image
if to_cuda:
gt_image = gt_image.cuda()
return gt_image.permute(1, 2, 0)
def downscale_output_resolution(self, downscale_factor):
"""Downscale the output resolution of the Gaussian Splatting model.
Args:
downscale_factor (float): Factor by which to downscale the resolution.
"""
self.training_cameras.rescale_output_resolution(1.0 / downscale_factor)
def generate_point_cloud(self):
"""Generate a point cloud from the Gaussian Splatting model.
Returns:
(Tensor, Tensor): The points and the colors of the point cloud.
Each has shape (N, 3), where N is the number of Gaussians.
"""
with torch.no_grad():
points = self.gaussians.get_xyz
# colors = self.gaussians.get_features[:, 0] * self._C0 + 0.5
colors = SH2RGB(self.gaussians.get_features[:, 0])
return points, colors
def plot_point_cloud(
self,
points=None,
colors=None,
n_points_to_plot: int = 50000,
width=1000,
height=500,
):
"""Plot the generated 3D point cloud with plotly.
Args:
n_points_to_plot (int, optional): _description_. Defaults to 50000.
points (_type_, optional): _description_. Defaults to None.
colors (_type_, optional): _description_. Defaults to None.
width (int, optional): Defaults to 1000.
height (int, optional): Defaults to 1000.
Raises:
ValueError: _description_
Returns:
go.Figure: The plotly figure.
"""
with torch.no_grad():
if points is None:
points, colors = self.generate_point_cloud()
points_idx = torch.randperm(points.shape[0])[:n_points_to_plot]
points_to_plot = points[points_idx].cpu()
colors_to_plot = colors[points_idx].cpu()
z = points_to_plot[:, 2]
x = points_to_plot[:, 0]
y = points_to_plot[:, 1]
trace = go.Scatter3d(
x=x,
y=y,
z=z,
mode="markers",
marker=dict(
size=3,
color=colors_to_plot, # set color to an array/list of desired values
# colorscale = 'Magma'
),
)
layout = go.Layout(
scene=dict(bgcolor="white", aspectmode="data"),
template="none",
width=width,
height=height,
)
fig = go.Figure(data=[trace], layout=layout)
# fig.update_layout(template='none', scene_aspectmode='data')
# fig.show()
return fig