File size: 4,491 Bytes
d8530c7
777e3d5
d8530c7
7d87cc1
d8530c7
 
 
 
 
 
 
7d87cc1
 
 
 
d8530c7
7d87cc1
 
 
d8530c7
 
7d87cc1
 
 
d8530c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d87cc1
d8530c7
 
 
 
 
 
7d87cc1
d8530c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from transform3d import transform_body_pose, apply_rot_delta, get_z_rot, change_for

def diffout2motion(diffout, normalizer):
    """Turn normalized diffusion output into an unnormalized motion tensor.

    The features are unnormalized with ``normalizer``; the per-frame z-rotation
    deltas are chained starting from the identity rotation; the result is
    composed with the xy orientation to get the global orientation; the
    pelvis translation deltas are rotated with ``change_for`` and summed
    cumulatively to get the translation.

    The feature layout below is what the slicing in this function assumes
    (NOTE(review): confirm against the normalizer configuration):

    - ``[..., 0:3]``    pelvis translation delta
    - ``[..., 3:9]``    xy orientation (6d)
    - ``[..., 9:15]``   z orientation delta (6d)
    - ``[..., -126:]``  body pose (21 joints x 6d), taken after the trailing
      ``body_joints_local_wo_z_rot`` features (if present) are dropped

    Args:
        diffout: Normalized feature tensor; presumably ``(batch, frames,
            feat_dim)`` -- TODO confirm. All intermediate tensors are created
            on ``diffout.device``.
        normalizer: Object exposing ``input_feats``, ``input_feats_dims``,
            ``uncat_inputs``, ``unnorm_inputs`` and ``cat_inputs``.

    Returns:
        Concatenation along the last dim of the integrated translation, the
        global orientation (6d) and the body pose features.

    Raises:
        ValueError: If ``'z_orient_delta'`` is not in
            ``normalizer.input_feats``; the motion cannot be rebuilt without
            it (previously this surfaced as an ``UnboundLocalError``).
        AssertionError: If ``'body_transl_delta_pelv'`` is not in
            ``normalizer.input_feats``.
    """
    input_feats = normalizer.input_feats
    if 'z_orient_delta' not in input_feats:
        raise ValueError(
            "diffout2motion requires 'z_orient_delta' in "
            "normalizer.input_feats to reconstruct the motion"
        )
    assert 'body_transl_delta_pelv' in input_feats, \
        "'body_transl_delta_pelv' must be in normalizer.input_feats"

    # Follow the input's device instead of hard-coding 'cuda', so the function
    # also works on CPU and on non-default GPUs.
    device = diffout.device

    feats_unnorm = normalizer.cat_inputs(normalizer.unnorm_inputs(
        normalizer.uncat_inputs(diffout, normalizer.input_feats_dims),
        normalizer.input_feats))[0]

    # The local joint features are not needed for the reconstruction; they are
    # sliced off the end of the feature vector when present.
    if "body_joints_local_wo_z_rot" in input_feats:
        idx = input_feats.index("body_joints_local_wo_z_rot")
        feats_unnorm = feats_unnorm[..., :-normalizer.input_feats_dims[idx]]

    # FIRST POSE FOR GENERATION & DELTAS FOR INTEGRATION
    # Zero translation for the first frame only (frame axis kept, length 1).
    first_trans = torch.zeros(*diffout.shape[:-1], 3,
                              device=device)[:, [0]]

    # Identity rotation per batch element: (1, 3, 3) -> (B, 3, 3), then 6d.
    first_orient_z = torch.eye(3, device=device).unsqueeze(0)
    first_orient_z = first_orient_z.repeat(feats_unnorm.shape[0], 1, 1)
    first_orient_z = transform_body_pose(first_orient_z, 'rot->6d')

    # Integrate the z orientation deltas frame by frame, starting from the
    # identity; frame 0's delta is not applied.
    z_orient_delta = feats_unnorm[..., 9:15]
    prev_z = first_orient_z
    full_z_angle = [first_orient_z[:, None]]
    for i in range(1, z_orient_delta.shape[1]):
        curr_z = apply_rot_delta(prev_z, z_orient_delta[:, i])
        prev_z = curr_z.clone()
        full_z_angle.append(curr_z[:, None])
    full_z_angle = torch.cat(full_z_angle, dim=1)
    full_z_angle_rotmat = get_z_rot(full_z_angle)

    xy_orient = feats_unnorm[..., 3:9]
    xy_orient_rotmat = transform_body_pose(xy_orient, '6d->rot')

    # GLOBAL ORIENTATION: z rotation composed with the xy orientation.
    full_global_orient_rotmat = full_z_angle_rotmat @ xy_orient_rotmat
    full_global_orient = transform_body_pose(full_global_orient_rotmat,
                                             'rot->6d')

    # Unnormalize the zero translation with the 'body_transl' statistics so
    # the starting point lives in the same space as the integrated deltas.
    first_trans = normalizer.cat_inputs(normalizer.unnorm_inputs(
        [first_trans], ['body_transl']))[0]

    # Rotate each pelvis delta (frames 1..T-1) with the global orientation of
    # the previous frame, then accumulate onto the starting translation.
    pelvis_delta = feats_unnorm[..., :3]
    trans_vel_pelv = change_for(pelvis_delta[:, 1:],
                                full_global_orient_rotmat[:, :-1],
                                forward=False)
    full_trans = torch.cumsum(trans_vel_pelv, dim=1) + first_trans
    full_trans = torch.cat([first_trans, full_trans], dim=1)

    # Global orientation followed by the 21 * 6 body pose features.
    full_rots = torch.cat([full_global_orient,
                           feats_unnorm[..., -21 * 6:]],
                          dim=-1)
    full_motion_unnorm = torch.cat([full_trans, full_rots], dim=-1)

    return full_motion_unnorm