JasonSmithSO's picture
Upload 777 files
0034848 verified
"""
This file contains the MANO definition and mesh sampling operations for the MANO mesh.
Adapted from the following open-source projects:
MANOPTH (https://github.com/hassony2/manopth)
Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
GraphCMR (https://github.com/nkolot/GraphCMR/)
"""
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
import os.path as osp
import json
import code
from custom_manopth.manolayer import ManoLayer
import scipy.sparse
import custom_mesh_graphormer.modeling.data.config as cfg
from pathlib import Path
from comfy.model_management import get_torch_device
from wrapper_for_mps import sparse_to_dense
# Torch device selected by ComfyUI's model management; spmm() below moves both
# of its operands to this device before every multiplication.
device = get_torch_device()
class MANO(nn.Module):
    """Right-hand MANO model wrapper exposing mesh topology and a 21-joint regressor.

    Construction loads a ``ManoLayer`` from the ``data`` directory next to this
    file. It copies the layer's faces and joint regressor to NumPy and appends
    one-hot rows that each pick a single fingertip vertex. It then reorders the
    rows so that row ``k`` corresponds to ``joints_name[k]``. The final
    regressor is also registered as the float32 buffer ``joint_regressor_torch``.
    """

    def __init__(self):
        super().__init__()
        self.mano_dir = str(Path(__file__).parent / "data")
        self.layer = self.get_layer()
        self.vertex_num = 778
        self.face = self.layer.th_faces.numpy()
        self.joint_regressor = self.layer.th_J_regressor.numpy()
        self.joint_num = 21
        self.joints_name = (
            'Wrist',
            'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4',
            'Index_1', 'Index_2', 'Index_3', 'Index_4',
            'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4',
            'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4',
            'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4',
        )
        # Parent/child joint index pairs: wrist to each finger base, then along each finger.
        self.skeleton = (
            (0, 1), (0, 5), (0, 9), (0, 13), (0, 17),
            (1, 2), (2, 3), (3, 4),
            (5, 6), (6, 7), (7, 8),
            (9, 10), (10, 11), (11, 12),
            (13, 14), (14, 15), (15, 16),
            (17, 18), (18, 19), (19, 20),
        )
        self.root_joint_idx = self.joints_name.index('Wrist')

        # Fingertip mesh vertex ids (right hand), thumb -> pinky.
        self.fingertip_vertex_idx = [745, 317, 444, 556, 673]

        # Vertex ids turned into one-hot regressor rows, thumb -> pinky.
        # NOTE(review): the middle tip uses vertex 445 here while
        # fingertip_vertex_idx lists 444; kept as-is so the regressor output is
        # unchanged -- confirm which id is intended.
        tip_vertices = np.array([745, 317, 445, 556, 673])
        vertex_ids = np.arange(self.joint_regressor.shape[1])
        # Row r is 1.0 exactly at column tip_vertices[r], 0.0 elsewhere.
        tip_rows = (tip_vertices[:, None] == vertex_ids[None, :]).astype(np.float32)
        extended = np.concatenate((self.joint_regressor, tip_rows))

        # Reorder so row k matches joints_name[k]; the appended tip rows
        # (indices 16..20) land on Thumb_4, Index_4, Middle_4, Ring_4, Pinky_4.
        joint_order = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
        self.joint_regressor = extended[joint_order, :]
        self.register_buffer('joint_regressor_torch', torch.from_numpy(self.joint_regressor).float())

    def get_layer(self):
        """Build the right-hand ``ManoLayer`` rooted at ``self.mano_dir`` (no PCA, non-flat mean pose)."""
        return ManoLayer(mano_root=self.mano_dir, flat_hand_mean=False, use_pca=False)

    def get_3d_joints(self, vertices):
        """Regress 3D joint locations from MANO mesh vertices.

        Args:
            vertices: tensor of shape (B, 778, 3).

        Returns:
            Tensor of shape (B, 21, 3), i.e. ``joint_regressor_torch`` applied to
            every sample's vertex array.
        """
        return torch.einsum('bik,ji->bjk', vertices, self.joint_regressor_torch)
class SparseMM(torch.autograd.Function):
    """Autograd-aware ``sparse @ dense`` matrix product.

    Gradient flows only to the dense operand; the matrix operand always receives
    ``None``. This custom Function exists because the builtin matrix
    multiplication does not support backpropagation in every case.
    """

    @staticmethod
    def forward(ctx, sparse, dense):
        ctx.save_for_backward(sparse)
        # Remember whether the dense input wants a gradient so backward can skip work.
        ctx.req_grad = dense.requires_grad
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        (sparse,) = ctx.saved_tensors
        if not ctx.req_grad:
            return None, None
        # For Y = S @ X, dL/dX = S^T @ dL/dY.
        return None, torch.matmul(sparse.t(), grad_output)
def spmm(sparse, dense):
    """Return ``sparse @ dense`` computed through ``SparseMM``.

    Both operands are first moved to the module-level ``device``, so the result
    lives there regardless of where the inputs started.
    """
    return SparseMM.apply(sparse.to(device), dense.to(device))
def _scipy_to_torch(mat):
    """Convert one scipy sparse matrix into a torch tensor via ``sparse_to_dense``."""
    coo = scipy.sparse.coo_matrix(mat)
    indices = torch.from_numpy(np.array([coo.row, coo.col])).long()
    values = torch.from_numpy(coo.data).float()
    return sparse_to_dense(torch.sparse_coo_tensor(indices, values, coo.shape))


def scipy_to_pytorch(A, U, D):
    """Convert the scipy up-/down-sampling matrices into torch tensors.

    Each matrix is turned into a float32 ``torch.sparse_coo_tensor`` of the same
    shape and then passed through ``sparse_to_dense``.

    Args:
        A: adjacency matrices; accepted for interface compatibility but unused.
        U: iterable of upsampling matrices (anything ``scipy.sparse.coo_matrix`` accepts).
        D: iterable of downsampling matrices (same requirements as ``U``).

    Returns:
        ``(ptU, ptD)``: lists of converted tensors, in the same order as ``U`` and ``D``.
    """
    ptU = [_scipy_to_torch(u) for u in U]
    ptD = [_scipy_to_torch(d) for d in D]
    return ptU, ptD
def adjmat_sparse(adjmat, nsize=1):
    """Create a row-normalized graph adjacency matrix as a torch tensor.

    Args:
        adjmat: adjacency matrix in any form ``scipy.sparse.csr_matrix`` accepts.
            It is copied and never modified.
        nsize: neighbourhood size. For ``nsize > 1`` the matrix is raised to the
            ``nsize``-th power and every stored entry is then reset to 1.

    Returns:
        The result of ``sparse_to_dense`` on a float32 sparse COO tensor. In that
        tensor, self-loops are set to 1 and each row is divided by its sum.
    """
    # copy=True: csr_matrix(csr_input) would otherwise share data arrays with the
    # caller, and the in-place diagonal update below could mutate the input.
    adjmat = scipy.sparse.csr_matrix(adjmat, copy=True)
    if nsize > 1:
        base = adjmat
        for _ in range(1, nsize):
            adjmat = adjmat @ base  # sparse matrix product
        adjmat.data = np.ones_like(adjmat.data)
    # Set every diagonal entry to 1 in one call instead of n element-wise CSR writes.
    adjmat.setdiag(1)
    inv_row_sum = np.asarray(1 / adjmat.sum(axis=-1))  # shape (n, 1)
    adjmat = scipy.sparse.coo_matrix(adjmat.multiply(inv_row_sum))
    indices = torch.from_numpy(np.array([adjmat.row, adjmat.col])).long()
    values = torch.from_numpy(adjmat.data).float()
    return sparse_to_dense(torch.sparse_coo_tensor(indices, values, adjmat.shape))
def get_graph_params(filename, nsize=1):
    """Load and process graph adjacency and up-/down-sampling matrices.

    Args:
        filename: path to an ``.npz`` archive with entries ``'A'``, ``'U'`` and ``'D'``.
        nsize: neighbourhood size forwarded to ``adjmat_sparse``.

    Returns:
        ``(A, U, D)``: processed adjacency matrices (via ``adjmat_sparse``) and the
        upsampling/downsampling matrices converted by ``scipy_to_pytorch``.
    """
    # NOTE: allow_pickle=True unpickles stored objects; only load trusted files.
    # The context manager closes the NpzFile's underlying file handle.
    with np.load(filename, encoding='latin1', allow_pickle=True) as data:
        A = data['A']
        U = data['U']
        D = data['D']
    U, D = scipy_to_pytorch(A, U, D)
    A = [adjmat_sparse(a, nsize=nsize) for a in A]
    return A, U, D
class Mesh(object):
    """Mesh object that is used for handling certain graph operations.

    It holds the adjacency (``_A``), upsampling (``_U``) and downsampling
    (``_D``) matrices returned by ``get_graph_params``. The sampling matrices
    are left-multiplied onto per-vertex feature tensors.
    """

    def __init__(self, filename=cfg.MANO_sampling_matrix,
                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
        adjacency, up_mats, down_mats = get_graph_params(filename=filename, nsize=nsize)
        # Adjacency matrices are intentionally not moved to `device`.
        self._A = adjacency
        self._U = [mat.to(device) for mat in up_mats]
        self._D = [mat.to(device) for mat in down_mats]
        self.num_downsampling = num_downsampling

    @staticmethod
    def _apply_chain(mats, order, x):
        """Left-multiply ``x`` by ``mats[k]`` for each ``k`` in ``order`` (via ``spmm``).

        Tensors with fewer than 3 dims are multiplied directly. 3-D tensors are
        treated as a batch along dim 0 and each sample is processed separately.
        Any other rank is returned unchanged.
        """
        def run(t):
            for k in order:
                t = spmm(mats[k], t)
            return t

        rank = x.ndimension()
        if rank < 3:
            return run(x)
        if rank == 3:
            return torch.stack([run(x[b]) for b in range(x.shape[0])], dim=0)
        return x

    def downsample(self, x, n1=0, n2=None):
        """Downsample mesh features by applying ``_D[n1], ..., _D[n2 - 1]`` in order.

        ``n2`` defaults to ``self.num_downsampling``.
        """
        stop = self.num_downsampling if n2 is None else n2
        return self._apply_chain(self._D, tuple(range(n1, stop)), x)

    def upsample(self, x, n1=1, n2=0):
        """Upsample mesh features by applying ``_U[n1 - 1], ..., _U[n2]`` in that order."""
        # Materialised as a tuple so the order can be replayed for every batch sample.
        return self._apply_chain(self._U, tuple(reversed(range(n2, n1))), x)