| import torch |
| from einops import rearrange |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
|
|
@torch.no_grad()
def generate_wobble_transformation(
    radius: Float[Tensor, "*#batch"],
    t: Float[Tensor, " time_step"],
    num_rotations: int = 1,
    scale_radius_with_t: bool = True,
    *,
    dtype: torch.dtype = torch.float32,
) -> Float[Tensor, "*batch time_step 4 4"]:
    """Build per-time-step 4x4 transforms that translate along a circle in x-y.

    Each returned matrix is the identity except for its translation column.
    That column is set to ``(r * sin(a), -r * cos(a), 0)``, where
    ``a = 2 * pi * num_rotations * t``. ``r`` is ``radius``, or
    ``radius * t`` when ``scale_radius_with_t`` is true. With scaling, the
    offset starts at the origin for ``t == 0`` and spirals outward as ``t``
    grows.

    Args:
        radius: Circle radius; its shape becomes the leading batch shape of
            the output. Expected to be on the same device as ``t``, since
            it is multiplied with values derived from ``t``.
        t: 1-D tensor of time values, one output matrix per entry.
            NOTE(review): presumably in [0, 1] so that ``num_rotations`` is
            the total number of turns — confirm with callers.
        num_rotations: Number of full turns per unit increase of ``t``.
        scale_radius_with_t: If true, multiply the radius by ``t``.
        dtype: Dtype of the returned tensor. Defaults to float32, which
            matches the historical behavior. Pass e.g. ``torch.float64`` to
            avoid truncating double-precision inputs.

    Returns:
        Tensor of shape ``(*radius.shape, len(t), 4, 4)`` on ``t.device``.
    """
    tf = torch.eye(4, dtype=dtype, device=t.device)
    # clone() because broadcast_to returns a non-writable expanded view.
    tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()

    # Trailing singleton axis so radius broadcasts against the time axis.
    radius = radius[..., None]
    if scale_radius_with_t:
        radius = radius * t

    # Same angle for both components; compute it once.
    angle = 2 * torch.pi * num_rotations * t
    tf[..., 0, 3] = torch.sin(angle) * radius
    tf[..., 1, 3] = -torch.cos(angle) * radius
    return tf
|
|
|
|
@torch.no_grad()
def generate_wobble(
    extrinsics: Float[Tensor, "*#batch 4 4"],
    radius: Float[Tensor, "*#batch"],
    t: Float[Tensor, " time_step"],
    num_rotations: int = 1,
    scale_radius_with_t: bool = True,
) -> Float[Tensor, "*batch time_step 4 4"]:
    """Apply a time-varying circular translation wobble to 4x4 extrinsics.

    For each time step in ``t``, ``extrinsics`` is right-multiplied by the
    matching transform from ``generate_wobble_transformation``. That
    transform is an identity rotation plus an x-y translation on a circle.
    NOTE(review): whether this moves the camera in its own frame depends on
    the extrinsics convention (camera-to-world vs world-to-camera), which
    is not visible here — confirm with callers.

    Args:
        extrinsics: Batch of 4x4 matrices, broadcastable against ``radius``.
        radius: Wobble radius per batch element.
        t: 1-D tensor of time values, one output matrix per entry.
        num_rotations: Forwarded to ``generate_wobble_transformation``.
        scale_radius_with_t: Forwarded to ``generate_wobble_transformation``.

    Returns:
        Tensor of shape ``(*batch, len(t), 4, 4)`` with the dtype and device
        of ``extrinsics``.
    """
    tf = generate_wobble_transformation(
        radius,
        t,
        num_rotations=num_rotations,
        scale_radius_with_t=scale_radius_with_t,
    )
    # The wobble transforms are created on t's device in a fixed dtype
    # (float32 by default). Match extrinsics so the matmul does not fail
    # on a dtype/device mismatch, e.g. for float64 extrinsics.
    tf = tf.to(dtype=extrinsics.dtype, device=extrinsics.device)
    # Insert a singleton time axis so extrinsics broadcast over time steps.
    return rearrange(extrinsics, "... i j -> ... () i j") @ tf
|
|