AnySplat / src /visualization /validation_in_3d.py
alexnasa's picture
Upload 243 files
2568013 verified
import torch
from jaxtyping import Float, Shaped
from torch import Tensor
from ..model.decoder.cuda_splatting import render_cuda_orthographic
from ..model.types import Gaussians
from ..visualization.annotation import add_label
from ..visualization.drawing.cameras import draw_cameras
from .drawing.cameras import compute_equal_aabb_with_margin
def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]:
shapes = torch.stack([torch.tensor(x.shape) for x in images])
padded_shape = shapes.max(dim=0)[0]
results = [
torch.ones(padded_shape.tolist(), dtype=x.dtype, device=x.device)
for x in images
]
for image, result in zip(images, results):
slices = [slice(0, x) for x in image.shape]
result[slices] = image[slices]
return results
def render_projections(
gaussians: Gaussians,
resolution: int,
margin: float = 0.1,
draw_label: bool = True,
extra_label: str = "",
) -> Float[Tensor, "batch 3 3 height width"]:
device = gaussians.means.device
b, _, _ = gaussians.means.shape
# Compute the minima and maxima of the scene.
minima = gaussians.means.min(dim=1).values
maxima = gaussians.means.max(dim=1).values
scene_minima, scene_maxima = compute_equal_aabb_with_margin(
minima, maxima, margin=margin
)
projections = []
for look_axis in range(3):
right_axis = (look_axis + 1) % 3
down_axis = (look_axis + 2) % 3
# Define the extrinsics for rendering.
extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device)
extrinsics[:, right_axis, 0] = 1
extrinsics[:, down_axis, 1] = 1
extrinsics[:, look_axis, 2] = 1
extrinsics[:, right_axis, 3] = 0.5 * (
scene_minima[:, right_axis] + scene_maxima[:, right_axis]
)
extrinsics[:, down_axis, 3] = 0.5 * (
scene_minima[:, down_axis] + scene_maxima[:, down_axis]
)
extrinsics[:, look_axis, 3] = scene_minima[:, look_axis]
extrinsics[:, 3, 3] = 1
# Define the intrinsics for rendering.
extents = scene_maxima - scene_minima
far = extents[:, look_axis]
near = torch.zeros_like(far)
width = extents[:, right_axis]
height = extents[:, down_axis]
projection = render_cuda_orthographic(
extrinsics,
width,
height,
near,
far,
(resolution, resolution),
torch.zeros((b, 3), dtype=torch.float32, device=device),
gaussians.means,
gaussians.covariances,
gaussians.harmonics,
gaussians.opacities,
fov_degrees=10.0,
)
if draw_label:
right_axis_name = "XYZ"[right_axis]
down_axis_name = "XYZ"[down_axis]
label = f"{right_axis_name}{down_axis_name} Projection {extra_label}"
projection = torch.stack([add_label(x, label) for x in projection])
projections.append(projection)
return torch.stack(pad(projections), dim=1)
def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]:
# Define colors for context and target views.
num_context_views = batch["context"]["extrinsics"].shape[1]
num_target_views = batch["target"]["extrinsics"].shape[1]
color = torch.ones(
(num_target_views + num_context_views, 3),
dtype=torch.float32,
device=batch["target"]["extrinsics"].device,
)
color[num_context_views:, 1:] = 0
return draw_cameras(
resolution,
torch.cat(
(batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0])
),
torch.cat(
(batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0])
),
color,
torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])),
torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])),
)