import os

import numpy as np
import torch
from torch.utils.data import DataLoader

N_POSES = 21


class AMASSDataset(torch.utils.data.Dataset):
    """Dataset of pose coefficient vectors (and optionally shape vectors).

    Data is read from ``<root_path>/<subset>/`` and, when ``normalize`` is
    true, normalized in place with either min-max scaling to ``[-1, 1]``
    (``min_max=True``) or z-score scaling (``min_max=False``).  The
    normalization statistics are cached under ``<root_path>/train/``.

    Args:
        root_path: Directory containing one sub-directory per subset.
        version: Stored on the instance; not otherwise used by this class.
        subset: One of ``'train'``, ``'valid'`` or ``'test'``.
        basis_path: Accepted for interface compatibility; unused here.
        sample_interval: If truthy, keep only every ``sample_interval``-th
            sample.
        num_coeffs: Number of trailing coefficient columns to keep.
        return_shape: If true, shapes are loaded and returned alongside poses.
        normalize: If true, normalize the loaded data.
        min_max: Selects min-max (true) or z-score (false) normalization.
    """

    def __init__(self, root_path, version='version0', subset='train', basis_path='base_amass.npy',
                 sample_interval=None, num_coeffs=100, return_shape=False, normalize=True, min_max=False):
        self.root_path = root_path
        self.version = version
        assert subset in ['train', 'valid', 'test']
        self.subset = subset
        self.sample_interval = sample_interval
        self.return_shape = return_shape
        self.normalize = normalize
        self.min_max = min_max
        self.num_coeffs = num_coeffs

        self.poses, self.shapes = self.read_data()

        if self.sample_interval:
            self._sample(sample_interval)
        if self.normalize:
            if self.min_max:
                self.min_poses, self.max_poses, self.min_shapes, self.max_shapes = self.Normalize()
            else:
                self.mean_poses, self.std_poses, self.mean_shapes, self.std_shapes = self.Normalize()

        self.real_data_len = len(self.poses)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` (wrapped modulo the dataset length).

        Returns:
            The ``idx``-th row of ``self.poses``; when ``return_shape`` is set,
            a ``(pose, shape)`` tuple taken at the same index.
        """
        wrapped = idx % self.real_data_len
        data_poses = self.poses[wrapped]
        if self.return_shape:
            return data_poses, self.shapes[wrapped]
        return data_poses

    def __len__(self):
        return len(self.poses)

    def _sample(self, sample_interval):
        """Keep every ``sample_interval``-th sample of poses and shapes."""
        print(f'Class AMASSDataset({self.subset}): sample dataset every {sample_interval} frame')
        self.poses = self.poses[::sample_interval]
        # Shapes are indexed with the same index as poses in __getitem__, so
        # they must be subsampled identically to stay aligned.
        if self.shapes is not None:
            self.shapes = self.shapes[::sample_interval]

    def read_data(self):
        """Load coefficients (and shapes when requested) for the subset.

        Returns:
            ``(coeffs, shapes)`` where ``shapes`` is ``None`` unless
            ``return_shape`` is set.
        """
        data_path = os.path.join(self.root_path, self.subset)
        # NOTE(review): the file is named 'train_coeffs.pt' for every subset;
        # confirm this matches the on-disk layout for 'valid' / 'test'.
        coeffs = torch.load(os.path.join(data_path, 'train_coeffs.pt'))
        shapes = torch.load(os.path.join(data_path, 'betas.pt')) if self.return_shape else None
        # NOTE(review): 300 is presumably the full coefficient count; confirm.
        if self.num_coeffs < 300:
            coeffs = coeffs[:, -self.num_coeffs:]
        return coeffs, shapes

    def _compute_stats(self, data):
        """Return ``(min, max)`` or ``(mean, std)`` of ``data`` along dim 0."""
        if self.min_max:
            return torch.min(data, dim=0)[0], torch.max(data, dim=0)[0]
        return torch.mean(data, dim=0), torch.std(data, dim=0)

    def _apply_normalization(self, data, first, second):
        """Normalize ``data`` with ``(min, max)`` or ``(mean, std)`` statistics."""
        if self.min_max:
            return 2 * (data - first) / (second - first) - 1
        return (data - first) / second

    def Normalize(self):
        """Normalize ``self.poses`` (and ``self.shapes``) in place.

        Statistics are loaded from the cache file under ``<root_path>/train/``
        when it exists; otherwise they are computed from the currently loaded
        data and written to that file.

        Returns:
            ``(min_poses, max_poses, min_shapes, max_shapes)`` for min-max
            mode, or ``(mean_poses, std_poses, mean_shapes, std_shapes)`` for
            z-score mode.  Shape statistics are ``None`` when ``return_shape``
            is false.
        """
        suffix = '_normalize1.pt' if self.min_max else '_normalize2.pt'
        normalize_path = os.path.join(self.root_path, 'train', 'coeffs_' + str(self.num_coeffs) + suffix)
        first_name, second_name = ('min', 'max') if self.min_max else ('mean', 'std')
        keys = [f'{first_name}_poses', f'{second_name}_poses', f'{first_name}_shapes', f'{second_name}_shapes']

        if os.path.exists(normalize_path):
            cached = torch.load(normalize_path)
            first_poses, second_poses, first_shapes, second_shapes = (cached[key] for key in keys)
            # The cache may have been written by a run with return_shape=False,
            # in which case the shape statistics are None.  Fall back to the
            # loaded shapes instead of crashing on `tensor - None`.
            if self.return_shape and (first_shapes is None or second_shapes is None):
                first_shapes, second_shapes = self._compute_stats(self.shapes)
        else:
            # NOTE(review): statistics come from the currently loaded subset;
            # build the cache from the 'train' subset first if that matters.
            first_poses, second_poses = self._compute_stats(self.poses)
            if self.return_shape:
                first_shapes, second_shapes = self._compute_stats(self.shapes)
            else:
                first_shapes, second_shapes = None, None
            os.makedirs(os.path.dirname(normalize_path), exist_ok=True)
            torch.save(
                dict(zip(keys, (first_poses, second_poses, first_shapes, second_shapes))),
                normalize_path,
            )

        self.poses = self._apply_normalization(self.poses, first_poses, second_poses)
        if self.return_shape:
            self.shapes = self._apply_normalization(self.shapes, first_shapes, second_shapes)
        return first_poses, second_poses, first_shapes, second_shapes

    def _restore(self, data, first, second):
        """Invert the normalization of ``data`` ([b, d] or [t, b, d])."""
        first = first.view(1, -1).to(data.device)
        second = second.view(1, -1).to(data.device)
        if len(data.shape) == 3:  # [t, b, data_dim]
            first = first.unsqueeze(0)
            second = second.unsqueeze(0)
        if self.min_max:
            return 0.5 * ((data + 1) * (second - first) + 2 * first)
        return data * second + first

    def Denormalize(self, poses, shapes=None):
        """Map normalized poses (and shapes) back to the original scale.

        Args:
            poses: Tensor of shape ``[b, data_dim]`` or ``[t, b, data_dim]``.
            shapes: Optional tensor to denormalize with the shape statistics.

        Returns:
            The denormalized poses, or a ``(poses, shapes)`` tuple when
            ``shapes`` is given and shape statistics are available.  If the
            dataset was built with ``normalize=False`` the inputs are returned
            unchanged.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]
        if not self.normalize:
            # Nothing was normalized, so there are no statistics to invert.
            return (poses, shapes) if shapes is not None else poses

        if self.min_max:
            pose_stats = (self.min_poses, self.max_poses)
            shape_stats = (self.min_shapes, self.max_shapes)
        else:
            pose_stats = (self.mean_poses, self.std_poses)
            shape_stats = (self.mean_shapes, self.std_shapes)

        restored_poses = self._restore(poses, *pose_stats)
        if shapes is not None and shape_stats[0] is not None:
            # _restore moves the statistics to shapes.device in both modes.
            return restored_poses, self._restore(shapes, *shape_stats)
        return restored_poses

    def eval(self, preds):
        pass


class Posenormalizer:
    """Normalize / denormalize pose tensors with precomputed statistics.

    Loads ``<rot_rep>_normalize1.pt`` (min / max) and
    ``<rot_rep>_normalize2.pt`` (mean / std) from ``data_path`` and moves the
    statistics to ``device``.

    Args:
        data_path: Directory holding the two statistics files.
        device: Device the statistics are moved to.
        normalize: If false, both methods return their input unchanged.
        min_max: Selects min-max (true) or z-score (false) scaling.
        rot_rep: ``'rot6d'`` or ``'axis'``.  NOTE(review): the default
            ``None`` fails the assertion, so callers must pass a value.
    """

    def __init__(self, data_path, device='cuda:0', normalize=True, min_max=True, rot_rep=None):
        assert rot_rep in ['rot6d', 'axis']
        self.normalize = normalize
        self.min_max = min_max
        self.rot_rep = rot_rep
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize1.pt'.format(rot_rep)))
        self.min_poses, self.max_poses = normalize_params['min_poses'].to(device), normalize_params['max_poses'].to(device)
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize2.pt'.format(rot_rep)))
        self.mean_poses, self.std_poses = normalize_params['mean_poses'].to(device), normalize_params['std_poses'].to(device)

    def _stats_for(self, poses):
        """Return the active statistics pair reshaped to broadcast with ``poses``."""
        if self.min_max:
            first, second = self.min_poses, self.max_poses
        else:
            first, second = self.mean_poses, self.std_poses
        first = first.view(1, -1)
        second = second.view(1, -1)
        if len(poses.shape) == 3:  # [t, b, data_dim]
            first = first.unsqueeze(0)
            second = second.unsqueeze(0)
        return first, second

    def offline_normalize(self, poses, from_axis=False):
        """Normalize ``poses`` ([b, data_dim] or [t, b, data_dim]).

        ``from_axis`` is accepted for interface compatibility and unused.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]
        if not self.normalize:
            return poses

        first, second = self._stats_for(poses)
        if self.min_max:
            return 2 * (poses - first) / (second - first) - 1
        return (poses - first) / second

    def offline_denormalize(self, poses, to_axis=False):
        """Invert :meth:`offline_normalize` for ``poses``.

        ``to_axis`` is accepted for interface compatibility and unused.
        """
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]
        if not self.normalize:
            return poses

        first, second = self._stats_for(poses)
        if self.min_max:
            return 0.5 * ((poses + 1) * (second - first) + 2 * first)
        return poses * second + first


# a simple eval process for completion task
class Evaler:
    """Compute per-sample vertex and joint position errors of body poses.

    Args:
        body_model: Callable taking ``pose_body=...`` and returning an object
            with ``v`` (vertices) and ``Jtr`` (joints) tensors.
        part: Optional body-part name restricting the evaluated joints and
            vertices; ``None`` evaluates everything.
    """

    def __init__(self, body_model, part=None):
        self.body_model = body_model
        self.part = part

        if self.part is not None:
            # NOTE(review): BodyPartIndices / BodySegIndices are not imported
            # in the visible part of this file; they must be in scope for
            # `part` to be usable.
            self.joint_idx = np.array(getattr(BodyPartIndices, self.part)) + 1  # skip pelvis
            self.vert_idx = np.array(getattr(BodySegIndices, self.part))
        else:
            self.joint_idx = slice(None)
            self.vert_idx = slice(None)

    @staticmethod
    def _mean_distance_mm(points_out, points_gt):
        """Mean Euclidean distance between two [k, 3] point sets, times 1000."""
        return np.sqrt(np.sum((points_out - points_gt) ** 2, 1)).mean() * 1000

    def eval_bodys(self, outs, gts):
        '''
        :param outs: [b, j*3] axis-angle results of body poses
        :param gts:  [b, j*3] axis-angle groundtruth of body poses
        :return: result dict for every sample
        '''
        sample_num = len(outs)
        body_gt = self.body_model(pose_body=gts)
        body_out = self.body_model(pose_body=outs)

        # Convert each batch tensor once instead of once per sample.
        verts_gt = body_gt.v.detach().cpu().numpy()
        verts_out = body_out.v.detach().cpu().numpy()
        joints_gt = body_gt.Jtr.detach().cpu().numpy()
        joints_out = body_out.Jtr.detach().cpu().numpy()

        eval_result = {'mpvpe_all': [], 'mpjpe_body': []}
        for n in range(sample_num):
            # MPVPE from all (selected) vertices
            eval_result['mpvpe_all'].append(
                self._mean_distance_mm(verts_out[n, self.vert_idx], verts_gt[n, self.vert_idx]))
            eval_result['mpjpe_body'].append(
                self._mean_distance_mm(joints_out[n, self.joint_idx], joints_gt[n, self.joint_idx]))

        return eval_result

    def multi_eval_bodys(self, outs, gts):
        '''
        :param outs: [b, hypo, j*3] axis-angle results of body poses, multiple hypothesis
        :param gts:  [b, j*3] axis-angle groundtruth of body poses
        :return: result dict
        '''
        hypo_num = outs.shape[1]
        per_hypo = [self.eval_bodys(outs[:, hypo], gts) for hypo in range(hypo_num)]

        # Best (minimum) error over hypotheses, per sample.
        return {
            'mpvpe_all': np.min([result['mpvpe_all'] for result in per_hypo], axis=0),
            'mpjpe_body': np.min([result['mpjpe_body'] for result in per_hypo], axis=0),
        }

    def print_eval_result(self, eval_result):
        print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print('MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))

    def print_multi_eval_result(self, eval_result, hypo_num):
        print(f'multihypo {hypo_num} MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print(f'multihypo {hypo_num} MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))