# Jialin Yang
# Initial release on Huggingface Spaces with Gradio UI
# 352b049
# This code is based on https://github.com/Mathux/ACTOR.git
import torch
import numpy as np
import torch.nn.functional as F
from . import rotation_conversions as geometry
from .smpl import SMPL
# Keys of the SMPL model output that ``Rotation2xyz.__call__`` accepts as its
# ``jointstype`` argument.
JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"]


class Rotation2xyz:
    """Turn a sequence of per-joint rotations into xyz positions via SMPL.

    The SMPL model is built once in the constructor, put in eval mode and
    moved to ``device``; every call reuses it.
    """

    def __init__(self, device, dataset="amass"):
        """Store the configuration and instantiate the SMPL model.

        Args:
            device: Device the SMPL model is moved to.
            dataset: Stored on the instance; not read by this class.
        """
        self.device = device
        self.dataset = dataset
        self.smpl_model = SMPL().eval().to(device)

    def __call__(
        self,
        x,
        mask,
        pose_rep,
        translation,
        glob,
        jointstype,
        vertstrans,
        betas=None,
        beta=0,
        glob_rot=None,
        get_rotations_back=False,
        **kwargs
    ):
        """Convert one motion sequence to xyz coordinates.

        Args:
            x: 3-D tensor laid out as ``(entries, feats, frames)``. When
                ``translation`` is truthy the last entry along dim 0 carries
                the translation in its first three features and the other
                entries are joint rotations; otherwise every entry is a
                rotation.
            mask: Boolean tensor over frames selecting the frames to convert,
                or ``None`` to convert all of them. Unselected frames are
                zero in the output.
            pose_rep: ``"xyz"`` (``x`` is returned untouched), ``"rotvec"``,
                ``"rotmat"``, ``"rotquat"`` or ``"rot6d"``.
            translation: Whether ``x`` contains a translation entry.
            glob: If truthy, the first rotation entry is the global
                orientation; otherwise ``glob_rot`` is used for every frame.
            jointstype: Key of the SMPL output to return; must be listed in
                ``JOINTSTYPES``.
            vertstrans: If truthy (and ``translation`` is truthy), the
                translation, re-based so that the first frame sits at the
                origin, is added to every output point.
            betas: Shape coefficients passed to SMPL. When ``None``, zeros
                are used with coefficient 1 set to ``beta``.
            beta: Value of shape coefficient 1 when ``betas`` is ``None``.
            glob_rot: Axis-angle global rotation (a single 3-vector) used
                when ``glob`` is falsy.
            get_rotations_back: If truthy, also return the body rotation
                matrices and the global orientation matrices.
            **kwargs: Accepted and ignored.

        Returns:
            A ``(points, 3, frames)`` tensor, or the tuple
            ``(xyz, rotations, global_orient)`` when ``get_rotations_back``
            is truthy. For ``pose_rep == "xyz"`` the input ``x`` itself.

        Raises:
            TypeError: ``glob`` is falsy and ``glob_rot`` is ``None``.
            NotImplementedError: unknown ``jointstype`` or ``pose_rep``.
        """
        if pose_rep == "xyz":
            return x

        if mask is None:
            mask = torch.ones((x.shape[-1]), dtype=bool, device=x.device)

        if not glob and glob_rot is None:
            raise TypeError("You must specify global rotation if glob is False")

        if jointstype not in JOINTSTYPES:
            raise NotImplementedError("This jointstype is not implemented.")

        if translation:
            x_translations = x[-1, :3]
            x_rotations = x[:-1]
        else:
            x_rotations = x

        # (entries, feats, frames) -> (frames, entries, feats) so that the
        # frame mask can index the leading dimension.
        x_rotations = x_rotations.permute(2, 0, 1)
        nframes, njoints, _ = x_rotations.shape

        # Compute rotations (convert only masked sequences output)
        if pose_rep == "rotvec":
            rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
        elif pose_rep == "rotmat":
            rotations = x_rotations[mask].view(-1, njoints, 3, 3)
        elif pose_rep == "rotquat":
            rotations = geometry.quaternion_to_matrix(x_rotations[mask])
        elif pose_rep == "rot6d":
            rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
        else:
            raise NotImplementedError("No geometry for this one.")

        if not glob:
            # One fixed global rotation, repeated for every selected frame.
            global_orient = torch.tensor(glob_rot, device=x.device)
            global_orient = geometry.axis_angle_to_matrix(global_orient).view(
                1, 1, 3, 3
            )
            global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
        else:
            # The first rotation entry is the global orientation.
            global_orient = rotations[:, 0]
            rotations = rotations[:, 1:]

        if betas is None:
            betas = torch.zeros(
                [rotations.shape[0], self.smpl_model.num_betas],
                dtype=rotations.dtype,
                device=rotations.device,
            )
            # Only applied to the default tensor, so a caller-supplied
            # ``betas`` is never modified.
            betas[:, 1] = beta

        out = self.smpl_model(
            body_pose=rotations, global_orient=global_orient, betas=betas
        )

        # get the desirable joints
        joints = out[jointstype]

        # Frames excluded by the mask stay at zero.
        x_xyz = torch.zeros(
            nframes, joints.shape[1], 3, device=x.device, dtype=x.dtype
        )
        x_xyz[mask] = joints

        # (frames, points, 3) -> (points, 3, frames)
        x_xyz = x_xyz.permute(1, 2, 0).contiguous()

        if translation and vertstrans:
            # the first translation root at the origin
            x_translations = x_translations - x_translations[:, [0]]
            # add the translation to all the joints
            x_xyz = x_xyz + x_translations[None, :, :]

        if get_rotations_back:
            return x_xyz, rotations, global_orient
        return x_xyz