from lib.kits.basic import *

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
from omegaconf import OmegaConf

from lib.platform import PM
from lib.body_models.skel_utils.transforms import params_q2rep, params_rep2q

from .utils.pose_transformer import TransformerDecoder
class SKELTransformerDecoderHead(nn.Module):
    """ Cross-attention based SKEL transformer decoder head.

    A single query token cross-attends to the backbone feature tokens, and linear decoders
    predict residual SKEL poses, betas and camera parameters on top of mean initial values.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        if cfg.pd_poses_repr == 'rotation_6d':
            n_poses = 24 * 6
        elif cfg.pd_poses_repr == 'euler_angle':
            n_poses = 46
        else:
            raise ValueError(f"Unknown pose representation: {cfg.pd_poses_repr}")
        n_betas = 10
        n_cam = 3
        self.input_is_mean_shape = False

        # Build transformer decoder.
        transformer_args = {
            'num_tokens' : 1,
            'token_dim'  : (n_poses + n_betas + n_cam) if self.input_is_mean_shape else 1,
            'dim'        : 1024,
        }
        transformer_args.update(OmegaConf.to_container(cfg.transformer_decoder, resolve=True))  # type: ignore
        self.transformer = TransformerDecoder(**transformer_args)

        # Build decoders for parameters.
        dim = transformer_args['dim']
        self.poses_decoder = nn.Linear(dim, n_poses)
        self.betas_decoder = nn.Linear(dim, n_betas)
        self.cam_decoder   = nn.Linear(dim, n_cam)

        # Load mean shape parameters as initial values.
        skel_mean_path = Path(__file__).parent / 'SKEL_mean.npz'
        skel_mean_params = np.load(skel_mean_path)
        init_poses = torch.from_numpy(skel_mean_params['poses'].astype(np.float32)).unsqueeze(0)  # (1, 46)
        if cfg.pd_poses_repr == 'rotation_6d':
            init_poses = params_q2rep(init_poses).reshape(1, 24*6)  # (1, 24*6)
        init_betas = torch.from_numpy(skel_mean_params['betas'].astype(np.float32)).unsqueeze(0)
        init_cam   = torch.from_numpy(skel_mean_params['cam'].astype(np.float32)).unsqueeze(0)
        self.register_buffer('init_poses', init_poses)
        self.register_buffer('init_betas', init_betas)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x, **kwargs):
        B = x.shape[0]
        # The ViT backbone's feature map is channel-first; rearrange it to a token-first sequence.
        x = einops.rearrange(x, 'b c h w -> b (h w) c')

        # Initialize the parameters from the registered mean values.
        init_poses = self.init_poses.expand(B, -1)  # (B, 46) or (B, 24*6), depending on the representation
        init_betas = self.init_betas.expand(B, -1)  # (B, 10)
        init_cam   = self.init_cam.expand(B, -1)  # (B, 3)

        # Input token to transformer is a zero token.
        with PM.time_monitor('init_token'):
            if self.input_is_mean_shape:
                token = torch.cat([init_poses, init_betas, init_cam], dim=1)[:, None, :]  # (B, 1, C)
            else:
                token = x.new_zeros(B, 1, 1)

        # Pass through transformer.
        with PM.time_monitor('transformer'):
            token_out = self.transformer(token, context=x)
            token_out = token_out.squeeze(1)  # (B, C)

        # Decode the SKEL parameters from token_out as residuals on top of the initial values.
        with PM.time_monitor('decode'):
            pd_poses = self.poses_decoder(token_out) + init_poses
            pd_betas = self.betas_decoder(token_out) + init_betas
            pd_cam   = self.cam_decoder(token_out) + init_cam

        # If predicting in rotation-6d, map the poses back to SKEL's 46 generalized pose coordinates.
        with PM.time_monitor('rot_repr_transform'):
            if self.cfg.pd_poses_repr == 'rotation_6d':
                pd_poses = params_rep2q(pd_poses.reshape(-1, 24, 6))  # (B, 46)
            elif self.cfg.pd_poses_repr == 'euler_angle':
                pd_poses = pd_poses.reshape(-1, 46)  # (B, 46)
            else:
                raise ValueError(f"Unknown pose representation: {self.cfg.pd_poses_repr}")

        pd_skel_params = {
            'poses'        : pd_poses,
            'poses_orient' : pd_poses[:, :3],
            'poses_body'   : pd_poses[:, 3:],
            'betas'        : pd_betas,
        }
        return pd_skel_params, pd_cam
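
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the module). It shows the shapes this
# head expects and returns. The feature-map size and the `transformer_decoder`
# sub-config below are assumptions (their real values come from the project config
# and backbone), and running it also requires the project dependencies
# (TransformerDecoder, PM, SKEL_mean.npz) to be available.
#
#   cfg = OmegaConf.create({
#       'pd_poses_repr'       : 'rotation_6d',
#       'transformer_decoder' : {...},  # project-specific decoder hyper-parameters
#   })
#   head = SKELTransformerDecoderHead(cfg)
#   feats = torch.randn(2, 1280, 16, 12)  # (B, C, H, W) backbone feature map (sizes assumed)
#   pd_skel_params, pd_cam = head(feats)
#   # pd_skel_params['poses'] -> (2, 46), pd_skel_params['betas'] -> (2, 10), pd_cam -> (2, 3)
# ---------------------------------------------------------------------------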