MoGe / utils3d /torch /transforms.py
Ruicheng's picture
first commit
ec0c8fa
raw
history blame
41.3 kB
from typing import *
from numbers import Number
import torch
import torch.nn.functional as F
from ._helpers import batched
# Names exported by `from <this module> import *`.
# Helpers defined in this file but not listed here (`_angle_from_tan`,
# `interpolate_transform`) are therefore not part of the star-import surface.
__all__ = [
    # OpenGL perspective / OpenCV intrinsics construction
    'perspective',
    'perspective_from_fov',
    'perspective_from_fov_xy',
    'intrinsics_from_focal_center',
    'intrinsics_from_fov',
    'intrinsics_from_fov_xy',
    # camera pose construction and convention conversion
    'view_look_at',
    'extrinsics_look_at',
    'perspective_to_intrinsics',
    'intrinsics_to_perspective',
    'extrinsics_to_view',
    'view_to_extrinsics',
    'normalize_intrinsics',
    'crop_intrinsics',
    # coordinate-space conversion
    'pixel_to_uv',
    'pixel_to_ndc',
    'uv_to_pixel',
    'project_depth',
    'depth_buffer_to_linear',
    # projection / unprojection
    'project_gl',
    'project_cv',
    'unproject_gl',
    'unproject_cv',
    # rotation representations
    'skew_symmetric',
    'rotation_matrix_from_vectors',
    'euler_axis_angle_rotation',
    'euler_angles_to_matrix',
    'matrix_to_euler_angles',
    'matrix_to_quaternion',
    'quaternion_to_matrix',
    'matrix_to_axis_angle',
    'axis_angle_to_matrix',
    'axis_angle_to_quaternion',
    'quaternion_to_axis_angle',
    # interpolation and misc.
    'slerp',
    'interpolate_extrinsics',
    'interpolate_view',
    'extrinsics_to_essential',
    'to4x4',
    # 2D transforms
    'rotation_matrix_2d',
    'rotate_2d',
    'translate_2d',
    'scale_2d',
    'apply_2d',
]
@batched(0,0,0,0)
def perspective(
    fov_y: Union[float, torch.Tensor],
    aspect: Union[float, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Build an OpenGL perspective projection matrix.

    Args:
        fov_y (float | torch.Tensor): field of view along the y axis (passed to `torch.tan`)
        aspect (float | torch.Tensor): aspect ratio; divides the x scale term
        near (float | torch.Tensor): near clipping plane
        far (float | torch.Tensor): far clipping plane

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # NOTE(review): the body reads `fov_y.shape[0]`, so it expects the `batched`
    # decorator to supply 1-D tensors -- confirm against `_helpers.batched`.
    batch_size = fov_y.shape[0]
    half_tan = torch.tan(fov_y / 2)
    depth_range = near - far

    proj = torch.zeros((batch_size, 4, 4), dtype=fov_y.dtype, device=fov_y.device)
    proj[:, 0, 0] = 1. / (half_tan * aspect)
    proj[:, 1, 1] = 1. / half_tan
    proj[:, 2, 2] = (near + far) / depth_range
    proj[:, 2, 3] = 2. * near * far / depth_range
    # The bottom row copies -z into the clip-space w component.
    proj[:, 3, 2] = -1.
    return proj
def perspective_from_fov(
    fov: Union[float, torch.Tensor],
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Build an OpenGL perspective matrix from the field of view of the larger image side.

    Args:
        fov (float | torch.Tensor): field of view along the largest dimension
        width (int | torch.Tensor): image width
        height (int | torch.Tensor): image height
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # NOTE(review): `torch.tan` / `torch.maximum` are applied to the raw
    # arguments here, so `fov`, `width` and `height` must already be tensors.
    longest_side = torch.maximum(width, height)
    # Rescale the half-angle tangent from the longest side to the y axis.
    half_tan_y = torch.tan(fov / 2) * height / longest_side
    vertical_fov = 2 * torch.atan(half_tan_y)
    return perspective(vertical_fov, width / height, near, far)
def perspective_from_fov_xy(
    fov_x: Union[float, torch.Tensor],
    fov_y: Union[float, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Build an OpenGL perspective matrix from separate x and y fields of view.

    Args:
        fov_x (float | torch.Tensor): field of view along the x axis
        fov_y (float | torch.Tensor): field of view along the y axis
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # The aspect ratio is the ratio of the two half-angle tangents.
    half_tan_x = torch.tan(fov_x / 2)
    half_tan_y = torch.tan(fov_y / 2)
    return perspective(fov_y, half_tan_x / half_tan_y, near, far)
@batched(0,0,0,0)
def intrinsics_from_focal_center(
    fx: Union[float, torch.Tensor],
    fy: Union[float, torch.Tensor],
    cx: Union[float, torch.Tensor],
    cy: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenCV intrinsics matrix

    Args:
        fx (float | torch.Tensor): focal length in x axis
        fy (float | torch.Tensor): focal length in y axis
        cx (float | torch.Tensor): principal point in x axis
        cy (float | torch.Tensor): principal point in y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    # NOTE(review): `fx.shape[0]` is read, so the `batched` decorator is
    # expected to supply 1-D tensors -- confirm against `_helpers.batched`.
    N = fx.shape[0]
    zeros = torch.zeros(N, dtype=fx.dtype, device=fx.device)
    ones = torch.ones(N, dtype=fx.dtype, device=fx.device)
    # Lay the nine entries out row-major on the last axis, then fold to 3x3:
    #   [[fx, 0, cx],
    #    [0, fy, cy],
    #    [0,  0,  1]]
    return torch.stack(
        [fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1
    ).unflatten(-1, (3, 3))
@batched(0, 0, 0, 0, 0, 0)
def intrinsics_from_fov(
    fov_max: Union[float, torch.Tensor] = None,
    fov_min: Union[float, torch.Tensor] = None,
    fov_x: Union[float, torch.Tensor] = None,
    fov_y: Union[float, torch.Tensor] = None,
    width: Union[int, torch.Tensor] = None,
    height: Union[int, torch.Tensor] = None,
) -> torch.Tensor:
    """
    Get normalized OpenCV intrinsics matrix from given field of view.
    You can provide either fov_max, fov_min, fov_x or fov_y.
    They are checked in that order; the first one that is not None is used
    (fov_x and fov_y may also be given together).

    Args:
        width (int | torch.Tensor): image width (used by every branch except fov_x + fov_y)
        height (int | torch.Tensor): image height (used by every branch except fov_x + fov_y)
        fov_max (float | torch.Tensor): field of view in largest dimension
        fov_min (float | torch.Tensor): field of view in smallest dimension
        fov_x (float | torch.Tensor): field of view in x axis
        fov_y (float | torch.Tensor): field of view in y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix

    Raises:
        ValueError: if none of fov_max, fov_min, fov_x, fov_y is provided
    """
    if fov_max is not None:
        fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2))
        fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2))
    elif fov_min is not None:
        fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2))
        fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2))
    elif fov_x is not None and fov_y is not None:
        fx = 1 / (2 * torch.tan(fov_x / 2))
        fy = 1 / (2 * torch.tan(fov_y / 2))
    elif fov_x is not None:
        fx = 1 / (2 * torch.tan(fov_x / 2))
        # Derive the y focal length from the x one via the aspect ratio.
        fy = fx * width / height
    elif fov_y is not None:
        fy = 1 / (2 * torch.tan(fov_y / 2))
        fx = fy * height / width
    else:
        # Without this branch `fx`/`fy` would be unbound below and the call
        # would fail with an opaque UnboundLocalError.
        raise ValueError("One of fov_max, fov_min, fov_x or fov_y must be provided")
    # Normalized intrinsics: the principal point sits at the image center.
    cx = 0.5
    cy = 0.5
    ret = intrinsics_from_focal_center(fx, fy, cx, cy)
    return ret
def intrinsics_from_fov_xy(
    fov_x: Union[float, torch.Tensor],
    fov_y: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Build an OpenCV intrinsics matrix from the x and y fields of view.

    The focal lengths are `0.5 / tan(fov / 2)` and the principal point is (0.5, 0.5).

    Args:
        fov_x (float | torch.Tensor): field of view along the x axis
        fov_y (float | torch.Tensor): field of view along the y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    half_tan_x = torch.tan(fov_x / 2)
    half_tan_y = torch.tan(fov_y / 2)
    center = 0.5
    return intrinsics_from_focal_center(0.5 / half_tan_x, 0.5 / half_tan_y, center, center)
@batched(1,1,1)
def view_look_at(
    eye: torch.Tensor,
    look_at: torch.Tensor,
    up: torch.Tensor
) -> torch.Tensor:
    """
    Build an OpenGL view matrix for a camera at `eye` looking towards `look_at`.

    Args:
        eye (torch.Tensor): [..., 3] the eye position
        look_at (torch.Tensor): [..., 3] the position to look at
        up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily orthogonal to view direction

    Returns:
        (torch.Tensor): [..., 4, 4], view matrix
    """
    batch_size = eye.shape[0]

    # The camera z axis points from the target back to the eye.
    backward = eye - look_at
    right = torch.cross(up, backward, dim=-1)
    true_up = torch.cross(backward, right, dim=-1)

    # Normalize the three axes and stack them as the rows of the rotation.
    rows = [axis / axis.norm(dim=-1, keepdim=True) for axis in (right, true_up, backward)]
    rotation = torch.stack(rows, dim=-2)
    translation = -torch.matmul(rotation, eye[..., None])

    view = torch.zeros((batch_size, 4, 4), dtype=eye.dtype, device=eye.device)
    view[:, :3, :3] = rotation
    view[:, :3, 3] = translation[:, :, 0]
    view[:, 3, 3] = 1.
    return view
@batched(1, 1, 1)
def extrinsics_look_at(
    eye: torch.Tensor,
    look_at: torch.Tensor,
    up: torch.Tensor
) -> torch.Tensor:
    """
    Build an OpenCV extrinsics matrix for a camera at `eye` looking towards `look_at`.

    Args:
        eye (torch.Tensor): [..., 3] the eye position
        look_at (torch.Tensor): [..., 3] the position to look at
        up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily orthogonal to view direction

    Returns:
        (torch.Tensor): [..., 4, 4], extrinsics matrix
    """
    batch_size = eye.shape[0]

    # The camera z axis points from the eye towards the target.
    forward = look_at - eye
    right = torch.cross(-up, forward, dim=-1)
    down = torch.cross(forward, right, dim=-1)

    # Normalize the three axes and stack them as the rows of the rotation.
    rows = [axis / axis.norm(dim=-1, keepdim=True) for axis in (right, down, forward)]
    rotation = torch.stack(rows, dim=-2)
    translation = -torch.matmul(rotation, eye[..., None])

    extrinsics = torch.zeros((batch_size, 4, 4), dtype=eye.dtype, device=eye.device)
    extrinsics[:, :3, :3] = rotation
    extrinsics[:, :3, 3] = translation[:, :, 0]
    extrinsics[:, 3, 3] = 1.
    return extrinsics
@batched(2)
def perspective_to_intrinsics(
    perspective: torch.Tensor
) -> torch.Tensor:
    """
    OpenGL perspective matrix to OpenCV intrinsics

    Args:
        perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix

    Returns:
        (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics

    Raises:
        AssertionError: if entries (0, 3), (1, 3) or (3, 3) of the matrix are not (close to) zero
    """
    # `torch.allclose` requires a Tensor for both operands, so compare against
    # an explicit zero tensor rather than the Python scalar 0.
    last_column = perspective[:, [0, 1, 3], 3]
    assert torch.allclose(last_column, torch.zeros_like(last_column)), "The perspective matrix is not a projection matrix"

    # Maps NDC x/y in [-1, 1] to uv in [0, 1] and flips the y axis.
    ndc_to_uv = torch.tensor(
        [[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]],
        dtype=perspective.dtype, device=perspective.device
    )
    # Negates the y and z columns (the axes that differ between the two conventions).
    axis_flip = torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device))

    # Rows 0, 1, 3 of the projection carry x, y and w; the depth row is dropped.
    ret = ndc_to_uv @ perspective[:, [0, 1, 3], :3] @ axis_flip
    # Rescale so that the bottom-right entry is exactly 1.
    return ret / ret[:, 2, 2, None, None]
@batched(2,0,0)
def intrinsics_to_perspective(
    intrinsics: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    Convert OpenCV intrinsics into an OpenGL perspective matrix.

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] OpenGL perspective matrix
    """
    batch_size = intrinsics.shape[0]
    focal_x = intrinsics[:, 0, 0]
    focal_y = intrinsics[:, 1, 1]
    center_x = intrinsics[:, 0, 2]
    center_y = intrinsics[:, 1, 2]
    depth_range = near - far

    proj = torch.zeros((batch_size, 4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    proj[:, 0, 0] = 2 * focal_x
    proj[:, 1, 1] = 2 * focal_y
    proj[:, 0, 2] = -2 * center_x + 1
    proj[:, 1, 2] = 2 * center_y - 1
    proj[:, 2, 2] = (near + far) / depth_range
    proj[:, 2, 3] = 2. * near * far / depth_range
    proj[:, 3, 2] = -1.
    return proj
@batched(2)
def extrinsics_to_view(
    extrinsics: torch.Tensor
) -> torch.Tensor:
    """
    Convert OpenCV camera extrinsics into an OpenGL view matrix.

    Args:
        extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix

    Returns:
        (torch.Tensor): [..., 4, 4] OpenGL view matrix
    """
    # One scale factor per row: the second and third rows are negated.
    row_scale = torch.tensor([1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device)
    return extrinsics * row_scale[:, None]
@batched(2)
def view_to_extrinsics(
    view: torch.Tensor
) -> torch.Tensor:
    """
    Convert an OpenGL view matrix into OpenCV camera extrinsics.

    Args:
        view (torch.Tensor): [..., 4, 4] OpenGL view matrix

    Returns:
        (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix
    """
    # One scale factor per row: the second and third rows are negated.
    row_scale = torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device)
    return view * row_scale[:, None]
@batched(2,0,0)
def normalize_intrinsics(
    intrinsics: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Normalize camera intrinsics(s) to uv space

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s)
    """
    zeros = torch.zeros_like(width)
    ones = torch.ones_like(width)
    # Pixel -> uv transform: u = (x + 0.5) / width, v = (y + 0.5) / height.
    # The nine entries are stacked on the LAST axis so that the reshape yields
    # one row-major 3x3 matrix per batch element. Stacking on axis 0 would
    # interleave entries across the batch whenever it holds more than one item.
    transform = torch.stack([
        1 / width, zeros, 0.5 / width,
        zeros, 1 / height, 0.5 / height,
        zeros, zeros, ones
    ], dim=-1).reshape(*zeros.shape, 3, 3).to(intrinsics)
    return transform @ intrinsics
@batched(2,0,0,0,0,0,0)
def crop_intrinsics(
    intrinsics: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor],
    left: Union[int, torch.Tensor],
    top: Union[int, torch.Tensor],
    crop_width: Union[int, torch.Tensor],
    crop_height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width]

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)
        left (int | torch.Tensor): [...] left crop boundary
        top (int | torch.Tensor): [...] top crop boundary
        crop_width (int | torch.Tensor): [...] crop width
        crop_height (int | torch.Tensor): [...] crop height

    Returns:
        (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s)
    """
    zeros = torch.zeros_like(width)
    ones = torch.ones_like(width)
    # uv in the full image -> uv in the crop:
    #   u' = (u * width - left) / crop_width, v' = (v * height - top) / crop_height.
    # The nine entries are stacked on the LAST axis so that the reshape yields
    # one row-major 3x3 matrix per batch element. Stacking on axis 0 would
    # interleave entries across the batch whenever it holds more than one item.
    transform = torch.stack([
        width / crop_width, zeros, -left / crop_width,
        zeros, height / crop_height, -top / crop_height,
        zeros, zeros, ones
    ], dim=-1).reshape(*zeros.shape, 3, 3).to(intrinsics)
    return transform @ intrinsics
@batched(1,0,0)
def pixel_to_uv(
    pixel: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Convert pixel coordinates to uv coordinates.

    Args:
        pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates defined in uv space, the range is (0, 1)
    """
    # Integer pixel indices are promoted to float32 before the division.
    pixel = pixel if torch.is_floating_point(pixel) else pixel.float()
    image_size = torch.stack([width, height], dim=-1).to(pixel)
    # The +0.5 moves from the pixel index to the pixel center.
    return (pixel + 0.5) / image_size
@batched(1,0,0)
def uv_to_pixel(
    uv: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Convert uv coordinates to pixel coordinates (computes `uv * size - 0.5`).

    Args:
        uv (torch.Tensor): [..., 2] pixel coordinates defined in uv space, the range is (0, 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates defined in image space
    """
    image_size = torch.stack([width, height], dim=-1).to(uv)
    # The -0.5 moves from the pixel center back to the pixel index.
    return uv * image_size - 0.5
@batched(1,0,0)
def pixel_to_ndc(
    pixel: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Convert pixel coordinates to normalized device coordinates.

    x maps left -> right onto (-1, 1); y is flipped so that the top row maps
    to +1 and the bottom row to -1.

    Args:
        pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates defined in ndc space, the range is (-1, 1)
    """
    if not torch.is_floating_point(pixel):
        pixel = pixel.float()
    image_size = torch.stack([width, height], dim=-1).to(pixel)
    scale = torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device)
    offset = torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device)
    # ndc = uv * (2, -2) + (-1, 1). The scale must MULTIPLY the uv value;
    # placing it in the denominator squeezes x into (-1, -0.5) instead of (-1, 1).
    ndc = (pixel + 0.5) / image_size * scale + offset
    return ndc
@batched(0,0,0)
def project_depth(
    depth: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Map linear depth to a screen-space depth value.

    `depth == near` maps to 0 and `depth == far` maps to 1.

    Args:
        depth (torch.Tensor): [...] depth value
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): depth value in screen space, value ranging in [0, 1]
    """
    numerator = far - near * far / depth
    denominator = far - near
    return numerator / denominator
@batched(0,0,0)
def depth_buffer_to_linear(
    depth: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Convert a screen-space depth value back to linear depth.

    A buffer value of 0 maps to `near` and a value of 1 maps to `far`.

    Args:
        depth (torch.Tensor): [...] screen depth value, ranging in [0, 1]
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): [...] linear depth
    """
    numerator = near * far
    denominator = far - (far - near) * depth
    return numerator / denominator
@batched(2, 2, 2, 2)
def project_gl(
    points: torch.Tensor,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    perspective: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project 3D points to screen space following the OpenGL convention (matrices are row major).

    Args:
        points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last
            dimension is 4, the points are assumed to be in homogeneous coordinates
        model (torch.Tensor): [..., 4, 4] model matrix
        view (torch.Tensor): [..., 4, 4] view matrix
        perspective (torch.Tensor): [..., 4, 4] perspective matrix

    Returns:
        scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1].
            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
        linear_depth (torch.Tensor): [..., N] linear depth (the clip-space w component)
    """
    assert perspective is not None, "perspective matrix is required"

    # Lift 3-vectors to homogeneous coordinates.
    if points.shape[-1] == 3:
        homogeneous_one = torch.ones_like(points[..., :1])
        points = torch.cat([points, homogeneous_one], dim=-1)

    if perspective is not None:
        mvp = perspective
    else:
        mvp = torch.eye(4).to(points)
    # Compose perspective @ view @ model, skipping whatever was not supplied.
    for matrix in (view, model):
        if matrix is not None:
            mvp = mvp @ matrix

    clip = points @ mvp.transpose(-1, -2)
    w = clip[..., 3:]
    ndc = clip[..., :3] / w
    # NDC [-1, 1] -> screen [0, 1].
    screen = ndc * 0.5 + 0.5
    return screen, clip[..., 3]
@batched(2, 2, 2)
def project_cv(
    points: torch.Tensor,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project 3D points to image uv coordinates following the OpenCV convention.

    Args:
        points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last
            dimension is 4, the points are assumed to be in homogeneous coordinates
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        linear_depth (torch.Tensor): [..., N] linear depth
    """
    assert intrinsics is not None, "intrinsics matrix is required"

    # Lift 3-vectors to homogeneous coordinates.
    if points.shape[-1] == 3:
        homogeneous_one = torch.ones_like(points[..., :1])
        points = torch.cat([points, homogeneous_one], dim=-1)

    # World -> camera space (optional), then camera -> image plane.
    if extrinsics is not None:
        points = points @ extrinsics.transpose(-1, -2)
    projected = points[..., :3] @ intrinsics.transpose(-2, -1)

    depth = projected[..., 2]
    uv = projected[..., :2] / projected[..., 2:]
    return uv, depth
@batched(2, 2, 2, 2)
def unproject_gl(
    screen_coord: torch.Tensor,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    perspective: torch.Tensor = None
) -> torch.Tensor:
    """
    Unproject screen space coordinates back to 3D following the OpenGL convention (matrices are row major).

    Args:
        screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1].
            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
        model (torch.Tensor): [..., 4, 4] model matrix
        view (torch.Tensor): [..., 4, 4] view matrix
        perspective (torch.Tensor): [..., 4, 4] perspective matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert perspective is not None, "perspective matrix is required"

    # Screen [0, 1] -> NDC [-1, 1], then append w = 1.
    ndc = screen_coord * 2 - 1
    clip = torch.cat([ndc, torch.ones_like(ndc[..., :1])], dim=-1)

    # Compose perspective @ view @ model, skipping whatever was not supplied.
    forward = perspective
    for matrix in (view, model):
        if matrix is not None:
            forward = forward @ matrix
    backward = torch.inverse(forward)

    homogeneous = clip @ backward.transpose(-1, -2)
    return homogeneous[..., :3] / homogeneous[..., 3:]
@batched(2, 1, 2, 2)
def unproject_cv(
    uv_coord: torch.Tensor,
    depth: torch.Tensor,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> torch.Tensor:
    """
    Unproject uv coordinates and depth back to 3D points following the OpenCV convention.

    Args:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        depth (torch.Tensor): [..., N] depth value
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert intrinsics is not None, "intrinsics matrix is required"

    # Back-project uv through the inverse intrinsics (rays with z = 1), then scale by depth.
    uv_homogeneous = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
    rays = uv_homogeneous @ torch.inverse(intrinsics).transpose(-2, -1)
    points = rays * depth[..., None]

    if extrinsics is None:
        return points

    # Apply the inverse extrinsics to the camera-space points.
    points_homogeneous = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
    return (points_homogeneous @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3]
def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Build the rotation matrices for a rotation about a single coordinate axis,
    one matrix per entry of `angle`.

    Args:
        axis: Axis label "X" or "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if `axis` is not one of "X", "Y", "Z".
    """
    c, s = torch.cos(angle), torch.sin(angle)
    one, zero = torch.ones_like(angle), torch.zeros_like(angle)

    # Row-major entries of the elementary rotation for each axis.
    layouts = {
        "X": (one, zero, zero, zero, c, -s, zero, s, c),
        "Y": (c, zero, s, zero, one, zero, -s, zero, c),
        "Z": (c, -s, zero, s, c, zero, zero, zero, one),
    }
    if axis not in layouts:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(layouts[axis], -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    The angle used for each letter of `convention` is looked up by axis
    (`'XYZ'.index(letter)`), and the three elementary rotations are composed
    as `third @ second @ first`.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ
        convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: on a malformed `euler_angles` shape or `convention` string.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    # The middle axis may not repeat either of its neighbours.
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    first, second, third = (
        euler_axis_angle_rotation(letter, euler_angles[..., 'XYZ'.index(letter)])
        for letter in convention
    )
    return third @ second @ first
def skew_symmetric(v: torch.Tensor):
    """
    Skew symmetric (cross-product) matrix of a 3D vector.

    Args:
        v (torch.Tensor): [..., 3] vectors

    Returns:
        (torch.Tensor): [..., 3, 3] matrices `K` such that `K @ u == cross(v, u)`
    """
    assert v.shape[-1] == 3, "v must be 3D"
    vx, vy, vz = v.unbind(dim=-1)
    zero = torch.zeros_like(vx)
    # Build the matrix one row at a time, then stack the rows.
    row0 = torch.stack([zero, -vz, vy], dim=-1)
    row1 = torch.stack([vz, zero, -vx], dim=-1)
    row2 = torch.stack([-vy, vx, zero], dim=-1)
    return torch.stack([row0, row1, row2], dim=-2)
def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor):
    """
    Rotation matrix that rotates v1 to v2.

    Both inputs are normalized first, so only their directions matter.

    Args:
        v1 (torch.Tensor): [..., 3] source vectors
        v2 (torch.Tensor): [..., 3] target vectors

    Returns:
        (torch.Tensor): [..., 3, 3] rotation matrices `R` such that `R @ v1` is parallel to `v2`
    """
    I = torch.eye(3).to(v1)
    v1 = F.normalize(v1, dim=-1)
    v2 = F.normalize(v2, dim=-1)
    v = torch.cross(v1, v2, dim=-1)     # rotation axis scaled by sin(angle)
    c = torch.sum(v1 * v2, dim=-1)      # cos(angle)
    K = skew_symmetric(v)
    # Rodrigues' formula in terms of the unnormalized axis:
    #   R = I + K + K^2 / (1 + c).
    # The scalar factor needs two TRAILING axes to broadcast over each 3x3
    # matrix; prepending axes only works for unbatched inputs.
    # NOTE: exactly opposite vectors (c == -1) divide by zero here.
    R = I + K + (1 / (1 + c))[..., None, None] * (K @ K)
    return R
def _angle_from_tan(
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
"""
Extract the first or third Euler angle from the two members of
the matrix which are positive constant times its sine and cosine.
Args:
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
convention.
data: Rotation matrices as tensor of shape (..., 3, 3).
horizontal: Whether we are looking for the angle for the third axis,
which means the relevant entries are in the same row of the
rotation matrix. If not, they are in the same column.
tait_bryan: Whether the first and third axes in the convention differ.
Returns:
Euler Angles in radians for each matrix in data as a tensor
of shape (...).
"""
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
if horizontal:
i2, i1 = i1, i2
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
if horizontal == even:
return torch.atan2(data[..., i1], data[..., i2])
if tait_bryan:
return torch.atan2(-data[..., i2], data[..., i1])
return torch.atan2(data[..., i2], -data[..., i1])
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.
    NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d)

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters. Every
            character must be one of "X", "Y", "Z" and each of those letters
            must appear in the string.

    Returns:
        Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d)

    Raises:
        ValueError: if `convention` fails the check above, or if the last two
            dimensions of `matrix` are not (3, 3).
    """
    if not all(c in 'XYZ' for c in convention) or not all(c in convention for c in 'XYZ'):
        raise ValueError(f"Invalid convention {convention}.")
    if not matrix.shape[-2:] == (3, 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    # Axis indices (0 = X, 1 = Y, 2 = Z) of the first and third rotation.
    i0 = 'XYZ'.index(convention[0])
    i2 = 'XYZ'.index(convention[2])
    tait_bryan = i0 != i2
    if tait_bryan:
        # The middle angle comes from the (i2, i0) entry; the sign is flipped
        # when i2 - i0 is -1 or 2.
        central_angle = torch.asin(matrix[..., i2, i0] * (-1.0 if i2 - i0 in [-1, 2] else 1.0))
    else:
        central_angle = torch.acos(matrix[..., i2, i2])

    # Angles in composition order
    o = [
        # First angle: read from row i2 of the matrix.
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2, :], True, tait_bryan
        ),
        central_angle,
        # Third angle: `matrix[..., i0]` indexes the last axis, i.e. column i0.
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0], False, tait_bryan
        ),
    ]
    # Reorder from composition order to fixed X, Y, Z order.
    return torch.stack([o[convention.index(c)] for c in 'XYZ'], -1)
def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation

    Args:
        axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
        eps (float): small offset added to every component before taking the
            norm, so that an all-zero rotation vector does not divide by zero

    Returns:
        torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters
    """
    batch_shape = axis_angle.shape[:-1]
    device, dtype = axis_angle.device, axis_angle.dtype

    angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True)
    axis = axis_angle / angle

    # Shape (..., 1, 1) so they broadcast against the (..., 3, 3) matrices.
    cos = torch.cos(angle)[..., None, :]
    sin = torch.sin(angle)[..., None, :]

    # Split into three (..., 1) components. The chunk size must be 1: a chunk
    # size of 3 returns a single tensor and the three-way unpacking fails.
    rx, ry, rz = torch.split(axis, 1, dim=-1)
    zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device)
    # Cross-product (skew-symmetric) matrix of the unit axis, row-major.
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view((*batch_shape, 3, 3))

    # Rodrigues' rotation formula: R = I + sin(a) K + (1 - cos(a)) K^2.
    ident = torch.eye(3, dtype=dtype, device=device)
    rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K)
    return rot_mat
def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector).

    The conversion goes through the quaternion representation.

    Args:
        rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
        eps (float): forwarded to `quaternion_to_axis_angle`

    Returns:
        torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices
    """
    return quaternion_to_axis_angle(matrix_to_quaternion(rot_mat), eps=eps)
def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector)

    Args:
        quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
        eps (float): lower bound on the vector-part norm used as divisor

    Returns:
        torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions
    """
    assert quaternion.shape[-1] == 4
    scalar_part = quaternion[..., 0:1]
    vector_part = quaternion[..., 1:]
    vector_norm = torch.norm(vector_part, dim=-1, keepdim=True)
    # Clamp the divisor so that a zero vector part yields a zero axis.
    direction = vector_part / vector_norm.clamp(min=eps)
    # The rotation angle is twice the angle whose sine/cosine are (|xyz|, w).
    return 2 * torch.atan2(vector_norm, scalar_part) * direction
def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z)

    Args:
        axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
        eps (float): passed to `F.normalize` when extracting the unit axis

    Returns:
        torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters
    """
    half_angle = torch.norm(axis_angle, dim=-1, keepdim=True) / 2
    unit_axis = F.normalize(axis_angle, dim=-1, eps=eps)
    # q = (cos(a / 2), sin(a / 2) * axis)
    return torch.cat([torch.cos(half_angle), torch.sin(half_angle) * unit_axis], dim=-1)
def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert 3x3 rotation matrix to quaternion (w, x, y, z)

    Args:
        rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
        eps (float): passed to `F.normalize` for the final normalization

    Returns:
        torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices
    """
    # Component magnitudes: |w|, |x|, |y|, |z| = 0.5 * sqrt(1 +/- m00 +/- m11 +/- m22).
    diagonal = torch.diagonal(rot_mat, dim1=-2, dim2=-1)
    diag_signs = torch.tensor([
        [1, 1, 1],
        [1, -1, -1],
        [-1, 1, -1],
        [-1, -1, 1]
    ], dtype=rot_mat.dtype, device=rot_mat.device)
    magnitude = (1 + diagonal @ diag_signs.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5)
    dominant = magnitude.max(dim=-1).indices

    # Signs of the pairwise component products, read from the off-diagonal entries.
    s_xw = torch.sign(rot_mat[..., 2, 1] - rot_mat[..., 1, 2])
    s_yw = torch.sign(rot_mat[..., 0, 2] - rot_mat[..., 2, 0])
    s_zw = torch.sign(rot_mat[..., 1, 0] - rot_mat[..., 0, 1])
    s_yz = torch.sign(rot_mat[..., 2, 1] + rot_mat[..., 1, 2])
    s_xz = torch.sign(rot_mat[..., 0, 2] + rot_mat[..., 2, 0])
    s_xy = torch.sign(rot_mat[..., 0, 1] + rot_mat[..., 1, 0])
    one = torch.ones_like(s_xw)

    # Row k holds the component signs relative to component k (taken as
    # positive); the row of the largest-magnitude component is selected.
    sign_table = torch.stack([
        torch.stack([one, s_xw, s_yw, s_zw], dim=-1),
        torch.stack([s_xw, one, s_xy, s_xz], dim=-1),
        torch.stack([s_yw, s_xy, one, s_yz], dim=-1),
        torch.stack([s_zw, s_xz, s_yz, one], dim=-1),
    ], dim=-2)
    row_index = dominant[..., None, None].expand(*dominant.shape, 1, 4)
    sign = sign_table.gather(-2, row_index).squeeze(-2)

    return F.normalize(sign * magnitude, dim=-1, eps=eps)
def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Converts a batch of quaternions (w, x, y, z) to rotation matrices

    The input is normalized first, so it does not need to be a unit quaternion.

    Args:
        quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
        eps (float): passed to `F.normalize`

    Returns:
        torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions
    """
    assert quaternion.shape[-1] == 4
    unit = F.normalize(quaternion, dim=-1, eps=eps)
    w, x, y, z = unit.unbind(dim=-1)
    vec = unit[..., 1:]
    zero = torch.zeros_like(w)
    identity = torch.eye(3, dtype=unit.dtype, device=unit.device)

    # Symmetric part: v v^T - |v|^2 I
    outer = vec[..., :, None] * vec[..., None, :]
    symmetric = outer - identity * (vec ** 2).sum(dim=-1)[..., None, None]

    # Antisymmetric part: cross-product matrix of the vector part.
    cross_matrix = torch.stack([
        torch.stack([zero, -z, y], dim=-1),
        torch.stack([z, zero, -x], dim=-1),
        torch.stack([-y, x, zero], dim=-1),
    ], dim=-2)

    return identity + 2 * (symmetric + w[..., None, None] * cross_matrix)
def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
    """Interpolate between two rotation matrices.

    NOTE(review): the implementation blends the two rotation vectors
    (axis-angle) linearly and converts the result back to a matrix. This
    matches the endpoints at t = 0 and t = 1 but is not the geodesic
    `R1 @ exp(t * log(R1^T R2))` in general -- confirm this is intended.

    Args:
        rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix
        rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix
    """
    assert rot_mat_1.shape[-2:] == (3, 3)
    start_vec = matrix_to_axis_angle(rot_mat_1)
    end_vec = matrix_to_axis_angle(rot_mat_2)
    # Promote Python scalars to a tensor so they can be indexed below.
    if isinstance(t, Number):
        t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device)
    weight = t[..., None]
    blended = (1 - weight) * start_vec + weight * end_vec
    return axis_angle_to_matrix(blended)
def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
    """Interpolate extrinsics between two camera poses.

    Both matrices are inverted, interpolated with `interpolate_transform`
    (linear for translation, `slerp` for rotation), and the result is inverted back.

    Args:
        ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose
        ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated camera pose
    """
    pose_start = torch.inverse(ext1)
    pose_end = torch.inverse(ext2)
    pose_interpolated = interpolate_transform(pose_start, pose_end, t)
    return torch.inverse(pose_interpolated)
def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]):
    """Interpolate view matrices between two camera poses.

    Delegates to `interpolate_extrinsics` with the view matrices passed through unchanged.

    Args:
        view1 (torch.Tensor): shape (..., 4, 4), the first view matrix
        view2 (torch.Tensor): shape (..., 4, 4), the second view matrix
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated view matrix
    """
    interpolated = interpolate_extrinsics(view1, view2, t)
    return interpolated
def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]):
    """Interpolate between two rigid 4x4 transforms.

    The translation column is interpolated linearly and the 3x3 rotation
    block is interpolated with `slerp`.

    Args:
        transform1 (torch.Tensor): shape (..., 4, 4), the first transform
        transform2 (torch.Tensor): shape (..., 4, 4), the second transform
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated transform
    """
    assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4)
    # Promote Python scalars to a tensor so they can be indexed below.
    if isinstance(t, Number):
        t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device)
    pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3]
    rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t)
    # Assemble the top three rows [R | p] ...
    transform = torch.cat([rot, pos[..., None]], dim=-1)
    # ... and append the homogeneous bottom row [0, 0, 0, 1], broadcast over the batch.
    bottom_row = torch.tensor(
        [0, 0, 0, 1], dtype=transform.dtype, device=transform.device
    ).expand_as(transform[..., :1, :])
    transform = torch.cat([transform, bottom_row], dim=-2)
    return transform
def extrinsics_to_essential(extrinsics: torch.Tensor):
    """
    extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0`

    The result is computed as `R @ [t]_x`, where `[t]_x` is the skew-symmetric
    cross-product matrix of the translation column.

    Args:
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix

    Returns:
        (torch.Tensor): [..., 3, 3] essential matrix
    """
    assert extrinsics.shape[-2:] == (4, 4)
    R = extrinsics[..., :3, :3]
    t = extrinsics[..., :3, 3]
    tx, ty, tz = t.unbind(dim=-1)
    # Zeros must share the per-component shape (...), not the vector shape
    # (..., 3), otherwise torch.stack rejects the mismatched entries.
    zeros = torch.zeros_like(tx)
    # Stack along the LAST axis so the nine entries of each matrix stay
    # contiguous per batch element before reshaping to (..., 3, 3).
    t_x = torch.stack([
        zeros, -tz, ty,
        tz, zeros, -tx,
        -ty, tx, zeros,
    ], dim=-1).reshape(*t.shape[:-1], 3, 3)
    return R @ t_x
def to4x4(R: torch.Tensor, t: torch.Tensor):
    """
    Assemble a homogeneous 4x4 matrix `[[R, t], [0, 0, 0, 1]]` from a rotation
    block and a translation vector.

    Args:
        R (torch.Tensor): [..., 3, 3] rotation matrix
        t (torch.Tensor): [..., 3] translation vector; batch dimensions must equal those of `R`

    Returns:
        (torch.Tensor): [..., 4, 4] transformation matrix
    """
    assert R.shape[-2:] == (3, 3)
    assert t.shape[-1] == 3
    assert R.shape[:-2] == t.shape[:-1]
    batch_shape = R.shape[:-2]
    # Upper three rows: rotation with the translation appended as a fourth column.
    upper = torch.cat([R, t[..., None]], dim=-1)
    # Constant homogeneous row, viewed once per batch element.
    last_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
    last_row = last_row.expand(*batch_shape, 1, 4)
    return torch.cat([upper, last_row], dim=-2)
def rotation_matrix_2d(theta: Union[float, torch.Tensor]):
    """
    2x2 matrix for 2D rotation

    ```
    [[cos(theta), -sin(theta)],
     [sin(theta),  cos(theta)]]
    ```

    Args:
        theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)

    Returns:
        (torch.Tensor): (..., 2, 2) rotation matrix
    """
    # Accept any Python real scalar (ints such as 0 included), not only float;
    # torch.cos / torch.sin require a tensor argument.
    if isinstance(theta, (int, float)):
        theta = torch.tensor(float(theta))
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    # Entries are stacked row-major on a new last axis, then folded into 2x2.
    return torch.stack([
        cos_t, -sin_t,
        sin_t, cos_t,
    ], dim=-1).unflatten(-1, (2, 2))
def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None):
    """
    3x3 matrix for 2D rotation around a center
    ```
    [[Rxx, Rxy, tx],
     [Ryx, Ryy, ty],
     [0,   0,   1]]
    ```
    The translation column is `center - R @ center`, so `center` is a fixed
    point of the transformation.

    Args:
        theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)
        center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0).
            The batch shapes of `theta` and `center` are broadcast against each other.

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    # Accept any Python real scalar (ints included), not only float.
    if isinstance(theta, (int, float)):
        theta = torch.tensor(float(theta))
    if center is None:
        center = torch.zeros(2).to(theta).expand(*theta.shape, -1)
    else:
        # Follow the dtype / device of the supplied center.
        theta = theta.to(center)
    # torch.cat does not broadcast, so bring both operands to the common
    # batch shape explicitly (e.g. scalar theta with a batch of centers).
    batch_shape = torch.broadcast_shapes(theta.shape, center.shape[:-1])
    R = rotation_matrix_2d(theta).expand(*batch_shape, 2, 2)
    center_col = center.expand(*batch_shape, 2)[..., :, None]
    top = torch.cat([R, center_col - R @ center_col], dim=-1)
    bottom = torch.tensor(
        [[0, 0, 1]], dtype=center.dtype, device=center.device
    ).expand(*batch_shape, 1, 3)
    return torch.cat([top, bottom], dim=-2)
def translate_2d(translation: torch.Tensor):
    """
    Homogeneous 3x3 matrix for a pure 2D translation
    ```
    [[1, 0, tx],
     [0, 1, ty],
     [0, 0, 1]]
    ```

    Args:
        translation (torch.Tensor): translation vector, arbitrary shape (..., 2)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    batch_shape = translation.shape[:-1]
    tensor_opts = dict(dtype=translation.dtype, device=translation.device)
    # Identity linear part, viewed once per batch element.
    linear = torch.eye(2, **tensor_opts).expand(*batch_shape, -1, -1)
    # Upper two rows: [I | t].
    upper = torch.cat([linear, translation.unsqueeze(-1)], dim=-1)
    # Constant homogeneous bottom row.
    lower = torch.tensor([[0, 0, 1]], **tensor_opts).expand(*batch_shape, -1, -1)
    return torch.cat([upper, lower], dim=-2)
def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None):
    """
    Scale matrix for 2D uniform scaling around a center
    ```
    [[s, 0, tx],
     [0, s, ty],
     [0, 0, 1]]
    ```
    The translation column is `center - scale * center`, so `center` is a
    fixed point of the transformation.

    Args:
        scale (float | torch.Tensor): scale factor, arbitrary shape (...,)
        center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0).
            The batch shapes of `scale` and `center` are broadcast against each other.

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    # Accept any Python real scalar (ints included), not only float.
    if isinstance(scale, (int, float)):
        scale = torch.tensor(float(scale))
    if center is None:
        center = torch.zeros(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1)
    else:
        # Follow the dtype / device of the supplied center.
        scale = scale.to(center)
    # torch.cat does not broadcast, so bring both operands to the common
    # batch shape explicitly (e.g. scalar scale with a batch of centers).
    batch_shape = torch.broadcast_shapes(scale.shape, center.shape[:-1])
    # Two trailing singleton axes let the per-element factor multiply whole
    # matrices / column vectors instead of broadcasting along matrix columns.
    factor = scale.expand(batch_shape)[..., None, None]
    center_col = center.expand(*batch_shape, 2)[..., :, None]
    linear = factor * torch.eye(2, dtype=scale.dtype, device=scale.device)
    top = torch.cat([linear, center_col - center_col * factor], dim=-1)
    bottom = torch.tensor(
        [[0, 0, 1]], dtype=scale.dtype, device=scale.device
    ).expand(*batch_shape, 1, 3)
    return torch.cat([top, bottom], dim=-2)
def apply_2d(transform: torch.Tensor, points: torch.Tensor):
    """
    Apply (3x3 or 2x3) 2D affine transformation to points
    ```
    p = R @ p + t
    ```
    where `R` is the upper-left 2x2 block and `t` the first two entries of the
    last column of `transform`.

    Args:
        transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix
        points (torch.Tensor): (..., N, 2) points to transform

    Returns:
        (torch.Tensor): (..., N, 2) transformed points
    """
    assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3"
    assert points.shape[-1] == 2, "points must be 2D"
    linear = transform[..., :2, :2]
    # Points are row vectors, so the translation must be a (..., 1, 2) row that
    # broadcasts over the N points. The previous (..., 2, 1) column added
    # t[i] to point i instead of t[j] to coordinate j.
    offset = transform[..., None, :2, 2]
    return points @ linear.mT + offset