Spaces:
Configuration error
Configuration error
""" | |
This file contains the MANO defination and mesh sampling operations for MANO mesh | |
Adapted from opensource projects | |
MANOPTH (https://github.com/hassony2/manopth) | |
Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) | |
GraphCMR (https://github.com/nkolot/GraphCMR/) | |
""" | |
from __future__ import division | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import os.path as osp | |
import json | |
import code | |
from custom_manopth.manolayer import ManoLayer | |
import scipy.sparse | |
import custom_mesh_graphormer.modeling.data.config as cfg | |
from pathlib import Path | |
from comfy.model_management import get_torch_device | |
from wrapper_for_mps import sparse_to_dense | |
# Module-wide torch device; spmm() moves both of its operands onto it before multiplying.
device = get_torch_device()
class MANO(nn.Module):
    """MANO hand model wrapper exposing a 21-joint regressor.

    The joint regressor taken from the MANO layer is extended with one
    one-hot row per fingertip vertex and then re-ordered so that its rows
    follow ``joints_name``. The result is kept both as a numpy array
    (``joint_regressor``) and as a registered buffer
    (``joint_regressor_torch``).
    """

    def __init__(self):
        super(MANO, self).__init__()

        # MANO assets are looked up in the "data" folder next to this file.
        self.mano_dir = str(Path(__file__).parent / "data")
        self.layer = self.get_layer()

        # Mesh topology and base joint regressor, converted to numpy.
        self.vertex_num = 778
        self.face = self.layer.th_faces.numpy()
        self.joint_regressor = self.layer.th_J_regressor.numpy()

        # Joint naming and kinematic connectivity (pairs of joint indices).
        self.joint_num = 21
        self.joints_name = (
            'Wrist',
            'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4',
            'Index_1', 'Index_2', 'Index_3', 'Index_4',
            'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4',
            'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4',
            'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4',
        )
        self.skeleton = (
            (0, 1), (0, 5), (0, 9), (0, 13), (0, 17),
            (1, 2), (2, 3), (3, 4),
            (5, 6), (6, 7), (7, 8),
            (9, 10), (10, 11), (11, 12),
            (13, 14), (14, 15), (15, 16),
            (17, 18), (18, 19), (19, 20),
        )
        self.root_joint_idx = self.joints_name.index('Wrist')

        # Add fingertips to the joint regressor: one one-hot row per tip vertex.
        self.fingertip_vertex_idx = [745, 317, 444, 556, 673]  # mesh vertex idx (right hand)
        # NOTE(review): the middle-finger row below selects vertex 445 while
        # fingertip_vertex_idx lists 444. The value 445 is kept as-is because
        # changing the regressor may alter results for anything that depends
        # on it -- confirm which index is intended before touching it.
        regressor_tip_vertices = (745, 317, 445, 556, 673)
        vertex_ids = np.arange(self.joint_regressor.shape[1])
        tip_rows = [
            (vertex_ids == tip).astype(np.float32).reshape(1, -1)
            for tip in regressor_tip_vertices
        ]
        extended = np.concatenate((self.joint_regressor, *tip_rows))

        # Re-order rows to match joints_name. Assuming the base regressor has
        # 16 rows, indices 16..20 are the fingertip rows appended above.
        row_order = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
        self.joint_regressor = extended[row_order, :]

        self.register_buffer(
            'joint_regressor_torch',
            torch.from_numpy(self.joint_regressor).float(),
        )

    def get_layer(self):
        """Build the MANO layer from ``self.mano_dir`` (right-hand model)."""
        return ManoLayer(mano_root=self.mano_dir, flat_hand_mean=False, use_pca=False)

    def get_3d_joints(self, vertices):
        """
        Regress joint locations from the MANO mesh vertices.
        Input:
            vertices: size = (B, 778, 3)
        Output:
            3D joints: size = (B, 21, 3)
        """
        # joints[b, j, k] = sum_i regressor[j, i] * vertices[b, i, k]
        return torch.einsum('bik,ji->bjk', vertices, self.joint_regressor_torch)
class SparseMM(torch.autograd.Function):
    """Redefine sparse @ dense matrix multiplication to enable backpropagation.

    The builtin matrix multiplication operation does not support
    backpropagation in some cases. Only the dense operand receives a
    gradient; the first operand is treated as a constant.
    """

    # autograd.Function requires forward/backward to be static methods and
    # to be invoked through ``SparseMM.apply``.
    @staticmethod
    def forward(ctx, sparse, dense):
        """Return ``sparse @ dense`` and remember what backward needs.

        Args:
            ctx: autograd context object.
            sparse: left operand (no gradient is produced for it).
            dense: right operand.
        """
        ctx.req_grad = dense.requires_grad
        ctx.save_for_backward(sparse)
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(sparse, dense)``.

        The gradient for ``sparse`` is always ``None``; the gradient for
        ``dense`` is ``sparse.T @ grad_output`` when ``dense`` required grad,
        otherwise ``None``.
        """
        sparse, = ctx.saved_tensors
        grad_input = None
        if ctx.req_grad:
            grad_input = torch.matmul(sparse.t(), grad_output)
        return None, grad_input
def spmm(sparse, dense):
    """Multiply ``sparse @ dense`` through ``SparseMM`` on the module-level device.

    Both operands are moved to ``device`` first, so they end up on the same
    device regardless of where the caller created them.
    """
    operands = (sparse.to(device), dense.to(device))
    return SparseMM.apply(*operands)
def _scipy_matrix_to_torch(matrix):
    """Convert one scipy sparse matrix to a torch tensor via ``sparse_to_dense``."""
    coo = scipy.sparse.coo_matrix(matrix)
    indices = torch.LongTensor(np.array([coo.row, coo.col]))
    values = torch.FloatTensor(coo.data)
    # sparse_to_dense comes from wrapper_for_mps; presumably it densifies the
    # COO tensor for backends without sparse support -- TODO confirm.
    return sparse_to_dense(torch.sparse_coo_tensor(indices, values, coo.shape))


def scipy_to_pytorch(A, U, D):
    """Convert scipy sparse matrices to pytorch tensors.

    Args:
        A: not used by this function; kept so existing call sites keep working.
        U: sequence of upsampling matrices (anything ``coo_matrix`` accepts).
        D: sequence of downsampling matrices (anything ``coo_matrix`` accepts).

    Returns:
        Tuple ``(ptU, ptD)`` of lists with the converted matrices, in the
        same order as the inputs.
    """
    ptU = [_scipy_matrix_to_torch(up) for up in U]
    ptD = [_scipy_matrix_to_torch(down) for down in D]
    return ptU, ptD
def adjmat_sparse(adjmat, nsize=1):
    """Create a row-normalized graph adjacency matrix with self-loops.

    Args:
        adjmat: adjacency matrix in any form ``scipy.sparse.csr_matrix`` accepts.
        nsize: neighborhood size; for ``nsize > 1`` the adjacency is multiplied
            by itself ``nsize - 1`` times so that vertices up to ``nsize`` hops
            away count as neighbors.

    Returns:
        The normalized matrix as a torch tensor, passed through
        ``sparse_to_dense``.
    """
    adjmat = scipy.sparse.csr_matrix(adjmat)
    if nsize > 1:
        orig_adjmat = adjmat.copy()
        for _ in range(1, nsize):
            # For scipy sparse *matrices* ``*`` is a matrix product.
            adjmat = adjmat * orig_adjmat
    # Binarize: every stored entry becomes a unit-weight edge.
    adjmat.data = np.ones_like(adjmat.data)
    # Add self-loops in one call instead of assigning element by element,
    # which repeatedly restructures the CSR matrix.
    adjmat.setdiag(1)
    # Scale each row by 1 / (number of neighbors, self included).
    num_neighbors = np.array(1 / adjmat.sum(axis=-1))
    adjmat = adjmat.multiply(num_neighbors)
    adjmat = scipy.sparse.coo_matrix(adjmat)
    indices = torch.LongTensor(np.array([adjmat.row, adjmat.col]))
    values = torch.from_numpy(adjmat.data).float()
    return sparse_to_dense(torch.sparse_coo_tensor(indices, values, adjmat.shape))
def get_graph_params(filename, nsize=1):
    """Load and process graph adjacency matrix and upsampling/downsampling matrices.

    Args:
        filename: path of the file holding the entries ``'A'``, ``'U'`` and ``'D'``.
        nsize: neighborhood size forwarded to ``adjmat_sparse``.

    Returns:
        Tuple ``(A, U, D)`` of lists: processed adjacency matrices, upsampling
        matrices and downsampling matrices.
    """
    # NOTE(security): allow_pickle=True unpickles objects from the file, which
    # can execute arbitrary code -- only load sampling files from a trusted source.
    data = np.load(filename, encoding='latin1', allow_pickle=True)
    try:
        A = data['A']
        U = data['U']
        D = data['D']
    finally:
        # An .npz archive keeps its file handle open until closed; plain
        # arrays have no close(), hence the getattr guard.
        close = getattr(data, 'close', None)
        if close is not None:
            close()
    U, D = scipy_to_pytorch(A, U, D)
    A = [adjmat_sparse(a, nsize=nsize) for a in A]
    return A, U, D
class Mesh(object):
    """Mesh object that is used for handling certain graph operations.

    Holds the per-level adjacency (``_A``), upsampling (``_U``) and
    downsampling (``_D``) matrices returned by ``get_graph_params``.
    """

    def __init__(self, filename=cfg.MANO_sampling_matrix,
                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
        adjacency, upsamplers, downsamplers = get_graph_params(filename=filename, nsize=nsize)
        # Only the sampling matrices are moved to ``device``; the adjacency
        # matrices stay wherever get_graph_params created them.
        self._A = adjacency
        self._U = [mat.to(device) for mat in upsamplers]
        self._D = [mat.to(device) for mat in downsamplers]
        self.num_downsampling = num_downsampling

    @staticmethod
    def _apply_levels(matrices, levels, x):
        """Left-multiply ``x`` by ``matrices[level]`` for each level, in order.

        Inputs with fewer than 3 dimensions are processed directly, 3-D inputs
        are processed one slice along dim 0 at a time and re-stacked, and
        anything with more dimensions is returned unchanged.
        """
        def run(tensor):
            for level in levels:
                tensor = spmm(matrices[level], tensor)
            return tensor

        ndim = x.ndimension()
        if ndim < 3:
            return run(x)
        if ndim == 3:
            return torch.stack([run(sample) for sample in x], dim=0)
        return x

    def downsample(self, x, n1=0, n2=None):
        """Downsample mesh from level ``n1`` to level ``n2``.

        ``n2`` defaults to ``self.num_downsampling``.
        """
        stop = self.num_downsampling if n2 is None else n2
        return self._apply_levels(self._D, range(n1, stop), x)

    def upsample(self, x, n1=1, n2=0):
        """Upsample mesh from level ``n1`` back to level ``n2``."""
        # Levels are visited from n1 - 1 down to n2.
        return self._apply_levels(self._U, range(n2, n1)[::-1], x)