# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations # so we can refer to class Type inside class | |
import numpy as np | |
import numpy.typing as npt | |
from animated_drawings.model.vectors import Vectors | |
from animated_drawings.model.quaternions import Quaternions | |
import logging | |
from typing import Union, Optional, List, Tuple | |
class Transform():
    """Base class from which all other scene objects descend.

    A Transform stores separate 4x4 float32 translation, rotation and scale
    matrices. The local transform is translate @ rotate @ scale. The world
    transform is the parent's world transform @ the local transform (or just
    the local transform for a root). Cached local/world matrices are refreshed
    lazily by update_transforms(); `dirty_bit` marks them as stale.
    """

    def __init__(self,
                 parent: Optional[Transform] = None,
                 name: Optional[str] = None,
                 children: Optional[List[Transform]] = None,
                 offset: Union[npt.NDArray[np.float32], Vectors, None] = None,
                 **kwargs
                 ) -> None:
        """
        parent: parent transform. Only this object's parent pointer is set;
            self is NOT appended to the parent's children list.
        name: optional name, used by get_transform_by_name().
        children: transforms to attach via add_child() (each gets self as
            parent). Defaults to no children.
        offset: optional initial translation, applied via offset().
        kwargs: forwarded to super().__init__ (cooperative inheritance).
        """
        super().__init__(**kwargs)

        self._parent: Optional[Transform] = parent

        self._children: List[Transform] = []
        # `None` default instead of `[]` avoids a list shared across all calls.
        if children is not None:
            for child in children:
                self.add_child(child)

        self.name: Optional[str] = name

        self._translate_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._rotate_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._scale_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)

        if offset is not None:
            self.offset(offset)

        self._local_transform: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._world_transform: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)

        self.dirty_bit: bool = True  # are world/local transforms stale?

    def update_transforms(self, parent_dirty_bit: bool = False, recurse_on_children: bool = True, update_ancestors: bool = False) -> None:
        """
        Update cached local/world matrices that are stale.

        - If update_ancestors is true, first walk up to the root ancestor and
          call update_transforms() on it (which recurses over its subtree).
        - If this transform's dirty bit is set, recompute the local matrix.
        - If this transform's or the parent's dirty bit is set, recompute the
          world matrix.
        - Unless recurse_on_children is false, recurse on EVERY child (a
          descendant may be dirty even when this transform is not), passing
          down whether this transform or an ancestor was dirty.
        - Finally clear this transform's dirty bit.
        """
        if update_ancestors:
            ancestor, ancestor_parent = self, self.get_parent()
            while ancestor_parent is not None:
                ancestor, ancestor_parent = ancestor_parent, ancestor_parent.get_parent()
            ancestor.update_transforms()

        # World matrix is stale if our local matrix or any ancestor's changed.
        world_stale = self.dirty_bit or parent_dirty_bit

        if self.dirty_bit:
            self.compute_local_transform()
        if world_stale:
            self.compute_world_transform()

        if recurse_on_children:
            for c in self.get_children():
                c.update_transforms(world_stale)

        self.dirty_bit = False

    def compute_local_transform(self) -> None:
        """Recompute the local matrix as translate @ rotate @ scale."""
        self._local_transform = self._translate_m @ self._rotate_m @ self._scale_m

    def compute_world_transform(self) -> None:
        """
        Recompute the world matrix from the parent's cached world matrix.

        Uses the parent's current `_world_transform` as-is; update_transforms()
        guarantees parents are refreshed before children.
        """
        # Explicit None check: a falsy parent (e.g. a subclass defining
        # __len__) must still contribute its transform.
        if self._parent is not None:
            self._world_transform = self._parent._world_transform @ self._local_transform
        else:
            self._world_transform = self._local_transform

    def get_world_transform(self, update_ancestors: bool = True) -> npt.NDArray[np.float32]:
        """
        Return a copy of this transform's 4x4 world matrix.

        If update_ancestors is true, refresh the transform hierarchy first so
        the returned matrix is current.
        """
        if update_ancestors:
            self.update_transforms(update_ancestors=True)
        return np.copy(self._world_transform)

    def set_scale(self, scale: float) -> None:
        """Set a uniform scale factor on all three axes."""
        self._scale_m[:-1, :-1] = scale * np.identity(3, dtype=np.float32)
        self.dirty_bit = True

    def set_position(self, pos: Union[npt.NDArray[np.float32], Vectors]) -> None:
        """
        Set the absolute values of the translational elements of transform.

        pos must be an array of shape (3,) or (1, 3), or a Vectors whose `.vs`
        has one of those shapes. Raises AssertionError otherwise.
        """
        if isinstance(pos, Vectors):
            pos = pos.vs
        if pos.shape == (1, 3):
            pos = np.squeeze(pos)
        elif pos.shape != (3,):
            msg = f'bad vector dim passed to set_position. Found: {pos.shape}'
            logging.critical(msg)
            # Explicit raise (not `assert`) so validation survives `python -O`.
            raise AssertionError(msg)
        self._translate_m[:-1, -1] = pos
        self.dirty_bit = True

    def get_local_position(self) -> npt.NDArray[np.float32]:
        """
        Return a copy of the local xyz translation.

        Recomputes the local matrix if dirty but does not clear the dirty bit,
        because the world matrix may still need updating.
        """
        if self.dirty_bit:
            self.compute_local_transform()
        return np.copy(self._local_transform[:-1, -1])

    def get_world_position(self, update_ancestors: bool = True) -> npt.NDArray[np.float32]:
        """
        Return a copy of the world xyz translation.

        If update_ancestors is true, refresh the transform hierarchy first so
        the returned position is current.
        """
        if update_ancestors:
            self.update_transforms(update_ancestors=True)
        return np.copy(self._world_transform[:-1, -1])

    def offset(self, pos: Union[npt.NDArray[np.float32], Vectors]) -> None:
        """
        Translate by the specified amount, relative to the current position.

        If pos is a Vectors, only its first vector is used.
        """
        if isinstance(pos, Vectors):
            pos = pos.vs[0]
        assert isinstance(pos, np.ndarray)  # type narrowing for the addition below
        self.set_position(self._translate_m[:-1, -1] + pos)

    def look_at(self, fwd_: Union[npt.NDArray[np.float32], Vectors, None]) -> None:
        """
        Set this transform's rotation so its local +Z axis points along fwd_.

        fwd_ is treated as a direction (it is normalized). If None, the
        transform's own world position is used as that direction. The rotation
        matrix columns become (right, up, fwd), with right = Y x fwd and
        up = fwd x right, where Y is the +Y reference vector. Translation and
        scale are unchanged. Raises AssertionError if fwd_ is not one 3-vector.
        """
        if fwd_ is None:
            fwd_ = Vectors(self.get_world_position())
        elif isinstance(fwd_, np.ndarray):
            fwd_ = Vectors(fwd_)

        fwd: Vectors = fwd_.copy()  # norming will change the vector

        if fwd.vs.shape != (1, 3):
            msg = f'look_at fwd_ vector must have shape [1,3]. Found: {fwd.vs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)

        # Normalize BEFORE the parallel check. Comparing the raw vector with
        # unit +Y only caught exactly (0, +/-1, 0) and missed scaled directions
        # like (0, 5, 0), whose cross product with +Y is zero.
        fwd.norm()

        tmp: Vectors = Vectors([0.0, 1.0, 0.0])

        # if fwd and tmp are (anti)parallel, perturb tmp to avoid a zero cross product
        if np.isclose(fwd.vs, tmp.vs).all() or np.isclose(fwd.vs, -tmp.vs).all():
            # NOTE(review): if tmp.vs is a 1x3 row this nudges every component,
            # not just x; either way tmp is no longer parallel to fwd.
            tmp.vs[0] += 0.001

        right: Vectors = tmp.cross(fwd)
        up: Vectors = fwd.cross(right)

        right.norm()
        up.norm()

        rotate_m = np.identity(4, dtype=np.float32)
        rotate_m[:-1, 0] = np.squeeze(right.vs)
        rotate_m[:-1, 1] = np.squeeze(up.vs)
        rotate_m[:-1, 2] = np.squeeze(fwd.vs)

        self._rotate_m = rotate_m
        self.dirty_bit = True

    def get_right_up_fwd_vectors(self) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        """
        Return the xyz parts of columns 0, 1 and 2 of the inverted world matrix.

        For a pure rotation the inverse is the transpose, so these equal the
        rows of the world rotation.
        """
        inverted: npt.NDArray[np.float32] = np.linalg.inv(self.get_world_transform())
        right: npt.NDArray[np.float32] = inverted[:-1, 0]
        up: npt.NDArray[np.float32] = inverted[:-1, 1]
        fwd: npt.NDArray[np.float32] = inverted[:-1, 2]
        return right, up, fwd

    def set_rotation(self, q: Quaternions) -> None:
        """Replace the rotation with the single quaternion q (qs shape must be (1, 4))."""
        if q.qs.shape != (1, 4):
            msg = f'set_rotation q must have dimension (1, 4). Found: {q.qs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)
        self._rotate_m = q.to_rotation_matrix()
        self.dirty_bit = True

    def rotation_offset(self, q: Quaternions) -> None:
        """Compose q with the current rotation as q * current (qs shape must be (1, 4))."""
        if q.qs.shape != (1, 4):
            msg = f'rotation_offset q must have dimension (1, 4). Found: {q.qs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)
        self._rotate_m = (q * Quaternions.from_rotation_matrix(self._rotate_m)).to_rotation_matrix()
        self.dirty_bit = True

    def add_child(self, child: Transform) -> None:
        """Append child to this transform's children and set its parent to self."""
        self._children.append(child)
        child.set_parent(self)

    def get_children(self) -> List[Transform]:
        """Return the internal children list (not a copy)."""
        return self._children

    def set_parent(self, parent: Transform) -> None:
        """
        Set the parent pointer and mark this transform dirty.

        Does not modify any parent's children list.
        """
        self._parent = parent
        self.dirty_bit = True

    def get_parent(self) -> Optional[Transform]:
        """Return the parent transform, or None for a root."""
        return self._parent

    def get_transform_by_name(self, name: str) -> Optional[Transform]:
        """Depth-first search of self and descendants for a transform named `name`; None if absent."""
        # are we match?
        if self.name == name:
            return self

        # recurse to check if a child is match
        for c in self.get_children():
            transform_or_none = c.get_transform_by_name(name)
            # Explicit None check: a falsy match must still be returned.
            if transform_or_none is not None:
                return transform_or_none

        # no match
        return None

    def draw(self, recurse: bool = True, **kwargs) -> None:
        """
        Draw this transform, then (if recurse) every child.

        Children receive the same kwargs and always recurse into their own
        children.
        """
        self._draw(**kwargs)

        if recurse:
            for child in self.get_children():
                child.draw(**kwargs)

    def _draw(self, **kwargs) -> None:
        """Transforms default to not being drawn. Subclasses must implement how they appear"""