ICON / lib /pymaf /core /fits_dict.py
Yuliang's picture
done
2d5f249
raw history blame
No virus
5.72 kB
'''
This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
'''
import os
import cv2
import torch
import numpy as np
from torchgeometry import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis
from core import path_config, constants
import logging
logger = logging.getLogger(__name__)
class FitsDict():
""" Dictionary keeping track of the best fit per image in the training set """
def __init__(self, options, train_dataset):
self.options = options
self.train_dataset = train_dataset
self.fits_dict = {}
self.valid_fit_state = {}
# array used to flip SMPL pose parameters
self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM,
dtype=torch.int64)
# Load dictionary state
for ds_name, ds in train_dataset.dataset_dict.items():
if ds_name in ['h36m']:
dict_file = os.path.join(path_config.FINAL_FITS_DIR,
ds_name + '.npy')
self.fits_dict[ds_name] = torch.from_numpy(np.load(dict_file))
self.valid_fit_state[ds_name] = torch.ones(len(
self.fits_dict[ds_name]),
dtype=torch.uint8)
else:
dict_file = os.path.join(path_config.FINAL_FITS_DIR,
ds_name + '.npz')
fits_dict = np.load(dict_file)
opt_pose = torch.from_numpy(fits_dict['pose'])
opt_betas = torch.from_numpy(fits_dict['betas'])
opt_valid_fit = torch.from_numpy(fits_dict['valid_fit']).to(
torch.uint8)
self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas],
dim=1)
self.valid_fit_state[ds_name] = opt_valid_fit
if not options.single_dataset:
for ds in train_dataset.datasets:
if ds.dataset not in ['h36m']:
ds.pose = self.fits_dict[ds.dataset][:, :72].numpy()
ds.betas = self.fits_dict[ds.dataset][:, 72:].numpy()
ds.has_smpl = self.valid_fit_state[ds.dataset].numpy()
def save(self):
""" Save dictionary state to disk """
for ds_name in self.train_dataset.dataset_dict.keys():
dict_file = os.path.join(self.options.checkpoint_dir,
ds_name + '_fits.npy')
np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())
def __getitem__(self, x):
""" Retrieve dictionary entries """
dataset_name, ind, rot, is_flipped = x
batch_size = len(dataset_name)
pose = torch.zeros((batch_size, 72))
betas = torch.zeros((batch_size, 10))
for ds, i, n in zip(dataset_name, ind, range(batch_size)):
params = self.fits_dict[ds][i]
pose[n, :] = params[:72]
betas[n, :] = params[72:]
pose = pose.clone()
# Apply flipping and rotation
pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped)
betas = betas.clone()
return pose, betas
def get_vaild_state(self, dataset_name, ind):
batch_size = len(dataset_name)
valid_fit = torch.zeros(batch_size, dtype=torch.uint8)
for ds, i, n in zip(dataset_name, ind, range(batch_size)):
valid_fit[n] = self.valid_fit_state[ds][i]
valid_fit = valid_fit.clone()
return valid_fit
def __setitem__(self, x, val):
""" Update dictionary entries """
dataset_name, ind, rot, is_flipped, update = x
pose, betas = val
batch_size = len(dataset_name)
# Undo flipping and rotation
pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot)
params = torch.cat((pose, betas), dim=-1).cpu()
for ds, i, n in zip(dataset_name, ind, range(batch_size)):
if update[n]:
self.fits_dict[ds][i] = params[n]
def flip_pose(self, pose, is_flipped):
"""flip SMPL pose parameters"""
is_flipped = is_flipped.byte()
pose_f = pose.clone()
pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
# we also negate the second and the third dimension of the axis-angle representation
pose_f[is_flipped, 1::3] *= -1
pose_f[is_flipped, 2::3] *= -1
return pose_f
def rotate_pose(self, pose, rot):
"""Rotate SMPL pose parameters by rot degrees"""
pose = pose.clone()
cos = torch.cos(-np.pi * rot / 180.)
sin = torch.sin(-np.pi * rot / 180.)
zeros = torch.zeros_like(cos)
r3 = torch.zeros(cos.shape[0], 1, 3, device=cos.device)
r3[:, 0, -1] = 1
R = torch.cat([
torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1),
torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1), r3
],
dim=1)
global_pose = pose[:, :3]
global_pose_rotmat = angle_axis_to_rotation_matrix(global_pose)
global_pose_rotmat_3b3 = global_pose_rotmat[:, :3, :3]
global_pose_rotmat_3b3 = torch.matmul(R, global_pose_rotmat_3b3)
global_pose_rotmat[:, :3, :3] = global_pose_rotmat_3b3
global_pose_rotmat = global_pose_rotmat[:, :-1, :-1].cpu().numpy()
global_pose_np = np.zeros((global_pose.shape[0], 3))
for i in range(global_pose.shape[0]):
aa, _ = cv2.Rodrigues(global_pose_rotmat[i])
global_pose_np[i, :] = aa.squeeze()
pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
return pose