# AnySplat: src/model/ply_export.py
# Uploaded by alexnasa ("Upload 243 files"), commit 2568013 (verified).
from pathlib import Path
import numpy as np
import torch
from einops import einsum, rearrange
from jaxtyping import Float
from plyfile import PlyData, PlyElement
from scipy.spatial.transform import Rotation as R
from torch import Tensor
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered PLY vertex property names for one Gaussian.

    The order is ``x, y, z``, ``nx, ny, nz``, ``f_dc_0..2``, then
    ``num_rest`` entries ``f_rest_0..f_rest_{num_rest-1}``, ``opacity``,
    ``scale_0..2`` and finally ``rot_0..3``.
    """
    names = ["x", "y", "z", "nx", "ny", "nz"]
    names += [f"f_dc_{k}" for k in range(3)]
    names += [f"f_rest_{k}" for k in range(num_rest)]
    names.append("opacity")
    names += [f"scale_{k}" for k in range(3)]
    names += [f"rot_{k}" for k in range(4)]
    return names
def _to_float32_numpy(tensor: Tensor) -> np.ndarray:
    """Detach ``tensor``, move it to the CPU and return a contiguous float32 array."""
    # Convert with ``.float()`` first: ``Tensor.numpy()`` rejects bfloat16.
    return tensor.detach().float().cpu().contiguous().numpy()


def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: Path,
    shift_and_scale: bool = False,
    save_sh_dc_only: bool = True,
) -> None:
    """Write Gaussians to a PLY file as float32 ``vertex`` properties.

    Properties follow ``construct_list_of_attributes``:
    ``x, y, z`` (means), ``nx, ny, nz`` (always zero), ``f_dc_0..2``,
    optionally ``f_rest_*``, ``opacity``, ``scale_0..2`` (natural log of
    ``scales``) and ``rot_0..3`` (quaternion written in w, x, y, z order).

    Args:
        means: Gaussian centers, shape (gaussian, 3).
        scales: Per-axis scales, shape (gaussian, 3). They are written as
            ``log(scales)``, so non-positive values produce -inf/NaN.
        rotations: Quaternions, shape (gaussian, 4). They are read by
            ``scipy.spatial.transform.Rotation.from_quat`` with its default
            scalar-last (x, y, z, w) convention.
        harmonics: Spherical-harmonic coefficients, shape (gaussian, 3, d_sh).
            Index 0 of the last axis is written as the DC term.
        opacities: Shape (gaussian,). Written unchanged (no activation or
            inverse is applied here).
        path: Output file. Missing parent directories are created. A ``str``
            is also accepted.
        shift_and_scale: If True, center the means on their per-axis median,
            then divide means and scales by the largest per-axis 95th
            percentile of ``|means|``.
        save_sh_dc_only: If True, only the DC band is written (no
            ``f_rest_*`` properties). Higher SH degrees are memory-heavy.
    """
    path = Path(path)

    if shift_and_scale:
        # torch.quantile only accepts float32/float64. Promote half-precision
        # inputs and leave float32/float64 inputs unchanged.
        work_dtype = torch.promote_types(means.dtype, torch.float32)
        means = means.detach().to(work_dtype)
        scales = scales.detach().to(work_dtype)
        # Shift the scene so that the median Gaussian is at the origin.
        means = means - means.median(dim=0).values
        # Rescale the scene so that most Gaussians are within range [-1, 1].
        scale_factor = means.abs().quantile(0.95, dim=0).max()
        # A degenerate scene (e.g. every mean equal to the median) gives a
        # zero factor. Skip the division instead of writing inf/NaN.
        if scale_factor > 0:
            means = means / scale_factor
            scales = scales / scale_factor

    # Round-trip through rotation matrices. No extra rotation is applied here.
    # The round trip normalizes each quaternion and may flip its overall sign,
    # which still describes the same rotation. scipy uses scalar-last
    # (x, y, z, w) order on both ends.
    quats_xyzw = R.from_matrix(
        R.from_quat(_to_float32_numpy(rotations)).as_matrix()
    ).as_quat()
    # Reorder to scalar-first (w, x, y, z) for rot_0..rot_3.
    quats_wxyz = quats_xyzw[:, [3, 0, 1, 2]].astype(np.float32)

    columns = [
        _to_float32_numpy(means),
        np.zeros((quats_wxyz.shape[0], 3), dtype=np.float32),  # nx, ny, nz
        _to_float32_numpy(harmonics[..., 0]),  # f_dc_0..2
    ]
    num_rest = 0
    if not save_sh_dc_only:
        # The higher SH bands can be large, so they are only extracted when
        # they are actually written.
        # (gaussian, 3, d_sh - 1) -> (gaussian, 3 * (d_sh - 1)), channel-major.
        f_rest = harmonics[..., 1:].flatten(start_dim=1)
        num_rest = f_rest.shape[1]
        columns.append(_to_float32_numpy(f_rest))
    columns += [
        _to_float32_numpy(opacities[..., None]),
        _to_float32_numpy(scales.log()),
        quats_wxyz,
    ]

    attribute_names = construct_list_of_attributes(num_rest)
    data = np.concatenate(columns, axis=1)
    elements = np.empty(
        data.shape[0], dtype=[(name, "f4") for name in attribute_names]
    )
    # Fill the structured array one column at a time (vectorized), instead of
    # building a Python tuple for every Gaussian.
    for index, name in enumerate(attribute_names):
        elements[name] = data[:, index]

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)