File size: 2,998 Bytes
78c7556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Feature names this module can compute (the keys of FEAT_GET_METHODS):
# - "body_transl_delta_pelv"
# - "body_orient_xy"
# - "z_orient_delta"
# - "body_pose"
# - "body_joints_local_wo_z_rot"
from transform3d import transform_body_pose, change_for, remove_z_rot, get_z_rot, rot_diff
from einops import rearrange
import torch

def to_tensor(array):
    """Return ``array`` as a torch tensor.

    Tensors are passed through unchanged (the very same object is returned);
    anything else is converted with ``torch.tensor``, which copies the data.
    """
    return array if torch.is_tensor(array) else torch.tensor(array)

def _get_body_transl_delta_pelv(data):
    """Get the per-frame pelvis translation delta in the previous pelvis frame.

    Computes v_i = t_i - t_{i-1} along dim 0 and re-expresses it, via
    ``change_for``, relative to the previous frame's pelvis orientation
    R_{i-1}.  Frame 0 has no predecessor, so its delta is set to zero.

    Args:
        data: mapping with ``'trans'`` (pelvis translation per frame) and
            ``'rots'``, whose first three entries on the last axis are the
            pelvis orientation in axis-angle form.

    Returns:
        Tensor of translation deltas with the first frame zeroed.
    """
    translation = to_tensor(data['trans'])
    # Shift by one frame: row i now holds t_{i-1} (row 0 wraps around and is
    # discarded below).
    previous_translation = translation.roll(1, 0)
    delta = translation - previous_translation

    root_aa = to_tensor(data['rots'][..., :3])
    root_rot = transform_body_pose(root_aa, "aa->rot")
    previous_root_rot = root_rot.roll(1, 0)

    delta_local = change_for(delta, previous_root_rot)
    delta_local[0] = 0  # first frame has no previous frame to diff against
    return delta_local

def _get_body_orient_xy(data):
    """Get the global pelvis orientation with its z-axis rotation removed.

    The first three values of ``data['rots']`` along the last axis are read
    as the pelvis orientation in axis-angle form and passed to
    ``remove_z_rot``.  The output representation is whatever
    ``remove_z_rot`` returns (not visible here; check ``transform3d``).
    """
    root_aa = to_tensor(data['rots'][..., :3])
    return remove_z_rot(root_aa, in_format="aa")

def _get_body_pose(data):
    """Get the body joint rotations (pelvis excluded) in 6D representation.

    ``data['rots']`` holds axis-angle values along its last axis.  Entries
    0:3 are the pelvis orientation and are skipped; the following 21 * 3
    entries (21 joints x 3 axis-angle components) are converted from
    axis-angle to 6D with ``transform_body_pose``.
    """
    n_body_joints = 21
    start = 3  # skip the pelvis / global orientation
    stop = start + n_body_joints * 3
    body_aa = to_tensor(data['rots'][..., start:stop])
    return transform_body_pose(body_aa, "aa->6d")

def _get_body_joints_local_wo_z_rot(data):
    """Get joint coordinates relative to the pelvis, with its z rotation undone.

    For each frame f, the first 22 joints are translated so that joint 0
    (the pelvis) sits at the origin. They are then rotated by the transpose
    of the pelvis' z-axis rotation matrix R_z(f), obtained from
    ``get_z_rot``:

        p_rel = R_z(f)^T @ (p_global - p_pelvis)

    Args:
        data: mapping with
            - ``'joint_positions'``: global joint positions indexed as
              ``[frame, joint, coord]`` (assumes at least 22 joints and 3-D
              coordinates; TODO confirm against the data loader);
            - ``'rots'``: rotations whose first three entries on the last
              axis are the pelvis orientation in axis-angle form.

    Returns:
        Tensor of shape ``(frames, 22 * coords)``, with coords presumably 3,
        holding the per-frame relative joint coordinates flattened
        joint-major.
    """
    joints_glob = to_tensor(data['joint_positions'][:, :22, :])
    pelvis_transl = joints_glob[:, 0, :]  # joint 0 is the pelvis
    pelvis_orient = to_tensor(data['rots'][..., :3])

    # Per-frame rotation matrix of the pelvis' z-axis component.
    pelvis_orient_z = get_z_rot(pelvis_orient, in_format="aa")

    offsets = joints_glob - pelvis_transl[:, None, :]

    # torch.einsum needs both operands with the same dtype on the same
    # device. Joint positions and rotations may arrive with different
    # precisions (e.g. float32 vs float64), which would make einsum raise,
    # so promote both to a common dtype. When they already match, these
    # calls are no-ops.
    common_dtype = torch.promote_types(pelvis_orient_z.dtype, offsets.dtype)
    pelvis_orient_z = pelvis_orient_z.to(device=offsets.device,
                                         dtype=common_dtype)
    offsets = offsets.to(dtype=common_dtype)

    # rel[f, j, i] = sum_d R[f, d, i] * offsets[f, j, d] == (R^T @ offset)_i
    rel_joints = torch.einsum('fdi,fjd->fji', pelvis_orient_z, offsets)

    return rearrange(rel_joints, '... j c -> ... (j c)')

def _get_z_orient_delta(data):
    """Get the delta of the pelvis' z-axis (heading) rotation, in 6D.

    Steps:
    1. Read the global pelvis orientation (axis-angle, first three entries
       of ``data['rots']``).
    2. Extract its z-rotation component with ``get_z_rot``, as a rotation
       matrix.
    3. Convert that back to axis-angle.
    4. Pass it to ``rot_diff``, which returns the rotation delta in 6D
       representation (presumably between consecutive frames; confirm in
       ``transform3d``).
    """
    root_aa = to_tensor(data['rots'][..., :3])
    z_rot_mat = get_z_rot(root_aa, in_format="aa")
    z_rot_aa = transform_body_pose(z_rot_mat, "rot->aa")
    return rot_diff(z_rot_aa, in_format="aa", out_format='6d')

# Feature name -> function that computes it from a single data mapping.
# Every function takes the mapping as its only argument.
FEAT_GET_METHODS = {
    "body_transl_delta_pelv": _get_body_transl_delta_pelv,
    "body_orient_xy": _get_body_orient_xy,
    "z_orient_delta": _get_z_orient_delta,
    "body_pose": _get_body_pose,
    "body_joints_local_wo_z_rot": _get_body_joints_local_wo_z_rot,
}