"""Contains utility code for gsplat renderer. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

from pathlib import Path
from typing import NamedTuple

import gsplat
import torch
from torch import nn

from sharp.utils import color_space as cs_utils
from sharp.utils import io, vis
from sharp.utils.gaussians import BackgroundColor, Gaussians3D
|
|
class RenderingOutputs(NamedTuple):
    """Outputs of the 3D Gaussians renderer."""

    color: torch.Tensor  # (B, 3, H, W)
    depth: torch.Tensor  # (B, 1, H, W)
    alpha: torch.Tensor  # (B, 1, H, W)
|
|
def write_renderings(rendering: RenderingOutputs, output_folder: Path, filename: str):
    """Write rendered color/depth/alpha images to files."""
    batch_size = len(rendering.color)
    if batch_size != 1:
        raise RuntimeError("Only renderings with batch size 1 can be saved.")

    def _save_image_tensor(tensor: torch.Tensor, suffix: str):
        np_array = tensor.permute(1, 2, 0).numpy()
        io.save_image(np_array, (output_folder / filename).with_suffix(suffix))

    color = (rendering.color[0].cpu() * 255.0).to(dtype=torch.uint8)
    colorized_depth = vis.colorize_depth(rendering.depth[0], val_max=100.0)
    colorized_alpha = vis.colorize_alpha(rendering.alpha[0])

    _save_image_tensor(color, ".color.png")
    _save_image_tensor(colorized_depth, ".depth.png")
    _save_image_tensor(colorized_alpha, ".alpha.png")
|
|
class GSplatRenderer(nn.Module):
    """Module to render 3D Gaussians to images using gsplat."""

    color_space: cs_utils.ColorSpace
    background_color: BackgroundColor
|
    def __init__(
        self,
        color_space: cs_utils.ColorSpace = "sRGB",
        background_color: BackgroundColor = "black",
        low_pass_filter_eps: float = 0.0,
    ) -> None:
        """Initialize the gsplat renderer.

        Args:
            color_space: The color space to use for rendering.
            background_color: The background color to use for rendering.
            low_pass_filter_eps: The epsilon value for the 2D low-pass filter,
                passed to gsplat as `eps2d`.
        """
        super().__init__()
        self.color_space = color_space
        self.background_color = background_color
        self.low_pass_filter_eps = low_pass_filter_eps
|
    def forward(
        self,
        gaussians: Gaussians3D,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        image_width: int,
        image_height: int,
    ) -> RenderingOutputs:
        """Render images from the given Gaussians.

        Args:
            gaussians: The Gaussians to render.
            extrinsics: The extrinsics of the camera to render from, in OpenCV format.
            intrinsics: The intrinsics of the camera to render with, in OpenCV format.
            image_width: The desired output image width.
            image_height: The desired output image height.
        """
        batch_size = len(gaussians.mean_vectors)
        outputs_list: list[RenderingOutputs] = []
|
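        # gsplat rasterizes a single set of Gaussians per call, so each batch
        # element is rendered on its own and the results are concatenated below.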
        for ib in range(batch_size):
            colors, alphas, meta = gsplat.rendering.rasterization(
                means=gaussians.mean_vectors[ib],
                quats=gaussians.quaternions[ib],
                scales=gaussians.singular_values[ib],
                opacities=gaussians.opacities[ib],
                colors=gaussians.colors[ib],
                viewmats=extrinsics[ib : ib + 1],
                Ks=intrinsics[ib : ib + 1, :3, :3],
                width=image_width,
                height=image_height,
                render_mode="RGB+D",
                rasterize_mode="classic",
                absgrad=False,
                packed=False,
                eps2d=self.low_pass_filter_eps,
            )
|
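            # With render_mode="RGB+D", the rasterizer returns four channels per
            # pixel: RGB in the first three and accumulated (alpha-weighted)
            # depth in the last.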
            rendered_color = colors[..., 0:3].permute([0, 3, 1, 2])
            rendered_depth_unnormalized = colors[..., 3:4].permute([0, 3, 1, 2])
            rendered_alpha = alphas.permute([0, 3, 1, 2])
|
            rendered_color = self.compose_with_background(
                rendered_color, rendered_alpha, self.background_color
            )
|
            if self.color_space == "sRGB":
                pass
            elif self.color_space == "linearRGB":
                rendered_color = cs_utils.linearRGB2sRGB(rendered_color)
            else:
                raise ValueError("Unsupported ColorSpace type.")
|
            cov2d = self._conics_to_covars2d(meta["conics"])
|
            # Reset the 2D covariances of splats at or behind the near plane to
            # the identity matrix. The mask and integer indices must share one
            # subscript so the assignment happens in place.
            splats_visible_mask = meta["depths"] > 1e-2
            cov2d[~splats_visible_mask, 0, 0] = 1
            cov2d[~splats_visible_mask, 1, 1] = 1
            cov2d[~splats_visible_mask, 0, 1] = 0
            cov2d[~splats_visible_mask, 1, 0] = 0
|
            # The "RGB+D" depth is alpha-weighted, so divide by alpha to recover
            # the expected per-pixel depth.
            rendered_depth = rendered_depth_unnormalized / torch.clip(rendered_alpha, min=1e-8)
|
            outputs = RenderingOutputs(
                color=rendered_color,
                depth=rendered_depth,
                alpha=rendered_alpha,
            )
            outputs_list.append(outputs)
|
        return RenderingOutputs(
            color=torch.cat([item.color for item in outputs_list], dim=0).contiguous(),
            depth=torch.cat([item.depth for item in outputs_list], dim=0).contiguous(),
            alpha=torch.cat([item.alpha for item in outputs_list], dim=0).contiguous(),
        )
|
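    # Compositing note: the rasterizer returns alpha-premultiplied color C and
    # coverage A, so blending over a background color B is the "over" operator
    # C + (1 - A) * B; each branch below instantiates B differently.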
    @staticmethod
    def compose_with_background(
        rendered_rgb: torch.Tensor,
        rendered_alpha: torch.Tensor,
        background_color: BackgroundColor,
    ) -> torch.Tensor:
        """Compose the alpha-premultiplied rendered RGB with a background color."""
        if background_color == "black":
            return rendered_rgb
        elif background_color == "white":
            return rendered_rgb + (1.0 - rendered_alpha)
        elif background_color == "random_color":
            # Broadcast a single random RGB color over the (B, 3, H, W) image.
            return (
                rendered_rgb
                + (1.0 - rendered_alpha)
                * torch.rand(3, dtype=rendered_rgb.dtype, device=rendered_rgb.device)[
                    None, :, None, None
                ]
            )
        elif background_color == "random_pixel":
            return rendered_rgb + (1.0 - rendered_alpha) * torch.rand_like(rendered_rgb)
        else:
            raise ValueError("Unsupported BackgroundColor type.")
|
    @staticmethod
    def _conics_to_covars2d(conics: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Convert conics (packed inverse 2D covariances) to covariance matrices."""
        a = conics[..., 0]
        b = conics[..., 1]
        c = conics[..., 2]
|
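        # A conic (a, b, c) packs the symmetric 2x2 inverse covariance
        # [[a, b], [b, c]]; inverting it gives
        #   covar = (1 / (a * c - b**2)) * [[c, -b], [-b, a]],
        # so `det` below is the determinant of the covariance (the reciprocal
        # of the conic determinant), regularized and clamped by `eps`.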
        det = 1 / (a * c - b**2 + eps)
        det = det.clamp(min=eps)
|
        # Match the dtype and device of the input conics.
        covars2d = torch.zeros(
            *conics.shape[:-1], 2, 2, dtype=conics.dtype, device=conics.device
        )
        covars2d[..., 1, 1] = a * det
        covars2d[..., 0, 0] = c * det
        covars2d[..., 0, 1] = -b * det
        covars2d[..., 1, 0] = -b * det
        covars2d = torch.nan_to_num(covars2d, nan=0.0, posinf=0.0, neginf=0.0)
        return covars2d
|
|
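# Usage sketch (illustrative only): assumes `gaussians` is a populated
# Gaussians3D batch and `extrinsics`/`intrinsics` are (B, 4, 4) camera matrices
# in OpenCV convention.
#
#   renderer = GSplatRenderer(color_space="linearRGB", background_color="white")
#   outputs = renderer(
#       gaussians, extrinsics, intrinsics, image_width=640, image_height=480
#   )
#   write_renderings(outputs, Path("./renders"), "frame_0000")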