File size: 11,809 Bytes
d49f7bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations  # so we can refer to class Type inside class
import logging
from pathlib import Path
from typing import List, Tuple, Optional

import numpy as np
import numpy.typing as npt

from animated_drawings.model.transform import Transform
from animated_drawings.model.box import Box
from animated_drawings.model.quaternions import Quaternions
from animated_drawings.model.vectors import Vectors
from animated_drawings.model.joint import Joint
from animated_drawings.model.time_manager import TimeManager
from animated_drawings.utils import resolve_ad_filepath


class BVH_Joint(Joint):
    """
        Joint class with channel order attribute and specialized vis widget
    """
    def __init__(self, channel_order: Optional[List[str]] = None, widget: bool = True, **kwargs) -> None:
        """
        :param channel_order: channel names from this joint's BVH CHANNELS line (e.g. 'Zrotation').
            When omitted, the joint gets its own new empty list.
        :param widget: if True, create a Box and attach it as a child to visualize this joint
        :param kwargs: forwarded unchanged to Joint.__init__
        """
        super().__init__(**kwargs)

        # A literal `[]` default would be one list object shared by every joint built without
        # channel_order; create a fresh list per instance instead.
        self.channel_order: List[str] = [] if channel_order is None else channel_order

        self.widget: Optional[Transform] = None
        if widget:
            self.widget = Box()
            self.add_child(self.widget)

    def _draw(self, **kwargs):
        """Draw this joint's visualization widget, if one was created, forwarding kwargs to its draw()."""
        if self.widget:
            self.widget.draw(**kwargs)


class BVH(Transform, TimeManager):
    """
    Class to encapsulate BVH (Biovision Hierarchy) animation data.
    Include a single skeletal hierarchy defined in the BVH, frame count and speed,
    and skeletal pos/rot data for each frame
    """

    def __init__(self,
                 name: str,
                 root_joint: BVH_Joint,
                 frame_max_num: int,
                 frame_time: float,
                 pos_data: npt.NDArray[np.float32],
                 rot_data: npt.NDArray[np.float32]
                 ) -> None:
        """
        Don't recommend calling this method directly.  Instead, use BVH.from_file().

        :param name: name of this animation (from_file passes the BVH file name)
        :param root_joint: root of the skeleton; attached as a child of this transform
        :param frame_max_num: number of frames available in pos_data / rot_data
        :param frame_time: duration of one frame, as read from the BVH 'Frame Time' line
        :param pos_data: per-frame root positions; (frames, 3) when built by _process_frame_data
        :param rot_data: per-frame joint quaternions; (frames, joint_count, 4) when built by _process_frame_data
        """
        super().__init__()

        self.name: str = name
        self.frame_max_num: int = frame_max_num
        self.frame_time: float = frame_time
        self.pos_data: npt.NDArray[np.float32] = pos_data
        self.rot_data: npt.NDArray[np.float32] = rot_data

        self.root_joint = root_joint
        self.add_child(self.root_joint)
        self.joint_num = self.root_joint.joint_count()

        self.cur_frame = 0  # initialize skeleton pose to first frame
        self.apply_frame(self.cur_frame)

    def get_joint_names(self) -> List[str]:
        """ Get names of joints in skeleton in the order in which BVH rotation data is stored. """
        return self.root_joint.get_chain_joint_names()

    def update(self) -> None:
        """Based upon internal time, determine which frame should be displayed and apply it"""
        cur_time: float = self.get_time()
        # Wrap around so the animation loops once the last frame has been shown
        cur_frame = round(cur_time / self.frame_time) % self.frame_max_num
        self.apply_frame(cur_frame)

    def apply_frame(self, frame_num: int) -> None:
        """ Apply root position and joint rotation data for specified frame_num """
        self.root_joint.set_position(self.pos_data[frame_num])
        # ptr is a 0-d array so the recursive walk can advance one shared counter in place
        self._apply_frame_rotations(self.root_joint, frame_num, ptr=np.array(0))

    def _apply_frame_rotations(self, joint: BVH_Joint, frame_num: int, ptr: npt.NDArray[np.int32]) -> None:
        """
        Set the rotation of joint and, recursively, of every BVH_Joint descendant for frame_num.
        :param ptr: index of joint's slot in rot_data; incremented in place (`ptr += 1`) so the
            depth-first visit order matches the order in which _pose_ea_to_q filled rot_data.
        """
        q = Quaternions(self.rot_data[frame_num, ptr])
        joint.set_rotation(q)

        ptr += 1

        for c in joint.get_children():
            if not isinstance(c, BVH_Joint):
                continue  # skip non-joint children such as the visualization widget
            self._apply_frame_rotations(c, frame_num, ptr)

    def get_skeleton_fwd(self, forward_perp_vector_joint_names: List[Tuple[str, str]], update: bool = True) -> Vectors:
        """
        Get current forward vector of skeleton in world coords. If update=True, ensure skeleton transforms are current.
        Input forward_perp_vector_joint_names, a list of pairs of joint names (e.g. [[leftshoulder, rightshoulder], [lefthip, righthip]])
        Finds average of vectors between joint pairs, then returns vector perpendicular to their average.
        :raises AssertionError: if a named joint cannot be found
        """
        if update:
            self.root_joint.update_transforms(update_ancestors=True)

        vectors_cw_perpendicular_to_fwd: List[Vectors] = []
        for (start_joint_name, end_joint_name) in forward_perp_vector_joint_names:
            start_joint = self.root_joint.get_transform_by_name(start_joint_name)
            if not start_joint:
                raise BVH._critical_error(f'Could not find BVH joint with name: {start_joint_name}')

            end_joint = self.root_joint.get_transform_by_name(end_joint_name)
            if not end_joint:
                raise BVH._critical_error(f'Could not find BVH joint with name: {end_joint_name}')

            bone_vector: Vectors = Vectors(end_joint.get_world_position()) - Vectors(start_joint.get_world_position())
            bone_vector.norm()
            vectors_cw_perpendicular_to_fwd.append(bone_vector)

        return Vectors(vectors_cw_perpendicular_to_fwd).average().perpendicular()

    @classmethod
    def from_file(cls, bvh_fn: str, start_frame_idx: int = 0, end_frame_idx: Optional[int] = None) -> BVH:
        """
        Given a path to a .bvh, constructs and returns BVH object.

        :param bvh_fn: path to the .bvh file; located via resolve_ad_filepath
        :param start_frame_idx: index of the first frame to keep (inclusive); must be >= 0
        :param end_frame_idx: index one past the last frame to keep. None (or 0) keeps frames through
            the end of the file; values beyond the file's frame count are clamped with a warning.
        :raises AssertionError: if the file is malformed or the requested frame range is empty/invalid
        """

        # search for the BVH file specified
        bvh_p: Path = resolve_ad_filepath(bvh_fn, 'bvh file')
        logging.info(f'Using BVH file located at {bvh_p.resolve()}')

        with open(str(bvh_p), 'r') as f:
            lines = f.read().splitlines()

        header = BVH._next_line(lines)
        if header != 'HIERARCHY':
            raise BVH._critical_error(f"Malformed BVH: expected 'HIERARCHY' but found {header!r}")

        # Parse the skeleton (consumes its lines from the front of `lines`)
        root_joint: BVH_Joint = BVH._parse_skeleton(lines)

        header = BVH._next_line(lines)
        if header != 'MOTION':
            raise BVH._critical_error(f"Malformed BVH: expected 'MOTION' but found {header!r}")

        # Parse motion metadata, e.g. 'Frames: 100' and 'Frame Time: 0.0333'
        frame_max_num = int(BVH._next_line(lines).split(':')[-1])
        frame_time = float(BVH._next_line(lines).split(':')[-1])

        # Parse motion data. split() with no separator tolerates tabs and repeated spaces;
        # blank lines (e.g. trailing ones at end of file) are skipped.
        frames = [[float(tok) for tok in line.split()] for line in lines if line.strip()]
        if len(frames) != frame_max_num:
            raise BVH._critical_error(f'framenum specified ({frame_max_num}) and found ({len(frames)}) do not match')

        # Set end_frame if not passed in
        if not end_frame_idx:
            end_frame_idx = frame_max_num

        # Ensure end_frame_idx <= frame_max_num
        if frame_max_num < end_frame_idx:
            msg = f'config specified end_frame_idx > bvh frame_max_num ({end_frame_idx} > {frame_max_num}). Replacing with frame_max_num.'
            logging.warning(msg)
            end_frame_idx = frame_max_num

        # A negative start would make the slice below disagree with the frame count computed from it,
        # and an empty range would leave no frame for __init__ to apply.
        if not 0 <= start_frame_idx < end_frame_idx:
            raise BVH._critical_error(
                f'invalid frame range: start_frame_idx={start_frame_idx}, end_frame_idx={end_frame_idx} '
                f'(bvh frame_max_num={frame_max_num})')

        # Split logically distinct root position data from joint euler angle rotation data
        pos_data: npt.NDArray[np.float32]
        rot_data: npt.NDArray[np.float32]
        pos_data, rot_data = BVH._process_frame_data(root_joint, frames)

        # slice position and rotation data using start and end frame indices
        pos_data = pos_data[start_frame_idx:end_frame_idx, :]
        rot_data = rot_data[start_frame_idx:end_frame_idx, :]

        # new frame_max_num is end_frame_idx minus start_frame_idx
        frame_max_num = end_frame_idx - start_frame_idx

        return BVH(bvh_p.name, root_joint, frame_max_num, frame_time, pos_data, rot_data)

    @classmethod
    def _parse_skeleton(cls, lines: List[str]) -> BVH_Joint:
        """
        Called recursively to parse and construct skeleton from BVH
        :param lines: partially-processed contents of BVH file. Is modified in-place.
        :return: Joint
        :raises AssertionError: if the hierarchy section is malformed or ends prematurely
        """

        # Get the joint name
        header = BVH._next_line(lines)
        if header.startswith(('ROOT', 'JOINT')):
            parts = header.split(maxsplit=1)  # any whitespace between keyword and name
            if len(parts) != 2:
                raise BVH._critical_error(f'Malformed BVH. Line: {header}')
            joint_name = parts[1]
        elif header.startswith('End Site'):
            joint_name = header
        else:
            raise BVH._critical_error(f'Malformed BVH. Line: {header}')

        if BVH._next_line(lines) != '{':
            raise BVH._critical_error(f"Malformed BVH: expected '{{' after {header!r}")

        # Get offset
        offset_line = BVH._next_line(lines)
        if not offset_line.startswith('OFFSET'):
            raise BVH._critical_error(f"Malformed BVH: expected OFFSET for {joint_name!r} but found {offset_line!r}")
        _, *xyz = offset_line.split()
        offset = Vectors(list(map(float, xyz)))

        # Get channels (End Sites have none)
        channel_order: List[str]
        if BVH._peek_line(lines).startswith('CHANNELS'):
            _, channel_num, *channel_order = BVH._next_line(lines).split()
        else:
            channel_num, channel_order = 0, []
        if int(channel_num) != len(channel_order):
            raise BVH._critical_error(
                f'Malformed BVH: {joint_name!r} declares {channel_num} channels but lists {len(channel_order)}')

        # Recurse for children
        children: List[BVH_Joint] = []
        while BVH._peek_line(lines) != '}':
            children.append(BVH._parse_skeleton(lines))
        lines.pop(0)  # }

        return BVH_Joint(name=joint_name, offset=offset, channel_order=channel_order, children=children)

    @classmethod
    def _process_frame_data(cls, skeleton: BVH_Joint, frames: List[List[float]]) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        """
        Given skeleton and frame data, return root position data and joint quaternion data, separately.
        :raises AssertionError: if a frame's value count differs from the skeleton's channel count
        """

        # Channel names of every BVH_Joint in depth-first pre-order; the mask below treats these
        # as the column order of each motion line.
        channels: List[str] = []

        def _collect_channels(joint: BVH_Joint) -> None:
            channels.extend(joint.channel_order)
            for child in joint.get_children():
                if isinstance(child, BVH_Joint):
                    _collect_channels(child)
        _collect_channels(skeleton)

        for frame_idx, frame in enumerate(frames):
            if len(frame) != len(channels):
                raise BVH._critical_error(
                    f'Malformed BVH: motion frame {frame_idx} has {len(frame)} values '
                    f'but the skeleton declares {len(channels)} channels')

        # create a mask so we retain only joint rotations and root position
        mask = np.array(['rotation' in channel for channel in channels], dtype=bool)
        mask[:3] = True  # hack to make sure we keep root position (assumed to be the first 3 channels)

        frame_arr = np.asarray(frames, dtype=np.float32)[:, mask]

        # split root pose data and joint euler angle data
        pos_data, ea_rots = np.split(frame_arr, [3], axis=1)

        # quaternion rot data will go here
        rot_data = np.empty([len(frame_arr), skeleton.joint_count(), 4], dtype=np.float32)
        BVH._pose_ea_to_q(skeleton, ea_rots, rot_data)

        return pos_data, rot_data

    @classmethod
    def _pose_ea_to_q(cls, joint: BVH_Joint, ea_rots: npt.NDArray[np.float32], q_rots: npt.NDArray[np.float32], p1: int = 0, p2: int = 0) -> Tuple[int, int]:
        """
        Given joint and array of euler angle rotation data, converts to quaternions and stores in q_rots.
        Only called by _process_frame_data(). Modifies q_rots inplace.
        :param p1: pointer to find where in ea_rots to read euler angles from
        :param p2: pointer to determine where in q_rots to input quaternion
        :return: the advanced (p1, p2) after this joint and all of its BVH_Joint descendants
        """
        # Axis letters of this joint's rotation channels, in file order, e.g. 'zxy'
        axis_chars = "".join([c[0].lower() for c in joint.channel_order if c.endswith('rotation')])

        q_rots[:, p2] = Quaternions.from_euler_angles(axis_chars, ea_rots[:, p1:p1+len(axis_chars)]).qs
        p1 += len(axis_chars)
        p2 += 1

        for child in joint.get_children():
            if isinstance(child, BVH_Joint):
                p1, p2 = BVH._pose_ea_to_q(child, ea_rots, q_rots, p1, p2)

        return p1, p2

    @staticmethod
    def _critical_error(msg: str) -> AssertionError:
        """
        Log msg at CRITICAL level and return an AssertionError carrying it, for the caller to raise.
        Raising explicitly (instead of `assert False, msg`) keeps validation active under `python -O`
        while preserving the AssertionError type that callers may already catch.
        """
        logging.critical(msg)
        return AssertionError(msg)

    @staticmethod
    def _peek_line(lines: List[str]) -> str:
        """Return the next line of lines, whitespace-stripped, without consuming it; fail cleanly at EOF."""
        if not lines:
            raise BVH._critical_error('Malformed BVH: unexpected end of file')
        return lines[0].strip()

    @staticmethod
    def _next_line(lines: List[str]) -> str:
        """Remove and return the next line of lines, whitespace-stripped; fail cleanly at EOF."""
        line = BVH._peek_line(lines)
        lines.pop(0)
        return line