File size: 3,368 Bytes
c87d1bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import numpy as np

from lib.utils import transforms


def make_collate_fn():
    """Build a collate function that batches a list of per-sample dicts.

    Returns:
        callable: ``collate_fn(items)`` (see its docstring).
    """
    def collate_fn(items):
        """Collate per-sample dicts into a single batch dict.

        ``None`` entries in ``items`` are skipped. ``'vid'`` and ``'gender'``
        are gathered into plain Python lists (only when every remaining sample
        has the key). Every key of the first sample is then passed through
        ``torch.stack``. Keys that cannot be stacked are left out of the
        batch, or stay lists in the case of ``'vid'``/``'gender'``. A key
        cannot be stacked when some sample lacks it, its values are not
        tensors, or their shapes differ.

        Args:
            items: list of sample dicts, possibly containing ``None``.

        Returns:
            dict: the collated batch. If every item is ``None`` the batch
            holds only the (empty) ``'vid'`` and ``'gender'`` lists.
        """
        items = [item for item in items if item is not None]
        batch = dict()

        # Metadata stays a list; it is omitted if any sample lacks the key.
        for meta_key in ('vid', 'gender'):
            try:
                batch[meta_key] = [item[meta_key] for item in items]
            except KeyError:
                pass

        # Guard: the original indexed items[0] unconditionally and raised
        # IndexError when the whole batch was filtered out.
        if not items:
            return batch

        for key in items[0].keys():
            try:
                batch[key] = torch.stack([item[key] for item in items])
            except (KeyError, TypeError, RuntimeError, ValueError):
                # Best-effort: missing in some sample, non-tensor values
                # (e.g. the 'vid' strings) or mismatched shapes -> skip key.
                pass
        return batch

    return collate_fn


def prepare_keypoints_data(target):
    """Split the leading 2-D keypoint entry off as the initial keypoints.

    ``target['init_kp2d']`` receives ``kp2d[:1]`` and ``target['kp2d']`` keeps
    the remaining entries ``kp2d[1:]``. When present, ``target['kp3d']`` is
    trimmed the same way so it stays aligned with ``kp2d``. ``target`` is
    modified in place and also returned.
    """
    keypoints_2d = target['kp2d']
    target['init_kp2d'] = keypoints_2d[:1]
    target['kp2d'] = keypoints_2d[1:]

    if 'kp3d' in target:
        target['kp3d'] = target['kp3d'][1:]

    return target


def prepare_smpl_data(target):
    """Convert SMPL parameters in ``target`` to the 6-D rotation format.

    * ``'pose'`` (optional): every rotation matrix is converted with
      ``transforms.matrix_to_rotation_6d``, then index 0 of the first axis
      (presumably the first frame) is dropped.
    * ``'betas'`` (optional): index 0 of the first axis is dropped.
    * ``'transl'`` (optional): its ``[1:]`` slice is stored as ``'cam'``;
      ``'transl'`` itself is left unchanged.
    * ``'init_pose'`` (required): converted to 6-D, nothing dropped.

    ``target`` is modified in place and returned.

    Raises:
        KeyError: if ``'init_pose'`` is missing (raised after the optional
            keys above have already been updated).
    """
    to_6d = transforms.matrix_to_rotation_6d

    if 'pose' in target:
        # All joints are converted; only the leading entry is removed.
        target['pose'] = to_6d(target['pose'])[1:]

    if 'betas' in target:
        target['betas'] = target['betas'][1:]

    if 'transl' in target:
        target['cam'] = target['transl'][1:]

    target['init_pose'] = to_6d(target['init_pose'])
    return target


def append_target(target, label, key_list, idx1, idx2=None, pad=True):
    """Copy selected entries of ``label`` into ``target`` for each key.

    For every ``key`` in ``key_list`` the value taken is ``label[key][idx1]``
    when ``idx2`` is None, otherwise the inclusive slice
    ``label[key][idx1:idx2 + 1]``. When ``pad`` is False the first two
    elements of that selection are discarded (presumably padding frames —
    confirm with callers). ``target`` is modified in place and returned.
    """
    for name in key_list:
        source = label[name]
        selected = source[idx1] if idx2 is None else source[idx1:idx2 + 1]
        target[name] = selected if pad else selected[2:]

    return target


def map_dmpl_to_smpl(pose):
    """Map an AMASS DMPL pose representation to the 24-joint SMPL layout.

    The flat per-frame pose is viewed as (n_frames, n_joints, 3). The first
    24 joints are kept, and SMPL joint 23 (right hand) takes the values of
    source joint 37. The input is not modified: the result is a fresh
    NumPy copy or torch clone.

    Args:
        pose: tensor / array with shape (n_frames, 156), i.e. 52 joints x 3
            (presumably SMPL-H axis-angle layout — confirm against AMASS).

    Returns:
        tensor / array with shape (n_frames, 24, 3), same type as ``pose``.
    """
    pose = pose.reshape(pose.shape[0], -1, 3)

    # Copy BEFORE writing: reshape usually returns a view of the caller's
    # data, so assigning into it directly would mutate the input.
    if isinstance(pose, np.ndarray):
        smpl_pose = pose[:, :24].copy()
    else:
        smpl_pose = pose[:, :24].clone()

    smpl_pose[:, 23] = pose[:, 37]     # right hand
    return smpl_pose


def transform_global_coordinate(pose, T, transl=None):
    """ Transform global coordinate of dataset with respect to the given matrix.

    Various datasets have different global coordinate systems,
    so all datasets are unified into one canonical coordinate system.

    The root (index 0 along axis 1) orientation is left-multiplied by ``T``.
    If a translation is given, each translation vector ``t`` is also replaced
    by ``T @ t``.

    Args:
        pose - SMPL pose in axis-angle form; tensor / array
               (assumed (n_frames, n_joints, 3) — TODO confirm)
        T - Transformation matrix; torch tensor (presumably (3, 3))
        transl - SMPL translation, (n_frames, 3) or (3,); must be the same
                 container type as ``pose``

    Returns:
        (pose, transl) — NumPy arrays if ``pose`` was a NumPy array,
        otherwise tensors. ``transl`` is None if it was not given.
    """

    return_to_numpy = False
    if isinstance(pose, np.ndarray):
        return_to_numpy = True
        pose = torch.from_numpy(pose).float()
        if transl is not None: transl = torch.from_numpy(transl).float()

    pose = transforms.axis_angle_to_matrix(pose)
    pose[:, 0] = T @ pose[:, 0]
    pose = transforms.matrix_to_axis_angle(pose)
    if transl is not None:
        # Row-vector form of T @ t for every row t. The previous
        # ``(T @ transl.T).squeeze().T`` collapsed a single-frame (1, 3)
        # input to (3,) and applied ``.T`` to a 1-D tensor (deprecated).
        transl = transl @ T.T

    if return_to_numpy:
        pose = pose.detach().numpy()
        if transl is not None: transl = transl.detach().numpy()
    return pose, transl