import torch
from transform3d import transform_body_pose, apply_rot_delta, get_z_rot, change_for


def _integrate_z_orient(first_orient_z, z_orient_delta):
    """Integrate per-frame z-orientation deltas starting from a first frame.

    Frame 0 is ``first_orient_z`` itself. Each later frame ``i`` is obtained
    by applying ``z_orient_delta[:, i]`` to frame ``i - 1`` with
    ``apply_rot_delta``. ``z_orient_delta[:, 0]`` is therefore never used.

    Args:
        first_orient_z: 6D orientation of the first frame, batch-first with
            no time axis (presumably (B, 6) -- TODO confirm).
        z_orient_delta: per-frame 6D deltas with time on dim 1.

    Returns:
        The integrated orientations stacked along dim 1, one entry per frame
        of ``z_orient_delta``.
    """
    prev_z = first_orient_z
    frames = [first_orient_z[:, None]]
    for i in range(1, z_orient_delta.shape[1]):
        curr_z = apply_rot_delta(prev_z, z_orient_delta[:, i])
        # Clone so a later update can never alias a frame already stored.
        prev_z = curr_z.clone()
        frames.append(curr_z[:, None])
    return torch.cat(frames, dim=1)


def diffout2motion(diffout, normalizer):
    """Turn normalized diffusion output into an unnormalized motion tensor.

    The output is unnormalized through ``normalizer``. The global
    orientation and translation are then rebuilt by integrating the
    per-frame deltas.

    The slicing below assumes this per-frame feature layout, after the
    optional trailing ``body_joints_local_wo_z_rot`` block is dropped:

    * ``[..., :3]``     -- pelvis translation delta (``body_transl_delta_pelv``)
    * ``[..., 3:9]``    -- xy part of the global orientation, 6D
    * ``[..., 9:15]``   -- z-orientation delta, 6D (``z_orient_delta``)
    * ``[..., -126:]``  -- body pose, 21 joints x 6D

    Args:
        diffout: normalized model output with time on dim 1 (presumably
            (B, T, D) -- TODO confirm). Its device and dtype are used for
            every tensor created here.
        normalizer: object providing ``input_feats``, ``input_feats_dims``,
            ``uncat_inputs``, ``unnorm_inputs`` and ``cat_inputs``.

    Returns:
        Per-frame concatenation of
        ``[translation (3), global orientation (6D), body pose (21*6)]``.

    Raises:
        ValueError: if ``z_orient_delta`` or ``body_transl_delta_pelv`` is
            missing from ``normalizer.input_feats``. The reconstruction
            needs both.
    """
    feats = normalizer.input_feats
    if 'z_orient_delta' not in feats:
        raise ValueError(
            "diffout2motion requires 'z_orient_delta' in "
            "normalizer.input_feats")
    if 'body_transl_delta_pelv' not in feats:
        raise ValueError(
            "diffout2motion requires 'body_transl_delta_pelv' in "
            "normalizer.input_feats")

    feats_unnorm = normalizer.cat_inputs(normalizer.unnorm_inputs(
        normalizer.uncat_inputs(diffout, normalizer.input_feats_dims),
        feats))[0]

    # Local joint positions are not used for reconstruction; the code
    # assumes they form the trailing slice and drops them. An explicit end
    # index avoids the ``[..., :-0]`` pitfall (an empty slice) when dim == 0.
    if "body_joints_local_wo_z_rot" in feats:
        idx = feats.index("body_joints_local_wo_z_rot")
        n_joint_feats = normalizer.input_feats_dims[idx]
        feats_unnorm = feats_unnorm[..., :feats_unnorm.shape[-1] - n_joint_feats]

    # First frame: zero translation (normalized space) and identity z-rotation,
    # created on the same device and dtype as the input.
    first_trans = diffout.new_zeros((*diffout.shape[:-1], 3))[:, [0]]
    first_orient_z = torch.eye(3, device=diffout.device, dtype=diffout.dtype)
    first_orient_z = first_orient_z.unsqueeze(0)                       # (1, 3, 3)
    first_orient_z = first_orient_z.repeat(feats_unnorm.shape[0], 1, 1)  # (B, 3, 3)
    first_orient_z = transform_body_pose(first_orient_z, 'rot->6d')

    # Z component of the global orientation: integrate the per-frame deltas.
    z_orient_delta = feats_unnorm[..., 9:15]
    full_z_angle = _integrate_z_orient(first_orient_z, z_orient_delta)
    full_z_angle_rotmat = get_z_rot(full_z_angle)

    # Global orientation = z-rotation composed with the xy orientation.
    xy_orient = feats_unnorm[..., 3:9]
    xy_orient_rotmat = transform_body_pose(xy_orient, '6d->rot')
    full_global_orient_rotmat = full_z_angle_rotmat @ xy_orient_rotmat
    full_global_orient = transform_body_pose(full_global_orient_rotmat,
                                             'rot->6d')

    # Bring the first-frame translation out of normalized space.
    first_trans = normalizer.cat_inputs(normalizer.unnorm_inputs(
        [first_trans],
        ['body_transl'])
    )[0]

    # Per the original comment, change_for(..., forward=False) maps the
    # pelvis-frame deltas into the global frame. The delta of frame i uses
    # the orientation of frame i - 1.
    pelvis_delta = feats_unnorm[..., :3]
    trans_vel_pelv = change_for(pelvis_delta[:, 1:],
                                full_global_orient_rotmat[:, :-1],
                                forward=False)

    # Integrate velocities, then prepend the first frame so every frame
    # has a translation.
    full_trans = torch.cumsum(trans_vel_pelv, dim=1) + first_trans
    full_trans = torch.cat([first_trans, full_trans], dim=1)

    full_rots = torch.cat([full_global_orient,
                           feats_unnorm[..., -21 * 6:]],
                          dim=-1)
    return torch.cat([full_trans, full_rots], dim=-1)