| import torch |
| from jaxtyping import Float, Shaped |
| from torch import Tensor |
|
|
| from ..model.decoder.cuda_splatting import render_cuda_orthographic |
| from ..model.types import Gaussians |
| from ..visualization.annotation import add_label |
| from ..visualization.drawing.cameras import draw_cameras |
| from .drawing.cameras import compute_equal_aabb_with_margin |
|
|
|
|
| def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]: |
| shapes = torch.stack([torch.tensor(x.shape) for x in images]) |
| padded_shape = shapes.max(dim=0)[0] |
| results = [ |
| torch.ones(padded_shape.tolist(), dtype=x.dtype, device=x.device) |
| for x in images |
| ] |
| for image, result in zip(images, results): |
| slices = [slice(0, x) for x in image.shape] |
| result[slices] = image[slices] |
| return results |
|
|
|
|
| def render_projections( |
| gaussians: Gaussians, |
| resolution: int, |
| margin: float = 0.1, |
| draw_label: bool = True, |
| extra_label: str = "", |
| ) -> Float[Tensor, "batch 3 3 height width"]: |
| device = gaussians.means.device |
| b, _, _ = gaussians.means.shape |
|
|
| |
| minima = gaussians.means.min(dim=1).values |
| maxima = gaussians.means.max(dim=1).values |
| scene_minima, scene_maxima = compute_equal_aabb_with_margin( |
| minima, maxima, margin=margin |
| ) |
|
|
| projections = [] |
| for look_axis in range(3): |
| right_axis = (look_axis + 1) % 3 |
| down_axis = (look_axis + 2) % 3 |
|
|
| |
| extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device) |
| extrinsics[:, right_axis, 0] = 1 |
| extrinsics[:, down_axis, 1] = 1 |
| extrinsics[:, look_axis, 2] = 1 |
| extrinsics[:, right_axis, 3] = 0.5 * ( |
| scene_minima[:, right_axis] + scene_maxima[:, right_axis] |
| ) |
| extrinsics[:, down_axis, 3] = 0.5 * ( |
| scene_minima[:, down_axis] + scene_maxima[:, down_axis] |
| ) |
| extrinsics[:, look_axis, 3] = scene_minima[:, look_axis] |
| extrinsics[:, 3, 3] = 1 |
|
|
| |
| extents = scene_maxima - scene_minima |
| far = extents[:, look_axis] |
| near = torch.zeros_like(far) |
| width = extents[:, right_axis] |
| height = extents[:, down_axis] |
|
|
| projection = render_cuda_orthographic( |
| extrinsics, |
| width, |
| height, |
| near, |
| far, |
| (resolution, resolution), |
| torch.zeros((b, 3), dtype=torch.float32, device=device), |
| gaussians.means, |
| gaussians.covariances, |
| gaussians.harmonics, |
| gaussians.opacities, |
| fov_degrees=10.0, |
| ) |
| if draw_label: |
| right_axis_name = "XYZ"[right_axis] |
| down_axis_name = "XYZ"[down_axis] |
| label = f"{right_axis_name}{down_axis_name} Projection {extra_label}" |
| projection = torch.stack([add_label(x, label) for x in projection]) |
|
|
| projections.append(projection) |
|
|
| return torch.stack(pad(projections), dim=1) |
|
|
|
|
| def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]: |
| |
| num_context_views = batch["context"]["extrinsics"].shape[1] |
| num_target_views = batch["target"]["extrinsics"].shape[1] |
| color = torch.ones( |
| (num_target_views + num_context_views, 3), |
| dtype=torch.float32, |
| device=batch["target"]["extrinsics"].device, |
| ) |
| color[num_context_views:, 1:] = 0 |
|
|
| return draw_cameras( |
| resolution, |
| torch.cat( |
| (batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0]) |
| ), |
| torch.cat( |
| (batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0]) |
| ), |
| color, |
| torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])), |
| torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])), |
| ) |
|
|