# motionfix-demo / geometry_utils.py
# Author: atnikos
# Commit 777e3d5: "fix demo changes"
import torch
from transform3d import transform_body_pose, apply_rot_delta, get_z_rot, change_for
def diffout2motion(diffout, normalizer):
    """Convert a normalized diffusion output into an absolute motion tensor.

    The features are first un-normalized with ``normalizer``. The per-frame
    z-orientation deltas are then integrated into absolute z-rotations. These
    are composed with the per-frame xy-orientation to form the global
    orientation. Finally, the pelvis translation deltas are mapped to the
    global frame and summed into an absolute trajectory.

    Feature layout after un-normalization, as implied by the slices below
    (NOTE(review): confirm the order against ``normalizer.input_feats``):

    * ``[..., 0:3]``    pelvis translation delta ("body_transl_delta_pelv")
    * ``[..., 3:9]``    xy body orientation in 6D
    * ``[..., 9:15]``   z orientation delta in 6D ("z_orient_delta")
    * ``[..., -126:]``  body pose, 21 x 6D
    * optionally "body_joints_local_wo_z_rot" as the trailing feature block,
      which is stripped before use.

    Args:
        diffout: Normalized diffusion output. The indexing below assumes
            (batch, frames, features) -- TODO confirm.
        normalizer: Object providing ``uncat_inputs``, ``unnorm_inputs``,
            ``cat_inputs``, ``input_feats`` and ``input_feats_dims``.

    Returns:
        ``[translation (3) | global orientation 6D (6) | body pose (21*6)]``
        concatenated along the last dimension.

    Raises:
        ValueError: if ``normalizer.input_feats`` lacks ``"z_orient_delta"``
            or ``"body_transl_delta_pelv"``, both of which the integration
            below requires.
    """
    # Validate up front. Without "z_orient_delta" the old code skipped all work
    # and crashed on an unbound return value. The pelvis check was an
    # ``assert``, which ``python -O`` strips.
    for required in ("z_orient_delta", "body_transl_delta_pelv"):
        if required not in normalizer.input_feats:
            raise ValueError(
                f"diffout2motion requires feature {required!r} in "
                f"normalizer.input_feats, got {normalizer.input_feats}"
            )

    feats_unnorm = normalizer.cat_inputs(
        normalizer.unnorm_inputs(
            normalizer.uncat_inputs(diffout, normalizer.input_feats_dims),
            normalizer.input_feats,
        )
    )[0]

    # Drop the local-joint features. This slice assumes they form the LAST
    # feature block (TODO confirm).
    if "body_joints_local_wo_z_rot" in normalizer.input_feats:
        idx = normalizer.input_feats.index("body_joints_local_wo_z_rot")
        feats_unnorm = feats_unnorm[..., :-normalizer.input_feats_dims[idx]]

    # ---- Global orientation ----------------------------------------------
    # Every sequence starts at the identity rotation, built on the features'
    # device/dtype rather than a hard-coded 'cuda'. Shape (B, 3, 3), then 6D.
    identity = torch.eye(3, device=feats_unnorm.device, dtype=feats_unnorm.dtype)
    first_orient_z = transform_body_pose(
        identity.unsqueeze(0).repeat(feats_unnorm.shape[0], 1, 1), 'rot->6d'
    )

    # Integrate the z-orientation deltas frame by frame. Frame 0 is pinned to
    # the identity, so the delta stored at index 0 is never applied.
    z_orient_delta = feats_unnorm[..., 9:15]
    prev_z = first_orient_z
    z_orients = [first_orient_z[:, None]]
    for i in range(1, z_orient_delta.shape[1]):
        curr_z = apply_rot_delta(prev_z, z_orient_delta[:, i])
        # Clone so the appended view is not affected if the next step
        # touches prev_z.
        prev_z = curr_z.clone()
        z_orients.append(curr_z[:, None])
    full_z_angle_rotmat = get_z_rot(torch.cat(z_orients, dim=1))

    # Global orientation = integrated z-rotation composed with the xy part.
    xy_orient_rotmat = transform_body_pose(feats_unnorm[..., 3:9], '6d->rot')
    full_global_orient_rotmat = full_z_angle_rotmat @ xy_orient_rotmat
    full_global_orient = transform_body_pose(full_global_orient_rotmat, 'rot->6d')

    # ---- Translation -----------------------------------------------------
    # Initial root translation: zeros on diffout's device/dtype, keeping only
    # the first frame, passed through the 'body_transl' un-normalization.
    first_trans = diffout.new_zeros((*diffout.shape[:-1], 3))[:, [0]]
    first_trans = normalizer.cat_inputs(
        normalizer.unnorm_inputs([first_trans], ['body_transl'])
    )[0]

    # Get the per-frame velocity in the global coordinate frame. The delta at
    # frame t is paired with the orientation at frame t-1. Then accumulate it
    # onto the start position.
    pelvis_delta = feats_unnorm[..., :3]
    trans_vel_pelv = change_for(
        pelvis_delta[:, 1:], full_global_orient_rotmat[:, :-1], forward=False
    )
    full_trans = torch.cat(
        [first_trans, torch.cumsum(trans_vel_pelv, dim=1) + first_trans], dim=1
    )

    # ---- Assemble [translation | global orient | body pose (21 x 6D)] -------
    full_rots = torch.cat([full_global_orient, feats_unnorm[..., -21 * 6:]], dim=-1)
    return torch.cat([full_trans, full_rots], dim=-1)