# ECON / lib / smplx / body_models.py
# Last commit: "gradio init" (df6cc56), by Yuliang
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import logging
import os
import os.path as osp
import pickle
from collections import namedtuple
from typing import Dict, Optional, Union
import numpy as np
import torch
import torch.nn as nn
logging.getLogger("smplx").setLevel(logging.ERROR)
from .lbs import find_dynamic_lmk_idx_and_bcoords, lbs, vertices2landmarks
from .utils import (
Array,
FLAMEOutput,
MANOOutput,
SMPLHOutput,
SMPLOutput,
SMPLXOutput,
Struct,
Tensor,
find_joint_kin_chain,
to_np,
to_tensor,
)
from .vertex_ids import vertex_ids as VERTEX_IDS
from .vertex_joint_selector import VertexJointSelector
# Generic output record of a body model. Fields are listed in positional
# order; each of them is optional and falls back to ``None`` when omitted.
ModelOutput = namedtuple(
    "ModelOutput",
    "vertices joints full_pose betas global_orient body_pose "
    "expression left_hand_pose right_hand_pose jaw_pose",
)
ModelOutput.__new__.__defaults__ = tuple(None for _ in ModelOutput._fields)
class SMPL(nn.Module):
    """SMPL body model implemented as a ``torch.nn.Module``.

    The constructor reads the model data (template vertices, shape and pose
    blend shapes, joint regressor, kinematic tree, skinning weights),
    registers it as buffers and optionally creates learnable parameters for
    the shape, global orientation, body pose and translation.
    """

    NUM_JOINTS = 23
    NUM_BODY_JOINTS = 23
    # Models whose shape space has fewer components than this value are
    # limited to at most 10 shape coefficients (see ``__init__``).
    SHAPE_SPACE_DIM = 300

    def __init__(
        self,
        model_path: str,
        kid_template_path: str = "",
        data_struct: Optional[Struct] = None,
        create_betas: bool = True,
        betas: Optional[Tensor] = None,
        num_betas: int = 10,
        create_global_orient: bool = True,
        global_orient: Optional[Tensor] = None,
        create_body_pose: bool = True,
        body_pose: Optional[Tensor] = None,
        create_transl: bool = True,
        transl: Optional[Tensor] = None,
        dtype=torch.float32,
        batch_size: int = 1,
        joint_mapper=None,
        gender: str = "neutral",
        age: str = "adult",
        vertex_ids: Dict[str, int] = None,
        v_template: Optional[Union[Tensor, Array]] = None,
        v_personal: Optional[Union[Tensor, Array]] = None,
        **kwargs,
    ) -> None:
        """SMPL model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        kid_template_path: str, optional
            Path of the kid template, loaded with ``np.load``. Only read
            when ``age == "kid"``. (default = "")
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_global_orient: bool, optional
            Flag for creating a member variable for the global orientation
            of the body. (default = True)
        global_orient: torch.tensor, optional, Bx3
            The default value for the global orientation variable.
            (default = None)
        create_body_pose: bool, optional
            Flag for creating a member variable for the pose of the body.
            (default = True)
        body_pose: torch.tensor, optional, Bx(Body Joints * 3)
            The default value for the body pose variable.
            (default = None)
        num_betas: int, optional
            Number of shape components to use
            (default = 10).
        create_betas: bool, optional
            Flag for creating a member variable for the shape space
            (default = True).
        betas: torch.tensor, optional, Bx10
            The default value for the shape member variable.
            (default = None)
        create_transl: bool, optional
            Flag for creating a member variable for the translation
            of the body. (default = True)
        transl: torch.tensor, optional, Bx3
            The default value for the transl variable.
            (default = None)
        dtype: torch.dtype, optional
            The data type for the created variables
        batch_size: int, optional
            The batch size used for creating the member variables
        joint_mapper: object, optional
            An object that re-maps the joints. Useful if one wants to
            re-order the SMPL joints to some other convention (e.g. MSCOCO)
            (default = None)
        gender: str, optional
            Which gender to load
        age: str, optional
            When equal to "kid", the (centered) kid template minus the model
            template is appended to the shape basis as one additional shape
            component. (default = "adult")
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected
        v_template: torch.tensor or array, optional
            Template vertices used instead of the ones stored in the model
            data. (default = None)
        v_personal: torch.tensor or array, optional
            Offsets that are added to the template vertices.
            (default = None)
        **kwargs
            Forwarded to ``VertexJointSelector``.
        """
        self.gender = gender
        self.age = age

        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = "SMPL_{}.{ext}".format(gender.upper(), ext="pkl")
                smpl_path = os.path.join(model_path, model_fn)
            else:
                smpl_path = model_path
            assert osp.exists(smpl_path), "Path {} does not exist!".format(smpl_path)

            # NOTE(security): unpickling can execute arbitrary code; only
            # load model files that come from a trusted source.
            with open(smpl_path, "rb") as smpl_file:
                data_struct = Struct(**pickle.load(smpl_file, encoding="latin1"))

        super(SMPL, self).__init__()
        self.batch_size = batch_size

        shapedirs = data_struct.shapedirs
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM:
            # Model with a reduced shape space: at most 10 coefficients.
            num_betas = min(num_betas, 10)
        else:
            num_betas = min(num_betas, self.SHAPE_SPACE_DIM)

        if self.age == "kid":
            # Append the offset between the centered kid template and the
            # model template as one extra shape direction.
            v_template_smil = np.load(kid_template_path)
            v_template_smil -= np.mean(v_template_smil, axis=0)
            v_template_diff = np.expand_dims(v_template_smil - data_struct.v_template, axis=2)
            shapedirs = np.concatenate((shapedirs[:, :, :num_betas], v_template_diff), axis=2)
            num_betas = num_betas + 1

        self._num_betas = num_betas
        shapedirs = shapedirs[:, :, :num_betas]
        # The shape components
        self.register_buffer("shapedirs", to_tensor(to_np(shapedirs), dtype=dtype))

        if vertex_ids is None:
            # SMPL and SMPL-H share the same topology, so any extra joints can
            # be drawn from the same place
            vertex_ids = VERTEX_IDS["smplh"]

        self.dtype = dtype
        self.joint_mapper = joint_mapper
        self.vertex_joint_selector = VertexJointSelector(vertex_ids=vertex_ids, **kwargs)

        self.faces = data_struct.f
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long),
        )

        if create_betas:
            if betas is None:
                default_betas = torch.zeros([batch_size, self.num_betas], dtype=dtype)
            else:
                if torch.is_tensor(betas):
                    default_betas = betas.clone().detach()
                else:
                    default_betas = torch.tensor(betas, dtype=dtype)
            self.register_parameter("betas", nn.Parameter(default_betas, requires_grad=True))

        # The tensor that contains the global rotation of the model
        # It is separated from the pose of the joints in case we wish to
        # optimize only over one of them
        if create_global_orient:
            if global_orient is None:
                default_global_orient = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                if torch.is_tensor(global_orient):
                    default_global_orient = global_orient.clone().detach()
                else:
                    default_global_orient = torch.tensor(global_orient, dtype=dtype)
            global_orient = nn.Parameter(default_global_orient, requires_grad=True)
            self.register_parameter("global_orient", global_orient)

        if create_body_pose:
            if body_pose is None:
                default_body_pose = torch.zeros([batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
            else:
                if torch.is_tensor(body_pose):
                    default_body_pose = body_pose.clone().detach()
                else:
                    default_body_pose = torch.tensor(body_pose, dtype=dtype)
            self.register_parameter(
                "body_pose", nn.Parameter(default_body_pose, requires_grad=True)
            )

        if create_transl:
            if transl is None:
                default_transl = torch.zeros([batch_size, 3], dtype=dtype, requires_grad=True)
            elif torch.is_tensor(transl):
                # Copy tensors explicitly instead of going through
                # ``torch.tensor(tensor)``, which emits a copy-construct
                # warning; the value is still cast to ``dtype``.
                default_transl = transl.clone().detach().to(dtype=dtype)
            else:
                default_transl = torch.tensor(transl, dtype=dtype)
            self.register_parameter("transl", nn.Parameter(default_transl, requires_grad=True))

        if v_template is None:
            v_template = data_struct.v_template
        if not torch.is_tensor(v_template):
            v_template = to_tensor(to_np(v_template), dtype=dtype)
        if v_personal is not None:
            v_personal = to_tensor(to_np(v_personal), dtype=dtype)
            # Out-of-place addition: an in-place ``+=`` would silently modify
            # a ``v_template`` tensor owned by the caller.
            v_template = v_template + v_personal
        # The vertices of the template model
        self.register_buffer("v_template", v_template)

        j_regressor = to_tensor(to_np(data_struct.J_regressor), dtype=dtype)
        self.register_buffer("J_regressor", j_regressor)

        # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
        num_pose_basis = data_struct.posedirs.shape[-1]
        # 207 x 20670
        posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=dtype))

        # indices of parents for each joints
        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)

        self.register_buffer("lbs_weights", to_tensor(to_np(data_struct.weights), dtype=dtype))

    @property
    def num_betas(self):
        """Number of shape coefficients used by this model instance."""
        return self._num_betas

    @property
    def num_expression_coeffs(self):
        """Number of expression coefficients; always 0 for this class."""
        return 0

    def create_mean_pose(self, data_struct) -> Tensor:
        """Placeholder; this class defines no mean pose and returns ``None``."""
        pass

    def name(self) -> str:
        """Return the display name of the model."""
        return "SMPL"

    @torch.no_grad()
    def reset_params(self, **params_dict) -> None:
        """Reset the parameters of the module in place.

        Parameters whose name is a key of ``params_dict`` are overwritten
        with the corresponding value; all other parameters are set to zero.
        """
        for param_name, param in self.named_parameters():
            if param_name in params_dict:
                param[:] = torch.tensor(params_dict[param_name])
            else:
                param.fill_(0)

    def get_num_verts(self) -> int:
        """Return the number of vertices of the template mesh."""
        return self.v_template.shape[0]

    def get_num_faces(self) -> int:
        """Return the number of faces of the mesh."""
        return self.faces.shape[0]

    def extra_repr(self) -> str:
        """Return the gender, joint count and beta count, one per line."""
        msg = [
            f"Gender: {self.gender.upper()}",
            f"Number of joints: {self.J_regressor.shape[0]}",
            f"Betas: {self.num_betas}",
        ]
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts=True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLOutput:
        """Forward pass for the SMPL model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3
            If given, ignore the member variable and use it as the global
            rotation of the body. Useful if someone wishes to predicts this
            with an external model. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        body_pose: torch.tensor, optional, shape Bx(J*3)
            If given, ignore the member variable `body_pose` and use it
            instead. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            axis-angle format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `transl` and use it
            instead. For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full axis-angle pose vector (default=False)
        pose2rot: bool, optional
            Forwarded to ``lbs``. (default=True)

        Returns
        -------
        SMPLOutput
            Holds the joints, betas, global orientation and body pose, plus
            the vertices / full pose when the matching flag is set.
        """
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = (global_orient if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas

        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None and hasattr(self, "transl"):
            transl = self.transl

        full_pose = torch.cat([global_orient, body_pose], dim=1)

        batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0])

        if betas.shape[0] != batch_size:
            # Broadcast the shape coefficients over the pose batch.
            num_repeats = int(batch_size / betas.shape[0])
            betas = betas.expand(num_repeats, -1)

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

        joints = self.vertex_joint_selector(vertices, joints)
        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )

        return output
class SMPLLayer(SMPL):
    """SMPL without member parameters; poses are given as rotation matrices."""

    def __init__(self, *args, **kwargs) -> None:
        """Build a SMPL module that does not create any member variables."""
        super().__init__(
            *args,
            create_body_pose=False,
            create_betas=False,
            create_global_orient=False,
            create_transl=False,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts=True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLOutput:
        """Forward pass for the SMPL model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body. Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body pose. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose tensor (default=False)
        pose2rot: bool, optional
            Accepted for interface compatibility; ``lbs`` is always called
            with ``pose2rot=False`` here.

        Returns
        -------
        SMPLOutput
        """
        device = self.shapedirs.device
        dtype = self.shapedirs.dtype

        # The batch size is the largest length among the supplied inputs,
        # and 1 when nothing was supplied.
        supplied = [v for v in (betas, global_orient, body_pose, transl) if v is not None]
        batch_size = max([1] + [len(v) for v in supplied])

        def _identity_rotations(num_joints: int):
            # (batch_size, num_joints, 3, 3) stack of identity matrices.
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, num_joints, -1, -1).contiguous()

        # Fill every missing input with its neutral value.
        if global_orient is None:
            global_orient = _identity_rotations(1)
        if body_pose is None:
            body_pose = _identity_rotations(self.NUM_BODY_JOINTS)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        root_rotation = global_orient.reshape(-1, 1, 3, 3)
        joint_rotations = body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)
        full_pose = torch.cat([root_rotation, joint_rotations], dim=1)

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        joints = self.vertex_joint_selector(vertices, joints)
        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if transl is not None:
            offset = transl.unsqueeze(dim=1)
            joints += offset
            vertices += offset

        return SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )
class SMPLH(SMPL):
    """SMPL+H: the SMPL body with articulated hands.

    Hand poses are either given as PCA coefficients (``use_pca=True``) or as
    full axis-angle vectors, and the mean hand pose is added to the full pose
    in ``forward``.
    """

    # The hand joints are replaced by MANO
    NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS

    def __init__(
        self,
        model_path,
        kid_template_path: str = "",
        data_struct: Optional[Struct] = None,
        create_left_hand_pose: bool = True,
        left_hand_pose: Optional[Tensor] = None,
        create_right_hand_pose: bool = True,
        right_hand_pose: Optional[Tensor] = None,
        use_pca: bool = True,
        num_pca_comps: int = 6,
        flat_hand_mean: bool = False,
        batch_size: int = 1,
        gender: str = "neutral",
        age: str = "adult",
        dtype=torch.float32,
        vertex_ids=None,
        use_compressed: bool = True,
        ext: str = "pkl",
        **kwargs,
    ) -> None:
        """SMPLH model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        kid_template_path: str, optional
            Forwarded to the ``SMPL`` constructor. (default = "")
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_left_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the left
            hand. (default = True)
        left_hand_pose: torch.tensor, optional, BxP
            The default value for the left hand pose member variable.
            (default = None)
        create_right_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the right
            hand. (default = True)
        right_hand_pose: torch.tensor, optional, BxP
            The default value for the right hand pose member variable.
            (default = None)
        use_pca: bool, optional
            If True, hand poses are PCA coefficients that are projected with
            the hand components in ``forward``. (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for each hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the mean hand poses are replaced by zeros; otherwise
            the mean poses stored in the model data are used.
            (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load
        age: str, optional
            Forwarded to the ``SMPL`` constructor. (default = "adult")
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected
        use_compressed: bool, optional
            Forwarded to the ``SMPL`` constructor. (default = True)
        ext: str, optional
            Extension of the model file, "pkl" or "npz". (default = "pkl")

        Raises
        ------
        ValueError
            If the model has to be loaded and ``ext`` is neither "pkl" nor
            "npz".
        """
        # If no data structure is passed, then load the data from the given
        # model folder
        if data_struct is None:
            # Load the model
            if osp.isdir(model_path):
                model_fn = "SMPLH_{}.{ext}".format(gender.upper(), ext=ext)
                smplh_path = os.path.join(model_path, model_fn)
            else:
                smplh_path = model_path
            assert osp.exists(smplh_path), "Path {} does not exist!".format(smplh_path)

            # NOTE(security): both loaders below can unpickle objects; only
            # load model files that come from a trusted source.
            if ext == "pkl":
                with open(smplh_path, "rb") as smplh_file:
                    model_data = pickle.load(smplh_file, encoding="latin1")
            elif ext == "npz":
                # Read every array while the archive is open so that the
                # file handle is released right away.
                with np.load(smplh_path, allow_pickle=True) as npz_file:
                    model_data = dict(npz_file)
            else:
                raise ValueError("Unknown extension: {}".format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]

        super(SMPLH, self).__init__(
            model_path=model_path,
            kid_template_path=kid_template_path,
            data_struct=data_struct,
            batch_size=batch_size,
            vertex_ids=vertex_ids,
            gender=gender,
            age=age,
            use_compressed=use_compressed,
            dtype=dtype,
            ext=ext,
            **kwargs,
        )

        self.use_pca = use_pca
        self.num_pca_comps = num_pca_comps
        self.flat_hand_mean = flat_hand_mean

        left_hand_components = data_struct.hands_componentsl[:num_pca_comps]
        right_hand_components = data_struct.hands_componentsr[:num_pca_comps]

        self.np_left_hand_components = left_hand_components
        self.np_right_hand_components = right_hand_components
        if self.use_pca:
            self.register_buffer(
                "left_hand_components", torch.tensor(left_hand_components, dtype=dtype)
            )
            self.register_buffer(
                "right_hand_components",
                torch.tensor(right_hand_components, dtype=dtype),
            )

        if self.flat_hand_mean:
            left_hand_mean = np.zeros_like(data_struct.hands_meanl)
            right_hand_mean = np.zeros_like(data_struct.hands_meanr)
        else:
            left_hand_mean = data_struct.hands_meanl
            right_hand_mean = data_struct.hands_meanr

        self.register_buffer("left_hand_mean", to_tensor(left_hand_mean, dtype=self.dtype))
        self.register_buffer("right_hand_mean", to_tensor(right_hand_mean, dtype=self.dtype))

        # Create the buffers for the pose of the left hand
        hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS
        if create_left_hand_pose:
            if left_hand_pose is None:
                default_lhand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            elif torch.is_tensor(left_hand_pose):
                # Explicit copy; ``torch.tensor(tensor)`` warns.
                default_lhand_pose = left_hand_pose.clone().detach().to(dtype=dtype)
            else:
                default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype)

            left_hand_pose_param = nn.Parameter(default_lhand_pose, requires_grad=True)
            self.register_parameter("left_hand_pose", left_hand_pose_param)

        if create_right_hand_pose:
            if right_hand_pose is None:
                default_rhand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            elif torch.is_tensor(right_hand_pose):
                default_rhand_pose = right_hand_pose.clone().detach().to(dtype=dtype)
            else:
                default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype)

            right_hand_pose_param = nn.Parameter(default_rhand_pose, requires_grad=True)
            self.register_parameter("right_hand_pose", right_hand_pose_param)

        # Create the buffer for the mean pose.
        pose_mean_tensor = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean)
        if not torch.is_tensor(pose_mean_tensor):
            pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype)
        self.register_buffer("pose_mean", pose_mean_tensor)

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Return the mean of the full pose vector.

        The global orientation and body pose parts are zero; the hand parts
        are the registered ``left_hand_mean`` / ``right_hand_mean`` buffers
        (which already reflect ``flat_hand_mean``). Both arguments are
        unused here.
        """
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype)

        pose_mean = torch.cat(
            [
                global_orient_mean,
                body_pose_mean,
                self.left_hand_mean,
                self.right_hand_mean,
            ],
            dim=0,
        )
        return pose_mean

    def name(self) -> str:
        """Return the display name of the model."""
        return "SMPL+H"

    def extra_repr(self):
        """Extend the parent summary with the hand-specific settings."""
        msg = super(SMPLH, self).extra_repr()
        msg = [msg]
        if self.use_pca:
            msg.append(f"Number of PCA components: {self.num_pca_comps}")
        msg.append(f"Flat hand mean: {self.flat_hand_mean}")
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLHOutput:
        """Forward pass for the SMPL+H model

        Every argument that is ``None`` is replaced by the member variable
        of the same name. With ``use_pca`` enabled the hand poses are
        treated as PCA coefficients and projected with the hand components.
        The mean pose is added to the concatenated full pose before ``lbs``
        is called.

        Parameters
        ----------
        betas, global_orient, body_pose, left_hand_pose, right_hand_pose,
        transl: torch.tensor, optional
            Override the corresponding member variable. (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose vector (default=False)
        pose2rot: bool, optional
            Forwarded to ``lbs``. (default=True)

        Returns
        -------
        SMPLHOutput
        """
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = (global_orient if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas
        left_hand_pose = (left_hand_pose if left_hand_pose is not None else self.left_hand_pose)
        right_hand_pose = (right_hand_pose if right_hand_pose is not None else self.right_hand_pose)

        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None:
            if hasattr(self, "transl"):
                transl = self.transl

        if self.use_pca:
            left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
            right_hand_pose = torch.einsum(
                "bi,ij->bj", [right_hand_pose, self.right_hand_components]
            )

        full_pose = torch.cat([global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1)
        full_pose += self.pose_mean

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLHOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )

        return output
class SMPLHLayer(SMPLH):
    """SMPL+H without member parameters; poses are given as rotation matrices."""

    def __init__(self, *args, **kwargs) -> None:
        """SMPL+H as a layer model constructor"""
        super(SMPLHLayer, self).__init__(
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_betas=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLHOutput:
        """Forward pass for the SMPL+H model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body. Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body pose. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the left hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the right hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose tensor (default=False)
        pose2rot: bool, optional
            Accepted for interface compatibility; ``lbs`` is always called
            with ``pose2rot=False`` here.

        Returns
        -------
        SMPLHOutput
        """
        model_vars = [
            betas,
            global_orient,
            body_pose,
            transl,
            left_hand_pose,
            right_hand_pose,
        ]
        # Largest length among the supplied inputs, 1 if none was supplied.
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Missing rotations default to identity matrices. The joint counts
        # come from the class constants so that they always agree with the
        # reshape calls below.
        if global_orient is None:
            global_orient = (
                torch.eye(3, device=device, dtype=dtype)
                .view(1, 1, 3, 3)
                .expand(batch_size, -1, -1, -1)
                .contiguous()
            )
        if body_pose is None:
            body_pose = (
                torch.eye(3, device=device, dtype=dtype)
                .view(1, 1, 3, 3)
                .expand(batch_size, self.NUM_BODY_JOINTS, -1, -1)
                .contiguous()
            )
        if left_hand_pose is None:
            left_hand_pose = (
                torch.eye(3, device=device, dtype=dtype)
                .view(1, 1, 3, 3)
                .expand(batch_size, self.NUM_HAND_JOINTS, -1, -1)
                .contiguous()
            )
        if right_hand_pose is None:
            right_hand_pose = (
                torch.eye(3, device=device, dtype=dtype)
                .view(1, 1, 3, 3)
                .expand(batch_size, self.NUM_HAND_JOINTS, -1, -1)
                .contiguous()
            )
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Concatenate all pose vectors
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if transl is not None:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLHOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )

        return output
class SMPLX(SMPLH):
"""
SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters
trained jointly for the face, hands and body.
SMPL-X uses standard vertex based linear blend skinning with learned
corrective blend shapes, has N=10475 vertices and K=54 joints,
which includes joints for the neck, jaw, eyeballs and fingers.
"""
NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS # 21
NUM_HAND_JOINTS = 15
NUM_FACE_JOINTS = 3
NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS
EXPRESSION_SPACE_DIM = 100
NECK_IDX = 12
def __init__(
    self,
    model_path: str,
    kid_template_path: str = "",
    num_expression_coeffs: int = 10,
    create_expression: bool = True,
    expression: Optional[Tensor] = None,
    create_jaw_pose: bool = True,
    jaw_pose: Optional[Tensor] = None,
    create_leye_pose: bool = True,
    leye_pose: Optional[Tensor] = None,
    create_reye_pose=True,
    reye_pose: Optional[Tensor] = None,
    use_face_contour: bool = False,
    batch_size: int = 1,
    gender: str = "neutral",
    age: str = "adult",
    dtype=torch.float32,
    ext: str = "npz",
    **kwargs,
) -> None:
    """SMPLX model constructor

    Parameters
    ----------
    model_path: str
        Hugging Face Hub repository id; the model file
        ``models/SMPLX_<GENDER>.<ext>`` is downloaded from it.
    kid_template_path: str, optional
        Forwarded to the parent constructor. (default = "")
    num_expression_coeffs: int, optional
        Number of expression components to use
        (default = 10).
    create_expression: bool, optional
        Flag for creating a member variable for the expression space
        (default = True).
    expression: torch.tensor, optional, Bx10
        The default value for the expression member variable.
        (default = None)
    create_jaw_pose: bool, optional
        Flag for creating a member variable for the jaw pose.
        (default = True)
    jaw_pose: torch.tensor, optional, Bx3
        The default value for the jaw pose variable.
        (default = None)
    create_leye_pose: bool, optional
        Flag for creating a member variable for the left eye pose.
        (default = True)
    leye_pose: torch.tensor, optional, Bx3
        The default value for the left eye pose variable.
        (default = None)
    create_reye_pose: bool, optional
        Flag for creating a member variable for the right eye pose.
        (default = True)
    reye_pose: torch.tensor, optional, Bx3
        The default value for the right eye pose variable.
        (default = None)
    use_face_contour: bool, optional
        Whether to compute the keypoints that form the facial contour
    batch_size: int, optional
        The batch size used for creating the member variables
    gender: str, optional
        Which gender to load
    age: str, optional
        Forwarded to the parent constructor. (default = "adult")
    dtype: torch.dtype
        The data type for the created variables
    ext: str, optional
        Extension of the model file, "pkl" or "npz". (default = "npz")

    Raises
    ------
    ValueError
        If ``ext`` is neither "pkl" nor "npz".
    """

    def _initial_value(value, shape):
        # Zeros when no default is given, otherwise a copy of `value`
        # converted to `dtype`. Tensors are copied explicitly because
        # ``torch.tensor(tensor)`` emits a copy-construct warning.
        if value is None:
            return torch.zeros(shape, dtype=dtype)
        if torch.is_tensor(value):
            return value.clone().detach().to(dtype=dtype)
        return torch.tensor(value, dtype=dtype)

    # Load the model
    from huggingface_hub import hf_hub_download

    model_fn = "SMPLX_{}.{ext}".format(gender.upper(), ext=ext)
    # The access token is read from the ICON environment variable. A missing
    # variable is passed on as None instead of raising a bare KeyError.
    # NOTE(review): recent huggingface_hub releases may expect `token=`
    # instead of `use_auth_token=`; confirm against the pinned version.
    smplx_path = hf_hub_download(
        repo_id=model_path,
        use_auth_token=os.environ.get("ICON"),
        filename=f"models/{model_fn}",
    )

    # NOTE(security): both loaders below can unpickle objects; only load
    # model files that come from a trusted source.
    if ext == "pkl":
        with open(smplx_path, "rb") as smplx_file:
            model_data = pickle.load(smplx_file, encoding="latin1")
    elif ext == "npz":
        # Read every array while the archive is open so that the file
        # handle is released right away.
        with np.load(smplx_path, allow_pickle=True) as npz_file:
            model_data = dict(npz_file)
    else:
        raise ValueError("Unknown extension: {}".format(ext))

    data_struct = Struct(**model_data)

    super(SMPLX, self).__init__(
        model_path=model_path,
        kid_template_path=kid_template_path,
        data_struct=data_struct,
        dtype=dtype,
        batch_size=batch_size,
        vertex_ids=VERTEX_IDS["smplx"],
        gender=gender,
        age=age,
        ext=ext,
        **kwargs,
    )

    lmk_faces_idx = data_struct.lmk_faces_idx
    self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long))
    lmk_bary_coords = data_struct.lmk_bary_coords
    self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype))

    self.use_face_contour = use_face_contour
    if self.use_face_contour:
        dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx
        dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long)
        self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx)

        dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords
        dynamic_lmk_bary_coords = torch.tensor(dynamic_lmk_bary_coords, dtype=dtype)
        self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_bary_coords)

        # NOTE(review): indentation was lost in the reviewed source; the
        # neck kinematic chain is assumed to belong to the face-contour
        # branch, as in upstream smplx. Confirm against the original file.
        neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
        self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long))

    if create_jaw_pose:
        jaw_pose_param = nn.Parameter(
            _initial_value(jaw_pose, [batch_size, 3]), requires_grad=True
        )
        self.register_parameter("jaw_pose", jaw_pose_param)

    if create_leye_pose:
        leye_pose_param = nn.Parameter(
            _initial_value(leye_pose, [batch_size, 3]), requires_grad=True
        )
        self.register_parameter("leye_pose", leye_pose_param)

    if create_reye_pose:
        reye_pose_param = nn.Parameter(
            _initial_value(reye_pose, [batch_size, 3]), requires_grad=True
        )
        self.register_parameter("reye_pose", reye_pose_param)

    shapedirs = data_struct.shapedirs
    if len(shapedirs.shape) < 3:
        shapedirs = shapedirs[:, :, None]
    if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM:
        # Reduced model: 10 shape followed by 10 expression components.
        expr_start_idx = 10
        num_expression_coeffs = min(num_expression_coeffs, 10)
    else:
        expr_start_idx = self.SHAPE_SPACE_DIM
        num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM)
    # Slice exactly the clamped number of expression directions so that
    # `expr_dirs` always has `num_expression_coeffs` components.
    expr_end_idx = expr_start_idx + num_expression_coeffs

    self._num_expression_coeffs = num_expression_coeffs

    expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
    self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype))

    if create_expression:
        expression_param = nn.Parameter(
            _initial_value(expression, [batch_size, self.num_expression_coeffs]),
            requires_grad=True,
        )
        self.register_parameter("expression", expression_param)
def name(self) -> str:
    """Return the display name of this body model family."""
    model_name = "SMPL-X"
    return model_name
@property
def num_expression_coeffs(self):
    """Number of expression coefficients stored by the constructor."""
    coeff_count = self._num_expression_coeffs
    return coeff_count
def create_mean_pose(self, data_struct, flat_hand_mean=False):
    """Assemble the flat mean-pose vector of the full model.

    The global orientation, body, jaw and both eye segments are all zero;
    only the hand segments are non-trivial and are taken from the
    ``left_hand_mean`` / ``right_hand_mean`` attributes already stored on
    the module. ``data_struct`` and ``flat_hand_mean`` are accepted for
    signature compatibility but are not read here.

    Returns
    -------
    pose_mean: np.ndarray
        Concatenation, in order, of global orientation (3), body pose
        (``NUM_BODY_JOINTS * 3``), jaw (3), left eye (3), right eye (3),
        left hand mean and right hand mean.
    """
    # Sizes of the all-zero leading segments, in concatenation order:
    # global orient, body, jaw, left eye, right eye.
    zero_block_sizes = (3, self.NUM_BODY_JOINTS * 3, 3, 3, 3)
    segments = [torch.zeros([size], dtype=self.dtype) for size in zero_block_sizes]
    segments.extend((self.left_hand_mean, self.right_hand_mean))
    return np.concatenate(segments, axis=0)
def extra_repr(self):
    """Extend the parent representation with the expression-space size."""
    lines = [
        super(SMPLX, self).extra_repr(),
        f"Number of Expression Coefficients: {self.num_expression_coeffs}",
    ]
    return "\n".join(lines)
def forward(
    self,
    betas: Optional[Tensor] = None,
    global_orient: Optional[Tensor] = None,
    body_pose: Optional[Tensor] = None,
    left_hand_pose: Optional[Tensor] = None,
    right_hand_pose: Optional[Tensor] = None,
    transl: Optional[Tensor] = None,
    expression: Optional[Tensor] = None,
    jaw_pose: Optional[Tensor] = None,
    leye_pose: Optional[Tensor] = None,
    reye_pose: Optional[Tensor] = None,
    return_verts: bool = True,
    return_full_pose: bool = False,
    pose2rot: bool = True,
    return_joint_transformation: bool = False,
    return_vertex_transformation: bool = False,
    pose_type: str = 'posed',
    **kwargs,
) -> SMPLXOutput:
    """
    Forward pass for the SMPLX model

    Parameters
    ----------
    global_orient: torch.tensor, optional, shape Bx3
        If given, ignore the member variable and use it as the global
        rotation of the body. Useful if someone wishes to predicts this
        with an external model. (default=None)
    betas: torch.tensor, optional, shape BxN_b
        If given, ignore the member variable `betas` and use it
        instead. For example, it can used if shape parameters
        `betas` are predicted from some external model.
        (default=None)
    expression: torch.tensor, optional, shape BxN_e
        If given, ignore the member variable `expression` and use it
        instead. For example, it can used if expression parameters
        `expression` are predicted from some external model.
    body_pose: torch.tensor, optional, shape Bx(J*3)
        If given, ignore the member variable `body_pose` and use it
        instead. It should be a tensor that contains joint rotations in
        axis-angle format. (default=None)
    left_hand_pose: torch.tensor, optional, shape BxP
        If given, ignore the member variable `left_hand_pose` and
        use this instead. It should either contain PCA coefficients or
        joint rotations in axis-angle format.
    right_hand_pose: torch.tensor, optional, shape BxP
        If given, ignore the member variable `right_hand_pose` and
        use this instead. It should either contain PCA coefficients or
        joint rotations in axis-angle format.
    jaw_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `jaw_pose` and
        use this instead. It should contain joint rotations in
        axis-angle format.
    leye_pose: torch.tensor, optional
        If given, ignore the member variable `leye_pose` and use this
        instead.
    reye_pose: torch.tensor, optional
        If given, ignore the member variable `reye_pose` and use this
        instead.
    transl: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `transl` and use it
        instead. For example, it can used if the translation
        `transl` is predicted from some external model.
        (default=None)
    return_verts: bool, optional
        Return the vertices. (default=True)
    return_full_pose: bool, optional
        Returns the full axis-angle pose vector (default=False)
    pose2rot: bool, optional
        Forwarded unchanged to `lbs`. (default=True)
    return_joint_transformation: bool, optional
        If True, the joint transformation computed by `lbs` is placed
        in the output. (default=False)
    return_vertex_transformation: bool, optional
        If True, the vertex transformation computed by `lbs` is placed
        in the output. (default=False)
    pose_type: str, optional
        'posed' (or any unrecognised value) keeps the given pose.
        't-pose' multiplies the whole pose vector by zero.
        'a-pose' zeroes the pose and sets body joints 15 / 16 to
        -45 / +45 degrees in their last axis-angle component
        (presumably the shoulders - confirm against the joint order).
        'da-pose' zeroes the pose and sets body joints 0 / 1 to
        +30 / -30 degrees in their last axis-angle component
        (presumably the hips - confirm against the joint order).
        For the last two modes the returned `body_pose` is the
        canonical one. (default='posed')

    Returns
    -------
    output: SMPLXOutput
        Vertices, joints (with landmarks appended) and the parameters
        used to produce them.
    """
    # If no shape and pose parameters are passed along, then use the
    # ones from the module
    global_orient = (global_orient if global_orient is not None else self.global_orient)
    body_pose = body_pose if body_pose is not None else self.body_pose
    betas = betas if betas is not None else self.betas
    left_hand_pose = (left_hand_pose if left_hand_pose is not None else self.left_hand_pose)
    right_hand_pose = (right_hand_pose if right_hand_pose is not None else self.right_hand_pose)
    jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
    leye_pose = leye_pose if leye_pose is not None else self.leye_pose
    reye_pose = reye_pose if reye_pose is not None else self.reye_pose
    expression = expression if expression is not None else self.expression

    apply_trans = transl is not None or hasattr(self, "transl")
    if transl is None:
        if hasattr(self, "transl"):
            transl = self.transl

    # Project PCA hand coefficients onto the stored hand components.
    if self.use_pca:
        left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
        right_hand_pose = torch.einsum(
            "bi,ij->bj", [right_hand_pose, self.right_hand_components]
        )

    full_pose = torch.cat(
        [
            global_orient,
            body_pose,
            jaw_pose,
            leye_pose,
            reye_pose,
            left_hand_pose,
            right_hand_pose,
        ],
        dim=1,
    )

    # Canonical poses: (body joint index, angle in degrees) pairs written
    # into the last axis-angle component of an otherwise zero body pose.
    canonical_body_rotations = {
        "a-pose": ((15, -45), (16, 45)),
        "da-pose": ((0, 30), (1, -30)),
    }
    if pose_type == "t-pose":
        full_pose *= 0.0
    elif pose_type in canonical_body_rotations:
        body_pose = torch.zeros_like(body_pose).view(body_pose.shape[0], -1, 3)
        for joint_idx, degrees in canonical_body_rotations[pose_type]:
            body_pose[:, joint_idx] = torch.tensor([0., 0., degrees * np.pi / 180.])
        body_pose = body_pose.view(body_pose.shape[0], -1)
        full_pose = torch.cat(
            [
                global_orient * 0.,
                body_pose,
                jaw_pose * 0.,
                leye_pose * 0.,
                reye_pose * 0.,
                left_hand_pose * 0.,
                right_hand_pose * 0.,
            ],
            dim=1,
        )

    # Add the mean pose of the model. Does not affect the body, only the
    # hands when flat_hand_mean == False
    # full_pose += self.pose_mean

    batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0])
    # Concatenate the shape and expression coefficients
    scale = int(batch_size / betas.shape[0])
    if scale > 1:
        betas = betas.expand(scale, -1)
    shape_components = torch.cat([betas, expression], dim=-1)

    shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

    if return_joint_transformation or return_vertex_transformation:
        vertices, joints, joint_transformation, vertex_transformation = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
            return_transformation=True,
        )
    else:
        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

    lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous())
    # Tile with the runtime batch size (not the construction-time
    # self.batch_size) so the barycentric coordinates always match
    # lmk_faces_idx and the vertices, with or without face contours.
    lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
    if self.use_face_contour:
        lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain,
            pose2rot=True,
        )
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords

        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)

    landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

    # Add any extra joints that might be needed
    joints = self.vertex_joint_selector(vertices, joints)
    # Add the landmarks to the joints
    joints = torch.cat([joints, landmarks], dim=1)
    # Map the joints to the current dataset
    if self.joint_mapper is not None:
        joints = self.joint_mapper(joints=joints, vertices=vertices)

    if apply_trans:
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

    # The transformation names only exist when one of the two flags was
    # set; the conditional expressions below never evaluate them otherwise.
    output = SMPLXOutput(
        vertices=vertices if return_verts else None,
        joints=joints,
        betas=betas,
        expression=expression,
        global_orient=global_orient,
        body_pose=body_pose,
        left_hand_pose=left_hand_pose,
        right_hand_pose=right_hand_pose,
        jaw_pose=jaw_pose,
        full_pose=full_pose if return_full_pose else None,
        joint_transformation=joint_transformation if return_joint_transformation else None,
        vertex_transformation=vertex_transformation if return_vertex_transformation else None,
    )
    return output
class SMPLXLayer(SMPLX):
    """SMPL-X variant without learnable pose/shape member variables.

    Every pose input of `forward` is expected in rotation-matrix format
    and missing inputs are replaced by identity rotations or zeros.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Just create a SMPLX module without any member variables
        super(SMPLXLayer, self).__init__(
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            create_betas=False,
            create_expression=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    @staticmethod
    def _identity_rotations(batch_size, num_joints, device, dtype):
        """Return a contiguous (batch_size, num_joints, 3, 3) stack of identities."""
        eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
        return eye.expand(batch_size, num_joints, -1, -1).contiguous()

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> SMPLXOutput:
        """
        Forward pass for the SMPLX model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body in rotation matrix format.
            Identity if omitted. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients. Zeros if omitted. (default=None)
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients. Zeros if omitted.
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body joint rotations in rotation matrix format.
            Identity if omitted. (default=None)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            Left hand joint rotations in rotation matrix format.
            Identity if omitted. (default=None)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            Right hand joint rotations in rotation matrix format.
            Identity if omitted. (default=None)
        jaw_pose: torch.tensor, optional, shape Bx3x3
            Jaw rotation in rotation matrix format. Identity if omitted.
        leye_pose: torch.tensor, optional, shape Bx3x3
            Left eye rotation in rotation matrix format.
            Identity if omitted.
        reye_pose: torch.tensor, optional, shape Bx3x3
            Right eye rotation in rotation matrix format.
            Identity if omitted.
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body. Zeros if omitted.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose vector (default=False)

        Returns
        -------
        output: SMPLXOutput
            A data class that contains the posed vertices and joints
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # The batch size is the largest leading dimension among the inputs
        # that were actually provided; the eye poses count as well so that
        # the defaults built below line up with them.
        model_vars = [
            betas,
            global_orient,
            body_pose,
            transl,
            expression,
            left_hand_pose,
            right_hand_pose,
            jaw_pose,
            leye_pose,
            reye_pose,
        ]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))

        if global_orient is None:
            global_orient = self._identity_rotations(batch_size, 1, device, dtype)
        if body_pose is None:
            body_pose = self._identity_rotations(
                batch_size, self.NUM_BODY_JOINTS, device, dtype
            )
        if left_hand_pose is None:
            left_hand_pose = self._identity_rotations(
                batch_size, self.NUM_HAND_JOINTS, device, dtype
            )
        if right_hand_pose is None:
            right_hand_pose = self._identity_rotations(
                batch_size, self.NUM_HAND_JOINTS, device, dtype
            )
        if jaw_pose is None:
            jaw_pose = self._identity_rotations(batch_size, 1, device, dtype)
        if leye_pose is None:
            leye_pose = self._identity_rotations(batch_size, 1, device, dtype)
        if reye_pose is None:
            reye_pose = self._identity_rotations(batch_size, 1, device, dtype)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Concatenate all pose vectors
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                jaw_pose.reshape(-1, 1, 3, 3),
                leye_pose.reshape(-1, 1, 3, 3),
                reye_pose.reshape(-1, 1, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )
        shape_components = torch.cat([betas, expression], dim=-1)

        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous())
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        if self.use_face_contour:
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords

            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
            ], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)
        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        # transl is always a tensor at this point (zeros when not given).
        if transl is not None:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLXOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            jaw_pose=jaw_pose,
            transl=transl,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class MANO(SMPL):
    """Hand-only model built on the SMPL machinery."""

    # The hand joints are replaced by MANO
    NUM_BODY_JOINTS = 1
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS

    def __init__(
        self,
        model_path: str,
        is_rhand: bool = True,
        data_struct: Optional[Struct] = None,
        create_hand_pose: bool = True,
        hand_pose: Optional[Tensor] = None,
        use_pca: bool = True,
        num_pca_comps: int = 6,
        flat_hand_mean: bool = False,
        batch_size: int = 1,
        dtype=torch.float32,
        vertex_ids=None,
        use_compressed: bool = True,
        ext: str = "pkl",
        **kwargs,
    ) -> None:
        """MANO model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        is_rhand: bool, optional
            Selects the ``MANO_RIGHT`` or ``MANO_LEFT`` file when
            `model_path` is a folder. When `model_path` is a file, the
            handedness is instead derived from whether its base name
            contains "RIGHT". (default = True)
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the
            hand. (default = True)
        hand_pose: torch.tensor, optional, BxP
            The default value for the hand pose member variable.
            (default = None)
        use_pca: bool, optional
            Whether the hand pose is expressed in PCA coefficients.
            Forced to False when `num_pca_comps` is 45. (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for the hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the stored hand mean is all zeros; otherwise the
            mean found in the model data is used. (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected
        use_compressed: bool, optional
            Forwarded unchanged to the parent constructor.
        ext: str, optional
            Extension of the model file, "pkl" or "npz". (default = "pkl")
        """
        # Plain attributes are stored before the parent constructor runs.
        self.num_pca_comps = num_pca_comps
        self.is_rhand = is_rhand

        # Without a ready-made struct, read the parameters from disk.
        if data_struct is None:
            if not osp.isdir(model_path):
                # A direct file path: handedness follows the file name.
                mano_path = model_path
                self.is_rhand = (True if "RIGHT" in os.path.basename(model_path) else False)
            else:
                side = "RIGHT" if is_rhand else "LEFT"
                mano_path = os.path.join(model_path, "MANO_{}.{ext}".format(side, ext=ext))

            assert osp.exists(mano_path), "Path {} does not exist!".format(mano_path)

            if ext == "npz":
                model_data = np.load(mano_path, allow_pickle=True)
            elif ext == "pkl":
                with open(mano_path, "rb") as mano_file:
                    model_data = pickle.load(mano_file, encoding="latin1")
            else:
                raise ValueError("Unknown extension: {}".format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]

        super(MANO, self).__init__(
            model_path=model_path,
            data_struct=data_struct,
            batch_size=batch_size,
            vertex_ids=vertex_ids,
            use_compressed=use_compressed,
            dtype=dtype,
            ext=ext,
            **kwargs,
        )

        # add only MANO tips to the extra joints
        mano_tip_ids = list(VERTEX_IDS["mano"].values())
        self.vertex_joint_selector.extra_joints_idxs = to_tensor(mano_tip_ids, dtype=torch.long)

        self.num_pca_comps = num_pca_comps
        # Asking for all 45 components switches PCA off.
        self.use_pca = False if num_pca_comps == 45 else use_pca
        self.flat_hand_mean = flat_hand_mean

        hand_components = data_struct.hands_components[:num_pca_comps]
        self.np_hand_components = hand_components
        if self.use_pca:
            self.register_buffer("hand_components", torch.tensor(hand_components, dtype=dtype))

        hand_mean = data_struct.hands_mean
        if self.flat_hand_mean:
            hand_mean = np.zeros_like(hand_mean)
        self.register_buffer("hand_mean", to_tensor(hand_mean, dtype=self.dtype))

        # Create the parameter for the pose of the hand
        if use_pca:
            hand_pose_dim = num_pca_comps
        else:
            hand_pose_dim = 3 * self.NUM_HAND_JOINTS
        if create_hand_pose:
            if hand_pose is not None:
                default_hand_pose = torch.tensor(hand_pose, dtype=dtype)
            else:
                default_hand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            self.register_parameter(
                "hand_pose", nn.Parameter(default_hand_pose, requires_grad=True)
            )

        # Create the buffer for the mean pose.
        mean_pose = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean)
        self.register_buffer("pose_mean", mean_pose.clone().to(dtype))

    def name(self) -> str:
        """Return the display name of this model family."""
        return "MANO"

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Return the mean pose: a zero global orientation followed by `hand_mean`.

        `data_struct` and `flat_hand_mean` are not read here; the stored
        `hand_mean` buffer already reflects the flat-hand choice.
        """
        zero_orient = torch.zeros([3], dtype=self.dtype)
        return torch.cat([zero_orient, self.hand_mean], dim=0)

    def extra_repr(self):
        """Extend the parent representation with the hand-specific settings."""
        lines = [super(MANO, self).extra_repr()]
        if self.use_pca:
            lines.append(f"Number of PCA components: {self.num_pca_comps}")
        lines.append(f"Flat hand mean: {self.flat_hand_mean}")
        return "\n".join(lines)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> MANOOutput:
        """Forward pass for the MANO model

        Arguments left as None fall back to the corresponding member
        variables. Note that both `vertices` and `joints` in the output
        are None when `return_verts` is False.
        """
        # Use the module's own parameters for anything not supplied.
        if global_orient is None:
            global_orient = self.global_orient
        if betas is None:
            betas = self.betas
        if hand_pose is None:
            hand_pose = self.hand_pose

        has_member_transl = hasattr(self, "transl")
        apply_trans = transl is not None or has_member_transl
        if transl is None and has_member_transl:
            transl = self.transl

        # Map PCA coefficients back to per-joint axis-angle values.
        if self.use_pca:
            hand_pose = torch.einsum("bi,ij->bj", [hand_pose, self.hand_components])

        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        full_pose += self.pose_mean

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=True,
        )

        # # Add pre-selected extra joints that might be needed
        # joints = self.vertex_joint_selector(vertices, joints)

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            offset = transl.unsqueeze(dim=1)
            joints = joints + offset
            vertices = vertices + offset

        return MANOOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            hand_pose=hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
class MANOLayer(MANO):
    """MANO variant without learnable pose/shape member variables.

    Pose inputs of `forward` are passed to `lbs` with ``pose2rot=False``,
    i.e. they are expected in rotation-matrix format.
    """

    def __init__(self, *args, **kwargs) -> None:
        """MANO as a layer model constructor"""
        super(MANOLayer, self).__init__(
            create_global_orient=False,
            create_hand_pose=False,
            create_betas=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def name(self) -> str:
        """Return the display name of this model family."""
        return "MANO"

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> MANOOutput:
        """Forward pass for the MANO model

        Parameters
        ----------
        betas: torch.tensor, optional
            Shape coefficients. Zeros if omitted.
        global_orient: torch.tensor, optional
            Global rotation in rotation matrix format. Identity if
            omitted. When given, its leading dimension is the batch size.
        hand_pose: torch.tensor, optional
            Hand joint rotations in rotation matrix format. Identity if
            omitted.
        transl: torch.tensor, optional
            Translation added to joints and vertices. Zeros if omitted.
        return_verts: bool, optional
            If False, both `vertices` and `joints` in the output are None.
            (default=True)
        return_full_pose: bool, optional
            Return the concatenated pose. (default=False)

        Returns
        -------
        output: MANOOutput
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        def identity_rotations(num_joints):
            # (batch_size, num_joints, 3, 3) stack of identity matrices.
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, num_joints, -1, -1).contiguous()

        if global_orient is None:
            # No orientation to read the batch size from: take it from
            # whichever other inputs were supplied, so that the defaults
            # created below match them instead of being fixed to 1.
            batch_size = 1
            for var in (betas, hand_pose, transl):
                if var is not None:
                    batch_size = max(batch_size, len(var))
            global_orient = identity_rotations(1)
        else:
            batch_size = global_orient.shape[0]

        if hand_pose is None:
            hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # transl is always a tensor here (zeros when it was not given).
        joints = joints + transl.unsqueeze(dim=1)
        vertices = vertices + transl.unsqueeze(dim=1)

        output = MANOOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            hand_pose=hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
class FLAME(SMPL):
NUM_JOINTS = 5
SHAPE_SPACE_DIM = 300
EXPRESSION_SPACE_DIM = 100
NECK_IDX = 0
def __init__(
    self,
    model_path: str,
    data_struct=None,
    num_expression_coeffs=10,
    create_expression: bool = True,
    expression: Optional[Tensor] = None,
    create_neck_pose: bool = True,
    neck_pose: Optional[Tensor] = None,
    create_jaw_pose: bool = True,
    jaw_pose: Optional[Tensor] = None,
    create_leye_pose: bool = True,
    leye_pose: Optional[Tensor] = None,
    create_reye_pose=True,
    reye_pose: Optional[Tensor] = None,
    use_face_contour=False,
    batch_size: int = 1,
    gender: str = "neutral",
    dtype: torch.dtype = torch.float32,
    ext="pkl",
    **kwargs,
) -> None:
    """FLAME model constructor

    Parameters
    ----------
    model_path: str
        The folder that holds ``FLAME_<GENDER>.<ext>``,
        ``flame_static_embedding.pkl`` and, when `use_face_contour`
        is set, ``flame_dynamic_embedding.npy``.
    data_struct: optional
        Accepted for signature compatibility; the model data is always
        read from `model_path` and this argument is overwritten.
    num_expression_coeffs: int, optional
        Number of expression components to use. Clamped to what the
        loaded model provides. (default = 10).
    create_expression: bool, optional
        Flag for creating a member variable for the expression space
        (default = True).
    expression: torch.tensor, optional, BxN_e
        The default value for the expression member variable.
        (default = None)
    create_neck_pose: bool, optional
        Flag for creating a member variable for the neck pose.
        (default = True)
    neck_pose: torch.tensor, optional, Bx3
        The default value for the neck pose variable.
        (default = None)
    create_jaw_pose: bool, optional
        Flag for creating a member variable for the jaw pose.
        (default = True)
    jaw_pose: torch.tensor, optional, Bx3
        The default value for the jaw pose variable.
        (default = None)
    create_leye_pose: bool, optional
        Flag for creating a member variable for the left eye pose.
        (default = True)
    leye_pose: torch.tensor, optional, Bx3
        The default value for the left eye pose variable.
        (default = None)
    create_reye_pose: bool, optional
        Flag for creating a member variable for the right eye pose.
        (default = True)
    reye_pose: torch.tensor, optional, Bx3
        The default value for the right eye pose variable.
        (default = None)
    use_face_contour: bool, optional
        Whether to compute the keypoints that form the facial contour
    batch_size: int, optional
        The batch size used for creating the member variables
    gender: str, optional
        Which gender to load
    dtype: torch.dtype
        The data type for the created variables
    ext: str, optional
        Extension of the model file, "pkl" or "npz". (default = "pkl")

    Raises
    ------
    AssertionError
        If the model file does not exist.
    ValueError
        If `ext` is neither "pkl" nor "npz".
    """
    model_fn = f"FLAME_{gender.upper()}.{ext}"
    flame_path = os.path.join(model_path, model_fn)
    assert osp.exists(flame_path), "Path {} does not exist!".format(flame_path)
    # NOTE(review): both loaders below unpickle file contents; only point
    # model_path at trusted model files.
    if ext == "npz":
        file_data = np.load(flame_path, allow_pickle=True)
    elif ext == "pkl":
        with open(flame_path, "rb") as smpl_file:
            file_data = pickle.load(smpl_file, encoding="latin1")
    else:
        raise ValueError("Unknown extension: {}".format(ext))
    data_struct = Struct(**file_data)

    super(FLAME, self).__init__(
        model_path=model_path,
        data_struct=data_struct,
        dtype=dtype,
        batch_size=batch_size,
        gender=gender,
        ext=ext,
        **kwargs,
    )

    self.use_face_contour = use_face_contour

    # No extra vertex-based joints are selected for this model.
    self.vertex_joint_selector.extra_joints_idxs = to_tensor([], dtype=torch.long)

    def _pose_parameter(value):
        # Zero (batch_size, 3) default, or a copy of the supplied value.
        if value is None:
            data = torch.zeros([batch_size, 3], dtype=dtype)
        else:
            data = torch.tensor(value, dtype=dtype)
        return nn.Parameter(data, requires_grad=True)

    # Registration order is kept: neck, jaw, left eye, right eye.
    if create_neck_pose:
        self.register_parameter("neck_pose", _pose_parameter(neck_pose))
    if create_jaw_pose:
        self.register_parameter("jaw_pose", _pose_parameter(jaw_pose))
    if create_leye_pose:
        self.register_parameter("leye_pose", _pose_parameter(leye_pose))
    if create_reye_pose:
        self.register_parameter("reye_pose", _pose_parameter(reye_pose))

    shapedirs = data_struct.shapedirs
    if len(shapedirs.shape) < 3:
        shapedirs = shapedirs[:, :, None]
    if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM:
        # print(f'WARNING: You are using a {self.name()} model, with only'
        #       ' 10 shape and 10 expression coefficients.')
        expr_start_idx = 10
        max_expression_coeffs = 10
    else:
        expr_start_idx = self.SHAPE_SPACE_DIM
        max_expression_coeffs = self.EXPRESSION_SPACE_DIM
    # Clamp first and derive the slice end from the clamped count, so that
    # expr_dirs always has exactly num_expression_coeffs columns and
    # matches the expression parameter created below.
    num_expression_coeffs = min(num_expression_coeffs, max_expression_coeffs)
    expr_end_idx = expr_start_idx + num_expression_coeffs
    self._num_expression_coeffs = num_expression_coeffs

    expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
    self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype))

    if create_expression:
        if expression is None:
            default_expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                             dtype=dtype)
        else:
            default_expression = torch.tensor(expression, dtype=dtype)
        expression_param = nn.Parameter(default_expression, requires_grad=True)
        self.register_parameter("expression", expression_param)

    # The pickle file that contains the barycentric coordinates for
    # regressing the landmarks
    landmark_bcoord_filename = osp.join(model_path, "flame_static_embedding.pkl")

    with open(landmark_bcoord_filename, "rb") as fp:
        landmarks_data = pickle.load(fp, encoding="latin1")

    lmk_faces_idx = landmarks_data["lmk_face_idx"].astype(np.int64)
    self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long))
    lmk_bary_coords = landmarks_data["lmk_b_coords"]
    self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype))
    if self.use_face_contour:
        face_contour_path = os.path.join(model_path, "flame_dynamic_embedding.npy")
        # The .npy file stores a pickled dict as a 0-d array; [()] unwraps it.
        contour_embeddings = np.load(face_contour_path, allow_pickle=True,
                                     encoding="latin1")[()]

        dynamic_lmk_faces_idx = np.array(contour_embeddings["lmk_face_idx"], dtype=np.int64)
        dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long)
        self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx)

        dynamic_lmk_b_coords = torch.tensor(contour_embeddings["lmk_b_coords"], dtype=dtype)
        self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_b_coords)

        neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
        self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long))
@property
def num_expression_coeffs(self):
    """Number of expression coefficients stored by the constructor."""
    coeff_count = self._num_expression_coeffs
    return coeff_count
def name(self) -> str:
    """Return the display name of this model family."""
    model_name = "FLAME"
    return model_name
def extra_repr(self):
    """Extend the parent representation with the expression and contour settings."""
    lines = [super(FLAME, self).extra_repr()]
    lines.append(f"Number of Expression Coefficients: {self.num_expression_coeffs}")
    lines.append(f"Use face contour: {self.use_face_contour}")
    return "\n".join(lines)
def forward(
    self,
    betas: Optional[Tensor] = None,
    global_orient: Optional[Tensor] = None,
    neck_pose: Optional[Tensor] = None,
    transl: Optional[Tensor] = None,
    expression: Optional[Tensor] = None,
    jaw_pose: Optional[Tensor] = None,
    leye_pose: Optional[Tensor] = None,
    reye_pose: Optional[Tensor] = None,
    return_verts: bool = True,
    return_full_pose: bool = False,
    pose2rot: bool = True,
    **kwargs,
) -> FLAMEOutput:
    """
    Forward pass for the FLAME model

    Every argument that is left as `None` falls back to the member
    variable of the same name stored on the module.

    Parameters
    ----------
    betas: torch.tensor, optional, shape Bx10
        If given, ignore the member variable `betas` and use it
        instead. For example, it can used if shape parameters
        `betas` are predicted from some external model.
        (default=None)
    global_orient: torch.tensor, optional, shape Bx3
        If given, ignore the member variable and use it as the global
        rotation of the head. Useful if someone wishes to predicts this
        with an external model. (default=None)
    neck_pose: torch.tensor, optional
        If given, ignore the member variable `neck_pose` and use this
        instead. Presumably Bx3 in axis-angle format, like `jaw_pose`.
    transl: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `transl` and use it
        instead. For example, it can used if the translation
        `transl` is predicted from some external model.
        When neither the argument nor the member variable is available
        no translation is applied. (default=None)
    expression: torch.tensor, optional, shape Bx10
        If given, ignore the member variable `expression` and use it
        instead. For example, it can used if expression parameters
        `expression` are predicted from some external model.
    jaw_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `jaw_pose` and
        use this instead. It should be joint rotations in
        axis-angle format.
    leye_pose: torch.tensor, optional
        If given, ignore the member variable `leye_pose` and use this
        instead. Presumably Bx3 in axis-angle format, like `jaw_pose`.
    reye_pose: torch.tensor, optional
        If given, ignore the member variable `reye_pose` and use this
        instead. Presumably Bx3 in axis-angle format, like `jaw_pose`.
    return_verts: bool, optional
        Return the vertices. (default=True)
    return_full_pose: bool, optional
        Returns the full pose vector (default=False)
    pose2rot: bool, optional
        Forwarded to `lbs` and to the dynamic landmark lookup, so both
        interpret the pose in the same representation. (default=True)

    Returns
    -------
    output: FLAMEOutput
        Holds `vertices` (None unless `return_verts`), `joints` (model
        joints followed by the landmarks), `betas`, `expression`,
        `global_orient`, `neck_pose`, `jaw_pose` and `full_pose`
        (None unless `return_full_pose`).
    """
    # If no shape and pose parameters are passed along, then use the
    # ones from the module
    global_orient = self.global_orient if global_orient is None else global_orient
    jaw_pose = self.jaw_pose if jaw_pose is None else jaw_pose
    neck_pose = self.neck_pose if neck_pose is None else neck_pose
    leye_pose = self.leye_pose if leye_pose is None else leye_pose
    reye_pose = self.reye_pose if reye_pose is None else reye_pose

    betas = self.betas if betas is None else betas
    expression = self.expression if expression is None else expression

    # The module only owns a translation when it was created with one
    if transl is None:
        transl = getattr(self, "transl", None)

    full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

    batch_size = max(betas.shape[0], global_orient.shape[0], jaw_pose.shape[0])
    # Broadcast a single set of shape coefficients over the whole batch
    scale = int(batch_size / betas.shape[0])
    if scale > 1:
        betas = betas.expand(scale, -1)

    # Concatenate the shape and expression coefficients
    shape_components = torch.cat([betas, expression], dim=-1)
    shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

    vertices, joints = lbs(
        shape_components,
        full_pose,
        self.v_template,
        shapedirs,
        self.posedirs,
        self.J_regressor,
        self.parents,
        self.lbs_weights,
        pose2rot=pose2rot,
    )

    # Tile the static landmark embedding with the batch size of the current
    # call; the size stored on the module at construction time may differ.
    lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
    lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)

    if self.use_face_contour:
        # Use the same pose representation that was handed to `lbs`
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain,
            pose2rot=pose2rot,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)

    landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

    # Add any extra joints that might be needed
    joints = self.vertex_joint_selector(vertices, joints)
    # Add the landmarks to the joints
    joints = torch.cat([joints, landmarks], dim=1)

    # Map the joints to the current dataset
    if self.joint_mapper is not None:
        joints = self.joint_mapper(joints=joints, vertices=vertices)

    if transl is not None:
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

    output = FLAMEOutput(
        vertices=vertices if return_verts else None,
        joints=joints,
        betas=betas,
        expression=expression,
        global_orient=global_orient,
        neck_pose=neck_pose,
        jaw_pose=jaw_pose,
        full_pose=full_pose if return_full_pose else None,
    )
    return output
class FLAMELayer(FLAME):
    """FLAME variant that creates no pose/shape member variables.

    All parameters are supplied to `forward`; rotations are passed to `lbs`
    with `pose2rot=False`, i.e. they are not converted from axis-angle.
    """

    def __init__(self, *args, **kwargs) -> None:
        """ FLAME as a layer model constructor """
        super(FLAMELayer, self).__init__(
            create_betas=False,
            create_expression=False,
            create_global_orient=False,
            create_neck_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            *args,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        neck_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> FLAMEOutput:
        """
        Forward pass for the FLAME layer

        Arguments left as `None` are replaced by neutral values: identity
        rotations for the poses and zeros for `betas`, `expression` and
        `transl`. The batch size is the largest leading dimension among
        the provided arguments (1 when nothing is provided).

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the head. Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format. (default=None)
        neck_pose: torch.tensor, optional
            Neck pose in rotation matrix format. The default built here
            has shape Bx1x3x3.
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the head.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients. For example, it can used if
            expression parameters `expression` are predicted from some
            external model.
        jaw_pose: torch.tensor, optional, shape Bx3x3
            Jaw pose. It should be joint rotations in
            rotation matrix format.
        leye_pose: torch.tensor, optional
            Left eye pose in rotation matrix format. The default built
            here has shape Bx1x3x3.
        reye_pose: torch.tensor, optional
            Right eye pose in rotation matrix format. The default built
            here has shape Bx1x3x3.
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose, i.e. the concatenated rotations
            (default=False)
        pose2rot: bool, optional
            Accepted for signature compatibility with `FLAME.forward`;
            this layer always calls `lbs` with `pose2rot=False`.

        Returns
        -------
        output: FLAMEOutput
            Holds `vertices` (None unless `return_verts`), `joints` (model
            joints followed by the landmarks), `betas`, `expression`,
            `global_orient`, `neck_pose`, `jaw_pose` and `full_pose`
            (None unless `return_full_pose`).
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from whichever inputs were provided, instead of
        # looking at `global_orient` only.
        provided = (
            global_orient, neck_pose, jaw_pose, leye_pose, reye_pose, betas, expression, transl
        )
        batch_size = 1
        for var in provided:
            if var is not None:
                batch_size = max(batch_size, var.shape[0])

        def _identity_rotation():
            # One identity rotation matrix per batch element, shape Bx1x3x3
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, -1, -1, -1).contiguous()

        if global_orient is None:
            global_orient = _identity_rotation()
        if neck_pose is None:
            neck_pose = _identity_rotation()
        if jaw_pose is None:
            jaw_pose = _identity_rotation()
        if leye_pose is None:
            leye_pose = _identity_rotation()
        if reye_pose is None:
            reye_pose = _identity_rotation()

        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

        # Concatenate the shape and expression coefficients
        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        # Tile the static landmark embedding with the batch size of the current
        # call; the size stored on the module at construction time may differ.
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)

        if self.use_face_contour:
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)

        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = FLAMEOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            neck_pose=neck_pose,
            jaw_pose=jaw_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
def build_layer(model_path: str,
                model_type: str = "smpl",
                **kwargs) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]:
    """Instantiate the layer version of a body model from a path and a type

    Parameters
    ----------
    model_path: str
        Either the path to the model file to load, or a folder in which
        each subfolder holds one model type, i.e.:
        model_path:
        |
        |-- smpl
            |-- SMPL_FEMALE
            |-- SMPL_NEUTRAL
            |-- SMPL_MALE
        |-- smplh
            |-- SMPLH_FEMALE
            |-- SMPLH_MALE
        |-- smplx
            |-- SMPLX_FEMALE
            |-- SMPLX_NEUTRAL
            |-- SMPLX_MALE
        |-- mano
            |-- MANO RIGHT
            |-- MANO LEFT
        |-- flame
            |-- FLAME_FEMALE
            |-- FLAME_MALE
            |-- FLAME_NEUTRAL
    model_type: str, optional
        Only used when `model_path` is a folder: selects the subfolder
        and the layer class. For a file path the type is read from the
        part of the file name before the first underscore instead.
    **kwargs: dict
        Keyword arguments handed to the layer constructor

    Returns
    -------
    body_model: nn.Module
        The PyTorch module that implements the corresponding body model

    Raises
    ------
    ValueError: In case the model type is not one of SMPL, SMPLH,
        SMPLX, MANO or FLAME
    """
    if not osp.isdir(model_path):
        # A file was given: its name prefix decides the model type
        model_type = osp.basename(model_path).split("_")[0].lower()
    else:
        model_path = os.path.join(model_path, model_type)

    kind = model_type.lower()
    # The SMPL family needs an exact match, MANO and FLAME a substring match
    if kind in ("smpl", "smplh", "smplx"):
        exact_match = {"smpl": SMPLLayer, "smplh": SMPLHLayer, "smplx": SMPLXLayer}
        return exact_match[kind](model_path, **kwargs)
    if "mano" in kind:
        return MANOLayer(model_path, **kwargs)
    if "flame" in kind:
        return FLAMELayer(model_path, **kwargs)
    raise ValueError(f"Unknown model type {model_type}, exiting!")
def create(model_path: str,
           model_type: str = "smpl",
           **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
    """Instantiate a body model from a path and a model type

    Parameters
    ----------
    model_path: str
        Either the path to the model file to load, or a folder in which
        each subfolder holds one model type, i.e.:
        model_path:
        |
        |-- smpl
            |-- SMPL_FEMALE
            |-- SMPL_NEUTRAL
            |-- SMPL_MALE
        |-- smplh
            |-- SMPLH_FEMALE
            |-- SMPLH_MALE
        |-- smplx
            |-- SMPLX_FEMALE
            |-- SMPLX_NEUTRAL
            |-- SMPLX_MALE
        |-- mano
            |-- MANO RIGHT
            |-- MANO LEFT
        |-- flame
            |-- FLAME_FEMALE
            |-- FLAME_MALE
            |-- FLAME_NEUTRAL
    model_type: str, optional
        Only used when `model_path` is a folder: selects the subfolder
        and the model class. For a file path the type is read from the
        part of the file name before the first underscore instead.
    **kwargs: dict
        Keyword arguments handed to the model constructor

    Returns
    -------
    body_model: nn.Module
        The PyTorch module that implements the corresponding body model

    Raises
    ------
    ValueError: In case the model type is not one of SMPL, SMPLH,
        SMPLX, MANO or FLAME
    """
    if not osp.isdir(model_path):
        # A file was given: its name prefix decides the model type
        model_type = osp.basename(model_path).split("_")[0].lower()
    else:
        # A folder was given: descend into the subfolder of the model type
        model_path = os.path.join(model_path, model_type)

    kind = model_type.lower()
    # The SMPL family needs an exact match, MANO and FLAME a substring match
    if kind in ("smpl", "smplh", "smplx"):
        exact_match = {"smpl": SMPL, "smplh": SMPLH, "smplx": SMPLX}
        return exact_match[kind](model_path, **kwargs)
    if "mano" in kind:
        return MANO(model_path, **kwargs)
    if "flame" in kind:
        return FLAME(model_path, **kwargs)
    raise ValueError(f"Unknown model type {model_type}, exiting!")