'''
This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
'''
import os
import cv2
import torch
import numpy as np
from torchgeometry import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis

from core import path_config, constants

import logging

logger = logging.getLogger(__name__)


class FitsDict():
    """Dictionary keeping track of the best fit per image in the training set.

    For every dataset name the class stores one tensor whose rows are the
    concatenation of 72 pose values and the shape (betas) values, plus a
    ``uint8`` tensor of per-row validity flags.
    """

    def __init__(self, options, train_dataset):
        """Load the stored fits for every dataset in ``train_dataset``.

        Args:
            options: configuration object; ``single_dataset`` is read here and
                ``checkpoint_dir`` is read by :meth:`save`.
            train_dataset: object exposing ``dataset_dict`` (keyed by dataset
                name) and, when ``options.single_dataset`` is false,
                ``datasets`` (iterable of per-dataset objects with a
                ``dataset`` name attribute).
        """
        self.options = options
        self.train_dataset = train_dataset
        self.fits_dict = {}
        self.valid_fit_state = {}
        # array used to flip SMPL pose parameters
        self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM,
                                          dtype=torch.int64)

        # Load dictionary state
        for ds_name in train_dataset.dataset_dict:
            if ds_name in ['h36m']:
                # h36m fits are stored as a single array; every row is
                # marked as a valid fit.
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npy')
                self.fits_dict[ds_name] = torch.from_numpy(np.load(dict_file))
                self.valid_fit_state[ds_name] = torch.ones(
                    len(self.fits_dict[ds_name]), dtype=torch.uint8)
            else:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npz')
                # NpzFile keeps the archive open until closed, so read the
                # arrays inside a context manager to release the handle.
                with np.load(dict_file) as fits:
                    opt_pose = torch.from_numpy(fits['pose'])
                    opt_betas = torch.from_numpy(fits['betas'])
                    opt_valid_fit = torch.from_numpy(
                        fits['valid_fit']).to(torch.uint8)
                self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas],
                                                    dim=1)
                self.valid_fit_state[ds_name] = opt_valid_fit

        # Push the loaded fits into the individual dataset objects so that
        # they serve them as SMPL annotations (h36m is left untouched).
        if not options.single_dataset:
            for ds in train_dataset.datasets:
                if ds.dataset not in ['h36m']:
                    ds.pose = self.fits_dict[ds.dataset][:, :72].numpy()
                    ds.betas = self.fits_dict[ds.dataset][:, 72:].numpy()
                    ds.has_smpl = self.valid_fit_state[ds.dataset].numpy()

    def save(self):
        """Save dictionary state to disk.

        Writes one ``<dataset name>_fits.npy`` file per dataset into
        ``options.checkpoint_dir``. Only the fit parameters are written; the
        validity flags are not saved.
        """
        for ds_name in self.train_dataset.dataset_dict:
            dict_file = os.path.join(self.options.checkpoint_dir,
                                     ds_name + '_fits.npy')
            np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())

    def __getitem__(self, x):
        """Retrieve dictionary entries.

        Args:
            x: tuple ``(dataset_name, ind, rot, is_flipped)`` where
                ``dataset_name`` and ``ind`` are per-sample sequences, ``rot``
                holds rotation angles in degrees and ``is_flipped`` holds the
                flip flags.

        Returns:
            Tuple ``(pose, betas)`` of shape ``(batch, 72)`` and
            ``(batch, 10)`` with the augmentation (rotation, then flipping)
            applied to the pose.
        """
        dataset_name, ind, rot, is_flipped = x
        batch_size = len(dataset_name)
        pose = torch.zeros((batch_size, 72))
        betas = torch.zeros((batch_size, 10))
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            params = self.fits_dict[ds][i]
            pose[n, :] = params[:72]
            betas[n, :] = params[72:]
        # Apply flipping and rotation
        pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped)
        return pose, betas

    def get_vaild_state(self, dataset_name, ind):
        """Return the validity flag of each requested fit.

        The misspelled name is kept for existing callers; see the
        ``get_valid_state`` alias.

        Args:
            dataset_name: per-sample sequence of dataset names.
            ind: per-sample sequence of indices into the dataset's fits.

        Returns:
            ``uint8`` tensor of length ``len(dataset_name)``.
        """
        batch_size = len(dataset_name)
        valid_fit = torch.zeros(batch_size, dtype=torch.uint8)
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            valid_fit[n] = self.valid_fit_state[ds][i]
        return valid_fit

    # Correctly spelled alias; the original name stays available.
    get_valid_state = get_vaild_state

    def __setitem__(self, x, val):
        """Update dictionary entries.

        Args:
            x: tuple ``(dataset_name, ind, rot, is_flipped, update)``; only
                samples whose ``update`` entry is truthy are written back.
            val: tuple ``(pose, betas)`` expressed in the augmented
                (rotated/flipped) frame.
        """
        dataset_name, ind, rot, is_flipped, update = x
        pose, betas = val
        # Undo flipping and rotation
        pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot)
        params = torch.cat((pose, betas), dim=-1).cpu()
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            if update[n]:
                self.fits_dict[ds][i] = params[n]

    def flip_pose(self, pose, is_flipped):
        """Flip SMPL pose parameters.

        Args:
            pose: ``(batch, 72)`` pose tensor; it is not modified.
            is_flipped: per-sample flip flags (non-zero means flipped).

        Returns:
            A new tensor in which the flagged rows are flipped.
        """
        # Boolean masks are the supported way to do mask indexing; uint8
        # masks are deprecated in PyTorch and trigger a warning.
        is_flipped = is_flipped.bool()
        pose_f = pose.clone()
        pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
        # we also negate the second and the third dimension of the
        # axis-angle representation
        pose_f[is_flipped, 1::3] *= -1
        pose_f[is_flipped, 2::3] *= -1
        return pose_f

    def rotate_pose(self, pose, rot):
        """Rotate SMPL pose parameters by rot degrees.

        Only the first three pose values (the global orientation in
        axis-angle form) are changed.

        Args:
            pose: ``(batch, 72)`` pose tensor; it is not modified.
            rot: 1-D tensor with one rotation angle in degrees per sample.

        Returns:
            A new pose tensor with the rotated global orientation.
        """
        pose = pose.clone()
        # Build the rotation on the pose's device/dtype so that the matmul
        # below cannot fail on a device or dtype mismatch.
        rot = rot.to(device=pose.device, dtype=pose.dtype)
        angle = -np.pi * rot / 180.
        cos = torch.cos(angle)
        sin = torch.sin(angle)
        zeros = torch.zeros_like(cos)
        ones = torch.ones_like(cos)
        # Per-sample in-plane rotation about the z axis, shape (batch, 3, 3).
        R = torch.stack([
            torch.stack([cos, -sin, zeros], dim=-1),
            torch.stack([sin, cos, zeros], dim=-1),
            torch.stack([zeros, zeros, ones], dim=-1),
        ], dim=1)

        # angle_axis_to_rotation_matrix yields a larger (homogeneous) matrix;
        # only its upper-left 3x3 rotation part is used.
        global_rotmat = angle_axis_to_rotation_matrix(pose[:, :3])[:, :3, :3]
        rotated = torch.matmul(R, global_rotmat).detach().cpu().numpy()

        # Convert back to axis-angle one sample at a time with OpenCV.
        global_pose_np = np.zeros((pose.shape[0], 3))
        for i, rotmat in enumerate(rotated):
            aa, _ = cv2.Rodrigues(rotmat)
            global_pose_np[i, :] = aa.squeeze()
        pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
        return pose