Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import annotations # so we can refer to class Type inside class | |
| import numpy as np | |
| import numpy.typing as npt | |
| import logging | |
| from typing import Union, Iterable, List, Tuple | |
| from animated_drawings.model.vectors import Vectors | |
| import math | |
| from animated_drawings.utils import TOLERANCE | |
| from functools import reduce | |
class Quaternions:
    """
    Wrapper class around ndarray interpreted as one or more quaternions. Quaternion order is [w, x, y, z].

    A single quaternion is stored with a leading batch dimension, i.e. shape (1, 4). Data is normalized to
    unit length on construction; normalization allocates a new array, so the array (or Quaternions) passed
    in is neither modified nor aliased.
    Strongly influenced by Daniel Holden's excellent Quaternions class.
    """

    @staticmethod
    def _fail(msg: str) -> None:
        """
        Log msg at CRITICAL level and raise AssertionError(msg).

        Raising explicitly (instead of `assert False, msg`) keeps the check active under `python -O`
        while preserving the AssertionError type existing callers may catch.
        """
        logging.critical(msg)
        raise AssertionError(msg)

    def __init__(self, qs: Union[Iterable[Union[int, float]], npt.NDArray[np.float32], Quaternions]) -> None:
        """
        :param qs: ndarray whose final dimension is 4, a list/tuple convertible to such an array,
            or another Quaternions object.
        :raises AssertionError: if qs has an unsupported type or its final dimension is not 4.
        """
        self.qs: npt.NDArray[np.float32]
        if isinstance(qs, np.ndarray):
            if qs.ndim == 0 or qs.shape[-1] != 4:
                found = qs.shape[-1] if qs.ndim > 0 else 'a 0-d array'
                self._fail(f'Final dimension passed to Quaternions must be 4. Found {found}')
            data = qs
        elif isinstance(qs, (tuple, list)):
            try:
                data = np.array(qs)
            except (TypeError, ValueError):  # e.g. ragged nested lists
                data = None
            if data is None or data.ndim == 0 or data.shape[-1] != 4:
                self._fail('Could not convert quaternion data to ndarray with shape[-1] == 4')
        elif isinstance(qs, Quaternions):
            data = qs.qs
        else:
            self._fail('Quaternions must be constructed from Quaternions or numpy array')

        # Promote a single quaternion (4,) to a batch of one (1, 4).
        if data.ndim == 1:
            data = np.expand_dims(data, axis=0)
        self.qs = data
        self.normalize()

    def normalize(self) -> None:
        """
        Rescale every quaternion to unit length, replacing self.qs with a new array.

        A zero-length quaternion cannot be normalized and yields NaNs (numpy 0/0).
        """
        self.qs = self.qs / np.expand_dims(np.sum(self.qs ** 2.0, axis=-1) ** 0.5, axis=-1)

    def to_rotation_matrix(self) -> npt.NDArray[np.float32]:
        """
        From Ken Shoemake
        https://www.ljll.math.upmc.fr/~frey/papers/scientific%20visualisation/Shoemake%20K.,%20Quaternions.pdf
        :return: float32 4x4 homogeneous rotation matrix when this object holds a single quaternion;
            otherwise an array of shape (*batch, 4, 4) with one matrix per quaternion.
        """
        w, x, y, z = (self.qs[..., i] for i in range(4))

        xx, yy, zz = x**2, y**2, z**2
        wx, wy, wz = w*x, w*y, w*z
        xy, xz, yz = x*y, x*z, y*z

        m = np.zeros((*self.qs.shape[:-1], 4, 4), dtype=np.float32)
        # Row 1
        m[..., 0, 0] = 1 - 2 * (yy + zz)
        m[..., 0, 1] = 2 * (xy - wz)
        m[..., 0, 2] = 2 * (xz + wy)
        # Row 2
        m[..., 1, 0] = 2 * (xy + wz)
        m[..., 1, 1] = 1 - 2 * (xx + zz)
        m[..., 1, 2] = 2 * (yz - wx)
        # Row 3
        m[..., 2, 0] = 2 * (xz - wy)
        m[..., 2, 1] = 2 * (yz + wx)
        m[..., 2, 2] = 1 - 2 * (xx + yy)
        # Homogeneous row/column: no translation.
        m[..., 3, 3] = 1.0

        # A single quaternion keeps the historical plain 4x4 return shape.
        if m.size == 16:
            return m.reshape(4, 4)
        return m

    @classmethod
    def rotate_between_vectors(cls, v1: Vectors, v2: Vectors) -> Quaternions:
        """
        Computes quaternion rotating from v1 to v2.

        Builds q = [|v1||v2| + v1.v2, v1 x v2] and normalizes it. That expression vanishes when v1 and v2
        point in opposite directions, so in that case a 180-degree rotation about an axis perpendicular
        to v1 is returned instead. Each argument is expected to hold a single 3D vector.
        A zero-length argument still yields NaNs.
        """
        xyz: List[float] = v1.cross(v2).vs.squeeze().tolist()
        a = v1.vs.squeeze()
        scale = math.sqrt((v1.length**2) * (v2.length**2))
        w: float = scale + np.dot(a, v2.vs.squeeze())

        # Relative threshold below which [w, *xyz] is numerically zero, i.e. v1 and v2 are antiparallel.
        if scale > 0 and np.linalg.norm([w, *xyz]) <= 1e-9 * scale:
            # Crossing v1 with the coordinate axis least aligned with it gives a well-conditioned
            # perpendicular; rotating 180 degrees about any perpendicular maps v1's direction onto -v1.
            helper = np.zeros(3)
            helper[np.argmin(np.abs(a))] = 1.0
            return Quaternions([0.0, *np.cross(a, helper).tolist()])

        return Quaternions([w, *xyz])

    @classmethod
    def from_angle_axis(cls, angles: npt.NDArray[np.float32], axes: Vectors) -> Quaternions:
        """
        Builds quaternions rotating by `angles` (radians) about `axes`.

        :param angles: angles in radians, shape (..., 1). A scalar or 1-D array of N angles is treated as
            shape (N, 1), i.e. one angle per axis row.
        :param axes: rotation axes; axes.norm() is called on it first (presumably normalizing in place,
            which would modify the caller's object — TODO confirm against Vectors).
        """
        axes.norm()
        angles = np.asarray(angles)
        if angles.ndim <= 1:
            # (N,) -> (N, 1) so each angle lines up with its own axis row when broadcasting.
            angles = angles.reshape(-1, 1)
        ss = np.sin(angles / 2.0)
        cs = np.cos(angles / 2.0)
        return Quaternions(np.concatenate([cs, axes.vs * ss], axis=-1))

    @classmethod
    def identity(cls, ret_shape: Tuple[int, ...]) -> Quaternions:
        """Return identity quaternions [1, 0, 0, 0] with batch shape ret_shape (at least (1,))."""
        qs = np.broadcast_to(np.array([1.0, 0.0, 0.0, 0.0]), [*ret_shape, 4])
        return Quaternions(qs)

    @classmethod
    def from_euler_angles(cls, order: str, angles: npt.NDArray[np.float32]) -> Quaternions:
        """
        Applies a series of euler angle rotations. Angles applied from right to left:
        the result is q(order[0]) * q(order[1]) * ... * q(order[-1]).
        :param order: string comprised of x, y, and/or z (case-insensitive)
        :param angles: angles in degrees, shape (..., len(order)); a 1-D array is treated as one set
        :return: Quaternions with batch shape angles.shape[:-1] (after the 1-D promotion)
        """
        if len(angles.shape) == 1:
            angles = np.expand_dims(angles, axis=0)

        if len(order) != angles.shape[-1]:
            cls._fail('length of orders and angles does not match')

        batch_shape = angles.shape[:-1]
        factors = [Quaternions.identity(batch_shape)]
        for pos, raw_char in enumerate(order):
            axis_char = raw_char.lower()
            if axis_char not in 'xyz':
                cls._fail(f'order contained unsupported char:{axis_char}')

            # Degrees -> radians; trailing axis of size 1 so it broadcasts against the (…, 3) axes.
            angle = np.expand_dims(angles[..., pos] * np.pi / 180, axis=-1)

            axis = np.zeros([*batch_shape, 3])
            axis[..., ord(axis_char) - ord('x')] = 1.0
            factors.append(Quaternions.from_angle_axis(angle, Vectors(axis)))

        # Fold from the right: identity * (q0 * (q1 * (... * q_last))).
        return reduce(lambda acc, q: q * acc, reversed(factors))

    @classmethod
    def from_rotation_matrix(cls, M: npt.NDArray[np.float32]) -> Quaternions:
        """
        As described here: https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf
        :param M: 4x4 homogeneous rotation matrix (checked against the 4x4 identity for orthogonality)
        :raises AssertionError: if M is not orthogonal or its determinant is not 1
        """
        is_orthogonal = np.isclose(M @ M.T, np.identity(4), atol=TOLERANCE)
        if not is_orthogonal.all():
            cls._fail("attempted to create quaternion from non-orthogonal rotation matrix")

        if not np.isclose(np.linalg.det(M), 1.0):
            cls._fail("attempted to create quaternion from rotation matrix with det != 1")

        # Note: Mike Day's article uses row vectors, whereas we used column, so here use transpose of matrix
        MT = M.T
        m00, m01, m02 = MT[0, 0], MT[0, 1], MT[0, 2]
        m10, m11, m12 = MT[1, 0], MT[1, 1], MT[1, 2]
        m20, m21, m22 = MT[2, 0], MT[2, 1], MT[2, 2]

        # Branch on the largest diagonal term for numerical stability (Day's method).
        if m22 < 0:
            if m00 > m11:
                t = 1 + m00 - m11 - m22
                q = [m12-m21, t, m01+m10, m20+m02]
            else:
                t = 1 - m00 + m11 - m22
                q = [m20-m02, m01+m10, t, m12+m21]
        else:
            if m00 < -m11:
                t = 1 - m00 - m11 + m22
                q = [m01-m10, m20+m02, m12+m21, t]
            else:
                t = 1 + m00 + m11 + m22
                q = [t, m12-m21, m20-m02, m01-m10]

        # Build as float64 so integer-valued matrices (e.g. exact 90-degree turns) do not hit a cast error.
        scaled = np.array(q, dtype=np.float64) * (0.5 / math.sqrt(t))
        return Quaternions(scaled)

    def __mul__(self, other: Quaternions) -> Quaternions:
        """
        Hamilton product self * other, broadcasting the batch dimensions of both operands.
        From https://danceswithcode.net/engineeringnotes/quaternions/quaternions.html
        """
        s0, s1, s2, s3 = (self.qs[..., i] for i in range(4))
        r0, r1, r2, r3 = (other.qs[..., i] for i in range(4))

        # Size the result by the broadcast shape so e.g. (1, 4) * (N, 4) works.
        t = np.empty(np.broadcast(self.qs, other.qs).shape)
        t[..., 0] = r0*s0 - r1*s1 - r2*s2 - r3*s3
        t[..., 1] = r0*s1 + r1*s0 - r2*s3 + r3*s2
        t[..., 2] = r0*s2 + r1*s3 + r2*s0 - r3*s1
        t[..., 3] = r0*s3 - r1*s2 + r2*s1 + r3*s0

        return Quaternions(t)

    def __neg__(self) -> Quaternions:
        """Return the conjugate (vector part negated); for unit quaternions this is the inverse rotation."""
        return Quaternions(self.qs * np.array([1, -1, -1, -1]))

    def __str__(self) -> str:
        return f"Quaternions({str(self.qs)})"

    def __repr__(self) -> str:
        return f"Quaternions({str(self.qs)})"