|
|
"""Contains basic data structures and functionality for 3D Gaussians. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Any, Literal, NamedTuple |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from plyfile import PlyData, PlyElement |
|
|
|
|
|
from sharp.utils import color_space as cs_utils |
|
|
from sharp.utils import linalg |
|
|
|
|
|
LOGGER = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
# Allowed background fill modes for rendering (semantics defined by consumers).
BackgroundColor = Literal["black", "white", "random_color", "random_pixel"]
|
|
|
|
|
|
|
|
class Gaussians3D(NamedTuple):
    """Bundles the per-Gaussian parameters of a 3D Gaussian scene."""

    # Gaussian centers.
    mean_vectors: torch.Tensor
    # Per-axis scales of the diagonalized covariance.
    singular_values: torch.Tensor
    # Orientation of each Gaussian's principal basis.
    quaternions: torch.Tensor
    # Per-Gaussian color values.
    colors: torch.Tensor
    # Per-Gaussian opacity values.
    opacities: torch.Tensor

    def to(self, device: torch.device) -> Gaussians3D:
        """Return a copy with every tensor field moved to ``device``."""
        # NamedTuple fields iterate in declaration order, so a positional
        # rebuild preserves the field assignment exactly.
        return Gaussians3D(*(tensor.to(device) for tensor in self))
|
|
|
|
|
|
|
|
class SceneMetaData(NamedTuple):
    """Metadata describing the camera setup a Gaussian scene was created from."""

    # Focal length in pixels. Only a single value is stored; load_ply keeps
    # fx and discards fy.
    focal_length_px: float
    # Image resolution as (width, height) in pixels (matches load_ply).
    resolution_px: tuple[int, int]
    # Color space the Gaussian colors are expressed in.
    color_space: cs_utils.ColorSpace
|
|
|
|
|
|
|
|
def get_unprojection_matrix(
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
    image_shape: tuple[int, int],
) -> torch.Tensor:
    """Compute unprojection matrix to transform Gaussians to Euclidean space.

    Args:
        extrinsics: The 4x4 extrinsics matrix of the camera view.
        intrinsics: The 4x4 intrinsics matrix of the camera view.
        image_shape: The (width, height) of the input image.

    Returns:
        A 4x4 matrix to transform Gaussians from NDC space to Euclidean space.
    """
    width_px, height_px = image_shape

    # Pixel -> NDC: scale x/y by 2/size and shift by -1 (via the third column);
    # z and the homogeneous coordinate pass through unchanged.
    ndc_matrix = torch.eye(4, device=intrinsics.device)
    ndc_matrix[0, 0] = 2.0 / width_px
    ndc_matrix[1, 1] = 2.0 / height_px
    ndc_matrix[0, 2] = -1.0
    ndc_matrix[1, 2] = -1.0

    # The forward projection is ndc @ intrinsics @ extrinsics; invert the
    # composite to map NDC back to Euclidean space.
    return torch.linalg.inv(ndc_matrix @ intrinsics @ extrinsics)
|
|
|
|
|
|
|
|
def unproject_gaussians(
    gaussians_ndc: Gaussians3D,
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
    image_shape: tuple[int, int],
) -> Gaussians3D:
    """Unproject Gaussians from NDC space to world coordinates."""
    # Only the top 3x4 of the 4x4 unprojection matrix is needed as an affine
    # transform for apply_transform.
    affine = get_unprojection_matrix(extrinsics, intrinsics, image_shape)[:3]
    return apply_transform(gaussians_ndc, affine)
|
|
|
|
|
|
|
|
def apply_transform(gaussians: Gaussians3D, transform: torch.Tensor) -> Gaussians3D:
    """Apply an affine transformation to 3D Gaussians.

    Args:
        gaussians: The Gaussians to transform.
        transform: An affine transform with shape 3x4.

    Returns:
        The transformed Gaussians.

    Note: This operation is not differentiable.
    """
    linear_part = transform[..., :3, :3]
    translation = transform[..., :3, 3]

    # Means transform affinely; covariances conjugate by the linear part only.
    new_means = gaussians.mean_vectors @ linear_part.T + translation

    covariances = compose_covariance_matrices(
        gaussians.quaternions, gaussians.singular_values
    )
    covariances = linear_part @ covariances @ linear_part.transpose(-1, -2)
    # Re-extract orientation/scales from the transformed covariances
    # (non-differentiable; see decompose_covariance_matrices).
    new_quaternions, new_singular_values = decompose_covariance_matrices(covariances)

    return Gaussians3D(
        mean_vectors=new_means,
        singular_values=new_singular_values,
        quaternions=new_quaternions,
        colors=gaussians.colors,
        opacities=gaussians.opacities,
    )
|
|
|
|
|
|
|
|
def decompose_covariance_matrices(
    covariance_matrices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Decompose 3D covariance matrices into quaternions and singular values.

    Args:
        covariance_matrices: The covariance matrices to decompose.

    Returns:
        Quaternion and singular values corresponding to the orientation and scales of
        the diagonalized matrix.

    Note: This operation is not differentiable.
    """
    # Remember input dtype/device so the results can be cast back at the end.
    device = covariance_matrices.device
    dtype = covariance_matrices.dtype

    # Detach and run the SVD on CPU in float64; the results are moved back to
    # the original dtype/device below. This also makes the op non-differentiable.
    covariance_matrices = covariance_matrices.detach().cpu().to(torch.float64)
    rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices)

    # SVD may return reflections (det = -1) instead of proper rotations.
    # NOTE(review): the two-index unpacking assumes a (batch, num_gaussians, 3, 3)
    # input layout — confirm against callers.
    batch_idx, gaussian_idx = torch.where(torch.linalg.det(rotations) < 0)
    num_reflections = len(gaussian_idx)
    if num_reflections > 0:
        LOGGER.warning(
            "Received %d reflection matrices from SVD. Flipping them to rotations.",
            num_reflections,
        )

    # Negating one column flips a reflection into a proper rotation; this is a
    # no-op when the index tensors are empty.
    rotations[batch_idx, gaussian_idx, :, -1] *= -1
    quaternions = linalg.quaternions_from_rotation_matrices(rotations)
    quaternions = quaternions.to(dtype=dtype, device=device)
    # SVD yields squared scales (variances) for a symmetric PSD matrix.
    singular_values = singular_values_2.sqrt().to(dtype=dtype, device=device)
    return quaternions, singular_values
|
|
|
|
|
|
|
|
def compose_covariance_matrices(
    quaternions: torch.Tensor, singular_values: torch.Tensor
) -> torch.Tensor:
    """Build 3x3 covariance matrices from quaternions and singular values.

    Args:
        quaternions: The quaternions describing the principal basis.
        singular_values: The scales of the diagonalized matrix.

    Returns:
        The 3x3 covariances matrices.
    """
    # Sigma = R diag(s)^2 R^T, where R is the rotation encoded by the quaternion.
    basis = linalg.rotation_matrices_from_quaternions(quaternions)
    variances = singular_values.square()
    # Broadcasting the eye against (..., 3, 1) places the variances on the diagonal.
    diagonal = torch.eye(3, device=quaternions.device) * variances[..., :, None]
    return basis @ diagonal @ basis.transpose(-1, -2)
|
|
|
|
|
|
|
|
def convert_spherical_harmonics_to_rgb(sh0: torch.Tensor) -> torch.Tensor:
    """Convert degree-0 spherical harmonics to RGB.

    Reference:
        https://en.wikipedia.org/wiki/Table_of_spherical_harmonics
    """
    # Y_0^0 = sqrt(1 / (4*pi)); adding 0.5 recenters zero-mean SH onto RGB.
    y00 = np.sqrt(1.0 / (4.0 * np.pi))
    return 0.5 + y00 * sh0
|
|
|
|
|
|
|
|
def convert_rgb_to_spherical_harmonics(rgb: torch.Tensor) -> torch.Tensor:
    """Convert RGB to degree-0 spherical harmonics.

    Reference:
        https://en.wikipedia.org/wiki/Table_of_spherical_harmonics
    """
    # Inverse of convert_spherical_harmonics_to_rgb: recenter, then divide
    # by Y_0^0 = sqrt(1 / (4*pi)).
    y00 = np.sqrt(1.0 / (4.0 * np.pi))
    recentered = rgb - 0.5
    return recentered / y00
|
|
|
|
|
|
|
|
def load_ply(path: Path) -> tuple[Gaussians3D, SceneMetaData]:
    """Load 3D Gaussians and scene metadata from a ply file.

    Args:
        path: Path of the ply file to read.

    Returns:
        The loaded Gaussians (every tensor carries a leading batch dimension
        of 1) and the scene metadata recovered from the supplemental elements.

    Raises:
        KeyError: If a required per-Gaussian vertex property is missing.
        ValueError: If the stored intrinsics or extrinsics have an
            unexpected number of elements.
    """
    plydata = PlyData.read(path)

    vertices = next(filter(lambda x: x.name == "vertex", plydata.elements))

    # Validate all required per-Gaussian properties up front so malformed
    # files fail with a clear error instead of an obscure one mid-parse.
    properties = ["x", "y", "z", "opacity"]
    properties.extend([f"f_dc_{i}" for i in range(3)])
    properties.extend([f"scale_{i}" for i in range(3)])
    # Quaternions have 4 components (rot_0..rot_3).
    properties.extend([f"rot_{i}" for i in range(4)])

    for prop in properties:
        if prop not in vertices:
            raise KeyError(f"Incompatible ply file: property {prop} not found in ply elements.")
    mean_vectors = np.stack(
        (
            np.asarray(vertices["x"]),
            np.asarray(vertices["y"]),
            np.asarray(vertices["z"]),
        ),
        axis=1,
    )

    # Scales are stored in log space; converted back via exp below.
    scale_logits = np.stack(
        (
            np.asarray(vertices["scale_0"]),
            np.asarray(vertices["scale_1"]),
            np.asarray(vertices["scale_2"]),
        ),
        axis=1,
    )

    quaternions = np.stack(
        (
            np.asarray(vertices["rot_0"]),
            np.asarray(vertices["rot_1"]),
            np.asarray(vertices["rot_2"]),
            np.asarray(vertices["rot_3"]),
        ),
        axis=1,
    )

    # Colors are stored as degree-0 spherical harmonics coefficients.
    spherical_harmonics_deg0 = np.stack(
        (
            np.asarray(vertices["f_dc_0"]),
            np.asarray(vertices["f_dc_1"]),
            np.asarray(vertices["f_dc_2"]),
        ),
        axis=1,
    )

    colors = convert_spherical_harmonics_to_rgb(spherical_harmonics_deg0)

    # Opacities are stored as pre-sigmoid logits; converted via sigmoid below.
    opacity_logits = np.asarray(vertices["opacity"])[..., None]

    # Supplemental (non-vertex) elements carry camera and color-space info.
    supplement_elements = [element for element in plydata.elements if element.name != "vertex"]
    supplement_data: dict[str, Any] = {}
    supplement_keys = ["extrinsic", "intrinsic", "color_space", "image_size"]

    for element in supplement_elements:
        for key in supplement_keys:
            if key not in supplement_data and key in element:
                supplement_data[key] = np.asarray(element[key])

    if "intrinsic" in supplement_data:
        intrinsics_data = supplement_data["intrinsic"]

        if "image_size" not in supplement_data:
            # Legacy layout: (fx, fy, width, height) packed into "intrinsic".
            if len(intrinsics_data) != 4:
                raise ValueError(
                    "Expect legacy intrinsics with len=4 containing image size, "
                    f"but received len={len(intrinsics_data)}"
                )
            focal_length_px = (intrinsics_data[0], intrinsics_data[1])
            width = int(intrinsics_data[2])
            height = int(intrinsics_data[3])

        else:
            # Current layout: a flattened 3x3 intrinsics matrix plus a
            # separate "image_size" element holding (width, height).
            if len(intrinsics_data) != 9:
                raise ValueError(
                    "Expect 9 elements in intrinsics, " f"but received {len(intrinsics_data)}."
                )
            intrinsics_matrix = intrinsics_data.reshape((3, 3))
            focal_length_px = (intrinsics_matrix[0, 0], intrinsics_matrix[1, 1])

            image_size_data = supplement_data["image_size"]
            # Cast to int so resolution_px matches SceneMetaData's tuple[int, int].
            width = int(image_size_data[0])
            height = int(image_size_data[1])

    else:
        # Fallback defaults for files without stored camera information.
        focal_length_px = (512, 512)
        width = 640
        height = 480

    extrinsics_data = supplement_data.get("extrinsic", np.eye(4).flatten())
    extrinsics_matrix = np.eye(4)

    # NOTE(review): the parsed extrinsics are validated but not returned or
    # applied anywhere in this function; parsing is kept for its validation
    # side effect.
    if len(extrinsics_data) == 12:
        # Legacy 3x4 layout with the rotation block stored transposed.
        extrinsics_matrix[:3] = extrinsics_data.reshape((3, 4))
        extrinsics_matrix[:3, :3] = extrinsics_matrix[:3, :3].copy().T
    elif len(extrinsics_data) == 16:
        extrinsics_matrix[:] = extrinsics_data.reshape((4, 4))
    else:
        raise ValueError(f"Unrecognized extrinsics matrix shape {len(extrinsics_data)}")

    # Default color-space index 1 when the element is absent.
    color_space_index = supplement_data.get("color_space", 1)
    color_space = cs_utils.decode_color_space(color_space_index)
    if color_space == "sRGB":
        colors = cs_utils.sRGB2linearRGB(colors)

    # Convert to float tensors with a leading batch dimension of 1, undoing
    # the log/logit parameterizations used on disk.
    mean_vectors = torch.from_numpy(mean_vectors).view(1, -1, 3).float()
    quaternions = torch.from_numpy(quaternions).view(1, -1, 4).float()
    singular_values = torch.exp(torch.from_numpy(scale_logits).view(1, -1, 3)).float()
    opacities = torch.sigmoid(torch.from_numpy(opacity_logits).view(1, -1)).float()
    colors = torch.from_numpy(colors).view(1, -1, 3).float()

    gaussians = Gaussians3D(
        mean_vectors=mean_vectors,
        quaternions=quaternions,
        singular_values=singular_values,
        opacities=opacities,
        colors=colors,
    )
    # Only fx (focal_length_px[0]) is kept in the metadata; fy is discarded.
    metadata = SceneMetaData(focal_length_px[0], (width, height), color_space)
    return gaussians, metadata
|
|
|
|
|
|
|
|
@torch.no_grad()
def save_ply(
    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> PlyData:
    """Save a predicted Gaussians3D scene to a ply file.

    Args:
        gaussians: The Gaussians to serialize; leading batch dimensions are
            flattened before writing.
        f_px: Focal length in pixels, written into the "intrinsic" element.
        image_shape: NOTE(review): unpacked as (height, width) below, which is
            the opposite order of load_ply / get_unprojection_matrix — confirm
            against callers.
        path: Output file path.

    Returns:
        The PlyData object that was written to disk.
    """

    def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
        # Logit function: maps opacities in (0, 1) back to unbounded logits.
        return torch.log(tensor / (1.0 - tensor))

    # Flatten batch and gaussian dims; scales are written in log space.
    xyz = gaussians.mean_vectors.flatten(0, 1)
    scale_logits = torch.log(gaussians.singular_values).flatten(0, 1)
    quaternions = gaussians.quaternions.flatten(0, 1)

    # Colors are stored as degree-0 SH coefficients of the sRGB values.
    colors = convert_rgb_to_spherical_harmonics(
        cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1))
    )
    color_space_index = cs_utils.encode_color_space("sRGB")

    # Opacities are stored as pre-sigmoid logits (inverse of load_ply).
    opacity_logits = _inverse_sigmoid(gaussians.opacities).flatten(0, 1).unsqueeze(-1)

    # Column order must match dtype_full below.
    attributes = torch.cat(
        (
            xyz,
            colors,
            opacity_logits,
            scale_logits,
            quaternions,
        ),
        dim=1,
    )

    # Per-vertex property schema: all float32 ("f4") fields.
    dtype_full = [
        (attribute, "f4")
        for attribute in ["x", "y", "z"]
        + [f"f_dc_{i}" for i in range(3)]
        + ["opacity"]
        + [f"scale_{i}" for i in range(3)]
        + [f"rot_{i}" for i in range(4)]
    ]

    num_gaussians = len(xyz)
    elements = np.empty(num_gaussians, dtype=dtype_full)
    # plyfile expects one tuple per vertex row.
    elements[:] = list(map(tuple, attributes.detach().cpu().numpy()))
    vertex_elements = PlyElement.describe(elements, "vertex")

    # NOTE(review): (height, width) unpacking — see docstring.
    image_height, image_width = image_shape

    # "image_size" element stores (width, height), matching load_ply's reader.
    dtype_image_size = [("image_size", "u4")]
    image_size_array = np.empty(2, dtype=dtype_image_size)
    image_size_array[:] = np.array([image_width, image_height])
    image_size_element = PlyElement.describe(image_size_array, "image_size")

    # Flattened 3x3 pinhole intrinsics with the principal point at the center.
    dtype_intrinsic = [("intrinsic", "f4")]
    intrinsic_array = np.empty(9, dtype=dtype_intrinsic)
    intrinsic = np.array(
        [
            f_px,
            0,
            image_width * 0.5,
            0,
            f_px,
            image_height * 0.5,
            0,
            0,
            1,
        ]
    )
    intrinsic_array[:] = intrinsic.flatten()
    intrinsic_element = PlyElement.describe(intrinsic_array, "intrinsic")

    # Extrinsics are written as an identity 4x4 (flattened, 16 values).
    dtype_extrinsic = [("extrinsic", "f4")]
    extrinsic_array = np.empty(16, dtype=dtype_extrinsic)
    extrinsic_array[:] = np.eye(4).flatten()
    extrinsic_element = PlyElement.describe(extrinsic_array, "extrinsic")

    # "frame" element: [frame count, gaussians per frame] — presumably a
    # single-frame layout marker; confirm against readers of this element.
    dtype_frames = [("frame", "i4")]
    frame_array = np.empty(2, dtype=dtype_frames)
    frame_array[:] = np.array([1, num_gaussians], dtype=np.int32)
    frame_element = PlyElement.describe(frame_array, "frame")

    dtype_disparity = [("disparity", "f4")]
    disparity_array = np.empty(2, dtype=dtype_disparity)

    # Store the 10th/90th percentile of inverse depth (1/z of the means).
    disparity = 1.0 / gaussians.mean_vectors[0, ..., -1]
    quantiles = (
        torch.quantile(disparity, q=torch.tensor([0.1, 0.9], device=disparity.device))
        .float()
        .cpu()
        .numpy()
    )
    disparity_array[:] = quantiles
    disparity_element = PlyElement.describe(disparity_array, "disparity")

    # Single-byte color-space tag decoded by cs_utils on load.
    dtype_color_space = [("color_space", "u1")]
    color_space_array = np.empty(1, dtype=dtype_color_space)
    color_space_array[:] = np.array([color_space_index]).flatten()
    color_space_element = PlyElement.describe(color_space_array, "color_space")

    # File format version tag, stored as (major, minor, patch) = 1.5.0.
    dtype_version = [("version", "u1")]
    version_array = np.empty(3, dtype=dtype_version)
    version_array[:] = np.array([1, 5, 0], dtype=np.uint8).flatten()
    version_element = PlyElement.describe(version_array, "version")

    plydata = PlyData(
        [
            vertex_elements,
            extrinsic_element,
            intrinsic_element,
            image_size_element,
            frame_element,
            disparity_element,
            color_space_element,
            version_element,
        ]
    )

    plydata.write(path)
    return plydata
|
|
|