Spaces:
Running
Running
import torch | |
from transform3d import transform_body_pose, apply_rot_delta, get_z_rot, change_for | |
def diffout2motion(diffout, normalizer):
    """Convert a normalized diffusion output into an unnormalized motion tensor.

    Steps:
      1. Unnormalize ``diffout`` feature-wise via ``normalizer``.
      2. Integrate the per-frame z-orientation deltas, starting from an
         identity rotation on the first frame.
      3. Compose the result with the per-frame xy orientation
         (``R_global = R_z @ R_xy``).
      4. Rotate the pelvis translation deltas into the global frame using
         the previous frame's global orientation. Cumulatively sum them,
         starting from the unnormalized zero translation.

    Feature layout assumed by the hard-coded slices below.
    NOTE(review): inferred from the slicing only; confirm it against the
    order of ``normalizer.input_feats``.
        [..., 0:3]      pelvis translation delta ("body_transl_delta_pelv")
        [..., 3:9]      xy part of the global orientation, 6D
        [..., 9:15]     z-orientation delta, 6D ("z_orient_delta")
        [..., -126:]    body pose, 21 joints x 6D
    Trailing "body_joints_local_wo_z_rot" features, if present, are dropped.

    Args:
        diffout: Normalized diffusion features. Presumably shaped
            (batch, frames, feats); the code treats dim 0 as batch and
            dim 1 as time. TODO confirm.
        normalizer: Object exposing ``input_feats``, ``input_feats_dims``,
            ``uncat_inputs``, ``unnorm_inputs`` and ``cat_inputs``.

    Returns:
        Per-frame tensor ``[translation (3) | global orient 6D (6) | body pose]``,
        concatenated along the last dimension.

    Raises:
        ValueError: If ``normalizer.input_feats`` lacks ``"z_orient_delta"``
            or ``"body_transl_delta_pelv"``. Only that parameterization is
            implemented here.
    """
    input_feats = normalizer.input_feats
    # Validate up front. Without z_orient_delta the reconstruction below is
    # undefined; previously this surfaced as an opaque NameError. Raise
    # instead of `assert` so the check survives `python -O`.
    if 'z_orient_delta' not in input_feats:
        raise ValueError(
            "diffout2motion requires 'z_orient_delta' in normalizer.input_feats")
    if 'body_transl_delta_pelv' not in input_feats:
        raise ValueError(
            "diffout2motion requires 'body_transl_delta_pelv' in "
            "normalizer.input_feats")

    # Follow the input's device rather than hard-coding 'cuda'. This lets the
    # function run on CPU and on non-default GPUs.
    device = diffout.device

    feats_unnorm = normalizer.cat_inputs(normalizer.unnorm_inputs(
        normalizer.uncat_inputs(diffout, normalizer.input_feats_dims),
        input_feats))[0]

    # Local joint positions are not part of the output; strip them from the
    # end of the feature vector.
    if "body_joints_local_wo_z_rot" in input_feats:
        idx = input_feats.index("body_joints_local_wo_z_rot")
        feats_unnorm = feats_unnorm[..., :-normalizer.input_feats_dims[idx]]

    # Zero translation for the first frame only. `[:, [0]]` keeps the time
    # axis, giving (B, 1, 3) for a 3-D diffout.
    first_trans = torch.zeros(*diffout.shape[:-1], 3, device=device)[:, [0]]

    # --- Z orientation: integrate deltas starting from identity -------------
    first_orient_z = torch.eye(3, device=device).unsqueeze(0)  # (1, 3, 3)
    first_orient_z = first_orient_z.repeat(feats_unnorm.shape[0], 1, 1)  # (B, 3, 3)
    first_orient_z = transform_body_pose(first_orient_z, 'rot->6d')

    z_orient_delta = feats_unnorm[..., 9:15]
    prev_z = first_orient_z
    full_z_angle = [first_orient_z[:, None]]
    # Frame 0 is fixed to identity, so its delta is not applied; each later
    # frame applies its delta to the previous frame's z orientation.
    for i in range(1, z_orient_delta.shape[1]):
        curr_z = apply_rot_delta(prev_z, z_orient_delta[:, i])
        prev_z = curr_z.clone()
        full_z_angle.append(curr_z[:, None])
    full_z_angle = torch.cat(full_z_angle, dim=1)
    full_z_angle_rotmat = get_z_rot(full_z_angle)

    # --- Global orientation: R_z @ R_xy per frame ----------------------------
    xy_orient = feats_unnorm[..., 3:9]
    xy_orient_rotmat = transform_body_pose(xy_orient, '6d->rot')
    full_global_orient_rotmat = full_z_angle_rotmat @ xy_orient_rotmat
    full_global_orient = transform_body_pose(full_global_orient_rotmat,
                                             'rot->6d')

    # --- Translation: rotate pelvis deltas to the global c.f., then integrate
    first_trans = normalizer.cat_inputs(
        normalizer.unnorm_inputs([first_trans], ['body_transl']))[0]
    pelvis_delta = feats_unnorm[..., :3]
    # The delta of frame t is expressed w.r.t. the orientation of frame t-1;
    # forward=False maps it back to the global frame.
    # NOTE(review): confirm against transform3d.change_for.
    trans_vel_pelv = change_for(pelvis_delta[:, 1:],
                                full_global_orient_rotmat[:, :-1],
                                forward=False)
    full_trans = torch.cumsum(trans_vel_pelv, dim=1) + first_trans
    full_trans = torch.cat([first_trans, full_trans], dim=1)

    # --- Assemble [translation | global orientation | body pose] ------------
    full_rots = torch.cat([full_global_orient, feats_unnorm[..., -21 * 6:]],
                          dim=-1)
    return torch.cat([full_trans, full_rots], dim=-1)