|
import torch |
|
import numpy as np |
|
import glob |
|
import os |
|
import io |
|
import random |
|
import pickle |
|
from torch.utils.data import Dataset, DataLoader |
|
from lib.data.augmentation import Augmenter3D |
|
from lib.utils.tools import read_pkl |
|
from lib.utils.utils_data import flip_data, crop_scale |
|
from lib.utils.utils_mesh import flip_thetas |
|
from lib.utils.utils_smpl import SMPL |
|
from torch.utils.data import Dataset, DataLoader |
|
from lib.data.datareader_h36m import DataReaderH36M |
|
from lib.data.datareader_mesh import DataReaderMesh |
|
from lib.data.dataset_action import random_move |
|
|
|
class SMPLDataset(Dataset):
    """Base dataset holding 2D keypoint clips and their matching SMPL parameters.

    Subclasses implement ``__getitem__`` to turn one stored clip into a sample.
    """

    def __init__(self, args, data_split, dataset):
        """Load one split of a mesh dataset.

        Args:
            args: Config object. Reads ``clip_len``, ``sample_stride``,
                ``data_stride``, ``data_root`` and the ``dt_file_*`` entry
                matching ``dataset``.
            data_split: ``'train'`` or ``'test'``.
            dataset: One of ``'h36m'``, ``'coco'`` or ``'pw3d'``.

        Raises:
            ValueError: if ``dataset`` is not a supported name.
            KeyError: if ``data_split`` is neither ``'train'`` nor ``'test'``.
        """
        # NOTE(review): this seeds the *global* `random` and NumPy RNGs as a
        # side effect of constructing the dataset.
        random.seed(0)
        np.random.seed(0)
        self.clip_len = args.clip_len
        self.data_split = data_split

        if dataset == "h36m":
            datareader = DataReaderH36M(n_frames=self.clip_len, sample_stride=args.sample_stride, data_stride_train=args.data_stride, data_stride_test=self.clip_len, dt_root=args.data_root, dt_file=args.dt_file_h36m)
        elif dataset == "coco":
            # Single-frame samples.
            datareader = DataReaderMesh(n_frames=1, sample_stride=args.sample_stride, data_stride_train=1, data_stride_test=1, dt_root=args.data_root, dt_file=args.dt_file_coco, res=[640, 640])
        elif dataset == "pw3d":
            datareader = DataReaderMesh(n_frames=self.clip_len, sample_stride=args.sample_stride, data_stride_train=args.data_stride, data_stride_test=self.clip_len, dt_root=args.data_root, dt_file=args.dt_file_pw3d, res=[1920, 1920])
        else:
            # ValueError subclasses Exception, so existing `except Exception` still works.
            raise ValueError(f"Mesh dataset undefined: {dataset!r}")

        split_id_train, split_id_test = datareader.get_split_id()
        train_data, test_data = datareader.read_2d()
        dt = datareader.dt_dataset

        # Index only the split that is kept; the other split's indexed
        # arrays were previously built and then discarded.
        # An unknown data_split raises KeyError here.
        split_id = {'train': split_id_train, 'test': split_id_test}[data_split]
        data_2d = {'train': train_data, 'test': test_data}[data_split]

        self.motion_2d = data_2d[split_id]
        self.motion_smpl_3d = {
            'pose': dt[data_split]['smpl_pose'][split_id],
            'shape': dt[data_split]['smpl_shape'][split_id],
        }
        self.smpl = SMPL(
            args.data_root,
            batch_size=1,
        )

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.motion_2d)

    def __getitem__(self, index):
        """Return one sample; must be implemented by subclasses."""
        raise NotImplementedError
|
|
|
class MotionSMPL(SMPLDataset):
    """Dataset yielding 2D keypoint clips with SMPL-derived 3D targets.

    Each sample runs the SMPL model on the stored pose/shape parameters to
    produce mesh vertices and regressed 3D joints, both root-centred on the
    first regressed joint.
    """

    def __init__(self, args, data_split, dataset):
        """Build the base dataset and read the ``args.flip`` augmentation flag."""
        super().__init__(args, data_split, dataset)
        self.flip = args.flip

    def __getitem__(self, index):
        """Generates one sample of data.

        Returns:
            A tuple ``(motion_2d, motion_smpl_3d)``:
            - ``motion_2d``: the 2D clip as a NumPy array, with channel 2 of
              the last axis clipped to [0, 1]. That channel is presumably
              keypoint confidence; this is not confirmed.
            - ``motion_smpl_3d``: a dict with these keys:
              - ``'theta'``: SMPL pose (72) concatenated with shape (10), per frame.
              - ``'kp_3d'``: joints regressed with ``J_regressor_h36m``.
              - ``'verts'``: mesh vertices.
        """
        # Work on a private copy. For an integer index this is otherwise a
        # view into self.motion_2d, and the in-place clip below would mutate
        # the stored dataset on every access.
        motion_2d = self.motion_2d[index].copy()
        motion_2d[:, :, 2] = np.clip(motion_2d[:, :, 2], 0, 1)
        motion_smpl_pose = self.motion_smpl_3d['pose'][index].reshape(-1, 24, 3)
        motion_smpl_shape = self.motion_smpl_3d['shape'][index]

        # Short-circuit order matches the original nesting, so random.random()
        # is only consumed for training samples with flipping enabled.
        if self.data_split == "train" and self.flip and random.random() > 0.5:
            motion_2d = flip_data(motion_2d)
            motion_smpl_pose = flip_thetas(motion_smpl_pose)

        motion_smpl_pose = torch.from_numpy(motion_smpl_pose).reshape(-1, 72).float()
        motion_smpl_shape = torch.from_numpy(motion_smpl_shape).reshape(-1, 10).float()

        # Targets are fixed ground truth, so skip autograd bookkeeping.
        # The result is detached anyway.
        with torch.no_grad():
            motion_smpl = self.smpl(
                betas=motion_smpl_shape,
                body_pose=motion_smpl_pose[:, 3:],
                global_orient=motion_smpl_pose[:, :3],
                pose2rot=True,
            )
        # NOTE(review): the x1000 scale presumably converts metres to millimetres; confirm.
        motion_verts = motion_smpl.vertices.detach() * 1000.0

        J_regressor = self.smpl.J_regressor_h36m
        J_regressor_batch = J_regressor[None, :].expand(motion_verts.shape[0], -1, -1).to(motion_verts.device)
        motion_3d_reg = torch.matmul(J_regressor_batch, motion_verts)

        # Root-centre vertices and joints on the first regressed joint.
        motion_verts = motion_verts - motion_3d_reg[:, :1, :]
        motion_3d_reg = motion_3d_reg - motion_3d_reg[:, :1, :]

        motion_theta = torch.cat((motion_smpl_pose, motion_smpl_shape), -1)
        motion_smpl_3d = {
            'theta': motion_theta,
            'kp_3d': motion_3d_reg,
            'verts': motion_verts,
        }
        return motion_2d, motion_smpl_3d