Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
from __future__ import annotations # so we can refer to class Type inside class | |
import logging | |
from pathlib import Path | |
from typing import List, Tuple, Optional | |
import numpy as np | |
import numpy.typing as npt | |
from animated_drawings.model.transform import Transform | |
from animated_drawings.model.box import Box | |
from animated_drawings.model.quaternions import Quaternions | |
from animated_drawings.model.vectors import Vectors | |
from animated_drawings.model.joint import Joint | |
from animated_drawings.model.time_manager import TimeManager | |
from animated_drawings.utils import resolve_ad_filepath | |
class BVH_Joint(Joint):
    """
    Joint class with channel order attribute and specialized vis widget
    """

    def __init__(self, channel_order: Optional[List[str]] = None, widget: bool = True, **kwargs) -> None:
        """
        :param channel_order: channel names declared for this joint on its BVH CHANNELS line
            (e.g. ['Xrotation', 'Yrotation', 'Zrotation']). Defaults to a new empty list.
        :param widget: if True, a Box is created and attached as a child to visualize the joint
        :param kwargs: forwarded unchanged to Joint.__init__
        """
        super().__init__(**kwargs)

        # A fresh list per instance: a shared `[]` default would be aliased by every joint
        # constructed without an explicit channel_order.
        self.channel_order: List[str] = channel_order if channel_order is not None else []

        self.widget: Optional[Transform] = None
        if widget:
            self.widget = Box()
            self.add_child(self.widget)

    def _draw(self, **kwargs) -> None:
        """Draw this joint's visualization widget, if it has one."""
        if self.widget:
            self.widget.draw(**kwargs)
class BVH(Transform, TimeManager):
    """
    Class to encapsulate BVH (Biovision Hierarchy) animation data.
    Includes a single skeletal hierarchy defined in the BVH, frame count and speed,
    and skeletal pos/rot data for each frame.
    """

    def __init__(self,
                 name: str,
                 root_joint: BVH_Joint,
                 frame_max_num: int,
                 frame_time: float,
                 pos_data: npt.NDArray[np.float32],
                 rot_data: npt.NDArray[np.float32]
                 ) -> None:
        """
        Don't recommend calling this method directly. Instead, use BVH.from_file().

        :param name: name of the animation (from_file passes the BVH file's name)
        :param root_joint: root of the skeleton hierarchy; attached as a child of this transform
        :param frame_max_num: number of frames stored in pos_data and rot_data
        :param frame_time: duration of one frame, in the same units as TimeManager.get_time()
            (presumably seconds, as in the BVH 'Frame Time' field)
        :param pos_data: root position per frame; from_file supplies shape (frame_max_num, 3)
        :param rot_data: per-joint rotation quaternions; from_file supplies shape
            (frame_max_num, joint_count, 4)
        """
        super().__init__()
        self.name: str = name
        self.frame_max_num: int = frame_max_num
        self.frame_time: float = frame_time
        self.pos_data: npt.NDArray[np.float32] = pos_data
        self.rot_data: npt.NDArray[np.float32] = rot_data

        self.root_joint = root_joint
        self.add_child(self.root_joint)
        self.joint_num = self.root_joint.joint_count()

        self.cur_frame = 0  # initialize skeleton pose to first frame
        self.apply_frame(self.cur_frame)

    def get_joint_names(self) -> List[str]:
        """ Get names of joints in skeleton in the order in which BVH rotation data is stored. """
        return self.root_joint.get_chain_joint_names()

    def update(self) -> None:
        """Based upon internal time, determine which frame should be displayed and apply it."""
        cur_time: float = self.get_time()
        # The modulo wraps the frame index, so playback loops once the last frame is reached.
        cur_frame = round(cur_time / self.frame_time) % self.frame_max_num
        self.apply_frame(cur_frame)

    def apply_frame(self, frame_num: int) -> None:
        """ Apply root position and joint rotation data for specified frame_num """
        self.root_joint.set_position(self.pos_data[frame_num])
        # A 0-d array is used as a counter that the recursive calls increment in place.
        self._apply_frame_rotations(self.root_joint, frame_num, ptr=np.array(0))

    def _apply_frame_rotations(self, joint: BVH_Joint, frame_num: int, ptr: npt.NDArray[np.int32]) -> None:
        """
        Set joint's rotation for frame_num, then recurse depth-first into its BVH_Joint children.

        :param ptr: 0-d array holding the rot_data joint index for `joint`; incremented in place so
            siblings and descendants read consecutive joint slots (same order as _pose_ea_to_q writes them)
        """
        q = Quaternions(self.rot_data[frame_num, ptr])
        joint.set_rotation(q)

        ptr += 1

        for c in joint.get_children():
            if not isinstance(c, BVH_Joint):
                continue  # skip non-joint children such as the visualization widget
            self._apply_frame_rotations(c, frame_num, ptr)

    def get_skeleton_fwd(self, forward_perp_vector_joint_names: List[Tuple[str, str]], update: bool = True) -> Vectors:
        """
        Get current forward vector of skeleton in world coords. If update=True, ensure skeleton transforms are current.
        Input forward_perp_vector_joint_names, a list of pairs of joint names (e.g. [[leftshoulder, rightshoulder], [lefthip, righthip]])
        Finds average of vectors between joint pairs, then returns vector perpendicular to their average.
        """
        if not forward_perp_vector_joint_names:
            msg = 'forward_perp_vector_joint_names must contain at least one pair of joint names'
            logging.critical(msg)
            assert False, msg

        if update:
            self.root_joint.update_transforms(update_ancestors=True)

        vectors_cw_perpendicular_to_fwd: List[Vectors] = []
        for (start_joint_name, end_joint_name) in forward_perp_vector_joint_names:
            start_joint = self.root_joint.get_transform_by_name(start_joint_name)
            if not start_joint:
                msg = f'Could not find BVH joint with name: {start_joint_name}'
                logging.critical(msg)
                assert False, msg

            end_joint = self.root_joint.get_transform_by_name(end_joint_name)
            if not end_joint:
                msg = f'Could not find BVH joint with name: {end_joint_name}'
                logging.critical(msg)
                assert False, msg

            bone_vector: Vectors = Vectors(end_joint.get_world_position()) - Vectors(start_joint.get_world_position())
            # Return value is discarded, so presumably norm() normalizes in place — confirm in Vectors.
            bone_vector.norm()
            vectors_cw_perpendicular_to_fwd.append(bone_vector)

        return Vectors(vectors_cw_perpendicular_to_fwd).average().perpendicular()

    @classmethod
    def from_file(cls, bvh_fn: str, start_frame_idx: int = 0, end_frame_idx: Optional[int] = None) -> BVH:
        """
        Given a path to a .bvh, construct and return a BVH object.

        :param bvh_fn: path to the BVH file, located via resolve_ad_filepath
        :param start_frame_idx: index of the first frame to keep (inclusive); must be >= 0
        :param end_frame_idx: index one past the last frame to keep (exclusive). None (or 0) keeps
            every frame through the end of the file; values beyond the file's frame count are clamped.
        :return: BVH holding frames [start_frame_idx, end_frame_idx) of the file
        """
        # search for the BVH file specified
        bvh_p: Path = resolve_ad_filepath(bvh_fn, 'bvh file')
        logging.info(f'Using BVH file located at {bvh_p.resolve()}')

        with open(str(bvh_p), 'r') as f:
            lines = f.read().splitlines()

        # strip() so trailing whitespace after the section keyword is tolerated
        if not lines or lines.pop(0).strip() != 'HIERARCHY':
            # Show only the next line; printing the whole remaining file floods the log.
            msg = f'Malformed BVH in line preceding {lines[:1]}'
            logging.critical(msg)
            assert False, msg

        # Parse the skeleton (consumes its lines from the front of `lines`)
        root_joint: BVH_Joint = BVH._parse_skeleton(lines)

        if not lines or lines.pop(0).strip() != 'MOTION':
            msg = f'Malformed BVH in line preceding {lines[:1]}'
            logging.critical(msg)
            assert False, msg

        # Parse motion metadata: the value after the last ':' on the 'Frames' and 'Frame Time' lines
        frame_max_num = int(lines.pop(0).split(':')[-1])
        frame_time = float(lines.pop(0).split(':')[-1])

        # Parse motion data, one row of channel values per frame. split() with no argument accepts
        # tabs and repeated spaces (split(' ') would yield empty tokens); blank lines are skipped.
        frames = [list(map(float, line.split())) for line in lines if line.strip()]
        if len(frames) != frame_max_num:
            msg = f'framenum specified ({frame_max_num}) and found ({len(frames)}) do not match'
            logging.critical(msg)
            assert False, msg

        # Split logically distinct root position data from joint euler angle rotation data
        pos_data: npt.NDArray[np.float32]
        rot_data: npt.NDArray[np.float32]
        pos_data, rot_data = BVH._process_frame_data(root_joint, frames)

        # Set end_frame if not passed in (any falsy value, including 0, means "through the last frame")
        if not end_frame_idx:
            end_frame_idx = frame_max_num

        # Ensure end_frame_idx <= frame_max_num
        if frame_max_num < end_frame_idx:
            msg = f'config specified end_frame_idx > bvh frame_max_num ({end_frame_idx} > {frame_max_num}). Replacing with frame_max_num.'
            logging.warning(msg)
            end_frame_idx = frame_max_num

        # An empty or inverted range, or a negative index, would leave frame_max_num disagreeing
        # with the sliced data and fail later inside apply_frame/update.
        if not 0 <= start_frame_idx < end_frame_idx:
            msg = (f'Invalid frame range: need 0 <= start_frame_idx < end_frame_idx, '
                   f'got start_frame_idx={start_frame_idx}, end_frame_idx={end_frame_idx}')
            logging.critical(msg)
            assert False, msg

        # slice position and rotation data using start and end frame indices
        pos_data = pos_data[start_frame_idx:end_frame_idx, :]
        rot_data = rot_data[start_frame_idx:end_frame_idx, :]

        # new frame_max_num is end_frame_idx minus start_frame_idx
        frame_max_num = end_frame_idx - start_frame_idx

        return BVH(bvh_p.name, root_joint, frame_max_num, frame_time, pos_data, rot_data)

    @classmethod
    def _parse_skeleton(cls, lines: List[str]) -> BVH_Joint:
        """
        Called recursively to parse and construct skeleton from BVH.

        Consumes one ROOT / JOINT / End Site section from the front of `lines`: its header,
        braces, OFFSET, optional CHANNELS line and all nested child sections.

        :param lines: partially-processed contents of BVH file. Is modified in-place.
        :return: BVH_Joint for the parsed section, with its parsed children attached
        """
        # Get the joint name
        header = lines[0].strip() if lines else ''
        if header.startswith(('ROOT', 'JOINT')):
            lines.pop(0)
            # maxsplit=1 keeps the whole remainder as the name; split() tolerates tabs/extra spaces
            _, joint_name = header.split(maxsplit=1)
        elif header.startswith('End Site'):
            lines.pop(0)
            joint_name = header
        else:
            msg = f'Malformed BVH. Line: {lines[0] if lines else "<end of file>"}'
            logging.critical(msg)
            assert False, msg

        if not lines or lines.pop(0).strip() != '{':
            msg = f'Malformed BVH in line preceding {lines[:1]}'
            logging.critical(msg)
            assert False, msg

        # Get offset
        if not lines or not lines[0].strip().startswith('OFFSET'):
            msg = f'Malformed BVH in line preceding {lines[:1]}'
            logging.critical(msg)
            assert False, msg
        _, *xyz = lines.pop(0).split()
        offset = Vectors(list(map(float, xyz)))

        # Get channels (End Site sections have none)
        if lines and lines[0].strip().startswith('CHANNELS'):
            _, channel_num, *channel_order = lines.pop(0).split()
        else:
            channel_num, channel_order = 0, []
        if int(channel_num) != len(channel_order):
            msg = f'Malformed BVH in line preceding {lines[:1]}'
            logging.critical(msg)
            assert False, msg

        # Recurse for children until this section's closing brace
        children: List[BVH_Joint] = []
        while True:
            if not lines:
                msg = f'Malformed BVH: file ended before closing brace of {joint_name}'
                logging.critical(msg)
                assert False, msg
            if lines[0].strip() == '}':
                break
            children.append(BVH._parse_skeleton(lines))
        lines.pop(0)  # }

        return BVH_Joint(name=joint_name, offset=offset, channel_order=channel_order, children=children)

    @classmethod
    def _process_frame_data(cls, skeleton: BVH_Joint, frames: List[List[float]]) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        """
        Given skeleton and frame data, return root position data and joint quaternion data, separately.

        :param skeleton: root of the parsed skeleton; its joints' channel_order lists define the column layout
        :param frames: one list of channel values per frame
        :return: (pos_data, rot_data) with shapes (frame_count, 3) and (frame_count, joint_count, 4)
        """
        def _collect_channels(joint: BVH_Joint, acc: List[str]) -> None:
            # Depth-first over BVH_Joint children, matching the column order of each frame row.
            acc.extend(joint.channel_order)
            for child in joint.get_children():
                if isinstance(child, BVH_Joint):
                    _collect_channels(child, acc)

        channels: List[str] = []
        _collect_channels(skeleton, channels)

        frame_arr = np.array(frames, dtype=np.float32)
        if frame_arr.ndim != 2 or frame_arr.shape[1] != len(channels):
            msg = (f'Malformed BVH: motion data has shape {frame_arr.shape} '
                   f'but skeleton declares {len(channels)} channels')
            logging.critical(msg)
            assert False, msg

        # create a mask so we retain only joint rotations and root position
        mask = np.array(['rotation' in c for c in channels], dtype=bool)
        mask[:3] = True  # hack to make sure we keep root position (assumes root's first 3 channels are positions)

        # split root pose data and joint euler angle data
        pos_data, ea_rots = np.split(frame_arr[:, mask], [3], axis=1)

        # quaternion rot data will go here
        rot_data = np.empty([len(frames), skeleton.joint_count(), 4], dtype=np.float32)
        BVH._pose_ea_to_q(skeleton, ea_rots, rot_data)

        return pos_data, rot_data

    @classmethod
    def _pose_ea_to_q(cls, joint: BVH_Joint, ea_rots: npt.NDArray[np.float32], q_rots: npt.NDArray[np.float32], p1: int = 0, p2: int = 0) -> Tuple[int, int]:
        """
        Given joint and array of euler angle rotation data, converts to quaternions and stores in q_rots.
        Only called by _process_frame_data(). Modifies q_rots inplace.

        :param p1: pointer to find where in ea_rots to read euler angles from
        :param p2: pointer to determine where in q_rots to input quaternion
        :return: (p1, p2) advanced past this joint and all of its BVH_Joint descendants
        """
        # Axis letters of this joint's rotation channels in declared order, e.g. 'zxy' for
        # Zrotation Xrotation Yrotation; empty for joints without rotation channels (End Sites).
        axis_chars = "".join([c[0].lower() for c in joint.channel_order if c.endswith('rotation')])
        q_rots[:, p2] = Quaternions.from_euler_angles(axis_chars, ea_rots[:, p1:p1+len(axis_chars)]).qs

        p1 += len(axis_chars)
        p2 += 1

        for child in joint.get_children():
            if isinstance(child, BVH_Joint):
                p1, p2 = BVH._pose_ea_to_q(child, ea_rots, q_rots, p1, p2)

        return p1, p2