import torch
import torch.nn.functional as F

from dataclasses import dataclass
from smplx.utils import ModelOutput
from smplx.lbs import batch_rodrigues
from typing import Optional, NewType

Tensor = NewType('Tensor', torch.Tensor)


@dataclass
class AVESOutput(ModelOutput):
    betas: Optional[Tensor] = None
    pose: Optional[Tensor] = None
    bone: Optional[Tensor] = None


class LBS(torch.nn.Module):
    '''
    Linear blend skinning with additional bone-length and scale parameters.

    Input:
        V (BN, V, 3): vertices to pose and shape
        pose (BN, J, 3, 3) or (BN, J, 3): pose as rotation matrices or axis-angle
        bone (BN, K): direct control of the relative joint distances (bone lengths)
        scale (1): scale applied to the whole kinematic tree
    '''
    def __init__(self, J, parents, weights):
        super(LBS, self).__init__()
        self.n_joints = J.shape[1]
        # rest-pose joints in homogeneous coordinates (1, J, 4, 1)
        self.register_buffer('h_joints', F.pad(J.unsqueeze(-1), [0, 0, 0, 1], value=0))
        # joint offsets relative to their parents (the root keeps its absolute position)
        self.register_buffer('kin_tree', torch.cat([J[:, [0], :], J[:, 1:] - J[:, parents[1:]]], dim=1).unsqueeze(-1))
        self.register_buffer('parents', parents)
        self.register_buffer('weights', weights[None].float())
    def __call__(self, V, pose, bone, scale, to_rotmats=False):
        batch_size = len(V)
        device = pose.device

        # homogeneous vertex coordinates (BN, V, 4, 1)
        V = F.pad(V.unsqueeze(-1), [0, 0, 0, 1], value=1)
        # scale the kinematic tree and stretch each joint offset by its bone length
        kin_tree = (scale * self.kin_tree) * bone[:, :, None, None]
        if to_rotmats:
            # axis-angle input: convert to rotation matrices
            pose = batch_rodrigues(pose.reshape(-1, 3)).view(batch_size, -1, 3, 3)
        pose = pose.view([batch_size, -1, 3, 3])

        # per-joint local transforms [R | t]
        T = torch.zeros([batch_size, self.n_joints, 4, 4]).float().to(device)
        T[:, :, -1, -1] = 1
        T[:, :, :3, :] = torch.cat([pose, kin_tree], dim=-1)

        # compose transforms along the kinematic chain
        T_rel = [T[:, 0]]
        for i in range(1, self.n_joints):
            T_rel.append(T_rel[self.parents[i]] @ T[:, i])
        T_rel = torch.stack(T_rel, dim=1)

        # make the transforms relative to the rest pose
        T_rel[:, :, :, [-1]] -= T_rel.clone() @ (self.h_joints * scale)

        # blend per-joint transforms with the skinning weights and apply to the vertices
        T_ = self.weights @ T_rel.view(batch_size, self.n_joints, -1)
        T_ = T_.view(batch_size, -1, 4, 4)
        V = T_ @ V

        return V[:, :, :3, 0]
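
# A minimal usage sketch of LBS in isolation (hypothetical shapes and random data;
# in this file the joints, parents, and weights come from the AVES model below).
# Identity rotations with unit bone lengths and scale 1 act as the identity
# transform, so the input vertices come back unchanged:
#
#   n_joints, n_verts = 25, 1000
#   J0 = torch.rand(1, n_joints, 3)                        # rest-pose joint locations
#   parents = torch.arange(-1, n_joints - 1)               # simple chain kinematic tree
#   weights = torch.rand(n_verts, n_joints).softmax(-1)    # skinning weights, rows sum to 1
#   lbs = LBS(J0, parents, weights)
#   pose = torch.eye(3).repeat(1, n_joints, 1, 1)          # one rotation matrix per joint
#   bone = torch.ones(1, n_joints)                         # relative bone lengths
#   verts = lbs(torch.rand(1, n_verts, 3), pose, bone, scale=1.0)   # -> (1, n_verts, 3)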


class AVES(torch.nn.Module):
    def __init__(self, **kwargs):
        super(AVES, self).__init__()

        # kinematic tree, skinning weights, and vertex-to-keypoint regressor
        self.register_buffer('kintree_table', kwargs['kintree_table'])
        self.register_buffer('parents', kwargs['kintree_table'][0])
        self.register_buffer('weights', kwargs['weights'])
        self.register_buffer('vert2kpt', kwargs['vert2kpt'])
        self.register_buffer('face', kwargs['F'])

        # mean shape and default joints, reoriented by a fixed rotation of the rest pose
        rot = torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], dtype=torch.float32)
        rot = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=torch.float32) @ rot
        # rot = torch.eye(3, dtype=torch.float32)
        V = (rot @ kwargs['V'].T).T.unsqueeze(0)
        J = (rot @ kwargs['J'].T).T.unsqueeze(0)
        self.register_buffer('V', V)
        self.register_buffer('J', J)
        self.LBS = LBS(self.J, self.parents, self.weights)

        # pose and bone priors (means and covariances)
        self.register_buffer('p_m', kwargs['pose_mean'])
        self.register_buffer('b_m', kwargs['bone_mean'])
        self.register_buffer('p_cov', kwargs['pose_cov'])
        self.register_buffer('b_cov', kwargs['bone_cov'])

        # standardized blend shape basis
        B = kwargs['Beta']
        sigma = kwargs['Beta_sigma']
        B = B * sigma[:, None, None]
        self.register_buffer('B', B)

        # PCA coefficients optimized to match the original template shape:
        # in the __call__ function, passing betas=self.beta_original returns the
        # template shape from ECCV 2020 (marcbadger/avian-mesh).
        self.register_buffer('beta_original', kwargs['beta_original'])

    def __call__(self, global_orient, pose, bone, transl=None,
                 scale=1, betas=None, pose2rot=False, **kwargs):
        '''
        Input:
            global_orient [bn, 3]: batched global orientation of the root joint
            pose [bn, 72]: batched body pose
            bone [bn, 24]: bone lengths; the bone variable captures
                non-rigid joint articulation in this model
            betas [bn, 15]: shape PCA coefficients
                If betas is None, the mean shape is returned.
                If betas is self.beta_original, the original template shape is returned.
        '''
        device = global_orient.device
        batch_size = global_orient.shape[0]
        V = self.V.repeat([batch_size, 1, 1]) * scale
        J = self.J.repeat([batch_size, 1, 1]) * scale

        # multi-bird shape space
        if betas is not None:
            V = V + torch.einsum('bk, kmn->bmn', betas, self.B)

        # prepend the root entries: unit bone length and global orientation
        bone = torch.cat([torch.ones([batch_size, 1]).to(device), bone], dim=1)
        pose = torch.cat([global_orient, pose], dim=1)

        # LBS
        verts = self.LBS(V, pose, bone, scale, to_rotmats=pose2rot)
        if transl is not None:
            verts = verts + transl[:, None, :]

        # regress 3D keypoints from the posed vertices
        keypoints = torch.einsum('bni,kn->bki', verts, self.vert2kpt)

        output = AVESOutput(
            vertices=verts,
            joints=keypoints,
            betas=betas,
            global_orient=global_orient,
            pose=pose,
            bone=bone,
            transl=transl,
            full_pose=None,
        )

        return output
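

# A minimal usage sketch (hypothetical shapes; `model_data` stands for the dict of
# buffers this class expects, e.g. loaded from the AVES / avian-mesh model files):
#
#   aves = AVES(**model_data)    # kintree_table, weights, vert2kpt, F, V, J, priors, Beta, ...
#   bn = 2
#   output = aves(
#       global_orient=torch.zeros(bn, 3),   # root orientation, axis-angle
#       pose=torch.zeros(bn, 72),           # 24 body joints, axis-angle
#       bone=torch.ones(bn, 24),            # relative bone lengths (1 = template)
#       transl=torch.zeros(bn, 3),
#       betas=None,                         # None -> mean shape
#       pose2rot=True,                      # inputs are axis-angle, convert to rotation matrices
#   )
#   output.vertices    # (bn, n_verts, 3) posed mesh
#   output.joints      # (bn, n_keypoints, 3) keypoints regressed from the vertices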