| import numpy as np |
| import torch |
| from einops import repeat |
| from jaxtyping import Float |
| from scipy.spatial.transform import Rotation as R |
| from torch import Tensor |
|
|
|
|
def generate_spin(
    num_frames: int,
    device: torch.device,
    elevation: float,
    radius: float,
) -> Float[Tensor, "frame 4 4"]:
    """Build a batch of 4x4 transforms covering one full turn about the y axis.

    Frame ``i`` is the product ``azimuth_i @ elevation @ translation`` where

    * ``translation`` is the identity with its first two rows negated and
      ``-radius`` written into the z entry of the last column,
    * ``elevation`` is a rotation about the x axis by ``elevation`` degrees,
    * ``azimuth_i`` is a rotation about the y axis by
      ``2 * pi * i / num_frames`` radians, so the last frame stops one step
      short of repeating the first.

    Args:
        num_frames: Number of transforms to generate; must be non-negative.
        device: Device on which every intermediate and the result live.
        elevation: Rotation about the x axis, in degrees.
        radius: Distance written (negated) into the z translation entry.

    Returns:
        A float32 tensor of shape ``(num_frames, 4, 4)`` on ``device``.

    Raises:
        ValueError: If ``num_frames`` is negative.
    """
    if num_frames < 0:
        raise ValueError(f"num_frames must be non-negative, got {num_frames}")
    if num_frames == 0:
        # Nothing to rotate; avoid handing SciPy an empty batch of rotations.
        return torch.empty((0, 4, 4), dtype=torch.float32, device=device)

    # Fixed transform applied first: negate the first two rows and offset
    # along z by -radius.
    tf_translation = torch.eye(4, dtype=torch.float32, device=device)
    tf_translation[:2] *= -1
    tf_translation[2, 3] = -radius

    # One rotation about y per frame, evenly spaced over [0, 2*pi).
    phi = 2 * np.pi * (np.arange(num_frames) / num_frames)
    rotation_vectors = np.zeros((num_frames, 3), dtype=np.float64)
    rotation_vectors[:, 1] = phi
    azimuth_matrices = R.from_rotvec(rotation_vectors).as_matrix()

    # Tensor.repeat materialises a fresh (num_frames, 4, 4) copy, so the
    # in-place write of the rotation block below is safe.
    tf_azimuth = torch.eye(4, dtype=torch.float32, device=device).repeat(
        num_frames, 1, 1
    )
    tf_azimuth[:, :3, :3] = torch.tensor(
        azimuth_matrices, dtype=torch.float32, device=device
    )

    # Single rotation about x. The angle stays in float64 until SciPy has
    # built the matrix; only the finished matrix is cast to float32.
    elevation_rad = np.deg2rad(elevation)
    elevation_matrix = R.from_rotvec([elevation_rad, 0.0, 0.0]).as_matrix()
    tf_elevation = torch.eye(4, dtype=torch.float32, device=device)
    tf_elevation[:3, :3] = torch.tensor(
        elevation_matrix, dtype=torch.float32, device=device
    )

    # (num_frames, 4, 4) @ (4, 4) @ (4, 4) broadcasts over the frame axis.
    return tf_azimuth @ tf_elevation @ tf_translation
|
|