|
from dataclasses import dataclass |
|
from typing import Literal |
|
|
|
import torch |
|
from einops import rearrange, repeat |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
import torchvision |
|
|
|
from ..types import Gaussians |
|
|
|
from .decoder import Decoder, DecoderOutput |
|
from math import sqrt |
|
from gsplat import rasterization |
|
|
|
from ...misc.utils import vis_depth_map |
|
|
|
# Accepted values for the ``depth_mode`` argument of decoders.
DepthRenderingMode = Literal[
    "depth",
    "disparity",
    "relative_disparity",
    "log",
]
|
|
|
@dataclass
class DecoderSplattingCUDACfg:
    """Configuration consumed by the gsplat-based splatting decoder.

    Field order is part of the positional constructor signature and is
    therefore kept fixed.
    """

    # Discriminator tag identifying this decoder type.
    name: Literal["splatting_cuda"]
    # Background colour handed to the rasterizer; the decoder turns it into a
    # float32 tensor (its annotation suggests 3 values -- presumably RGB).
    background_color: list[float]
    # Copied onto the decoder instance as ``make_scale_invariant``.
    make_scale_invariant: bool
|
|
|
|
|
class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]):
    """Render Gaussians with gsplat's ``rasterization``, one batch item and view at a time.

    The configured background colour is registered as a non-persistent buffer,
    so it follows the module across devices but is not stored in ``state_dict``.
    """

    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DecoderSplattingCUDACfg,
    ) -> None:
        super().__init__(cfg)
        # Stored for callers; not read by the rendering code in this class.
        self.make_scale_invariant = cfg.make_scale_invariant
        self.register_buffer(
            "background_color",
            torch.tensor(cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    @staticmethod
    def _to_pixel_intrinsics(
        intrinsics: Float[Tensor, "view 3 3"],
        height: int,
        width: int,
    ) -> Float[Tensor, "view 3 3"]:
        """Scale normalized intrinsics to pixels: row 0 by ``width``, row 1 by ``height``."""
        pixel_intrinsics = intrinsics.clone()
        pixel_intrinsics[:, 0] = intrinsics[:, 0] * width
        pixel_intrinsics[:, 1] = intrinsics[:, 1] * height
        return pixel_intrinsics

    def rendering_fn(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Rasterize ``gaussians`` into every (batch, view) camera.

        Args:
            gaussians: Per-batch Gaussian parameters (means, opacities,
                rotations, scales, harmonics, covariances).
            extrinsics: Camera matrices; they are inverted before being used
                as gsplat view matrices, so presumably camera-to-world.
            intrinsics: Normalized intrinsics, scaled to pixels by ``image_shape``.
            near, far: Accepted for interface compatibility; unused here.
            image_shape: Output ``(height, width)``.
            depth_mode, cam_rot_delta, cam_trans_delta: Accepted for interface
                compatibility; unused here.

        Returns:
            ``DecoderOutput(color, depth, alpha, lod_rendering=None)`` with
            color ``[B, V, 3, H, W]`` clamped to [0, 1], and depth and alpha
            ``[B, V, H, W]``.
        """
        B, V, _, _ = intrinsics.shape
        H, W = image_shape

        # NOTE(review): assumes ``harmonics`` is (batch, gaussian, 3, d_sh); the
        # permute gives the (..., d_sh, 3) layout gsplat uses for SH colours.
        features = gaussians.harmonics.permute(0, 1, 3, 2).contiguous()
        covariances = gaussians.covariances
        # gsplat takes one background row per camera; a single camera is
        # rendered per call, so the (3,) buffer becomes (1, 3).
        background = self.background_color.unsqueeze(0)

        rendered_imgs, rendered_depths, rendered_alphas = [], [], []
        for i in range(B):
            xyz_i = gaussians.means[i].float()
            feature_i = features[i].float()
            covar_i = covariances[i].float()
            scale_i = gaussians.scales[i].float()
            rotation_i = gaussians.rotations[i].float()
            # reshape(-1) rather than squeeze(): keeps a 1-D (N,) vector even
            # when there is a single Gaussian.
            opacity_i = gaussians.opacities[i].reshape(-1).float()
            w2c_i = extrinsics[i].float().inverse()
            K_i = self._to_pixel_intrinsics(intrinsics[i].float(), H, W)
            # Number of SH coefficients is (degree + 1) ** 2.
            sh_degree = int(sqrt(feature_i.shape[-2])) - 1

            view_imgs, view_depths, view_alphas = [], [], []
            for j in range(V):
                rendering, alpha, _ = rasterization(
                    xyz_i,
                    rotation_i,
                    scale_i,
                    opacity_i,
                    feature_i,
                    w2c_i[j : j + 1],
                    K_i[j : j + 1],
                    W,
                    H,
                    sh_degree=sh_degree,
                    render_mode="RGB+D",
                    packed=False,
                    near_plane=1e-10,
                    backgrounds=background,
                    radius_clip=0.1,
                    covars=covar_i,
                    rasterize_mode="classic",
                )
                # "RGB+D" appends depth as the last channel of the render.
                img, depth = torch.split(rendering, [3, 1], dim=-1)
                view_imgs.append(img.clamp(0.0, 1.0).permute(0, 3, 1, 2))
                view_depths.append(depth)
                view_alphas.append(alpha)

            rendered_imgs.append(torch.cat(view_imgs, dim=0))
            # Drop only the trailing channel axis: a bare squeeze() would also
            # remove the view axis when V == 1 (or H/W when they equal 1).
            rendered_depths.append(torch.cat(view_depths, dim=0).squeeze(-1))
            rendered_alphas.append(torch.cat(view_alphas, dim=0).squeeze(-1))

        return DecoderOutput(
            torch.stack(rendered_imgs),
            torch.stack(rendered_depths),
            torch.stack(rendered_alphas),
            lod_rendering=None,
        )

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Delegate to :meth:`rendering_fn` with the same arguments."""
        return self.rendering_fn(
            gaussians,
            extrinsics,
            intrinsics,
            near,
            far,
            image_shape,
            depth_mode,
            cam_rot_delta,
            cam_trans_delta,
        )
|
|
|
|