ev2hands / model /utils.py
chris10's picture
init
15bc41b
raw
history blame
No virus
1.39 kB
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from manopth.manolayer import ManoLayer
def create_mano_layers(mano_path, device, n_cmps):
    """Build left- and right-hand MANO layers wrapped in a small callable adapter.

    Each side gets a manopth ``ManoLayer`` loaded from ``<mano_path>/mano`` with a
    PCA pose space of ``n_cmps`` components (``flat_hand_mean=False``,
    ``robust_rot=True``), moved to ``device``.

    Args:
        mano_path: Directory containing the ``mano`` model folder.
        device: Target passed to ``.to()`` for both layers.
        n_cmps: Number of PCA pose components (``ncomps`` of ``ManoLayer``).

    Returns:
        dict with keys ``'left'`` and ``'right'``. Each value is an adapter
        exposing:

        * ``m``: the underlying ``ManoLayer``.
        * ``faces``: a NumPy copy of ``th_faces``.
        * ``shapedirs``: the layer's ``th_shapedirs`` tensor, shared rather
          than copied.
        * a ``__call__`` returning an object with ``vertices`` and ``joints``.

    Side effects:
        May negate a slice of the left layer's ``th_shapedirs`` in place and
        print a message (see the check near the end of the function).
    """

    class Output:
        # Plain attribute container for the adapter's results.
        def __init__(self, vertices, joints):
            self.vertices = vertices
            self.joints = joints

    class SmplxAdapter:
        # Wraps a ManoLayer behind a (global_orient, hand_pose, betas, transl)
        # call signature; the name suggests it mimics an smplx-style API.
        def __init__(self, side):
            layer = ManoLayer(
                mano_root=f'{mano_path}/mano',
                use_pca=True,
                ncomps=n_cmps,
                side=side,
                flat_hand_mean=False,
                robust_rot=True,
            )
            self.m = layer.to(device)
            self.faces = self.m.th_faces.cpu().numpy()
            # Same tensor object as self.m.th_shapedirs, so later in-place edits
            # to the layer's buffer are visible through this attribute too.
            self.shapedirs = self.m.th_shapedirs

        def __call__(self, global_orient, hand_pose, betas, transl):
            # ManoLayer takes a single pose tensor: the root rotation followed
            # by the hand-pose coefficients, concatenated along dim 1.
            # Assumes both are (batch, k) tensors; TODO confirm against callers.
            full_pose = torch.cat([global_orient, hand_pose], 1)
            verts, jts = self.m(full_pose, betas, transl)
            # In-place rescale by 1/1000. This is presumably a millimetre-to-metre
            # conversion; manopth's output units are not visible here, so confirm.
            verts /= 1000
            jts /= 1000
            return Output(verts, jts)

    # 'left' is constructed before 'right'.
    layers = {side: SmplxAdapter(side=side) for side in ('left', 'right')}

    left_dirs = layers['left'].m.th_shapedirs
    right_dirs = layers['right'].m.th_shapedirs
    # If the left and right shapedirs nearly coincide at index 0 of dim 1
    # (total absolute difference below 1), flip the sign of that slice on the
    # left layer, in place.
    # NOTE(review): presumably a workaround for a MANO left-hand model issue
    # where this component was not mirrored relative to the right hand; confirm.
    mismatch = (left_dirs[:, 0, :] - right_dirs[:, 0, :]).abs().sum()
    if mismatch < 1:
        print('Fix th_shapedirs bug of MANO')
        left_dirs[:, 0, :] *= -1

    return layers