# PSHuman/lib/hybrik/models/simple3dpose.py
from collections import namedtuple
import os
import numpy as np
import torch
import torch.nn as nn
import yaml
from torch.nn import functional as F
from .layers.Resnet import ResNet
from .layers.smpl.SMPL import SMPL_layer
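# Container for all quantities the model can return; every field defaults to
# None so partially-filled results can be constructed.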
ModelOutput = namedtuple(
typename='ModelOutput',
field_names=[
'pred_shape', 'pred_theta_mats', 'pred_phi', 'pred_delta_shape',
'pred_leaf', 'pred_uvd_jts', 'pred_xyz_jts_29', 'pred_xyz_jts_24',
'pred_xyz_jts_24_struct', 'pred_xyz_jts_17', 'pred_vertices',
'maxvals', 'cam_scale', 'cam_trans', 'cam_root', 'uvd_heatmap',
'transl', 'img_feat', 'pred_camera', 'pred_aa'
])
ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields)
def update_config(config_file):
with open(config_file) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
return config
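# Flatten the spatial dims of a [N, C, ...] heatmap and apply a global
# softmax per channel, so each channel becomes a probability map.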
def norm_heatmap(norm_type, heatmap):
# Input tensor shape: [N,C,...]
shape = heatmap.shape
if norm_type == 'softmax':
heatmap = heatmap.reshape(*shape[:2], -1)
# global soft max
heatmap = F.softmax(heatmap, 2)
return heatmap.reshape(*shape)
else:
raise NotImplementedError
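# HybrIK-style single-stage model: a ResNet backbone predicts per-joint 3D
# (u, v, d) heatmaps plus SMPL shape, twist and camera parameters, and an
# analytical SMPL inverse-kinematics layer recovers the full body mesh.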
class HybrIKBaseSMPLCam(nn.Module):
def __init__(self,
cfg_file,
smpl_path,
data_path,
norm_layer=nn.BatchNorm2d):
super(HybrIKBaseSMPLCam, self).__init__()
cfg = update_config(cfg_file)['MODEL']
self.deconv_dim = cfg['NUM_DECONV_FILTERS']
self._norm_layer = norm_layer
self.num_joints = cfg['NUM_JOINTS']
self.norm_type = cfg['POST']['NORM_TYPE']
self.depth_dim = cfg['EXTRA']['DEPTH_DIM']
self.height_dim = cfg['HEATMAP_SIZE'][0]
self.width_dim = cfg['HEATMAP_SIZE'][1]
self.smpl_dtype = torch.float32
        self.preact = ResNet(f"resnet{cfg['NUM_LAYERS']}")
        # Initialise the backbone from torchvision's ImageNet-pretrained weights.
        import torchvision.models as tm
if cfg['NUM_LAYERS'] == 101:
''' Load pretrained model '''
x = tm.resnet101(pretrained=True)
self.feature_channel = 2048
elif cfg['NUM_LAYERS'] == 50:
x = tm.resnet50(pretrained=True)
self.feature_channel = 2048
elif cfg['NUM_LAYERS'] == 34:
x = tm.resnet34(pretrained=True)
self.feature_channel = 512
elif cfg['NUM_LAYERS'] == 18:
x = tm.resnet18(pretrained=True)
self.feature_channel = 512
else:
raise NotImplementedError
        model_state = self.preact.state_dict()
        # Copy only the pretrained tensors whose names and shapes match ours.
        state = {
            k: v
            for k, v in x.state_dict().items()
            if k in model_state and v.size() == model_state[k].size()
        }
        model_state.update(state)
        self.preact.load_state_dict(model_state)
self.deconv_layers = self._make_deconv_layer()
self.final_layer = nn.Conv2d(self.deconv_dim[2],
self.num_joints * self.depth_dim,
kernel_size=1,
stride=1,
padding=0)
h36m_jregressor = np.load(
os.path.join(data_path, 'J_regressor_h36m.npy'))
self.smpl = SMPL_layer(smpl_path,
h36m_jregressor=h36m_jregressor,
dtype=self.smpl_dtype)
self.joint_pairs_24 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14),
(16, 17), (18, 19), (20, 21), (22, 23))
self.joint_pairs_29 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14),
(16, 17), (18, 19), (20, 21), (22, 23),
(25, 26), (27, 28))
self.leaf_pairs = ((0, 1), (3, 4))
self.root_idx_smpl = 0
# mean shape
init_shape = np.load(os.path.join(data_path, 'h36m_mean_beta.npy'))
self.register_buffer('init_shape', torch.Tensor(init_shape).float())
        init_cam = torch.tensor([0.9, 0.0, 0.0])
        self.register_buffer('init_cam', init_cam.float())
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Linear(self.feature_channel, 1024)
self.drop1 = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(1024, 1024)
self.drop2 = nn.Dropout(p=0.5)
self.decshape = nn.Linear(1024, 10)
self.decphi = nn.Linear(1024, 23 * 2) # [cos(phi), sin(phi)]
self.deccam = nn.Linear(1024, 3)
self.focal_length = cfg['FOCAL_LENGTH']
self.input_size = 256.0
    def _make_deconv_layer(self):
        # Three stride-2 deconv -> BN -> ReLU stages upsample the backbone
        # feature map 8x before the final 1x1 conv produces the heatmaps.
        deconv_layers = []
        in_channels = self.feature_channel
        for out_channels in self.deconv_dim:
            deconv_layers.append(
                nn.ConvTranspose2d(in_channels,
                                   out_channels,
                                   kernel_size=4,
                                   stride=2,
                                   padding=1,
                                   bias=False))
            deconv_layers.append(self._norm_layer(out_channels))
            deconv_layers.append(nn.ReLU(inplace=True))
            in_channels = out_channels
        return nn.Sequential(*deconv_layers)
def _initialize(self):
for name, m in self.deconv_layers.named_modules():
if isinstance(m, nn.ConvTranspose2d):
nn.init.normal_(m.weight, std=0.001)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
for m in self.final_layer.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight, std=0.001)
nn.init.constant_(m.bias, 0)
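    # Mirror (u, v, d) joint predictions for horizontal-flip test-time
    # augmentation: negate the u axis and swap left/right joint pairs.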
def flip_uvd_coord(self, pred_jts, shift=False, flatten=True):
if flatten:
assert pred_jts.dim() == 2
num_batches = pred_jts.shape[0]
pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
else:
assert pred_jts.dim() == 3
num_batches = pred_jts.shape[0]
# flip
if shift:
pred_jts[:, :, 0] = -pred_jts[:, :, 0]
else:
pred_jts[:, :, 0] = -1 / self.width_dim - pred_jts[:, :, 0]
for pair in self.joint_pairs_29:
dim0, dim1 = pair
idx = torch.Tensor((dim0, dim1)).long()
inv_idx = torch.Tensor((dim1, dim0)).long()
pred_jts[:, idx] = pred_jts[:, inv_idx]
if flatten:
pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3)
return pred_jts
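    # Same left/right mirroring for root-relative xyz joint coordinates.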
def flip_xyz_coord(self, pred_jts, flatten=True):
if flatten:
assert pred_jts.dim() == 2
num_batches = pred_jts.shape[0]
pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
else:
assert pred_jts.dim() == 3
num_batches = pred_jts.shape[0]
pred_jts[:, :, 0] = -pred_jts[:, :, 0]
for pair in self.joint_pairs_29:
dim0, dim1 = pair
idx = torch.Tensor((dim0, dim1)).long()
inv_idx = torch.Tensor((dim1, dim0)).long()
pred_jts[:, idx] = pred_jts[:, inv_idx]
if flatten:
pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3)
return pred_jts
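    # Flip the twist angles: sin(phi) changes sign under mirroring, and
    # left/right pairs swap (indices shift by 1 since the root has no twist).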
def flip_phi(self, pred_phi):
pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1]
for pair in self.joint_pairs_24:
dim0, dim1 = pair
idx = torch.Tensor((dim0 - 1, dim1 - 1)).long()
inv_idx = torch.Tensor((dim1 - 1, dim0 - 1)).long()
pred_phi[:, idx] = pred_phi[:, inv_idx]
return pred_phi
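    # Full pipeline: backbone -> 3D heatmaps -> soft-argmax uvd joints ->
    # camera back-projection -> analytical SMPL inverse kinematics.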
def forward(self,
x,
flip_item=None,
flip_output=False,
gt_uvd=None,
gt_uvd_weight=None,
**kwargs):
batch_size = x.shape[0]
# torch.cuda.synchronize()
# model_start_t = time.time()
x0 = self.preact(x)
out = self.deconv_layers(x0)
out = self.final_layer(out)
# torch.cuda.synchronize()
# preat_end_t = time.time()
out = out.reshape((out.shape[0], self.num_joints, -1))
maxvals, _ = torch.max(out, dim=2, keepdim=True)
out = norm_heatmap(self.norm_type, out)
assert out.dim() == 3, out.shape
heatmaps = out / out.sum(dim=2, keepdim=True)
heatmaps = heatmaps.reshape(
(heatmaps.shape[0], self.num_joints, self.depth_dim,
self.height_dim, self.width_dim))
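        # Soft-argmax: marginalise the volumetric heatmap onto the x, y and z
        # axes, then take expectations over bin indices to obtain sub-voxel
        # joint coordinates.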
hm_x0 = heatmaps.sum((2, 3))
hm_y0 = heatmaps.sum((2, 4))
hm_z0 = heatmaps.sum((3, 4))
range_tensor = torch.arange(hm_x0.shape[-1],
dtype=torch.float32,
device=hm_x0.device)
hm_x = hm_x0 * range_tensor
hm_y = hm_y0 * range_tensor
hm_z = hm_z0 * range_tensor
coord_x = hm_x.sum(dim=2, keepdim=True)
coord_y = hm_y.sum(dim=2, keepdim=True)
coord_z = hm_z.sum(dim=2, keepdim=True)
coord_x = coord_x / float(self.width_dim) - 0.5
coord_y = coord_y / float(self.height_dim) - 0.5
coord_z = coord_z / float(self.depth_dim) - 0.5
# -0.5 ~ 0.5
pred_uvd_jts_29 = torch.cat((coord_x, coord_y, coord_z), dim=2)
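        # Global-feature head: average-pool the backbone features, then
        # regress SMPL shape (betas), twist angles (phi) and the
        # weak-perspective camera (scale + xy translation).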
x0 = self.avg_pool(x0)
x0 = x0.view(x0.size(0), -1)
init_shape = self.init_shape.expand(batch_size, -1) # (B, 10,)
init_cam = self.init_cam.expand(batch_size, -1) # (B, 3,)
xc = x0
xc = self.fc1(xc)
xc = self.drop1(xc)
xc = self.fc2(xc)
xc = self.drop2(xc)
delta_shape = self.decshape(xc)
pred_shape = delta_shape + init_shape
pred_phi = self.decphi(xc)
pred_camera = self.deccam(xc).reshape(batch_size, -1) + init_cam
camScale = pred_camera[:, :1].unsqueeze(1)
camTrans = pred_camera[:, 1:].unsqueeze(1)
camDepth = self.focal_length / (self.input_size * camScale + 1e-9)
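        # Back-project normalised uv to metric xy with the pinhole model and
        # the predicted camera; depth stays in units of 2.2 m (the bounding
        # volume the "unit: 2.2m" comments below refer to).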
        pred_xyz_jts_29 = torch.zeros_like(pred_uvd_jts_29)
        pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, 2:].clone()  # unit: 2.2m
        pred_xyz_jts_29_meter = (
            pred_uvd_jts_29[:, :, :2] * self.input_size / self.focal_length
        ) * (pred_xyz_jts_29[:, :, 2:] * 2.2 + camDepth) - camTrans  # unit: m
        pred_xyz_jts_29[:, :, :2] = pred_xyz_jts_29_meter / 2.2  # unit: 2.2m
        camera_root = pred_xyz_jts_29[:, [0]] * 2.2
camera_root[:, :, :2] += camTrans
camera_root[:, :, [2]] += camDepth
if not self.training:
pred_xyz_jts_29 = pred_xyz_jts_29 - pred_xyz_jts_29[:, [0]]
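        # If a horizontally flipped copy of the input was run first, merge the
        # two predictions (test-time flip augmentation).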
if flip_item is not None:
            assert flip_output
pred_xyz_jts_29_orig, pred_phi_orig, pred_leaf_orig, pred_shape_orig = flip_item
if flip_output:
pred_xyz_jts_29 = self.flip_xyz_coord(pred_xyz_jts_29,
flatten=False)
if flip_output and flip_item is not None:
pred_xyz_jts_29 = (pred_xyz_jts_29 + pred_xyz_jts_29_orig.reshape(
batch_size, 29, 3)) / 2
pred_xyz_jts_29_flat = pred_xyz_jts_29.reshape(batch_size, -1)
pred_phi = pred_phi.reshape(batch_size, 23, 2)
if flip_output:
pred_phi = self.flip_phi(pred_phi)
if flip_output and flip_item is not None:
pred_phi = (pred_phi + pred_phi_orig) / 2
pred_shape = (pred_shape + pred_shape_orig) / 2
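        # Analytical inverse kinematics: recover SMPL joint rotations from the
        # predicted 3D skeleton, twist angles and shape, then pose the mesh.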
output = self.smpl.hybrik(
pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) *
2.2, # unit: meter
betas=pred_shape.type(self.smpl_dtype),
phis=pred_phi.type(self.smpl_dtype),
global_orient=None,
return_verts=True)
pred_vertices = output.vertices.float()
# -0.5 ~ 0.5
# pred_xyz_jts_24_struct = output.joints.float() / 2.2
pred_xyz_jts_24_struct = output.joints.float() / 2
# -0.5 ~ 0.5
# pred_xyz_jts_17 = output.joints_from_verts.float() / 2.2
pred_xyz_jts_17 = output.joints_from_verts.float() / 2
pred_theta_mats = output.rot_mats.float().reshape(batch_size, 24, 3, 3)
        pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, 72) / 2
pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72)
pred_xyz_jts_17_flat = pred_xyz_jts_17.reshape(batch_size, 17 * 3)
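        # Camera-space root translation: the metric offset between the heatmap
        # root and the regressed 17-joint root, shifted by the predicted
        # camera xy translation and depth.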
        transl = (pred_xyz_jts_29[:, 0, :] - pred_xyz_jts_17[:, 0, :]) * 2.2
transl[:, :2] += camTrans[:, 0]
transl[:, 2] += camDepth[:, 0, 0]
new_cam = torch.zeros_like(transl)
new_cam[:, 1:] = transl[:, :2]
new_cam[:, 0] = self.focal_length / \
(self.input_size * transl[:, 2] + 1e-9)
# pred_aa = output.rot_aa.reshape(batch_size, 24, 3)
output = dict(
pred_phi=pred_phi,
pred_delta_shape=delta_shape,
pred_shape=pred_shape,
# pred_aa=pred_aa,
pred_theta_mats=pred_theta_mats,
pred_uvd_jts=pred_uvd_jts_29.reshape(batch_size, -1),
pred_xyz_jts_29=pred_xyz_jts_29_flat,
pred_xyz_jts_24=pred_xyz_jts_24,
pred_xyz_jts_24_struct=pred_xyz_jts_24_struct,
pred_xyz_jts_17=pred_xyz_jts_17_flat,
pred_vertices=pred_vertices,
maxvals=maxvals,
cam_scale=camScale[:, 0],
cam_trans=camTrans[:, 0],
cam_root=camera_root,
pred_camera=new_cam,
transl=transl,
# uvd_heatmap=torch.stack([hm_x0, hm_y0, hm_z0], dim=2),
# uvd_heatmap=heatmaps,
# img_feat=x0
)
return output
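    # Pose the SMPL body directly from ground-truth axis-angle pose and shape,
    # bypassing the network heads (useful for supervision and debugging).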
def forward_gt_theta(self, gt_theta, gt_beta):
output = self.smpl(pose_axis_angle=gt_theta,
betas=gt_beta,
global_orient=None,
return_verts=True)
return output