Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
# Number of body-pose joints handled by this dataset (see __getitem__: poses
# are [21, 3] or [21, 6]) — presumably SMPL body joints excluding the pelvis
# root; TODO confirm against the body model used elsewhere in the project.
N_POSES = 21
class AMASSDataset(torch.utils.data.Dataset):
    """AMASS pose-coefficient dataset with optional z-score or min-max normalization.

    Each item is a 1-D tensor of the last ``num_coeffs`` pose coefficients,
    optionally paired with the shape (betas) tensor when ``return_shape`` is True.
    Normalization statistics are always computed/cached from the *train* split so
    that valid/test use identical statistics.
    """

    def __init__(self, root_path, version='version0', subset='train', basis_path='base_amass.npy',
                 sample_interval=None, num_coeffs=100, return_shape=False,
                 normalize=True, min_max=False):
        """
        Args:
            root_path: dataset root; tensors are read from ``root_path/subset``.
            version: dataset version tag (stored, not used in any path here).
            basis_path: kept for backward compatibility; unused in this class.
            subset: one of 'train', 'valid', 'test'.
            sample_interval: if set, keep every ``sample_interval``-th frame.
            num_coeffs: number of trailing coefficients to keep (only applied
                when < 300, the presumed full coefficient count — TODO confirm).
            return_shape: also return betas from ``__getitem__``.
            normalize: normalize poses/shapes in-place at construction time.
            min_max: use min-max scaling to [-1, 1] instead of z-score.
        """
        self.root_path = root_path
        self.version = version
        assert subset in ['train', 'valid', 'test']
        self.subset = subset
        self.sample_interval = sample_interval
        self.return_shape = return_shape
        self.normalize = normalize
        self.min_max = min_max
        self.num_coeffs = num_coeffs

        self.poses, self.shapes = self.read_data()

        if self.sample_interval:
            self._sample(sample_interval)
        if self.normalize:
            if self.min_max:
                self.min_poses, self.max_poses, self.min_shapes, self.max_shapes = self.Normalize()
            else:
                self.mean_poses, self.std_poses, self.mean_shapes, self.std_shapes = self.Normalize()

        self.real_data_len = len(self.poses)

    def __getitem__(self, idx):
        """
        Return:
            [21, 3] or [21, 6] for poses including body and root orient
            [10] for shapes (betas) [Optional]
        """
        # Modulo lets callers oversample the dataset with idx >= len(poses).
        data_poses = self.poses[idx % self.real_data_len]
        if self.return_shape:
            return data_poses, self.shapes[idx % self.real_data_len]
        return data_poses

    def __len__(self):
        return len(self.poses)

    def _sample(self, sample_interval):
        # Subsample frames to reduce temporal redundancy between neighbours.
        print(f'Class AMASSDataset({self.subset}): sample dataset every {sample_interval} frame')
        self.poses = self.poses[::sample_interval]

    def read_data(self):
        """Load pose coefficients (and optionally betas) for this subset."""
        data_path = os.path.join(self.root_path, self.subset)
        coeffs = torch.load(os.path.join(data_path, 'train_coeffs.pt'))
        shapes = torch.load(os.path.join(data_path, 'betas.pt')) if self.return_shape else None
        if self.num_coeffs < 300:
            # Keep only the trailing (highest-index) coefficients.
            coeffs = coeffs[:, -self.num_coeffs:]
        return coeffs, shapes

    def Normalize(self):
        """Normalize poses/shapes in-place; return the statistics used.

        Statistics are loaded from a cache file under the *train* split when it
        exists; otherwise they are computed from the current data and saved
        there. NOTE(review): when the cache is missing and subset != 'train',
        the stats are computed from this subset but saved to the train path —
        behavior preserved from the original; verify this is intended.
        """
        if self.min_max:
            normalize_path = os.path.join(self.root_path, 'train', 'coeffs_' + str(self.num_coeffs) + '_normalize1.pt')
        else:
            normalize_path = os.path.join(self.root_path, 'train', 'coeffs_' + str(self.num_coeffs) + '_normalize2.pt')

        if os.path.exists(normalize_path):
            normalize_params = torch.load(normalize_path)
            if self.min_max:
                min_poses, max_poses, min_shapes, max_shapes = (
                    normalize_params['min_poses'],
                    normalize_params['max_poses'],
                    normalize_params['min_shapes'],
                    normalize_params['max_shapes'],
                )
            else:
                mean_poses, std_poses, mean_shapes, std_shapes = (
                    normalize_params['mean_poses'],
                    normalize_params['std_poses'],
                    normalize_params['mean_shapes'],
                    normalize_params['std_shapes'],
                )
        else:
            if self.min_max:
                min_poses = torch.min(self.poses, dim=0)[0]
                max_poses = torch.max(self.poses, dim=0)[0]
                min_shapes = torch.min(self.shapes, dim=0)[0] if self.return_shape else None
                max_shapes = torch.max(self.shapes, dim=0)[0] if self.return_shape else None
                torch.save({
                    'min_poses': min_poses,
                    'max_poses': max_poses,
                    'min_shapes': min_shapes,
                    'max_shapes': max_shapes,
                }, normalize_path)
            else:
                mean_poses = torch.mean(self.poses, dim=0)
                std_poses = torch.std(self.poses, dim=0)
                mean_shapes = torch.mean(self.shapes, dim=0) if self.return_shape else None
                std_shapes = torch.std(self.shapes, dim=0) if self.return_shape else None
                torch.save({
                    'mean_poses': mean_poses,
                    'std_poses': std_poses,
                    'mean_shapes': mean_shapes,
                    'std_shapes': std_shapes,
                }, normalize_path)

        if self.min_max:
            # Scale to [-1, 1]: x_norm = 2 * (x - min) / (max - min) - 1.
            self.poses = 2 * (self.poses - min_poses) / (max_poses - min_poses) - 1
            if self.return_shape:
                self.shapes = 2 * (self.shapes - min_shapes) / (max_shapes - min_shapes) - 1
            return min_poses, max_poses, min_shapes, max_shapes
        else:
            self.poses = (self.poses - mean_poses) / std_poses
            if self.return_shape:
                self.shapes = (self.shapes - mean_shapes) / std_shapes
            return mean_poses, std_poses, mean_shapes, std_shapes

    def Denormalize(self, poses, shapes=None):
        """Invert :meth:`Normalize` on a batch of poses (and optionally shapes).

        Args:
            poses: [b, data_dim] or [t, b, data_dim] normalized poses.
            shapes: optional normalized betas with matching batch layout.

        Returns:
            Denormalized poses, or ``(poses, shapes)`` when shapes are given
            and shape statistics are available.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if self.min_max:
            min_poses = self.min_poses.view(1, -1).to(poses.device)
            max_poses = self.max_poses.view(1, -1).to(poses.device)
            if len(poses.shape) == 3:  # [t, b, data_dim]
                min_poses = min_poses.unsqueeze(0)
                max_poses = max_poses.unsqueeze(0)
            # Inverse of the [-1, 1] scaling: x = 0.5 * ((x_norm + 1) * (max - min) + 2 * min).
            denorm_poses = 0.5 * ((poses + 1) * (max_poses - min_poses) + 2 * min_poses)
            if shapes is not None and self.min_shapes is not None:
                min_shapes = self.min_shapes.view(1, -1).to(shapes.device)
                max_shapes = self.max_shapes.view(1, -1).to(shapes.device)
                if len(shapes.shape) == 3:
                    min_shapes = min_shapes.unsqueeze(0)
                    max_shapes = max_shapes.unsqueeze(0)
                denorm_shapes = 0.5 * ((shapes + 1) * (max_shapes - min_shapes) + 2 * min_shapes)
                return denorm_poses, denorm_shapes
            return denorm_poses
        else:
            mean_poses = self.mean_poses.view(1, -1).to(poses.device)
            std_poses = self.std_poses.view(1, -1).to(poses.device)
            if len(poses.shape) == 3:  # [t, b, data_dim]
                mean_poses = mean_poses.unsqueeze(0)
                std_poses = std_poses.unsqueeze(0)
            denorm_poses = poses * std_poses + mean_poses
            if shapes is not None and self.mean_shapes is not None:
                # Bug fix: move shape stats to the shapes' device, matching the
                # min-max branch (the original omitted .to(...) here).
                mean_shapes = self.mean_shapes.view(1, -1).to(shapes.device)
                std_shapes = self.std_shapes.view(1, -1).to(shapes.device)
                if len(shapes.shape) == 3:
                    mean_shapes = mean_shapes.unsqueeze(0)
                    std_shapes = std_shapes.unsqueeze(0)
                denorm_shapes = shapes * std_shapes + mean_shapes
                return denorm_poses, denorm_shapes
            return denorm_poses

    def eval(self, preds):
        # Intentionally a no-op; evaluation is handled by the Evaler class.
        pass
class Posenormalizer:
    """Stand-alone pose normalizer using precomputed statistics.

    Loads both min-max ('*_normalize1.pt') and z-score ('*_normalize2.pt')
    statistics at construction time; ``min_max`` selects which pair the
    offline_(de)normalize methods apply.
    """

    def __init__(self, data_path, device='cuda:0', normalize=True, min_max=True, rot_rep=None):
        """
        Args:
            data_path: directory containing '{rot_rep}_normalize1.pt' and
                '{rot_rep}_normalize2.pt' statistic files.
            device: device the statistics tensors are moved to.
            normalize: when False, both offline_* methods are identity.
            min_max: choose min-max scaling over z-score.
            rot_rep: rotation representation, 'rot6d' or 'axis'.
        """
        assert rot_rep in ['rot6d', 'axis']
        self.normalize = normalize
        self.min_max = min_max
        self.rot_rep = rot_rep
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize1.pt'.format(rot_rep)))
        self.min_poses, self.max_poses = normalize_params['min_poses'].to(device), normalize_params['max_poses'].to(device)
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize2.pt'.format(rot_rep)))
        self.mean_poses, self.std_poses = normalize_params['mean_poses'].to(device), normalize_params['std_poses'].to(device)

    def offline_normalize(self, poses, from_axis=False):
        """Normalize poses of shape [b, data_dim] or [t, b, data_dim].

        ``from_axis`` is kept for interface compatibility; it is not used here.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]
        if not self.normalize:
            return poses

        if self.min_max:
            min_poses = self.min_poses.view(1, -1)
            max_poses = self.max_poses.view(1, -1)
            if len(poses.shape) == 3:  # [t, b, data_dim]
                min_poses = min_poses.unsqueeze(0)
                max_poses = max_poses.unsqueeze(0)
            normalized_poses = 2 * (poses - min_poses) / (max_poses - min_poses) - 1
        else:
            mean_poses = self.mean_poses.view(1, -1)
            std_poses = self.std_poses.view(1, -1)
            if len(poses.shape) == 3:  # [t, b, data_dim]
                mean_poses = mean_poses.unsqueeze(0)
                std_poses = std_poses.unsqueeze(0)
            normalized_poses = (poses - mean_poses) / std_poses
        return normalized_poses

    def offline_denormalize(self, poses, to_axis=False):
        """Invert :meth:`offline_normalize` for [b, data_dim] or [t, b, data_dim].

        ``to_axis`` is kept for interface compatibility; it is not used here.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]
        if not self.normalize:
            return poses

        if self.min_max:
            min_poses = self.min_poses.view(1, -1)
            max_poses = self.max_poses.view(1, -1)
            if len(poses.shape) == 3:  # [t, b, data_dim]
                min_poses = min_poses.unsqueeze(0)
                max_poses = max_poses.unsqueeze(0)
            denormalized_poses = 0.5 * ((poses + 1) * (max_poses - min_poses) + 2 * min_poses)
        else:
            mean_poses = self.mean_poses.view(1, -1)
            std_poses = self.std_poses.view(1, -1)
            if len(poses.shape) == 3:  # [t, b, data_dim]
                mean_poses = mean_poses.unsqueeze(0)
                std_poses = std_poses.unsqueeze(0)
            denormalized_poses = poses * std_poses + mean_poses
        return denormalized_poses
# a simple eval process for completion task
class Evaler:
    """Evaluate predicted body poses against ground truth in millimetres.

    Runs poses through ``body_model`` and reports MPVPE (mean per-vertex
    position error) and MPJPE (mean per-joint position error).
    """

    def __init__(self, body_model, part=None):
        """
        Args:
            body_model: callable taking ``pose_body`` and returning an object
                with vertices ``.v`` and joints ``.Jtr``.
            part: optional body-part name; restricts evaluation to that part's
                joints/vertices. NOTE(review): BodyPartIndices/BodySegIndices
                are not defined in this file — presumably imported elsewhere.
        """
        self.body_model = body_model
        self.part = part
        if self.part is not None:
            self.joint_idx = np.array(getattr(BodyPartIndices, self.part)) + 1  # skip pelvis
            self.vert_idx = np.array(getattr(BodySegIndices, self.part))
        else:
            self.joint_idx = slice(None)
            self.vert_idx = slice(None)

    def eval_bodys(self, outs, gts):
        '''
        :param outs: [b, j*3] axis-angle results of body poses
        :param gts: [b, j*3] axis-angle ground truth of body poses
        :return: dict with per-sample 'mpvpe_all' and 'mpjpe_body' lists (mm)
        '''
        eval_result = {'mpvpe_all': [], 'mpjpe_body': []}
        body_gt = self.body_model(pose_body=gts)
        body_out = self.body_model(pose_body=outs)
        # Hoist tensor->numpy conversion out of the per-sample loop
        # (the original converted the full batch once per sample).
        verts_gt = body_gt.v.detach().cpu().numpy()
        verts_out = body_out.v.detach().cpu().numpy()
        joints_gt = body_gt.Jtr.detach().cpu().numpy()
        joints_out = body_out.Jtr.detach().cpu().numpy()
        for n in range(len(outs)):
            # MPVPE from all (selected) vertices, metres -> millimetres.
            mesh_gt = verts_gt[n, self.vert_idx]
            mesh_out = verts_out[n, self.vert_idx]
            eval_result['mpvpe_all'].append(np.sqrt(np.sum((mesh_out - mesh_gt) ** 2, 1)).mean() * 1000)
            joint_gt_body = joints_gt[n, self.joint_idx]
            joint_out_body = joints_out[n, self.joint_idx]
            eval_result['mpjpe_body'].append(
                np.sqrt(np.sum((joint_out_body - joint_gt_body) ** 2, 1)).mean() * 1000)
        return eval_result

    def multi_eval_bodys(self, outs, gts):
        '''
        :param outs: [b, hypo, j*3] axis-angle results of body poses, multiple hypotheses
        :param gts: [b, j*3] axis-angle ground truth of body poses
        :return: dict with per-sample best-hypothesis errors (mm)
        '''
        hypo_num = outs.shape[1]
        eval_result = {'mpvpe_all': [], 'mpjpe_body': []}
        for hypo in range(hypo_num):
            result = self.eval_bodys(outs[:, hypo], gts)
            eval_result['mpvpe_all'].append(result['mpvpe_all'])
            eval_result['mpjpe_body'].append(result['mpjpe_body'])
        # Keep the best hypothesis per sample (minimum error across hypotheses).
        eval_result['mpvpe_all'] = np.min(eval_result['mpvpe_all'], axis=0)
        eval_result['mpjpe_body'] = np.min(eval_result['mpjpe_body'], axis=0)
        return eval_result

    def print_eval_result(self, eval_result):
        print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print('MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))

    def print_multi_eval_result(self, eval_result, hypo_num):
        print(f'multihypo {hypo_num} MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print(f'multihypo {hypo_num} MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))