from lib.kits.basic import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
from pathlib import Path  # `Path` is used below; imported explicitly in case `lib.kits.basic` does not re-export it.
from omegaconf import OmegaConf

from lib.platform import PM
from lib.body_models.skel_utils.transforms import params_q2rep, params_rep2q

from .utils.pose_transformer import TransformerDecoder
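
# This head consumes backbone feature maps of shape (B, C, H, W), cross-attends a
# single query token over the flattened (H*W) feature tokens, and regresses SKEL
# poses/betas plus a 3-parameter camera as residual offsets from the mean
# parameters stored in SKEL_mean.npz.
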
class SKELTransformerDecoderHead(nn.Module):
    """ Cross-attention based SKEL transformer decoder head. """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        if cfg.pd_poses_repr == 'rotation_6d':
            n_poses = 24 * 6
        elif cfg.pd_poses_repr == 'euler_angle':
            n_poses = 46
        else:
            raise ValueError(f"Unknown pose representation: {cfg.pd_poses_repr}")
        n_betas = 10
        n_cam = 3
        self.input_is_mean_shape = False

        # Build transformer decoder.
        transformer_args = {
            'num_tokens' : 1,
            'token_dim'  : (n_poses + n_betas + n_cam) if self.input_is_mean_shape else 1,
            'dim'        : 1024,
        }
        transformer_args.update(OmegaConf.to_container(cfg.transformer_decoder, resolve=True))  # type: ignore
        self.transformer = TransformerDecoder(**transformer_args)

        # Build decoders for parameters.
        dim = transformer_args['dim']
        self.poses_decoder = nn.Linear(dim, n_poses)
        self.betas_decoder = nn.Linear(dim, n_betas)
        self.cam_decoder   = nn.Linear(dim, n_cam)

        # Load mean shape parameters as initial values.
        skel_mean_path = Path(__file__).parent / 'SKEL_mean.npz'
        skel_mean_params = np.load(skel_mean_path)
        init_poses = torch.from_numpy(skel_mean_params['poses'].astype(np.float32)).unsqueeze(0)  # (1, 46)
        if cfg.pd_poses_repr == 'rotation_6d':
            init_poses = params_q2rep(init_poses).reshape(1, 24*6)  # (1, 24*6)
        init_betas = torch.from_numpy(skel_mean_params['betas'].astype(np.float32)).unsqueeze(0)
        init_cam = torch.from_numpy(skel_mean_params['cam'].astype(np.float32)).unsqueeze(0)
        self.register_buffer('init_poses', init_poses)
        self.register_buffer('init_betas', init_betas)
        self.register_buffer('init_cam', init_cam)
    def forward(self, x, **kwargs):
        B = x.shape[0]
        # The pretrained ViT backbone outputs channel-first feature maps; convert to token-first.
        x = einops.rearrange(x, 'b c h w -> b (h w) c')

        # Initialize the parameters from the mean values.
        init_poses = self.init_poses.expand(B, -1)  # (B, 46) or (B, 24*6)
        init_betas = self.init_betas.expand(B, -1)  # (B, 10)
        init_cam   = self.init_cam.expand(B, -1)    # (B, 3)

        # Input token to the transformer is a zero token.
        with PM.time_monitor('init_token'):
            if self.input_is_mean_shape:
                token = torch.cat([init_poses, init_betas, init_cam], dim=1)[:, None, :]  # (B, 1, C)
            else:
                token = x.new_zeros(B, 1, 1)

        # Pass through transformer.
        with PM.time_monitor('transformer'):
            token_out = self.transformer(token, context=x)
            token_out = token_out.squeeze(1)  # (B, C)

        # Parse the SKEL parameters out of token_out as offsets from the initial values.
        with PM.time_monitor('decode'):
            pd_poses = self.poses_decoder(token_out) + init_poses
            pd_betas = self.betas_decoder(token_out) + init_betas
            pd_cam   = self.cam_decoder(token_out) + init_cam

        with PM.time_monitor('rot_repr_transform'):
            if self.cfg.pd_poses_repr == 'rotation_6d':
                pd_poses = params_rep2q(pd_poses.reshape(-1, 24, 6))  # (B, 46)
            elif self.cfg.pd_poses_repr == 'euler_angle':
                pd_poses = pd_poses.reshape(-1, 46)  # (B, 46)
            else:
                raise ValueError(f"Unknown pose representation: {self.cfg.pd_poses_repr}")

        pd_skel_params = {
            'poses'        : pd_poses,
            'poses_orient' : pd_poses[:, :3],
            'poses_body'   : pd_poses[:, 3:],
            'betas'        : pd_betas,
        }
        return pd_skel_params, pd_cam
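
# --- Usage sketch (illustrative, not part of the module) ---------------------
# A minimal sketch of how this head could be driven. The fields under
# `transformer_decoder` must match whatever `.utils.pose_transformer.TransformerDecoder`
# accepts; the omitted keys below, the feature-map channel count, and the spatial
# size are assumptions for illustration only.
#
#   cfg = OmegaConf.create({
#       'pd_poses_repr': 'euler_angle',
#       'transformer_decoder': {...},  # decoder hyper-parameters expected by TransformerDecoder
#   })
#   head = SKELTransformerDecoderHead(cfg)
#   feats = torch.randn(2, 1280, 16, 12)  # (B, C, H, W) ViT feature maps (assumed shape)
#   skel_params, pd_cam = head(feats)
#   # skel_params['poses']: (2, 46), skel_params['betas']: (2, 10), pd_cam: (2, 3)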