|
|
"""Contains utility functionality to render different modalities. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import dataclasses
from typing import Literal, NamedTuple

import numpy as np
import torch

from .gaussians import Gaussians3D
from .linalg import eyes
|
|
|
|
|
TrajectoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]
LookAtMode = Literal["point", "ahead"]
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class CameraInfo:
    """Camera info for a pinhole camera."""

    intrinsics: torch.Tensor
    extrinsics: torch.Tensor
    width: int
    height: int
|
|
|
|
|
|
|
|
class FocusRange(NamedTuple):
    """Parametrizes a range of depth / disparity values."""

    min: float
    focus: float
    max: float
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class TrajectoryParams:
    """Parameters for a camera trajectory."""

    type: TrajectoryType = "rotate_forward"
    lookat_mode: LookAtMode = "point"
    max_disparity: float = 0.08
    max_zoom: float = 0.15
    distance_m: float = 0.0
    num_steps: int = 60
    num_repeats: int = 1
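
# Illustrative (follows from the defaults above): TrajectoryParams() yields one
# 60-step "rotate_forward" loop, while TrajectoryParams(type="swipe",
# num_repeats=2) plays the same left-to-right sweep twice.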
|
|
|
|
|
|
|
|
def compute_max_offset(
    scene: Gaussians3D,
    params: TrajectoryParams,
    resolution_px: tuple[int, int],
    f_px: float,
) -> np.ndarray:
    """Compute the maximum camera offset along the X/Y/Z axes."""
    scene_points = scene.mean_vectors
    extrinsics = torch.eye(4).to(scene_points.device)
    min_depth, _, _ = _compute_depth_quantiles(scene_points, extrinsics)

    # Convert the maximum screen-space disparity, expressed as a fraction of
    # the angular image diagonal, into a metric lateral offset at the nearest
    # scene depth.
    r_px = resolution_px
    diagonal = np.sqrt((r_px[0] / f_px) ** 2 + (r_px[1] / f_px) ** 2)
    max_lateral_offset_m = params.max_disparity * diagonal * min_depth

    # The medial (zoom) offset is a fraction of the nearest scene depth.
    max_medial_offset_m = params.max_zoom * min_depth
    max_offset_xyz_m = np.array(
        [max_lateral_offset_m, max_lateral_offset_m, max_medial_offset_m]
    )

    return max_offset_xyz_m
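
# Worked example: for a 640 x 480 image with f_px = 500 and a nearest scene
# depth of 2 m, the angular diagonal is sqrt(1.28**2 + 0.96**2) = 1.6, so the
# default TrajectoryParams give a lateral limit of 0.08 * 1.6 * 2 = 0.256 m
# and a medial limit of 0.15 * 2 = 0.3 m.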
|
|
|
|
|
|
|
|
def create_eye_trajectory(
    scene: Gaussians3D,
    params: TrajectoryParams,
    resolution_px: tuple[int, int],
    f_px: float,
) -> list[torch.Tensor]:
    """Create an eye trajectory for the configured trajectory type."""
    max_offset_xyz_m = compute_max_offset(
        scene,
        params,
        resolution_px,
        f_px,
    )

    if params.type == "swipe":
        return create_eye_trajectory_swipe(
            max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
        )
    elif params.type == "shake":
        return create_eye_trajectory_shake(
            max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
        )
    elif params.type == "rotate":
        return create_eye_trajectory_rotate(
            max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
        )
    elif params.type == "rotate_forward":
        return create_eye_trajectory_rotate_forward(
            max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
        )
    else:
        raise ValueError(f"Invalid trajectory type {params.type}.")
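
# Illustrative usage sketch (assumes a Gaussians3D ``scene`` is already
# loaded; the variable names are hypothetical):
#
#   params = TrajectoryParams(type="swipe", num_steps=30)
#   eye_positions = create_eye_trajectory(scene, params, (640, 480), f_px=500.0)
#   assert len(eye_positions) == params.num_steps * params.num_repeats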
|
|
|
|
|
|
|
|
def create_eye_trajectory_swipe(
    offset_xyz_m: np.ndarray,
    distance_m: float,
    num_steps: int,
    num_repeats: int,
) -> list[torch.Tensor]:
    """Create a left-to-right swipe trajectory."""
    offset_x_m, _, _ = offset_xyz_m
    eye_positions = [
        torch.tensor([x, 0, distance_m], dtype=torch.float32)
        for x in np.linspace(-offset_x_m, offset_x_m, num_steps)
    ]
    # Repeating the list replays the same left-to-right sweep num_repeats times.
    return eye_positions * num_repeats
|
|
|
|
|
|
|
|
def create_eye_trajectory_shake(
    offset_xyz_m: np.ndarray,
    distance_m: float,
    num_steps: int,
    num_repeats: int,
) -> list[torch.Tensor]:
    """Create a left-right shake followed by an up-down shake trajectory."""
    num_steps_total = num_steps * num_repeats
    num_steps_horizontal = num_steps_total // 2
    num_steps_vertical = num_steps_total - num_steps_horizontal

    offset_x_m, offset_y_m, _ = offset_xyz_m
    eye_positions: list[torch.Tensor] = []
    # Horizontal sinusoidal shake, followed by a vertical one.
    eye_positions.extend(
        torch.tensor(
            [offset_x_m * np.sin(2 * np.pi * t), 0.0, distance_m],
            dtype=torch.float32,
        )
        for t in np.linspace(0, num_repeats, num_steps_horizontal)
    )
    eye_positions.extend(
        torch.tensor(
            [0.0, offset_y_m * np.sin(2 * np.pi * t), distance_m],
            dtype=torch.float32,
        )
        for t in np.linspace(0, num_repeats, num_steps_vertical)
    )

    return eye_positions
|
|
|
|
|
|
|
|
def create_eye_trajectory_rotate(
    offset_xyz_m: np.ndarray,
    distance_m: float,
    num_steps: int,
    num_repeats: int,
) -> list[torch.Tensor]:
    """Create a rotating trajectory."""
    num_steps_total = num_steps * num_repeats
    offset_x_m, offset_y_m, _ = offset_xyz_m
    eye_positions = [
        torch.tensor(
            [
                offset_x_m * np.sin(2 * np.pi * t),
                offset_y_m * np.cos(2 * np.pi * t),
                distance_m,
            ],
            dtype=torch.float32,
        )
        for t in np.linspace(0, num_repeats, num_steps_total)
    ]

    return eye_positions
|
|
|
|
|
|
|
|
def create_eye_trajectory_rotate_forward(
    offset_xyz_m: np.ndarray,
    distance_m: float,
    num_steps: int,
    num_repeats: int,
) -> list[torch.Tensor]:
    """Create a rotating trajectory that also moves forward and back along Z."""
    num_steps_total = num_steps * num_repeats
    offset_x_m, _, offset_z_m = offset_xyz_m
    eye_positions = [
        torch.tensor(
            [
                offset_x_m * np.sin(2 * np.pi * t),
                0.0,
                # (1 - cos) / 2 ramps smoothly from 0 to offset_z_m and back.
                distance_m + offset_z_m * (1.0 - np.cos(2 * np.pi * t)) / 2,
            ],
            dtype=torch.float32,
        )
        for t in np.linspace(0, num_repeats, num_steps_total)
    ]

    return eye_positions
|
|
|
|
|
|
|
|
def create_camera_model(
    scene: Gaussians3D,
    intrinsics: torch.Tensor,
    resolution_px: tuple[int, int],
    lookat_mode: LookAtMode = "point",
) -> PinholeCameraModel:
    """Create a camera model that simulates a general pinhole camera."""
    screen_extrinsics = torch.eye(4)
    screen_intrinsics = intrinsics.clone()

    image_width, image_height = resolution_px
    screen_resolution_px = get_screen_resolution_px_from_input(
        width=image_width, height=image_height
    )

    # Rescale the focal lengths and principal point to the screen resolution.
    screen_intrinsics[0] *= screen_resolution_px[0] / image_width
    screen_intrinsics[1] *= screen_resolution_px[1] / image_height

    camera_model = PinholeCameraModel(
        scene,
        screen_extrinsics=screen_extrinsics,
        screen_intrinsics=screen_intrinsics,
        screen_resolution_px=screen_resolution_px,
        focus_depth_quantile=0.1,
        min_depth_focus=2.0,
        lookat_mode=lookat_mode,
    )
    return camera_model
|
|
|
|
|
|
|
|
def create_camera_matrix(
    position: torch.Tensor,
    look_at_position: torch.Tensor | None = None,
    world_up: torch.Tensor | None = None,
    inverse: bool = False,
) -> torch.Tensor:
    """Create a look-at camera matrix from position and direction vectors."""
    device = position.device

    if look_at_position is None:
        look_at_position = torch.zeros(3, device=device)
    if world_up is None:
        world_up = torch.tensor([0.0, 0.0, 1.0], device=device)

    position, look_at_position, world_up = torch.broadcast_tensors(
        position, look_at_position, world_up
    )

    # Build an orthonormal right/down/front camera basis.
    camera_front = look_at_position - position
    camera_front = camera_front / camera_front.norm(dim=-1, keepdim=True)

    camera_right = torch.cross(camera_front, world_up, dim=-1)
    camera_right = camera_right / camera_right.norm(dim=-1, keepdim=True)

    camera_down = torch.cross(camera_front, camera_right, dim=-1)
    rotation_matrix = torch.stack([camera_right, camera_down, camera_front], dim=-1)

    matrix = eyes(dim=4, shape=position.shape[:-1], device=device)
    if inverse:
        # World-to-camera: rotation R^T and translation -R^T @ t.
        matrix[..., :3, :3] = rotation_matrix.transpose(-1, -2)
        matrix[..., :3, 3:4] = -rotation_matrix.transpose(-1, -2) @ position[..., None]
    else:
        # Camera-to-world: rotation R and translation t.
        matrix[..., :3, :3] = rotation_matrix
        matrix[..., :3, 3] = position

    return matrix
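
# Illustrative sketch with hypothetical values: an eye 2 m in front of the
# scene along -Y, looking at the world origin.
#
#   eye = torch.tensor([0.0, -2.0, 0.0])
#   camera_to_world = create_camera_matrix(eye)
#   world_to_camera = create_camera_matrix(eye, inverse=True)
#   # The two matrices are inverses of each other, up to float error:
#   # torch.allclose(camera_to_world @ world_to_camera, torch.eye(4), atol=1e-6)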
|
|
|
|
|
|
|
|
class PinholeCameraModel:
    """Camera model that focuses on a point."""

    def __init__(
        self,
        scene: Gaussians3D,
        screen_extrinsics: torch.Tensor,
        screen_intrinsics: torch.Tensor,
        screen_resolution_px: tuple[int, int],
        focus_depth_quantile: float = 0.1,
        min_depth_focus: float = 2.0,
        lookat_point: tuple[float, float, float] | None = None,
        lookat_mode: LookAtMode = "point",
    ) -> None:
        """Initialize PinholeCameraModel.

        Args:
            scene: The scene to display.
            screen_extrinsics: Extrinsics of the default position.
            screen_intrinsics: Intrinsics to use for rendering.
            screen_resolution_px: Width and height to render.
            focus_depth_quantile: Where inside the depth range to focus on.
            min_depth_focus: Minimum depth to focus at.
            lookat_point: A point that the camera's Z axis points towards.
            lookat_mode: "point" to look at a fixed point,
                "ahead" to look straight ahead.
        """
        self.scene = scene
        self.screen_extrinsics = screen_extrinsics
        self.screen_intrinsics = screen_intrinsics
        self.screen_resolution_px = screen_resolution_px

        self.focus_depth_quantile = focus_depth_quantile
        self.min_depth_focus = min_depth_focus
        self.lookat_point = lookat_point
        self.lookat_mode = lookat_mode

        scene_points = scene.mean_vectors
        if scene_points.ndim == 3:
            # Drop the batch dimension for batched scenes.
            scene_points = scene_points[0]
        elif scene_points.ndim != 2:
            raise ValueError("Unsupported dimensionality of scene points.")
        self._scene_points = scene_points.cpu()

        self.depth_quantiles = _compute_depth_quantiles(
            self._scene_points,
            self.screen_extrinsics,
            q_focus=self.focus_depth_quantile,
        )
|
|
|
|
|
    def compute(self, eye_pos: torch.Tensor) -> CameraInfo:
        """Compute the camera info for an eye position."""
        # In "ahead" mode the look-at target moves with the eye; in "point"
        # mode it stays anchored at the world origin.
        origin = eye_pos if self.lookat_mode == "ahead" else torch.zeros(3)

        if self.lookat_point is None:
            depth_focus = max(self.min_depth_focus, self.depth_quantiles.focus)
            look_at_position = origin + torch.tensor([0.0, 0.0, depth_focus])
        else:
            look_at_position = origin + torch.tensor([*self.lookat_point])

        world_up = torch.tensor([0.0, -1.0, 0.0])
        extrinsics_modifier = create_camera_matrix(
            eye_pos, look_at_position, world_up, inverse=True
        )
        extrinsics = extrinsics_modifier @ self.screen_extrinsics

        camera_info = CameraInfo(
            intrinsics=self.screen_intrinsics,
            extrinsics=extrinsics,
            width=self.screen_resolution_px[0],
            height=self.screen_resolution_px[1],
        )
        return camera_info
|
|
|
|
|
    def set_screen_extrinsics(self, new_value: torch.Tensor) -> None:
        """Modify the default extrinsics and refresh the depth quantiles."""
        self.screen_extrinsics = new_value
        self.depth_quantiles = _compute_depth_quantiles(
            self._scene_points, self.screen_extrinsics
        )
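
# Illustrative end-to-end sketch (``render_gaussians`` is a hypothetical
# renderer; everything else uses APIs defined in this module):
#
#   camera_model = create_camera_model(scene, intrinsics, (640, 480))
#   params = TrajectoryParams()
#   for eye_pos in create_eye_trajectory(scene, params, (640, 480), f_px=500.0):
#       camera_info = camera_model.compute(eye_pos)
#       frame = render_gaussians(scene, camera_info)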
|
|
|
|
|
|
|
|
def get_screen_resolution_px_from_input(width: int, height: int) -> tuple[int, int]:
    """Get the screen resolution for the given input dimensions."""
    resolution_px = (width, height)

    # Downsample very large inputs by 2x.
    if resolution_px[1] > 3000:
        resolution_px = (resolution_px[0] // 2, resolution_px[1] // 2)

    # Round odd sides up to even (many video encoders require even dimensions).
    if resolution_px[0] % 2 != 0:
        resolution_px = (resolution_px[0] + 1, resolution_px[1])
    if resolution_px[1] % 2 != 0:
        resolution_px = (resolution_px[0], resolution_px[1] + 1)
    return resolution_px
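
# For example, a 4032 x 3024 input is halved to 2016 x 1512 (both sides already
# even), so get_screen_resolution_px_from_input(4032, 3024) == (2016, 1512).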
|
|
|
|
|
|
|
|
def _compute_depth_quantiles(
    points: torch.Tensor,
    extrinsics: torch.Tensor,
    q_near: float = 0.001,
    q_focus: float = 0.1,
    q_far: float = 0.999,
) -> FocusRange:
    """Compute depth quantiles for the scene points under the given extrinsics."""
    # Transform points into the camera frame and keep positive depths only.
    points_local = points @ extrinsics[:3, :3].T + extrinsics[:3, 3]
    depth_values = points_local[..., 2].flatten()
    depth_values = depth_values[depth_values > 0]
    q_values = torch.tensor([q_near, q_focus, q_far])
    depth_quantiles_pt = torch.quantile(depth_values.cpu(), q_values)
    depth_quantiles = FocusRange(
        min=float(depth_quantiles_pt[0]),
        focus=float(depth_quantiles_pt[1]),
        max=float(depth_quantiles_pt[2]),
    )
    return depth_quantiles
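
# Example: for points at depths 1, 2, ..., 100 m along +Z under identity
# extrinsics, FocusRange.focus (q_focus=0.1) interpolates to roughly 10.9 m.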
|
|
|