File size: 12,823 Bytes
e321b92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

# Joint-count constant. NOTE(review): not referenced by the classes in this module;
# presumably the 21 SMPL body joints excluding the root/pelvis — confirm against callers.
N_POSES = 21


class AMASSDataset(torch.utils.data.Dataset):
    """Pre-computed pose coefficients (and optionally betas) for one AMASS split.

    Data is read from ``<root_path>/<subset>/train_coeffs.pt`` (plus ``betas.pt`` when
    ``return_shape`` is True). Optionally only every ``sample_interval``-th sample is kept,
    and the data is normalized (z-score, or min-max to [-1, 1]) with statistics of the
    *train* split, cached under ``<root_path>/train/``.

    Args:
        root_path: directory containing the ``train``/``valid``/``test`` sub-directories.
        version: stored on the instance; not otherwise used by this class.
        subset: one of ``'train'``, ``'valid'``, ``'test'``.
        basis_path: accepted for interface compatibility; not used by this class.
        sample_interval: if truthy, keep every ``sample_interval``-th sample.
        num_coeffs: number of coefficient columns to keep; the LAST ``num_coeffs`` columns
            are kept when it is below 300 (NOTE(review): 300 looks like the stored
            coefficient count — confirm).
        return_shape: also load betas and return ``(poses, shapes)`` pairs.
        normalize: normalize poses (and shapes) on load.
        min_max: use min-max normalization to [-1, 1] instead of z-score.
    """

    def __init__(self, root_path, version='version0', subset='train', basis_path='base_amass.npy',
                 sample_interval=None, num_coeffs=100, return_shape=False,
                 normalize=True, min_max=False):

        self.root_path = root_path
        self.version = version
        assert subset in ['train', 'valid', 'test']
        self.subset = subset
        self.sample_interval = sample_interval
        self.return_shape = return_shape
        self.normalize = normalize
        self.min_max = min_max
        self.num_coeffs = num_coeffs
        self.poses, self.shapes = self.read_data()

        if self.sample_interval:
            self._sample(sample_interval)
        if self.normalize:
            if self.min_max:
                self.min_poses, self.max_poses, self.min_shapes, self.max_shapes = self.Normalize()
            else:
                self.mean_poses, self.std_poses, self.mean_shapes, self.std_shapes = self.Normalize()

        self.real_data_len = len(self.poses)

    def __getitem__(self, idx):
        """Return the coefficient row ``self.poses[idx]`` (index wrapped modulo the length).

        When ``return_shape`` is True, returns ``(poses_row, shapes_row)`` instead.
        """
        data_poses = self.poses[idx % self.real_data_len]
        if self.return_shape:
            return data_poses, self.shapes[idx % self.real_data_len]
        return data_poses

    def __len__(self):
        return len(self.poses)

    def _sample(self, sample_interval):
        """Keep every ``sample_interval``-th sample of poses (and shapes, if loaded)."""
        print(f'Class AMASSDataset({self.subset}): sample dataset every {sample_interval} frame')
        self.poses = self.poses[::sample_interval]
        # __getitem__ indexes poses and shapes with the same idx, so both must be sub-sampled
        # identically to stay aligned.
        if self.shapes is not None:
            self.shapes = self.shapes[::sample_interval]

    def _load_split(self, subset):
        """Load ``(coeffs, shapes)`` for ``subset``; ``shapes`` is None unless ``return_shape``."""
        data_path = os.path.join(self.root_path, subset)
        # Every split stores its coefficients in a file named 'train_coeffs.pt'.
        coeffs = torch.load(os.path.join(data_path, 'train_coeffs.pt'))
        shapes = torch.load(os.path.join(data_path, 'betas.pt')) if self.return_shape else None
        if self.num_coeffs < 300:
            coeffs = coeffs[:, -self.num_coeffs:]
        return coeffs, shapes

    def read_data(self):
        """Load ``(coeffs, shapes)`` for this dataset's own split."""
        return self._load_split(self.subset)

    def _compute_normalize_params(self):
        """Compute normalization statistics from the train split, keyed like the cache files."""
        if self.subset == 'train':
            poses, shapes = self.poses, self.shapes
        else:
            # The cache lives under train/ and is shared by every split, so the statistics
            # must come from train data even when this dataset is valid/test.
            poses, shapes = self._load_split('train')

        if self.min_max:
            return {
                'min_poses': torch.min(poses, dim=0)[0],
                'max_poses': torch.max(poses, dim=0)[0],
                'min_shapes': torch.min(shapes, dim=0)[0] if self.return_shape else None,
                'max_shapes': torch.max(shapes, dim=0)[0] if self.return_shape else None,
            }
        return {
            'mean_poses': torch.mean(poses, dim=0),
            'std_poses': torch.std(poses, dim=0),
            'mean_shapes': torch.mean(shapes, dim=0) if self.return_shape else None,
            'std_shapes': torch.std(shapes, dim=0) if self.return_shape else None,
        }

    def Normalize(self):
        """Normalize ``self.poses`` (and ``self.shapes``) in place with train-split statistics.

        Statistics are loaded from ``<root_path>/train/coeffs_<num_coeffs>_normalize1.pt``
        (min-max) or ``..._normalize2.pt`` (z-score); if missing they are computed from the
        train split and saved there.

        Returns:
            ``(min_poses, max_poses, min_shapes, max_shapes)`` for min-max, otherwise
            ``(mean_poses, std_poses, mean_shapes, std_shapes)``. Shape entries are None when
            the statistics were computed with ``return_shape=False``.
        """
        if self.min_max:
            keys = ('min_poses', 'max_poses', 'min_shapes', 'max_shapes')
            suffix = '_normalize1.pt'
        else:
            keys = ('mean_poses', 'std_poses', 'mean_shapes', 'std_shapes')
            suffix = '_normalize2.pt'
        normalize_path = os.path.join(self.root_path, 'train', 'coeffs_' + str(self.num_coeffs) + suffix)

        if os.path.exists(normalize_path):
            # NOTE(review): a cache written with return_shape=False holds None shape stats;
            # reusing it with return_shape=True fails below.
            normalize_params = torch.load(normalize_path)
        else:
            normalize_params = self._compute_normalize_params()
            torch.save(normalize_params, normalize_path)

        if self.min_max:
            min_poses, max_poses, min_shapes, max_shapes = (normalize_params[k] for k in keys)
            self.poses = 2 * (self.poses - min_poses) / (max_poses - min_poses) - 1
            if self.return_shape:
                self.shapes = 2 * (self.shapes - min_shapes) / (max_shapes - min_shapes) - 1
            return min_poses, max_poses, min_shapes, max_shapes

        mean_poses, std_poses, mean_shapes, std_shapes = (normalize_params[k] for k in keys)
        self.poses = (self.poses - mean_poses) / std_poses
        if self.return_shape:
            self.shapes = (self.shapes - mean_shapes) / std_shapes
        return mean_poses, std_poses, mean_shapes, std_shapes

    @staticmethod
    def _expand_stats(first, second, like):
        """View two statistic vectors as [1, d] (or [1, 1, d]) on ``like``'s device."""
        first = first.view(1, -1).to(like.device)
        second = second.view(1, -1).to(like.device)
        if len(like.shape) == 3:  # [t, b, data_dim]
            first = first.unsqueeze(0)
            second = second.unsqueeze(0)
        return first, second

    def Denormalize(self, poses, shapes=None):
        """Invert the normalization applied on load.

        Args:
            poses: normalized poses, ``[b, data_dim]`` or ``[t, b, data_dim]``.
            shapes: optional normalized shapes with the same rank convention.

        Returns:
            Denormalized poses, or ``(poses, shapes)`` when ``shapes`` is given and shape
            statistics exist. With ``normalize=False`` the inputs are returned unchanged.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if not self.normalize:
            # No statistics were computed, so there is nothing to invert.
            return poses if shapes is None else (poses, shapes)

        if self.min_max:
            min_poses, max_poses = self._expand_stats(self.min_poses, self.max_poses, poses)
            denormalized_poses = 0.5 * ((poses + 1) * (max_poses - min_poses) + 2 * min_poses)
            if shapes is None or self.min_shapes is None:
                return denormalized_poses
            min_shapes, max_shapes = self._expand_stats(self.min_shapes, self.max_shapes, shapes)
            denormalized_shapes = 0.5 * ((shapes + 1) * (max_shapes - min_shapes) + 2 * min_shapes)
            return denormalized_poses, denormalized_shapes

        mean_poses, std_poses = self._expand_stats(self.mean_poses, self.std_poses, poses)
        denormalized_poses = poses * std_poses + mean_poses
        if shapes is None or self.mean_shapes is None:
            return denormalized_poses
        mean_shapes, std_shapes = self._expand_stats(self.mean_shapes, self.std_shapes, shapes)
        denormalized_shapes = shapes * std_shapes + mean_shapes
        return denormalized_poses, denormalized_shapes

    def eval(self, preds):
        pass


class Posenormalizer:
    """Offline (de)normalization of pose vectors with pre-computed statistics.

    Min/max statistics are loaded from ``<data_path>/<rot_rep>_normalize1.pt`` and mean/std
    statistics from ``<data_path>/<rot_rep>_normalize2.pt``; both are moved to ``device``.

    Args:
        data_path: directory holding the statistic files.
        device: device the statistics are loaded onto.
        normalize: if False, both offline methods return their input unchanged.
        min_max: use min-max ([-1, 1]) statistics instead of mean/std.
        rot_rep: rotation representation, ``'rot6d'`` or ``'axis'`` (must be given).
    """

    def __init__(self, data_path, device='cuda:0', normalize=True, min_max=True, rot_rep=None):
        assert rot_rep in ['rot6d', 'axis']
        self.normalize = normalize
        self.min_max = min_max
        self.rot_rep = rot_rep
        # map_location lets files saved on another device load directly onto ``device``.
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize1.pt'.format(rot_rep)),
                                      map_location=device)
        self.min_poses, self.max_poses = normalize_params['min_poses'].to(device), normalize_params['max_poses'].to(device)
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize2.pt'.format(rot_rep)),
                                      map_location=device)
        self.mean_poses, self.std_poses = normalize_params['mean_poses'].to(device), normalize_params['std_poses'].to(device)

    @staticmethod
    def _expand(stat, like):
        """View a statistic vector as [1, d] (or [1, 1, d] for 3-D ``like``) on ``like``'s device."""
        stat = stat.view(1, -1).to(like.device)
        return stat.unsqueeze(0) if len(like.shape) == 3 else stat

    def offline_normalize(self, poses, from_axis=False):
        """Normalize ``poses`` ([b, data_dim] or [t, b, data_dim]).

        ``from_axis`` is accepted for interface compatibility and is not used.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if not self.normalize:
            return poses

        if self.min_max:
            min_poses = self._expand(self.min_poses, poses)
            max_poses = self._expand(self.max_poses, poses)
            return 2 * (poses - min_poses) / (max_poses - min_poses) - 1

        mean_poses = self._expand(self.mean_poses, poses)
        std_poses = self._expand(self.std_poses, poses)
        return (poses - mean_poses) / std_poses

    def offline_denormalize(self, poses, to_axis=False):
        """Invert :meth:`offline_normalize` for ``poses`` ([b, data_dim] or [t, b, data_dim]).

        ``to_axis`` is accepted for interface compatibility and is not used.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if not self.normalize:
            return poses

        if self.min_max:
            min_poses = self._expand(self.min_poses, poses)
            max_poses = self._expand(self.max_poses, poses)
            return 0.5 * ((poses + 1) * (max_poses - min_poses) + 2 * min_poses)

        mean_poses = self._expand(self.mean_poses, poses)
        std_poses = self._expand(self.std_poses, poses)
        return poses * std_poses + mean_poses


# a simple eval process for completion task
class Evaler:
    """Per-sample vertex/joint error evaluation for the pose-completion task.

    ``body_model`` is called as ``body_model(pose_body=...)`` and must return an object
    whose ``.v`` (vertices) and ``.Jtr`` (joints) tensors are indexed
    ``[sample, point, coord]``. Errors are mean Euclidean distances scaled by 1000
    (reported as mm by the print helpers).

    NOTE(review): ``BodyPartIndices`` / ``BodySegIndices`` are not defined or imported in
    this module, so passing ``part`` raises NameError unless they are provided elsewhere.
    """

    def __init__(self, body_model, part=None):
        self.body_model = body_model
        self.part = part

        if self.part is not None:
            self.joint_idx = np.array(getattr(BodyPartIndices, self.part)) + 1  # skip pelvis
            self.vert_idx = np.array(getattr(BodySegIndices, self.part))
        else:
            self.joint_idx = slice(None)
            self.vert_idx = slice(None)

    @staticmethod
    def _mean_point_error(pred, gt):
        """Per-sample mean Euclidean distance between ``[b, p, c]`` point sets, times 1000."""
        return np.sqrt(np.sum((pred - gt) ** 2, -1)).mean(axis=1) * 1000

    def eval_bodys(self, outs, gts):
        '''
        :param outs: [b, j*3] axis-angle results of body poses
        :param gts:  [b, j*3] axis-angle groundtruth of body poses
        :return: result dict with one MPVPE / MPJPE value per sample
        '''
        body_gt = self.body_model(pose_body=gts)
        body_out = self.body_model(pose_body=outs)

        # Copy each result to host memory once for the whole batch rather than once per sample.
        verts_gt = body_gt.v.detach().cpu().numpy()[:, self.vert_idx]
        verts_out = body_out.v.detach().cpu().numpy()[:, self.vert_idx]
        joints_gt = body_gt.Jtr.detach().cpu().numpy()[:, self.joint_idx]
        joints_out = body_out.Jtr.detach().cpu().numpy()[:, self.joint_idx]

        return {
            'mpvpe_all': list(self._mean_point_error(verts_out, verts_gt)),
            'mpjpe_body': list(self._mean_point_error(joints_out, joints_gt)),
        }

    def multi_eval_bodys(self, outs, gts):
        '''
        :param outs: [b, hypo, j*3] axis-angle results of body poses, multiple hypothesis
        :param gts:  [b, j*3] axis-angle groundtruth of body poses
        :return: result dict holding, per sample, the minimum error over hypotheses
        '''
        hypo_num = outs.shape[1]
        eval_result = {'mpvpe_all': [], 'mpjpe_body': []}
        for hypo in range(hypo_num):
            result = self.eval_bodys(outs[:, hypo], gts)
            eval_result['mpvpe_all'].append(result['mpvpe_all'])
            eval_result['mpjpe_body'].append(result['mpjpe_body'])

        # Best-of-N: keep the lowest error across hypotheses for each sample.
        eval_result['mpvpe_all'] = np.min(eval_result['mpvpe_all'], axis=0)
        eval_result['mpjpe_body'] = np.min(eval_result['mpjpe_body'], axis=0)

        return eval_result

    def print_eval_result(self, eval_result):
        print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print('MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))

    def print_multi_eval_result(self, eval_result, hypo_num):
        print(f'multihypo {hypo_num} MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print(f'multihypo {hypo_num} MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))