from typing import * from numbers import Number import torch import torch.nn.functional as F from ._helpers import batched __all__ = [ 'perspective', 'perspective_from_fov', 'perspective_from_fov_xy', 'intrinsics_from_focal_center', 'intrinsics_from_fov', 'intrinsics_from_fov_xy', 'view_look_at', 'extrinsics_look_at', 'perspective_to_intrinsics', 'intrinsics_to_perspective', 'extrinsics_to_view', 'view_to_extrinsics', 'normalize_intrinsics', 'crop_intrinsics', 'pixel_to_uv', 'pixel_to_ndc', 'uv_to_pixel', 'project_depth', 'depth_buffer_to_linear', 'project_gl', 'project_cv', 'unproject_gl', 'unproject_cv', 'skew_symmetric', 'rotation_matrix_from_vectors', 'euler_axis_angle_rotation', 'euler_angles_to_matrix', 'matrix_to_euler_angles', 'matrix_to_quaternion', 'quaternion_to_matrix', 'matrix_to_axis_angle', 'axis_angle_to_matrix', 'axis_angle_to_quaternion', 'quaternion_to_axis_angle', 'slerp', 'interpolate_extrinsics', 'interpolate_view', 'extrinsics_to_essential', 'to4x4', 'rotation_matrix_2d', 'rotate_2d', 'translate_2d', 'scale_2d', 'apply_2d', ] @batched(0,0,0,0) def perspective( fov_y: Union[float, torch.Tensor], aspect: Union[float, torch.Tensor], near: Union[float, torch.Tensor], far: Union[float, torch.Tensor] ) -> torch.Tensor: """ Get OpenGL perspective matrix Args: fov_y (float | torch.Tensor): field of view in y axis aspect (float | torch.Tensor): aspect ratio near (float | torch.Tensor): near plane to clip far (float | torch.Tensor): far plane to clip Returns: (torch.Tensor): [..., 4, 4] perspective matrix """ N = fov_y.shape[0] ret = torch.zeros((N, 4, 4), dtype=fov_y.dtype, device=fov_y.device) ret[:, 0, 0] = 1. / (torch.tan(fov_y / 2) * aspect) ret[:, 1, 1] = 1. / (torch.tan(fov_y / 2)) ret[:, 2, 2] = (near + far) / (near - far) ret[:, 2, 3] = 2. * near * far / (near - far) ret[:, 3, 2] = -1. return ret def perspective_from_fov( fov: Union[float, torch.Tensor], width: Union[int, torch.Tensor], height: Union[int, torch.Tensor], near: Union[float, torch.Tensor], far: Union[float, torch.Tensor] ) -> torch.Tensor: """ Get OpenGL perspective matrix from field of view in largest dimension Args: fov (float | torch.Tensor): field of view in largest dimension width (int | torch.Tensor): image width height (int | torch.Tensor): image height near (float | torch.Tensor): near plane to clip far (float | torch.Tensor): far plane to clip Returns: (torch.Tensor): [..., 4, 4] perspective matrix """ fov_y = 2 * torch.atan(torch.tan(fov / 2) * height / torch.maximum(width, height)) aspect = width / height return perspective(fov_y, aspect, near, far) def perspective_from_fov_xy( fov_x: Union[float, torch.Tensor], fov_y: Union[float, torch.Tensor], near: Union[float, torch.Tensor], far: Union[float, torch.Tensor] ) -> torch.Tensor: """ Get OpenGL perspective matrix from field of view in x and y axis Args: fov_x (float | torch.Tensor): field of view in x axis fov_y (float | torch.Tensor): field of view in y axis near (float | torch.Tensor): near plane to clip far (float | torch.Tensor): far plane to clip Returns: (torch.Tensor): [..., 4, 4] perspective matrix """ aspect = torch.tan(fov_x / 2) / torch.tan(fov_y / 2) return perspective(fov_y, aspect, near, far) @batched(0,0,0,0) def intrinsics_from_focal_center( fx: Union[float, torch.Tensor], fy: Union[float, torch.Tensor], cx: Union[float, torch.Tensor], cy: Union[float, torch.Tensor] ) -> torch.Tensor: """ Get OpenCV intrinsics matrix Args: focal_x (float | torch.Tensor): focal length in x axis focal_y (float | torch.Tensor): focal length in y axis cx (float | torch.Tensor): principal point in x axis cy (float | torch.Tensor): principal point in y axis Returns: (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix """ N = fx.shape[0] ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device) zeros, ones = torch.zeros(N, dtype=fx.dtype, device=fx.device), torch.ones(N, dtype=fx.dtype, device=fx.device) ret = torch.stack([fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1).unflatten(-1, (3, 3)) return ret @batched(0, 0, 0, 0, 0, 0) def intrinsics_from_fov( fov_max: Union[float, torch.Tensor] = None, fov_min: Union[float, torch.Tensor] = None, fov_x: Union[float, torch.Tensor] = None, fov_y: Union[float, torch.Tensor] = None, width: Union[int, torch.Tensor] = None, height: Union[int, torch.Tensor] = None, ) -> torch.Tensor: """ Get normalized OpenCV intrinsics matrix from given field of view. You can provide either fov_max, fov_min, fov_x or fov_y Args: width (int | torch.Tensor): image width height (int | torch.Tensor): image height fov_max (float | torch.Tensor): field of view in largest dimension fov_min (float | torch.Tensor): field of view in smallest dimension fov_x (float | torch.Tensor): field of view in x axis fov_y (float | torch.Tensor): field of view in y axis Returns: (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix """ if fov_max is not None: fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2)) fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2)) elif fov_min is not None: fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2)) fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2)) elif fov_x is not None and fov_y is not None: fx = 1 / (2 * torch.tan(fov_x / 2)) fy = 1 / (2 * torch.tan(fov_y / 2)) elif fov_x is not None: fx = 1 / (2 * torch.tan(fov_x / 2)) fy = fx * width / height elif fov_y is not None: fy = 1 / (2 * torch.tan(fov_y / 2)) fx = fy * height / width cx = 0.5 cy = 0.5 ret = intrinsics_from_focal_center(fx, fy, cx, cy) return ret def intrinsics_from_fov_xy( fov_x: Union[float, torch.Tensor], fov_y: Union[float, torch.Tensor] ) -> torch.Tensor: """ Get OpenCV intrinsics matrix from field of view in x and y axis Args: fov_x (float | torch.Tensor): field of view in x axis fov_y (float | torch.Tensor): field of view in y axis Returns: (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix """ focal_x = 0.5 / torch.tan(fov_x / 2) focal_y = 0.5 / torch.tan(fov_y / 2) cx = cy = 0.5 return intrinsics_from_focal_center(focal_x, focal_y, cx, cy) @batched(1,1,1) def view_look_at( eye: torch.Tensor, look_at: torch.Tensor, up: torch.Tensor ) -> torch.Tensor: """ Get OpenGL view matrix looking at something Args: eye (torch.Tensor): [..., 3] the eye position look_at (torch.Tensor): [..., 3] the position to look at up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction Returns: (torch.Tensor): [..., 4, 4], view matrix """ N = eye.shape[0] z = eye - look_at x = torch.cross(up, z, dim=-1) y = torch.cross(z, x, dim=-1) # x = torch.cross(y, z, dim=-1) x = x / x.norm(dim=-1, keepdim=True) y = y / y.norm(dim=-1, keepdim=True) z = z / z.norm(dim=-1, keepdim=True) R = torch.stack([x, y, z], dim=-2) t = -torch.matmul(R, eye[..., None]) ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device) ret[:, :3, :3] = R ret[:, :3, 3] = t[:, :, 0] ret[:, 3, 3] = 1. return ret @batched(1, 1, 1) def extrinsics_look_at( eye: torch.Tensor, look_at: torch.Tensor, up: torch.Tensor ) -> torch.Tensor: """ Get OpenCV extrinsics matrix looking at something Args: eye (torch.Tensor): [..., 3] the eye position look_at (torch.Tensor): [..., 3] the position to look at up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction Returns: (torch.Tensor): [..., 4, 4], extrinsics matrix """ N = eye.shape[0] z = look_at - eye x = torch.cross(-up, z, dim=-1) y = torch.cross(z, x, dim=-1) # x = torch.cross(y, z, dim=-1) x = x / x.norm(dim=-1, keepdim=True) y = y / y.norm(dim=-1, keepdim=True) z = z / z.norm(dim=-1, keepdim=True) R = torch.stack([x, y, z], dim=-2) t = -torch.matmul(R, eye[..., None]) ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device) ret[:, :3, :3] = R ret[:, :3, 3] = t[:, :, 0] ret[:, 3, 3] = 1. return ret @batched(2) def perspective_to_intrinsics( perspective: torch.Tensor ) -> torch.Tensor: """ OpenGL perspective matrix to OpenCV intrinsics Args: perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix Returns: (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics """ assert torch.allclose(perspective[:, [0, 1, 3], 3], 0), "The perspective matrix is not a projection matrix" ret = torch.tensor([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype, device=perspective.device) \ @ perspective[:, [0, 1, 3], :3] \ @ torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device)) return ret / ret[:, 2, 2, None, None] @batched(2,0,0) def intrinsics_to_perspective( intrinsics: torch.Tensor, near: Union[float, torch.Tensor], far: Union[float, torch.Tensor], ) -> torch.Tensor: """ OpenCV intrinsics to OpenGL perspective matrix Args: intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix near (float | torch.Tensor): [...] near plane to clip far (float | torch.Tensor): [...] far plane to clip Returns: (torch.Tensor): [..., 4, 4] OpenGL perspective matrix """ N = intrinsics.shape[0] fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1] cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2] ret = torch.zeros((N, 4, 4), dtype=intrinsics.dtype, device=intrinsics.device) ret[:, 0, 0] = 2 * fx ret[:, 1, 1] = 2 * fy ret[:, 0, 2] = -2 * cx + 1 ret[:, 1, 2] = 2 * cy - 1 ret[:, 2, 2] = (near + far) / (near - far) ret[:, 2, 3] = 2. * near * far / (near - far) ret[:, 3, 2] = -1. return ret @batched(2) def extrinsics_to_view( extrinsics: torch.Tensor ) -> torch.Tensor: """ OpenCV camera extrinsics to OpenGL view matrix Args: extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix Returns: (torch.Tensor): [..., 4, 4] OpenGL view matrix """ return extrinsics * torch.tensor([1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device)[:, None] @batched(2) def view_to_extrinsics( view: torch.Tensor ) -> torch.Tensor: """ OpenGL view matrix to OpenCV camera extrinsics Args: view (torch.Tensor): [..., 4, 4] OpenGL view matrix Returns: (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix """ return view * torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device)[:, None] @batched(2,0,0) def normalize_intrinsics( intrinsics: torch.Tensor, width: Union[int, torch.Tensor], height: Union[int, torch.Tensor] ) -> torch.Tensor: """ Normalize camera intrinsics(s) to uv space Args: intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize width (int | torch.Tensor): [...] image width(s) height (int | torch.Tensor): [...] image height(s) Returns: (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s) """ zeros = torch.zeros_like(width) ones = torch.ones_like(width) transform = torch.stack([ 1 / width, zeros, 0.5 / width, zeros, 1 / height, 0.5 / height, zeros, zeros, ones ]).reshape(*zeros.shape, 3, 3).to(intrinsics) return transform @ intrinsics @batched(2,0,0,0,0,0,0) def crop_intrinsics( intrinsics: torch.Tensor, width: Union[int, torch.Tensor], height: Union[int, torch.Tensor], left: Union[int, torch.Tensor], top: Union[int, torch.Tensor], crop_width: Union[int, torch.Tensor], crop_height: Union[int, torch.Tensor] ) -> torch.Tensor: """ Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] Args: intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop width (int | torch.Tensor): [...] image width(s) height (int | torch.Tensor): [...] image height(s) left (int | torch.Tensor): [...] left crop boundary top (int | torch.Tensor): [...] top crop boundary crop_width (int | torch.Tensor): [...] crop width crop_height (int | torch.Tensor): [...] crop height Returns: (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s) """ zeros = torch.zeros_like(width) ones = torch.ones_like(width) transform = torch.stack([ width / crop_width, zeros, -left / crop_width, zeros, height / crop_height, -top / crop_height, zeros, zeros, ones ]).reshape(*zeros.shape, 3, 3).to(intrinsics) return transform @ intrinsics @batched(1,0,0) def pixel_to_uv( pixel: torch.Tensor, width: Union[int, torch.Tensor], height: Union[int, torch.Tensor] ) -> torch.Tensor: """ Args: pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) width (int | torch.Tensor): [...] image width(s) height (int | torch.Tensor): [...] image height(s) Returns: (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) """ if not torch.is_floating_point(pixel): pixel = pixel.float() uv = (pixel + 0.5) / torch.stack([width, height], dim=-1).to(pixel) return uv @batched(1,0,0) def uv_to_pixel( uv: torch.Tensor, width: Union[int, torch.Tensor], height: Union[int, torch.Tensor] ) -> torch.Tensor: """ Args: uv (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) width (int | torch.Tensor): [...] image width(s) height (int | torch.Tensor): [...] image height(s) Returns: (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) """ pixel = uv * torch.stack([width, height], dim=-1).to(uv) - 0.5 return pixel @batched(1,0,0) def pixel_to_ndc( pixel: torch.Tensor, width: Union[int, torch.Tensor], height: Union[int, torch.Tensor] ) -> torch.Tensor: """ Args: pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) width (int | torch.Tensor): [...] image width(s) height (int | torch.Tensor): [...] image height(s) Returns: (torch.Tensor): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1) """ if not torch.is_floating_point(pixel): pixel = pixel.float() ndc = (pixel + 0.5) / (torch.stack([width, height], dim=-1).to(pixel) * torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device)) \ + torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device) return ndc @batched(0,0,0) def project_depth( depth: torch.Tensor, near: Union[float, torch.Tensor], far: Union[float, torch.Tensor] ) -> torch.Tensor: """ Project linear depth to depth value in screen space Args: depth (torch.Tensor): [...] depth value near (float | torch.Tensor): [...] near plane to clip far (float | torch.Tensor): [...] far plane to clip Returns: (torch.Tensor): [..., 1] depth value in screen space, value ranging in [0, 1] """ return (far - near * far / depth) / (far - near) @batched(0,0,0) def depth_buffer_to_linear( depth: torch.Tensor, near: Union[float, torch.Tensor], far: Union[float, torch.Tensor] ) -> torch.Tensor: """ Linearize depth value to linear depth Args: depth (torch.Tensor): [...] screen depth value, ranging in [0, 1] near (float | torch.Tensor): [...] near plane to clip far (float | torch.Tensor): [...] far plane to clip Returns: (torch.Tensor): [...] linear depth """ return near * far / (far - (far - near) * depth) @batched(2, 2, 2, 2) def project_gl( points: torch.Tensor, model: torch.Tensor = None, view: torch.Tensor = None, perspective: torch.Tensor = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Project 3D points to 2D following the OpenGL convention (except for row major matrice) Args: points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last dimension is 4, the points are assumed to be in homogeneous coordinates model (torch.Tensor): [..., 4, 4] model matrix view (torch.Tensor): [..., 4, 4] view matrix perspective (torch.Tensor): [..., 4, 4] perspective matrix Returns: scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1]. The origin (0., 0., 0.) is corresponding to the left & bottom & nearest linear_depth (torch.Tensor): [..., N] linear depth """ assert perspective is not None, "perspective matrix is required" if points.shape[-1] == 3: points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) mvp = perspective if perspective is not None else torch.eye(4).to(points) if view is not None: mvp = mvp @ view if model is not None: mvp = mvp @ model clip_coord = points @ mvp.transpose(-1, -2) ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:] scr_coord = ndc_coord * 0.5 + 0.5 linear_depth = clip_coord[..., 3] return scr_coord, linear_depth @batched(2, 2, 2) def project_cv( points: torch.Tensor, extrinsics: torch.Tensor = None, intrinsics: torch.Tensor = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Project 3D points to 2D following the OpenCV convention Args: points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last dimension is 4, the points are assumed to be in homogeneous coordinates extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix Returns: uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. The origin (0., 0.) is corresponding to the left & top linear_depth (torch.Tensor): [..., N] linear depth """ assert intrinsics is not None, "intrinsics matrix is required" if points.shape[-1] == 3: points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) if extrinsics is not None: points = points @ extrinsics.transpose(-1, -2) points = points[..., :3] @ intrinsics.transpose(-2, -1) uv_coord = points[..., :2] / points[..., 2:] linear_depth = points[..., 2] return uv_coord, linear_depth @batched(2, 2, 2, 2) def unproject_gl( screen_coord: torch.Tensor, model: torch.Tensor = None, view: torch.Tensor = None, perspective: torch.Tensor = None ) -> torch.Tensor: """ Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) Args: screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1]. The origin (0., 0., 0.) is corresponding to the left & bottom & nearest model (torch.Tensor): [..., 4, 4] model matrix view (torch.Tensor): [..., 4, 4] view matrix perspective (torch.Tensor): [..., 4, 4] perspective matrix Returns: points (torch.Tensor): [..., N, 3] 3d points """ assert perspective is not None, "perspective matrix is required" ndc_xy = screen_coord * 2 - 1 clip_coord = torch.cat([ndc_xy, torch.ones_like(ndc_xy[..., :1])], dim=-1) transform = perspective if view is not None: transform = transform @ view if model is not None: transform = transform @ model transform = torch.inverse(transform) points = clip_coord @ transform.transpose(-1, -2) points = points[..., :3] / points[..., 3:] return points @batched(2, 1, 2, 2) def unproject_cv( uv_coord: torch.Tensor, depth: torch.Tensor, extrinsics: torch.Tensor = None, intrinsics: torch.Tensor = None ) -> torch.Tensor: """ Unproject uv coordinates to 3D view space following the OpenCV convention Args: uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. The origin (0., 0.) is corresponding to the left & top depth (torch.Tensor): [..., N] depth value extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix Returns: points (torch.Tensor): [..., N, 3] 3d points """ assert intrinsics is not None, "intrinsics matrix is required" points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1) points = points @ torch.inverse(intrinsics).transpose(-2, -1) points = points * depth[..., None] if extrinsics is not None: points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3] return points def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: """ Return the rotation matrices for one of the rotations about an axis of which Euler angles describe, for each value of the angle given. Args: axis: Axis label "X" or "Y or "Z". angle: any shape tensor of Euler angles in radians Returns: Rotation matrices as tensor of shape (..., 3, 3). """ cos = torch.cos(angle) sin = torch.sin(angle) one = torch.ones_like(angle) zero = torch.zeros_like(angle) if axis == "X": R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) elif axis == "Y": R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) elif axis == "Z": R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) else: raise ValueError("letter must be either X, Y or Z.") return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor: """ Convert rotations given as Euler angles in radians to rotation matrices. Args: euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. Returns: Rotation matrices as tensor of shape (..., 3, 3). """ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: raise ValueError("Invalid input euler angles.") if len(convention) != 3: raise ValueError("Convention must have 3 letters.") if convention[1] in (convention[0], convention[2]): raise ValueError(f"Invalid convention {convention}.") for letter in convention: if letter not in ("X", "Y", "Z"): raise ValueError(f"Invalid letter {letter} in convention string.") matrices = [ euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)]) for c in convention ] # return functools.reduce(torch.matmul, matrices) return matrices[2] @ matrices[1] @ matrices[0] def skew_symmetric(v: torch.Tensor): "Skew symmetric matrix from a 3D vector" assert v.shape[-1] == 3, "v must be 3D" x, y, z = v.unbind(dim=-1) zeros = torch.zeros_like(x) return torch.stack([ zeros, -z, y, z, zeros, -x, -y, x, zeros, ], dim=-1).reshape(*v.shape[:-1], 3, 3) def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor): "Rotation matrix that rotates v1 to v2" I = torch.eye(3).to(v1) v1 = F.normalize(v1, dim=-1) v2 = F.normalize(v2, dim=-1) v = torch.cross(v1, v2, dim=-1) c = torch.sum(v1 * v2, dim=-1) K = skew_symmetric(v) R = I + K + (1 / (1 + c))[None, None] * (K @ K) return R def _angle_from_tan( axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool ) -> torch.Tensor: """ Extract the first or third Euler angle from the two members of the matrix which are positive constant times its sine and cosine. Args: axis: Axis label "X" or "Y or "Z" for the angle we are finding. other_axis: Axis label "X" or "Y or "Z" for the middle axis in the convention. data: Rotation matrices as tensor of shape (..., 3, 3). horizontal: Whether we are looking for the angle for the third axis, which means the relevant entries are in the same row of the rotation matrix. If not, they are in the same column. tait_bryan: Whether the first and third axes in the convention differ. Returns: Euler Angles in radians for each matrix in data as a tensor of shape (...). """ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] if horizontal: i2, i1 = i1, i2 even = (axis + other_axis) in ["XY", "YZ", "ZX"] if horizontal == even: return torch.atan2(data[..., i1], data[..., i2]) if tait_bryan: return torch.atan2(-data[..., i2], data[..., i1]) return torch.atan2(data[..., i2], -data[..., i1]) def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: """ Convert rotations given as rotation matrices to Euler angles in radians. NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d) Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). convention: Convention string of three uppercase letters. Returns: Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d) """ if not all(c in 'XYZ' for c in convention) or not all(c in convention for c in 'XYZ'): raise ValueError(f"Invalid convention {convention}.") if not matrix.shape[-2:] == (3, 3): raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") i0 = 'XYZ'.index(convention[0]) i2 = 'XYZ'.index(convention[2]) tait_bryan = i0 != i2 if tait_bryan: central_angle = torch.asin(matrix[..., i2, i0] * (-1.0 if i2 - i0 in [-1, 2] else 1.0)) else: central_angle = torch.acos(matrix[..., i2, i2]) # Angles in composition order o = [ _angle_from_tan( convention[0], convention[1], matrix[..., i2, :], True, tait_bryan ), central_angle, _angle_from_tan( convention[2], convention[1], matrix[..., i0], False, tait_bryan ), ] return torch.stack([o[convention.index(c)] for c in 'XYZ'], -1) def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation Args: axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors Returns: torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters """ batch_shape = axis_angle.shape[:-1] device, dtype = axis_angle.device, axis_angle.dtype angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True) axis = axis_angle / angle cos = torch.cos(angle)[..., None, :] sin = torch.sin(angle)[..., None, :] rx, ry, rz = torch.split(axis, 3, dim=-1) zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device) K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view((*batch_shape, 3, 3)) ident = torch.eye(3, dtype=dtype, device=device) rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K) return rot_mat def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector) Args: rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert Returns: torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices """ quat = matrix_to_quaternion(rot_mat) axis_angle = quaternion_to_axis_angle(quat, eps=eps) return axis_angle def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector) Args: quaternion (torch.Tensor): shape (..., 4), the quaternions to convert Returns: torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions """ assert quaternion.shape[-1] == 4 norm = torch.norm(quaternion[..., 1:], dim=-1, keepdim=True) axis = quaternion[..., 1:] / norm.clamp(min=eps) angle = 2 * torch.atan2(norm, quaternion[..., 0:1]) return angle * axis def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z) Args: axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors Returns: torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters """ axis = F.normalize(axis_angle, dim=-1, eps=eps) angle = torch.norm(axis_angle, dim=-1, keepdim=True) quat = torch.cat([torch.cos(angle / 2), torch.sin(angle / 2) * axis], dim=-1) return quat def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: """Convert 3x3 rotation matrix to quaternion (w, x, y, z) Args: rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert Returns: torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices """ # Extract the diagonal and off-diagonal elements of the rotation matrix m00, m01, m02, m10, m11, m12, m20, m21, m22 = rot_mat.flatten(-2).unbind(dim=-1) diag = torch.diagonal(rot_mat, dim1=-2, dim2=-1) M = torch.tensor([ [1, 1, 1], [1, -1, -1], [-1, 1, -1], [-1, -1, 1] ], dtype=rot_mat.dtype, device=rot_mat.device) wxyz = (1 + diag @ M.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5) _, max_idx = wxyz.max(dim=-1) xw = torch.sign(m21 - m12) yw = torch.sign(m02 - m20) zw = torch.sign(m10 - m01) yz = torch.sign(m21 + m12) xz = torch.sign(m02 + m20) xy = torch.sign(m01 + m10) ones = torch.ones_like(xw) sign = torch.where( max_idx[..., None] == 0, torch.stack([ones, xw, yw, zw], dim=-1), torch.where( max_idx[..., None] == 1, torch.stack([xw, ones, xy, xz], dim=-1), torch.where( max_idx[..., None] == 2, torch.stack([yw, xy, ones, yz], dim=-1), torch.stack([zw, xz, yz, ones], dim=-1) ) ) ) quat = sign * wxyz quat = F.normalize(quat, dim=-1, eps=eps) return quat def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: """Converts a batch of quaternions (w, x, y, z) to rotation matrices Args: quaternion (torch.Tensor): shape (..., 4), the quaternions to convert Returns: torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions """ assert quaternion.shape[-1] == 4 quaternion = F.normalize(quaternion, dim=-1, eps=eps) w, x, y, z = quaternion.unbind(dim=-1) zeros = torch.zeros_like(w) I = torch.eye(3, dtype=quaternion.dtype, device=quaternion.device) xyz = quaternion[..., 1:] A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(dim=-1)[..., None, None] B = torch.stack([ zeros, -z, y, z, zeros, -x, -y, x, zeros ], dim=-1).unflatten(-1, (3, 3)) rot_mat = I + 2 * (A + w[..., None, None] * B) return rot_mat def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor: """Spherical linear interpolation between two rotation matrices Args: rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix t (torch.Tensor): scalar or shape (...,), the interpolation factor Returns: torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix """ assert rot_mat_1.shape[-2:] == (3, 3) rot_vec_1 = matrix_to_axis_angle(rot_mat_1) rot_vec_2 = matrix_to_axis_angle(rot_mat_2) if isinstance(t, Number): t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device) rot_vec = (1 - t[..., None]) * rot_vec_1 + t[..., None] * rot_vec_2 rot_mat = axis_angle_to_matrix(rot_vec) return rot_mat def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor: """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. Args: ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose t (torch.Tensor): scalar or shape (...,), the interpolation factor Returns: torch.Tensor: shape (..., 4, 4), the interpolated camera pose """ return torch.inverse(interpolate_transform(torch.inverse(ext1), torch.inverse(ext2), t)) def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]): """Interpolate view matrices between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. Args: ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose t (torch.Tensor): scalar or shape (...,), the interpolation factor Returns: torch.Tensor: shape (..., 4, 4), the interpolated camera pose """ return interpolate_extrinsics(view1, view2, t) def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]): assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4) if isinstance(t, Number): t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device) pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3] rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t) transform = torch.cat([rot, pos[..., None]], dim=-1) transform = torch.cat([ext, torch.tensor([0, 0, 0, 1], dtype=transform.dtype, device=transform.device).expand_as(transform[..., :1, :])], dim=-2) return transform def extrinsics_to_essential(extrinsics: torch.Tensor): """ extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` Args: extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix Returns: (torch.Tensor): [..., 3, 3] essential matrix """ assert extrinsics.shape[-2:] == (4, 4) R = extrinsics[..., :3, :3] t = extrinsics[..., :3, 3] zeros = torch.zeros_like(t) t_x = torch.stack([ zeros, -t[..., 2], t[..., 1], t[..., 2], zeros, -t[..., 0], -t[..., 1], t[..., 0], zeros ]).reshape(*t.shape[:-1], 3, 3) return R @ t_x def to4x4(R: torch.Tensor, t: torch.Tensor): """ Compose rotation matrix and translation vector to 4x4 transformation matrix Args: R (torch.Tensor): [..., 3, 3] rotation matrix t (torch.Tensor): [..., 3] translation vector Returns: (torch.Tensor): [..., 4, 4] transformation matrix """ assert R.shape[-2:] == (3, 3) assert t.shape[-1] == 3 assert R.shape[:-2] == t.shape[:-1] return torch.cat([ torch.cat([R, t[..., None]], dim=-1), torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device).expand(*R.shape[:-2], 1, 4) ], dim=-2) def rotation_matrix_2d(theta: Union[float, torch.Tensor]): """ 2x2 matrix for 2D rotation Args: theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) Returns: (torch.Tensor): (..., 2, 2) rotation matrix """ if isinstance(theta, float): theta = torch.tensor(theta) return torch.stack([ torch.cos(theta), -torch.sin(theta), torch.sin(theta), torch.cos(theta), ], dim=-1).unflatten(-1, (2, 2)) def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None): """ 3x3 matrix for 2D rotation around a center ``` [[Rxx, Rxy, tx], [Ryx, Ryy, ty], [0, 0, 1]] ``` Args: theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0) Returns: (torch.Tensor): (..., 3, 3) transformation matrix """ if isinstance(theta, float): theta = torch.tensor(theta) if center is not None: theta = theta.to(center) if center is None: center = torch.zeros(2).to(theta).expand(*theta.shape, -1) R = rotation_matrix_2d(theta) return torch.cat([ torch.cat([ R, center[..., :, None] - R @ center[..., :, None], ], dim=-1), torch.tensor([[0, 0, 1]], dtype=center.dtype, device=center.device).expand(*center.shape[:-1], -1, -1), ], dim=-2) def translate_2d(translation: torch.Tensor): """ Translation matrix for 2D translation ``` [[1, 0, tx], [0, 1, ty], [0, 0, 1]] ``` Args: translation (torch.Tensor): translation vector, arbitrary shape (..., 2) Returns: (torch.Tensor): (..., 3, 3) transformation matrix """ return torch.cat([ torch.cat([ torch.eye(2, dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1), translation[..., None], ], dim=-1), torch.tensor([[0, 0, 1]], dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1), ], dim=-2) def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None): """ Scale matrix for 2D scaling ``` [[s, 0, tx], [0, s, ty], [0, 0, 1]] ``` Args: scale (float | torch.Tensor): scale factor, arbitrary shape (...,) center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0) Returns: (torch.Tensor): (..., 3, 3) transformation matrix """ if isinstance(scale, float): scale = torch.tensor(scale) if center is not None: scale = scale.to(center) if center is None: center = torch.zeros(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1) return torch.cat([ torch.cat([ scale * torch.eye(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape[:-1], -1, -1), center[..., :, None] - center[..., :, None] * scale[..., None, None], ], dim=-1), torch.tensor([[0, 0, 1]], dtype=scale.dtype, device=scale.device).expand(*center.shape[:-1], -1, -1), ], dim=-2) def apply_2d(transform: torch.Tensor, points: torch.Tensor): """ Apply (3x3 or 2x3) 2D affine transformation to points ``` p = R @ p + t ``` Args: transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix points (torch.Tensor): (..., N, 2) points to transform Returns: (torch.Tensor): (..., N, 2) transformed points """ assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3" assert points.shape[-1] == 2, "points must be 2D" return points @ transform[..., :2, :2].mT + transform[..., :2, None, 2]