Spaces:
Build error
Build error
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
from typing import Optional | |
import torch | |
from torch import Tensor | |
from einops import rearrange | |
from mGPT.utils.easyconvert import rep_to_rep, nfeats_of, to_matrix | |
import mGPT.utils.geometry_tools as geometry_tools | |
from .base import Rots2Rfeats | |
class Globalvelandy(Rots2Rfeats):
    """Convert (rotations, translation) motion data to flat per-frame features.

    Each output frame is laid out as ``[root_y, vx, vy, pose_features...]``:
    the vertical root coordinate (``trans[..., 2]``, kept absolute), the
    frame-to-frame velocity of the two horizontal coordinates, and the
    per-joint rotations expressed in ``pose_rep`` and flattened over joints.
    """

    def __init__(self,
                 path: Optional[str] = None,
                 normalization: bool = False,
                 pose_rep: str = "rot6d",
                 canonicalize: bool = False,
                 offset: bool = True,
                 **kwargs) -> None:
        """Configure the feature converter.

        Args:
            path: forwarded to ``Rots2Rfeats.__init__`` (presumably where the
                normalization statistics live — confirm in the base class).
            normalization: forwarded to ``Rots2Rfeats.__init__``.
            pose_rep: rotation representation used for the pose features.
            canonicalize: if True, ``forward`` undoes the vertical-axis part
                of the first frame's root rotation on the root orientations
                and on the horizontal velocities.
            offset: when canonicalizing, add an extra pi/2 about the vertical
                axis to the rotation that is undone.
            **kwargs: accepted and ignored.
        """
        super().__init__(path=path, normalization=normalization)
        self.canonicalize = canonicalize
        self.pose_rep = pose_rep
        self.nfeats = nfeats_of(pose_rep)
        self.offset = offset

    def forward(self, data, data_rep='matrix', first_frame=None) -> Tensor:
        """Encode ``data`` into features, then apply ``self.normalize``.

        Args:
            data: object exposing ``rots`` (per-joint rotations in
                ``data_rep``) and ``trans`` (root translation whose last axis
                holds the coordinates, index 2 being the vertical one, with
                time on axis -2).
            data_rep: rotation representation of ``data.rots``.
            first_frame: velocity to prepend for the first frame. If None, a
                zero velocity is used so there is one feature row per frame.

        Returns:
            Tensor whose last axis is ``[root_y, vx, vy, flattened poses]``.
        """
        poses, trans = data.rots, data.trans

        # The gravity axis is the last translation coordinate (z for SMPL);
        # it is stored as an absolute value under the name root_y.
        root_y = trans[..., 2]
        trajectory = trans[..., [0, 1]]

        # Horizontal velocity between consecutive frames (axis -2 = time).
        vel_trajectory = torch.diff(trajectory, dim=-2)
        if first_frame is None:
            # Zero velocity for the first frame keeps the frame count.
            # Derived from `trajectory` rather than the diff, so a
            # single-frame sequence (whose diff is empty) still works.
            first_frame = torch.zeros_like(trajectory[..., :1, :])
        vel_trajectory = torch.cat((first_frame, vel_trajectory), dim=-2)

        if self.canonicalize:
            matrix_poses = rep_to_rep(data_rep, 'matrix', poses)
            # Joint 0 is treated as the root / global orientation.
            global_orient = matrix_poses[..., 0, :, :]

            # Root rotation of the first frame as an axis-angle vector.
            # NOTE(review): poses[0, 0] assumes unbatched (frames, joints, ...)
            # input — confirm. Cloned because it is edited in place below and
            # rep_to_rep may hand back its input unchanged when no conversion
            # is needed, which would silently modify data.rots.
            rot2d = rep_to_rep(data_rep, 'rotvec', poses[0, 0, ...]).clone()
            # Keep only the component about the vertical axis.
            rot2d[..., :2] = 0
            if self.offset:
                # add a bit more rotation
                rot2d[..., 2] += torch.pi / 2
            rot2d = rep_to_rep('rotvec', 'matrix', rot2d)

            # Left-multiply every root orientation by rot2d^T.
            global_orient = torch.einsum("...kj,...kl->...jl", rot2d,
                                         global_orient)
            matrix_poses = torch.cat(
                (global_orient[..., None, :, :], matrix_poses[..., 1:, :, :]),
                dim=-3)
            poses = rep_to_rep('matrix', data_rep, matrix_poses)

            # Rotate each horizontal velocity by the same rot2d^T.
            vel_trajectory = torch.einsum("...kj,...lk->...lj",
                                          rot2d[..., :2, :2], vel_trajectory)

        poses = rep_to_rep(data_rep, self.pose_rep, poses)
        features = torch.cat(
            (root_y[..., None], vel_trajectory,
             rearrange(poses, "... joints rot -> ... (joints rot)")),
            dim=-1)
        return self.normalize(features)

    def extract(self, features):
        """Split an (unnormalized) feature tensor into its components.

        Returns:
            ``(root_y, vel_trajectory, poses)``: channel 0, channels 1-2, and
            the remaining channels reshaped to ``(..., joints, self.nfeats)``.
        """
        root_y = features[..., 0]
        vel_trajectory = features[..., 1:3]
        poses_features = features[..., 3:]
        poses = rearrange(poses_features,
                          "... (joints rot) -> ... joints rot",
                          rot=self.nfeats)
        return root_y, vel_trajectory, poses

    def inverse(self, features, last_frame=None):
        """Decode features back into rotation matrices and translations.

        Canonicalization applied in ``forward`` is not undone.

        Args:
            features: features as returned by ``forward`` (normalized).
            last_frame: currently unused; kept for interface compatibility.

        Returns:
            ``RotTransDatastruct`` with ``rots`` as rotation matrices and
            ``trans`` whose first two coordinates are the integrated
            horizontal velocity (anchored at 0 on the first frame) and whose
            last coordinate is ``root_y``.
        """
        features = self.unnormalize(features)
        root_y, vel_trajectory, poses = self.extract(features)

        # Integrate velocities over time (axis -2) to recover positions.
        trajectory = torch.cumsum(vel_trajectory, dim=-2)
        # The first frame should already be 0 (forward prepends a zero
        # velocity by default), but predicted features may not honor that,
        # so anchor it explicitly.
        trajectory = trajectory - trajectory[..., [0], :]

        # Reassemble translation with the vertical axis last, as in forward.
        trans = torch.cat([trajectory, root_y[..., None]], dim=-1)
        matrix_poses = rep_to_rep(self.pose_rep, 'matrix', poses)

        # Function-level import, presumably to avoid a circular import
        # with the smpl module — confirm.
        from ..smpl import RotTransDatastruct
        return RotTransDatastruct(rots=matrix_poses, trans=trans)