| | from typing import Optional |
| |
|
| | import torch |
| | from einops import einsum, rearrange, repeat |
| | from jaxtyping import Float |
| | from torch import Tensor |
| |
|
| | from ...geometry.projection import unproject |
| | from ..annotation import add_label |
| | from .lines import draw_lines |
| | from .types import Scalar, sanitize_scalar |
| |
|
| |
|
def _draw_gray_segments(
    image: Float[Tensor, "3 height width"],
    starts: Float[Tensor, "batch 4 2"],
    ends: Float[Tensor, "batch 4 2"],
    x_range: Float[Tensor, "2"],
    y_range: Float[Tensor, "2"],
) -> Float[Tensor, "3 height width"]:
    """Draw 2-pixel-wide gray (0.25) segments from each start point to the matching end point."""
    return draw_lines(
        image,
        rearrange(starts, "b p xy -> (b p) xy"),
        rearrange(ends, "b p xy -> (b p) xy"),
        color=0.25,
        width=2,
        x_range=x_range,
        y_range=y_range,
    )


def draw_cameras(
    resolution: int,
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    color: Float[Tensor, "batch 3"],
    near: Optional[Scalar] = None,
    far: Optional[Scalar] = None,
    margin: float = 0.1,
    frustum_scale: float = 0.05,
) -> Float[Tensor, "3 3 height width"]:
    """Draw camera frustums as three axis-aligned orthographic projections.

    Each output image drops one world axis and plots the remaining two. The
    images are the YZ, ZX and XY projections, in that order.

    Every camera is drawn in its ``color`` as a small pyramid:
    - lines from the camera origin to the four image corners, unprojected at
      a depth of ``frustum_scale`` times the padded scene span;
    - the rectangle that joins those four corners.

    When ``near`` and/or ``far`` is given, that frustum plane is outlined in
    gray. When both are given, their matching corners are also connected.

    Args:
        resolution: Side length of each square output image.
        extrinsics: Per-camera 4x4 matrices.
            - ``[:3, :3]`` rotates the unprojected corner rays.
            - ``[:3, 3]`` is the camera origin.
        intrinsics: Per-camera 3x3 matrices used to unproject image corners.
        color: Per-camera RGB color of the frustum glyph.
        near: Optional near depth, converted with ``sanitize_scalar``.
        far: Optional far depth, converted with ``sanitize_scalar``.
        margin: Relative padding added around the scene's bounding cube.
        frustum_scale: Glyph depth as a fraction of the padded scene span.

    Returns:
        The three labeled projection images, stacked along a new first axis.
    """
    device = extrinsics.device

    # Fit one padded cube around the cameras (and near/far planes) so that
    # all three projections share the same scale.
    minima, maxima = compute_aabb(extrinsics, intrinsics, near, far)
    scene_minima, scene_maxima = compute_equal_aabb_with_margin(
        minima, maxima, margin=margin
    )
    span = (scene_maxima - scene_minima).max()

    # Unproject the glyph corners and the near/far plane corners.
    # near/far are converted with sanitize_scalar, exactly as compute_aabb
    # does. unproject_frustum_corners rearranges its depth as a 1D tensor, so
    # a plain Python number would otherwise fail here.
    corner_depth = (span * frustum_scale)[None]
    frustum_corners = unproject_frustum_corners(extrinsics, intrinsics, corner_depth)
    near_corners = None
    if near is not None:
        near_corners = unproject_frustum_corners(
            extrinsics, intrinsics, sanitize_scalar(near, device)
        )
    far_corners = None
    if far is not None:
        far_corners = unproject_frustum_corners(
            extrinsics, intrinsics, sanitize_scalar(far, device)
        )

    projections = []
    for projected_axis in range(3):
        image = torch.zeros(
            (3, resolution, resolution),
            dtype=torch.float32,
            device=device,
        )
        # Drop `projected_axis` and keep the other two axes in cyclic order.
        image_x_axis = (projected_axis + 1) % 3
        image_y_axis = (projected_axis + 2) % 3

        # Only called within this iteration, so capturing the loop variables
        # is safe.
        def project(points: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch 2"]:
            x = points[..., image_x_axis]
            y = points[..., image_y_axis]
            return torch.stack([x, y], dim=-1)

        # x_range = (min, max) along the image x axis; likewise y_range.
        x_range, y_range = torch.stack(
            (project(scene_minima), project(scene_maxima)), dim=-1
        )

        # Gray near/far outlines. roll(1, 1) pairs every corner with its
        # predecessor, which closes each quadrilateral.
        projected_near_corners = None
        if near_corners is not None:
            projected_near_corners = project(near_corners)
            image = _draw_gray_segments(
                image,
                projected_near_corners,
                projected_near_corners.roll(1, 1),
                x_range,
                y_range,
            )
        projected_far_corners = None
        if far_corners is not None:
            projected_far_corners = project(far_corners)
            image = _draw_gray_segments(
                image,
                projected_far_corners,
                projected_far_corners.roll(1, 1),
                x_range,
                y_range,
            )
        # Connect matching near/far corners when both planes exist.
        if projected_near_corners is not None and projected_far_corners is not None:
            image = _draw_gray_segments(
                image, projected_near_corners, projected_far_corners, x_range, y_range
            )

        # Colored glyph, per camera:
        # - r=0: four segments from the origin to the glyph corners;
        # - r=1: the four rectangle edges (previous corner -> corner).
        projected_origins = project(extrinsics[:, :3, 3])
        projected_frustum_corners = project(frustum_corners)
        start = [
            repeat(projected_origins, "b xy -> (b p) xy", p=4),
            rearrange(projected_frustum_corners.roll(1, 1), "b p xy -> (b p) xy"),
        ]
        start = rearrange(torch.cat(start, dim=0), "(r b p) xy -> (b r p) xy", r=2, p=4)
        image = draw_lines(
            image,
            start,
            repeat(projected_frustum_corners, "b p xy -> (b r p) xy", r=2),
            color=repeat(color, "b c -> (b r p) c", r=2, p=4),
            width=2,
            x_range=x_range,
            y_range=y_range,
        )

        x_name = "XYZ"[image_x_axis]
        y_name = "XYZ"[image_y_axis]
        image = add_label(image, f"{x_name}{y_name} Projection")

        projections.append(image)

    return torch.stack(projections)
| |
|
| |
|
def compute_aabb(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    near: Optional[Scalar] = None,
    far: Optional[Scalar] = None,
) -> tuple[
    Float[Tensor, "3"],
    Float[Tensor, "3"],
]:
    """Return the axis-aligned bounding box of the cameras and their frustums.

    The box always contains every camera origin (``extrinsics[:, :3, 3]``).
    For each of ``near`` and ``far`` that is given, it also contains the four
    frustum-plane corners of every camera at that depth. Each depth is first
    converted with ``sanitize_scalar``.

    Returns:
        ``(minima, maxima)``: the per-axis minimum and maximum over all of
        those points.
    """
    # Camera origins are always part of the box.
    point_sets = [extrinsics[:, :3, 3]]

    # Near plane first, then far plane; None means "not requested".
    for plane_depth in (near, far):
        if plane_depth is None:
            continue
        depth = sanitize_scalar(plane_depth, extrinsics.device)
        plane_corners = unproject_frustum_corners(extrinsics, intrinsics, depth)
        point_sets.append(rearrange(plane_corners, "b p xyz -> (b p) xyz"))

    all_points = torch.cat(point_sets, dim=0)
    lower = all_points.min(dim=0).values
    upper = all_points.max(dim=0).values
    return lower, upper
| |
|
| |
|
def compute_equal_aabb_with_margin(
    minima: Float[Tensor, "*#batch 3"],
    maxima: Float[Tensor, "*#batch 3"],
    margin: float = 0.1,
) -> tuple[
    Float[Tensor, "*batch 3"],
    Float[Tensor, "*batch 3"],
]:
    """Grow an axis-aligned box into a padded cube with the same center.

    The cube's side length is the box's longest side times ``1 + margin``.
    For batched inputs, each box gets its own cube. The longest side is taken
    over the last (xyz) axis only, so boxes in a batch do not affect one
    another.

    Args:
        minima: Lower corner(s) of the box(es); the last axis is xyz.
        maxima: Upper corner(s); must broadcast against ``minima``.
        margin: Relative padding. For example, 0.1 makes the cube's side 10%
            longer than the box's longest side.

    Returns:
        ``(scene_minima, scene_maxima)``: the lower and upper corners of the
        cube(s), with the broadcast shape of the inputs.
    """
    midpoint = (maxima + minima) * 0.5
    # Reduce over xyz only; keepdim lets the result broadcast against
    # `midpoint`. A bare `.max()` would also reduce over batch dimensions.
    span = (maxima - minima).max(dim=-1, keepdim=True).values * (1 + margin)
    half_span = 0.5 * span
    scene_minima = midpoint - half_span
    scene_maxima = midpoint + half_span
    return scene_minima, scene_maxima
| |
|
| |
|
def unproject_frustum_corners(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    depth: Float[Tensor, "#batch"],
) -> Float[Tensor, "batch 4 3"]:
    """Return each camera's four frustum corner points at ``depth``.

    Steps:
    1. Take the corners of the normalized image square [0, 1] x [0, 1] in
       loop order (0, 0), (1, 0), (1, 1), (0, 1). Consecutive corners, and
       the last/first pair, share an edge.
    2. Unproject the corners with ``intrinsics``.
    3. Rescale each ray so its last (z) component is 1.
    4. Rotate the rays by ``extrinsics[:, :3, :3]``.
    5. Place each point at ``origin + depth * ray``, where ``origin`` is
       ``extrinsics[:, :3, 3]``.

    Because of step 3, ``depth`` is measured along the camera's z axis rather
    than along each individual ray.
    """
    device = extrinsics.device

    # Image-square corners in loop order. This equals a 2x2 "xy" meshgrid of
    # linspace(0, 1, 2) reordered with [0, 1, 3, 2], in the default dtype.
    corners_2d = torch.tensor(
        [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
        device=device,
    )

    # Rays through the corners. The ones(1) depth and the singleton axis on
    # the intrinsics presumably broadcast over cameras and corners; the
    # result is assumed to be (batch, 4, 3) -- TODO confirm against `unproject`.
    rays = unproject(
        corners_2d,
        torch.ones(1, dtype=torch.float32, device=device),
        rearrange(intrinsics, "b i j -> b () i j"),
    )

    # Normalize to unit z, then rotate by each camera's 3x3 extrinsic block.
    rays = rays / rays[..., -1:]
    rays = einsum(extrinsics[..., :3, :3], rays, "b i j, b r j -> b r i")

    camera_origins = rearrange(extrinsics[:, :3, 3], "b xyz -> b () xyz")
    per_camera_depth = rearrange(depth, "b -> b () ()")
    return camera_origins + per_camera_depth * rays
| |
|