import torch from typing import Optional from .transforms import real_orient_mat2q, real_orient_q2mat def update_params_after_orient_rotation( poses : torch.Tensor, # (B, 46) rot_mat : torch.Tensor, # the rotation orientation matrix root_offset : Optional[torch.Tensor] = None, # the offset from custom root to model root ): ''' ### Args - `poses`: torch.Tensor, shape = (B, 46) - `rot_mat`: torch.Tensor, shape = (B, 3, 3) - `root_offset`: torch.Tensor or None, shape = (B, 3) - If None, the function won't update the translation. - If not None, the function will calculate the root translation offset that make the model rotate around the custom root instead of the model root. ### Returns - If `root_offset` is None: - `poses`: torch.Tensor, shape = (B, 46) - If `root_offset` is not None: - `poses`: torch.Tensor, shape = (B, 46) - `trans_offset`: torch.Tensor, shape = (B, 3) ''' poses = poses.clone() # 1. Transform the SKEL orientation to real matrix. orient_q = poses[:, :3] # (B, 3) orient_mat = real_orient_q2mat(orient_q) # (B, 3, 3) orient_mat = torch.einsum('bij,bjk->bik', rot_mat, orient_mat) # (B, 3, 3) orient_q = real_orient_mat2q(orient_mat) # (B, 3) poses[:, :3] = orient_q # 2. Update the translation if needed. if root_offset is not None: root_before = root_offset.clone() # (B, 3) root_after = torch.einsum('bij,bj->bi', rot_mat, root_before) # (B, 3) root_offset = root_after - root_before # (B, 3) ret = poses, root_offset else: ret = poses return ret