| | import torch |
| | from jaxtyping import Float, Shaped |
| | from torch import Tensor |
| |
|
| | from ..model.decoder.cuda_splatting import render_cuda_orthographic |
| | from ..model.types import Gaussians |
| | from ..visualization.annotation import add_label |
| | from ..visualization.drawing.cameras import draw_cameras |
| | from .drawing.cameras import compute_equal_aabb_with_margin |
| |
|
| |
|
def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]:
    """Pad tensors with ones so that they all share the same shape.

    The target shape is the per-dimension maximum over all inputs. Each input
    is copied into the leading ("top-left") corner of a new tensor filled with
    ones; the remaining entries stay at one.

    Args:
        images: Tensors to pad. They must all have the same number of
            dimensions, since their shapes are stacked to find the maximum.

    Returns:
        A list of new tensors, one per input and in the same order, each with
        the common padded shape and the dtype/device of its source tensor.
        An empty input list yields an empty list.
    """
    if not images:
        return []

    shapes = torch.stack([torch.tensor(x.shape) for x in images])
    padded_shape = shapes.max(dim=0)[0].tolist()

    results = []
    for image in images:
        result = torch.ones(padded_shape, dtype=image.dtype, device=image.device)
        # Index with a tuple: a list of slices is a non-tuple sequence, which
        # PyTorch deprecates for multidimensional indexing.
        region = tuple(slice(0, size) for size in image.shape)
        result[region] = image
        results.append(result)
    return results
| |
|
| |
|
def render_projections(
    gaussians: Gaussians,
    resolution: int,
    margin: float = 0.1,
    draw_label: bool = True,
    extra_label: str = "",
) -> Float[Tensor, "batch 3 3 height width"]:
    """Render one orthographic view of the Gaussians along each coordinate axis.

    A bounding box is obtained from the per-batch minima and maxima of the
    Gaussian means via ``compute_equal_aabb_with_margin``. For each of the
    three axes a 4x4 matrix is assembled and passed to
    ``render_cuda_orthographic`` together with the box extents.

    Args:
        gaussians: Gaussians to render; ``means`` must be a 3-dimensional
            tensor whose first dimension is the batch.
        resolution: Side length of the square image shape passed to the
            renderer.
        margin: Forwarded to ``compute_equal_aabb_with_margin``.
        draw_label: If true, each rendered image is passed through
            ``add_label`` with a text naming the two in-plane axes.
        extra_label: Text appended to the label.

    Returns:
        The three per-axis renders, padded to a common shape and stacked
        along dimension 1.
    """
    means = gaussians.means
    device = means.device
    batch_size, _, _ = means.shape

    # Bounding box of the means; the helper applies the margin.
    lower, upper = compute_equal_aabb_with_margin(
        means.min(dim=1).values,
        means.max(dim=1).values,
        margin=margin,
    )
    box_size = upper - lower
    box_center = 0.5 * (lower + upper)
    axis_names = "XYZ"

    views = []
    for look in range(3):
        right = (look + 1) % 3
        down = (look + 2) % 3

        # Columns 0..2 hold the right/down/look unit axes. Column 3 places the
        # origin at the box center in-plane and on the box's minimum face
        # along the look axis.
        # NOTE(review): presumably a camera-to-world pose; confirm against
        # render_cuda_orthographic.
        pose = torch.zeros((batch_size, 4, 4), dtype=torch.float32, device=device)
        for column, axis in enumerate((right, down, look)):
            pose[:, axis, column] = 1
        pose[:, right, 3] = box_center[:, right]
        pose[:, down, 3] = box_center[:, down]
        pose[:, look, 3] = lower[:, look]
        pose[:, 3, 3] = 1

        # The view volume spans the whole box: depth from 0 to the extent
        # along the look axis, width/height from the in-plane extents.
        far_plane = box_size[:, look]
        near_plane = torch.zeros_like(far_plane)

        rendered = render_cuda_orthographic(
            pose,
            box_size[:, right],
            box_size[:, down],
            near_plane,
            far_plane,
            (resolution, resolution),
            torch.zeros((batch_size, 3), dtype=torch.float32, device=device),
            means,
            gaussians.covariances,
            gaussians.harmonics,
            gaussians.opacities,
            fov_degrees=10.0,
        )

        if draw_label:
            right_name = axis_names[right]
            down_name = axis_names[down]
            label = f"{right_name}{down_name} Projection {extra_label}"
            rendered = torch.stack([add_label(image, label) for image in rendered])

        views.append(rendered)

    # Pad to a common shape before stacking the three views.
    return torch.stack(pad(views), dim=1)
| |
|
| |
|
def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]:
    """Draw the context and target cameras of the first batch element.

    Extrinsics, intrinsics, near and far values of the context views are
    concatenated with those of the target views (batch index 0 only) and
    handed to ``draw_cameras``. Context cameras get the color (1, 1, 1) and
    target cameras get (1, 0, 0).

    Args:
        batch: Dictionary with "context" and "target" entries, each providing
            "extrinsics", "intrinsics", "near" and "far" tensors whose second
            dimension of "extrinsics" is the number of views.
        resolution: Forwarded to ``draw_cameras`` as its first argument.

    Returns:
        Whatever ``draw_cameras`` returns for the combined camera set.
    """
    context = batch["context"]
    target = batch["target"]

    num_context_views = context["extrinsics"].shape[1]
    num_target_views = target["extrinsics"].shape[1]

    # Start with all-ones colors, then zero the last two channels of the
    # target rows so the two camera groups are visually distinct.
    color = torch.ones(
        (num_target_views + num_context_views, 3),
        dtype=torch.float32,
        device=target["extrinsics"].device,
    )
    color[num_context_views:, 1:] = 0

    def first_element_views(key: str) -> Tensor:
        # Context views first, then target views, matching the color rows.
        return torch.cat((context[key][0], target[key][0]))

    return draw_cameras(
        resolution,
        first_element_views("extrinsics"),
        first_element_views("intrinsics"),
        color,
        first_element_views("near"),
        first_element_views("far"),
    )
| |
|