# [file-viewer residue, not part of the program] bill-jiang's picture | Init | 4409449 | raw | history blame | 4.59 kB
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
from typing import Optional
import torch
from torch import Tensor
from einops import rearrange
from mGPT.utils.easyconvert import rep_to_rep, nfeats_of, to_matrix
import mGPT.utils.geometry_tools as geometry_tools
from .base import Rots2Rfeats
class Globalvelandy(Rots2Rfeats):
    """Convert rotations + root translation into flat features and back.

    Per frame, the feature layout produced by :meth:`forward` and consumed by
    :meth:`extract` / :meth:`inverse` is::

        [root_y (1), planar root velocity (2), flattened joint rotations]

    Normalization is delegated to ``normalize`` / ``unnormalize``, which are
    not defined in this class (presumably provided by ``Rots2Rfeats``).
    """

    def __init__(self,
                 path: Optional[str] = None,
                 normalization: bool = False,
                 pose_rep: str = "rot6d",
                 canonicalize: bool = False,
                 offset: bool = True,
                 **kwargs) -> None:
        """Store the conversion settings.

        Args:
            path: Forwarded unchanged to the base-class constructor.
            normalization: Forwarded unchanged to the base-class constructor.
            pose_rep: Name of the rotation representation used in the
                features; ``nfeats_of(pose_rep)`` gives the per-joint width.
            canonicalize: If True, :meth:`forward` removes the rotation about
                the third (vertical) axis taken from the first root pose.
            offset: If True (and ``canonicalize`` is set), an extra ``pi / 2``
                is added to that vertical rotation before it is removed.
            **kwargs: Accepted and ignored.
        """
        super().__init__(path=path, normalization=normalization)
        self.canonicalize = canonicalize
        self.pose_rep = pose_rep
        self.nfeats = nfeats_of(pose_rep)
        self.offset = offset

    def forward(self, data, data_rep='matrix', first_frame=None) -> Tensor:
        """Build the feature tensor from ``data.rots`` and ``data.trans``.

        Args:
            data: Object exposing ``rots`` (joint rotations expressed in
                ``data_rep``) and ``trans`` (root translation whose last axis
                holds at least three coordinates).
            data_rep: Representation name of ``data.rots``, as understood by
                ``rep_to_rep``.
            first_frame: Velocity to place at the first time step. When
                None, zeros are used so that the output keeps the same number
                of frames as the input.

        Returns:
            The concatenated features after ``self.normalize``.
        """
        poses, trans = data.rots, data.trans

        # extract the root gravity axis
        # for smpl it is the last coordinate
        root_y = trans[..., 2]
        trajectory = trans[..., [0, 1]]

        # Compute the difference of trajectory along the second-to-last axis
        vel_trajectory = torch.diff(trajectory, dim=-2)

        # 0 for the first one => keep the dimensionality.
        # Built from `trajectory` rather than from `vel_trajectory` so that a
        # single-frame input (empty diff) does not raise, and so that NaN/Inf
        # velocities cannot leak into the padding through `0 * x`.
        if first_frame is None:
            first_frame = torch.zeros_like(trajectory[..., :1, :])
        vel_trajectory = torch.cat((first_frame, vel_trajectory), dim=-2)

        # first normalize the data
        if self.canonicalize:
            matrix_poses = rep_to_rep(data_rep, 'matrix', poses)
            global_orient = matrix_poses[..., 0, :, :]

            # remove the rotation
            # NOTE(review): `poses[0, 0, ...]` indexes the first two axes
            # directly, which presumably means (frame 0, root joint) of an
            # unbatched sequence -- confirm against callers.
            # `.clone()` guards the in-place edits below in case `rep_to_rep`
            # hands back (a view of) its input; its implementation is not
            # visible from here.
            rot2d = rep_to_rep(data_rep, 'rotvec', poses[0, 0, ...]).clone()

            # Remove the first rotation along the vertical axis: keep only
            # the third rotation-vector component.
            rot2d[..., :2] = 0

            if self.offset:
                # add a bit more rotation
                rot2d[..., 2] += torch.pi / 2

            rot2d = rep_to_rep('rotvec', 'matrix', rot2d)

            # turn with the same amount all the rotations
            # (this einsum is rot2d^T @ global_orient)
            global_orient = torch.einsum("...kj,...kl->...jl", rot2d,
                                         global_orient)

            matrix_poses = torch.cat(
                (global_orient[..., None, :, :], matrix_poses[..., 1:, :, :]),
                dim=-3)
            poses = rep_to_rep('matrix', data_rep, matrix_poses)

            # Turn the trajectory as well, using the top-left 2x2 block
            # (this einsum is vel_trajectory @ rot2d[:2, :2])
            vel_trajectory = torch.einsum("...kj,...lk->...lj",
                                          rot2d[..., :2, :2], vel_trajectory)

        poses = rep_to_rep(data_rep, self.pose_rep, poses)
        features = torch.cat(
            (root_y[..., None], vel_trajectory,
             rearrange(poses, "... joints rot -> ... (joints rot)")),
            dim=-1)
        features = self.normalize(features)
        return features

    def extract(self, features):
        """Split a feature tensor into its three components.

        Args:
            features: Tensor laid out on its last axis as
                ``[root_y, vel_x, vel_y, flattened poses]``.

        Returns:
            Tuple ``(root_y, vel_trajectory, poses)`` where ``poses`` has its
            last axis unflattened into ``(joints, self.nfeats)``.
        """
        root_y = features[..., 0]
        vel_trajectory = features[..., 1:3]
        poses_features = features[..., 3:]
        poses = rearrange(poses_features,
                          "... (joints rot) -> ... joints rot",
                          rot=self.nfeats)
        return root_y, vel_trajectory, poses

    def inverse(self, features, last_frame=None):
        """Rebuild rotations and translation from (normalized) features.

        Args:
            features: Tensor as returned by :meth:`forward`.
            last_frame: Currently unused; kept for interface compatibility.

        Returns:
            ``RotTransDatastruct`` holding matrix rotations (``rots``) and the
            integrated translation (``trans``).
        """
        features = self.unnormalize(features)
        root_y, vel_trajectory, poses = self.extract(features)

        # integrate the trajectory
        trajectory = torch.cumsum(vel_trajectory, dim=-2)
        # First frame should be 0, but if inferred it is better to ensure it
        trajectory = trajectory - trajectory[..., :1, :]

        # Get back the translation
        trans = torch.cat([trajectory, root_y[..., None]], dim=-1)
        matrix_poses = rep_to_rep(self.pose_rep, 'matrix', poses)

        # Function-scope import kept as in the original (presumably to avoid
        # an import cycle with the smpl module).
        from ..smpl import RotTransDatastruct
        return RotTransDatastruct(rots=matrix_poses, trans=trans)