| from typing import Optional |
|
|
| import torch |
| from einops import einsum, rearrange, repeat |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
| from ...geometry.projection import unproject |
| from ..annotation import add_label |
| from .lines import draw_lines |
| from .types import Scalar, sanitize_scalar |
|
|
|
|
| def draw_cameras( |
| resolution: int, |
| extrinsics: Float[Tensor, "batch 4 4"], |
| intrinsics: Float[Tensor, "batch 3 3"], |
| color: Float[Tensor, "batch 3"], |
| near: Optional[Scalar] = None, |
| far: Optional[Scalar] = None, |
| margin: float = 0.1, |
| frustum_scale: float = 0.05, |
| ) -> Float[Tensor, "3 3 height width"]: |
| device = extrinsics.device |
|
|
| |
| minima, maxima = compute_aabb(extrinsics, intrinsics, near, far) |
| scene_minima, scene_maxima = compute_equal_aabb_with_margin( |
| minima, maxima, margin=margin |
| ) |
| span = (scene_maxima - scene_minima).max() |
|
|
| |
| corner_depth = (span * frustum_scale)[None] |
| frustum_corners = unproject_frustum_corners(extrinsics, intrinsics, corner_depth) |
| if near is not None: |
| near_corners = unproject_frustum_corners(extrinsics, intrinsics, near) |
| if far is not None: |
| far_corners = unproject_frustum_corners(extrinsics, intrinsics, far) |
|
|
| |
| projections = [] |
| for projected_axis in range(3): |
| image = torch.zeros( |
| (3, resolution, resolution), |
| dtype=torch.float32, |
| device=device, |
| ) |
| image_x_axis = (projected_axis + 1) % 3 |
| image_y_axis = (projected_axis + 2) % 3 |
|
|
| def project(points: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch 2"]: |
| x = points[..., image_x_axis] |
| y = points[..., image_y_axis] |
| return torch.stack([x, y], dim=-1) |
|
|
| x_range, y_range = torch.stack( |
| (project(scene_minima), project(scene_maxima)), dim=-1 |
| ) |
|
|
| |
| if near is not None: |
| projected_near_corners = project(near_corners) |
| image = draw_lines( |
| image, |
| rearrange(projected_near_corners, "b p xy -> (b p) xy"), |
| rearrange(projected_near_corners.roll(1, 1), "b p xy -> (b p) xy"), |
| color=0.25, |
| width=2, |
| x_range=x_range, |
| y_range=y_range, |
| ) |
| if far is not None: |
| projected_far_corners = project(far_corners) |
| image = draw_lines( |
| image, |
| rearrange(projected_far_corners, "b p xy -> (b p) xy"), |
| rearrange(projected_far_corners.roll(1, 1), "b p xy -> (b p) xy"), |
| color=0.25, |
| width=2, |
| x_range=x_range, |
| y_range=y_range, |
| ) |
| if near is not None and far is not None: |
| image = draw_lines( |
| image, |
| rearrange(projected_near_corners, "b p xy -> (b p) xy"), |
| rearrange(projected_far_corners, "b p xy -> (b p) xy"), |
| color=0.25, |
| width=2, |
| x_range=x_range, |
| y_range=y_range, |
| ) |
|
|
| |
| projected_origins = project(extrinsics[:, :3, 3]) |
| projected_frustum_corners = project(frustum_corners) |
| start = [ |
| repeat(projected_origins, "b xy -> (b p) xy", p=4), |
| rearrange(projected_frustum_corners.roll(1, 1), "b p xy -> (b p) xy"), |
| ] |
| start = rearrange(torch.cat(start, dim=0), "(r b p) xy -> (b r p) xy", r=2, p=4) |
| image = draw_lines( |
| image, |
| start, |
| repeat(projected_frustum_corners, "b p xy -> (b r p) xy", r=2), |
| color=repeat(color, "b c -> (b r p) c", r=2, p=4), |
| width=2, |
| x_range=x_range, |
| y_range=y_range, |
| ) |
|
|
| x_name = "XYZ"[image_x_axis] |
| y_name = "XYZ"[image_y_axis] |
| image = add_label(image, f"{x_name}{y_name} Projection") |
|
|
| |
| projections.append(image) |
|
|
| return torch.stack(projections) |
|
|
|
|
| def compute_aabb( |
| extrinsics: Float[Tensor, "batch 4 4"], |
| intrinsics: Float[Tensor, "batch 3 3"], |
| near: Optional[Scalar] = None, |
| far: Optional[Scalar] = None, |
| ) -> tuple[ |
| Float[Tensor, "3"], |
| Float[Tensor, "3"], |
| ]: |
| """Compute an axis-aligned bounding box for the camera frustums.""" |
|
|
| device = extrinsics.device |
|
|
| |
| points = [extrinsics[:, :3, 3]] |
|
|
| if near is not None: |
| near = sanitize_scalar(near, device) |
| corners = unproject_frustum_corners(extrinsics, intrinsics, near) |
| points.append(rearrange(corners, "b p xyz -> (b p) xyz")) |
|
|
| if far is not None: |
| far = sanitize_scalar(far, device) |
| corners = unproject_frustum_corners(extrinsics, intrinsics, far) |
| points.append(rearrange(corners, "b p xyz -> (b p) xyz")) |
|
|
| points = torch.cat(points, dim=0) |
| return points.min(dim=0).values, points.max(dim=0).values |
|
|
|
|
| def compute_equal_aabb_with_margin( |
| minima: Float[Tensor, "*#batch 3"], |
| maxima: Float[Tensor, "*#batch 3"], |
| margin: float = 0.1, |
| ) -> tuple[ |
| Float[Tensor, "*batch 3"], |
| Float[Tensor, "*batch 3"], |
| ]: |
| midpoint = (maxima + minima) * 0.5 |
| span = (maxima - minima).max() * (1 + margin) |
| scene_minima = midpoint - 0.5 * span |
| scene_maxima = midpoint + 0.5 * span |
| return scene_minima, scene_maxima |
|
|
|
|
| def unproject_frustum_corners( |
| extrinsics: Float[Tensor, "batch 4 4"], |
| intrinsics: Float[Tensor, "batch 3 3"], |
| depth: Float[Tensor, "#batch"], |
| ) -> Float[Tensor, "batch 4 3"]: |
| device = extrinsics.device |
|
|
| |
| xy = torch.linspace(0, 1, 2, device=device) |
| xy = torch.stack(torch.meshgrid(xy, xy, indexing="xy"), dim=-1) |
| xy = rearrange(xy, "i j xy -> (i j) xy") |
| xy = xy[torch.tensor([0, 1, 3, 2], device=device)] |
|
|
| |
| directions = unproject( |
| xy, |
| torch.ones(1, dtype=torch.float32, device=device), |
| rearrange(intrinsics, "b i j -> b () i j"), |
| ) |
|
|
| |
| |
| directions = directions / directions[..., -1:] |
| directions = einsum(extrinsics[..., :3, :3], directions, "b i j, b r j -> b r i") |
|
|
| origins = rearrange(extrinsics[:, :3, 3], "b xyz -> b () xyz") |
| depth = rearrange(depth, "b -> b () ()") |
| return origins + depth * directions |
|
|