File size: 4,591 Bytes
4409449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

from typing import Optional

import torch
from torch import Tensor
from einops import rearrange

from mGPT.utils.easyconvert import rep_to_rep, nfeats_of, to_matrix
import mGPT.utils.geometry_tools as geometry_tools

from .base import Rots2Rfeats


class Globalvelandy(Rots2Rfeats):
    """Encode rotations + translation as "global velocity and height" features.

    Per frame, the feature vector is laid out as
    ``[root_y (1), planar velocity (2), flattened poses (joints * nfeats)]``:

    * ``root_y`` is translation coordinate 2,
    * the planar velocity is the frame-to-frame difference of translation
      coordinates 0 and 1 (the first frame gets ``first_frame`` or zeros),
    * the poses are expressed in ``pose_rep``.
    """

    def __init__(self,
                 path: Optional[str] = None,
                 normalization: bool = False,
                 pose_rep: str = "rot6d",
                 canonicalize: bool = False,
                 offset: bool = True,
                 **kwargs) -> None:
        """
        Args:
            path: forwarded to ``Rots2Rfeats.__init__``.
            normalization: forwarded to ``Rots2Rfeats.__init__``.
            pose_rep: rotation representation used for the pose part of the
                features (also determines ``self.nfeats``).
            canonicalize: if True, ``forward`` removes the heading of the
                first frame's root joint from all root orientations and from
                the planar velocities.
            offset: only used when canonicalizing; adds pi/2 to the z
                component of the heading axis-angle vector before removal.
            **kwargs: accepted and ignored.
        """
        super().__init__(path=path, normalization=normalization)

        self.canonicalize = canonicalize
        self.pose_rep = pose_rep
        self.nfeats = nfeats_of(pose_rep)
        self.offset = offset

    def forward(self, data, data_rep='matrix', first_frame=None) -> Tensor:
        """Encode ``data`` into a (normalized) feature tensor.

        Args:
            data: object exposing ``rots`` (rotations in ``data_rep``) and
                ``trans`` (translations; frames on axis -2, xyz on axis -1).
            data_rep: representation of ``data.rots``.
            first_frame: optional planar velocity for the first frame, shaped
                like a single frame of the velocity, i.e. ``(..., 1, 2)``.
                Defaults to zeros.

        Returns:
            ``self.normalize`` applied to the concatenation of
            ``root_y``, the planar velocity and the flattened poses.
        """
        poses, trans = data.rots, data.trans

        # Extract the root gravity axis; for SMPL it is the last coordinate.
        root_y = trans[..., 2]
        trajectory = trans[..., [0, 1]]

        # Frame-to-frame planar velocity (one fewer frame than the input).
        vel_trajectory = torch.diff(trajectory, dim=-2)

        # Prepend a velocity for the first frame to keep the frame count.
        # The zero default is built from ``trajectory`` (not from the diff):
        # for a single-frame input the diff is empty and could not be indexed.
        if first_frame is None:
            first_frame = torch.zeros_like(trajectory[..., [0], :])

        vel_trajectory = torch.cat((first_frame, vel_trajectory), dim=-2)

        if self.canonicalize:
            matrix_poses = rep_to_rep(data_rep, 'matrix', poses)
            global_orient = matrix_poses[..., 0, :, :]

            # Orientation of the root joint in the first frame.
            # NOTE(review): ``poses[0, 0]`` assumes unbatched input laid out
            # as (frames, joints, ...) — TODO confirm with callers.
            # Clone: rep_to_rep may hand back its input unchanged when no
            # conversion is needed, and rot2d is modified in place below;
            # without the copy that would write into ``data.rots``.
            rot2d = rep_to_rep(data_rep, 'rotvec', poses[0, 0, ...]).clone()

            # Keep only the z component of the axis-angle vector, so the
            # resulting rotation is purely about the z (vertical) axis.
            rot2d[..., :2] = 0

            if self.offset:
                # Add an extra quarter turn about the vertical axis.
                rot2d[..., 2] += torch.pi / 2

            rot2d = rep_to_rep('rotvec', 'matrix', rot2d)

            # Left-multiply every root orientation by rot2d transposed
            # ("...kj,...kl->...jl" is rot2d^T @ global_orient).
            global_orient = torch.einsum("...kj,...kl->...jl", rot2d,
                                         global_orient)

            matrix_poses = torch.cat(
                (global_orient[..., None, :, :], matrix_poses[..., 1:, :, :]),
                dim=-3)

            poses = rep_to_rep('matrix', data_rep, matrix_poses)

            # Rotate the planar velocities by the same transposed rotation
            # (2x2 block of rot2d; "...kj,...lk->...lj" is v_l @ rot2d).
            vel_trajectory = torch.einsum("...kj,...lk->...lj",
                                          rot2d[..., :2, :2], vel_trajectory)

        poses = rep_to_rep(data_rep, self.pose_rep, poses)
        features = torch.cat(
            (root_y[..., None], vel_trajectory,
             rearrange(poses, "... joints rot -> ... (joints rot)")),
            dim=-1)
        return self.normalize(features)

    def extract(self, features):
        """Split features into ``(root_y, vel_trajectory, poses)``.

        ``root_y`` is channel 0, ``vel_trajectory`` channels 1-2, and the
        remaining channels are reshaped to ``(..., joints, self.nfeats)``.
        """
        root_y = features[..., 0]
        vel_trajectory = features[..., 1:3]
        poses_features = features[..., 3:]
        poses = rearrange(poses_features,
                          "... (joints rot) -> ... joints rot",
                          rot=self.nfeats)
        return root_y, vel_trajectory, poses

    def inverse(self, features, last_frame=None):
        """Decode normalized features back into rotations and translations.

        Args:
            features: normalized features as produced by ``forward``.
            last_frame: accepted for interface compatibility; currently
                unused.

        Returns:
            ``RotTransDatastruct`` whose ``rots`` are rotation matrices and
            whose ``trans`` has the first frame's planar position at 0.
        """
        features = self.unnormalize(features)
        root_y, vel_trajectory, poses = self.extract(features)

        # Integrate the planar velocity along the frame axis.
        trajectory = torch.cumsum(vel_trajectory, dim=-2)
        # The first frame should already be at 0, but predicted features may
        # carry a non-zero first velocity, so re-anchor explicitly.
        trajectory = trajectory - trajectory[..., [0], :]

        # Reassemble the translation as (x, y, root_y).
        trans = torch.cat([trajectory, root_y[..., None]], dim=-1)
        matrix_poses = rep_to_rep(self.pose_rep, 'matrix', poses)

        from ..smpl import RotTransDatastruct
        return RotTransDatastruct(rots=matrix_poses, trans=trans)