|
import numpy as np |
|
import torch.nn as nn |
|
import torch |
|
import os |
|
import torch.nn.functional as F |
|
from .pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction, PointNetFeaturePropagation |
|
|
|
|
|
class AttentionBlock(nn.Module):
    """Scaled dot-product attention over flattened spatial positions.

    Computes ``softmax(key @ query^T / sqrt(C_value), dim=1) @ value``,
    where key/value are (B, C, ...) feature maps flattened to (B, C, L)
    and query is a (B, C_q, L) tensor transposed to (B, L, C_q).
    """

    def __init__(self):
        super(AttentionBlock, self).__init__()

    def forward(self, key, value, query):
        # (B, C_q, L) -> (B, L, C_q) so key can right-multiply query.
        query = query.permute(0, 2, 1)

        batch, key_channels = key.shape[:2]
        key = key.view(batch, key_channels, -1)

        batch, value_channels = value.shape[:2]
        value = value.view(batch, value_channels, -1)

        # Similarity map scaled by 1/sqrt(value channel count); softmax
        # normalizes across the key-channel axis (dim=1).
        sim_map = torch.bmm(key, query) * (value_channels ** -0.5)
        sim_map = F.softmax(sim_map, dim=1)

        # (B, KC, C_q) x (B, C_v, L) -> (B, KC, L) context features.
        context = torch.bmm(sim_map, value)
        return context
|
|
|
|
|
class MANORegressor(nn.Module):
    """Regress MANO hand-model parameters from a hand point cloud.

    Two PointNet++ set-abstraction stages condense per-point features into
    a single 512-d descriptor; an MLP then predicts the parameter vector
    laid out as [global_orient(3) | hand_pose | betas | transl(3)].
    """

    def __init__(self, n_inp_features=4, n_pose_params=6, n_shape_params=10):
        super(MANORegressor, self).__init__()

        # Mirrors the PointNet++ reference implementation's flag; it is
        # always True here, so the extra feature channels are used.
        normal_channel = True
        additional_channel = n_inp_features if normal_channel else 0
        self.normal_channel = normal_channel

        self.sa1 = PointNetSetAbstractionMsg(
            128, [0.4, 0.8], [64, 128], additional_channel,
            [[128, 128, 256], [128, 196, 256]])
        self.sa2 = PointNetSetAbstraction(
            npoint=None, radius=None, nsample=None,
            in_channel=512 + 3, mlp=[256, 512], group_all=True)

        self.n_pose_params = n_pose_params
        self.n_mano_params = n_pose_params + n_shape_params

        # Head output: 3 (global orient) + pose + shape + 3 (translation).
        self.mano_regressor = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.3),
            nn.Linear(1024, 3 + self.n_mano_params + 3),
        )

    def J3dtoJ2d(self, j3d, scale):
        """Project 3D joints to 2D by per-axis scaling (orthographic).

        Args:
            j3d:   (B, N, 3) joint positions.
            scale: (B, N, 2) per-joint x/y scale factors.

        Returns:
            (B, N, 2) tensor of scaled 2D joints on j3d's device.
        """
        B, N = j3d.shape[:2]
        j2d = torch.zeros(B, N, 2, device=j3d.device)
        j2d[:, :, 0] = scale[:, :, 0] * j3d[:, :, 0]
        j2d[:, :, 1] = scale[:, :, 1] * j3d[:, :, 1]
        return j2d

    def forward(self, xyz, features, mano_hand, previous_mano_params=None):
        """Run the regressor and evaluate the MANO layer.

        Args:
            xyz:      (B, 3, N) point coordinates.
            features: per-point features fed to the first SA stage.
            mano_hand: MANO layer; must expose ``shapedirs``, ``faces``
                and return an object with ``vertices`` and ``joints``.
            previous_mano_params: accepted but currently unused — the
                original allocated zero tensors for it and never read
                them; kept in the signature for interface compatibility.

        Returns:
            dict with 'vertices', 'j3d', the four MANO argument tensors,
            and (eval mode only) 'faces' tiled to the batch size.
        """
        batch_size = xyz.shape[0]

        l1_xyz, l1_points = self.sa1(xyz, features)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)

        # group_all=True leaves a trailing singleton dim; drop it so the
        # Linear layers see (B, 512).
        l2_points = l2_points.squeeze(-1)

        mano_params = self.mano_regressor(l2_points)

        # Slice the flat parameter vector by the documented layout.
        global_orient = mano_params[:, :3]
        hand_pose = mano_params[:, 3:3 + self.n_pose_params]
        betas = mano_params[:, 3 + self.n_pose_params:-3]
        transl = mano_params[:, -3:]

        # Evaluate the MANO layer on whatever device it lives on.
        device = mano_hand.shapedirs.device
        mano_args = {
            'global_orient': global_orient.to(device),
            'hand_pose': hand_pose.to(device),
            'betas': betas.to(device),
            'transl': transl.to(device),
        }

        output = mano_hand(**mano_args)

        mano_outs = dict()
        mano_outs['vertices'] = output.vertices
        mano_outs['j3d'] = output.joints
        mano_outs.update(mano_args)

        # Faces are constant topology; only needed outside training
        # (visualization / evaluation).
        if not self.training:
            mano_outs['faces'] = np.tile(mano_hand.faces, (batch_size, 1, 1))

        return mano_outs
|
|
|
|
|
class TEHNet(nn.Module):
    """Two-hand segmentation + MANO regression network.

    A PointNet++ encoder/decoder (sa1-sa3, fp3-fp1) produces per-point
    features; a 1x1-conv classifier segments points into ``num_classes``
    classes, and two attention-guided MANORegressor heads predict the
    left/right hand parameters.
    """

    def __init__(self, n_pose_params, num_classes=4):
        """
        Args:
            n_pose_params: number of MANO pose parameters per hand head.
            num_classes: per-point segmentation classes (default 4).
        """
        super(TEHNet, self).__init__()

        normal_channel = True

        if normal_channel:
            # 1 extra feature channel, plus one more when the ERPC env
            # flag is set (presumably an extra per-point input feature —
            # TODO confirm against the data loader).
            additional_channel = 1 + int(os.getenv('ERPC', 0))
        else:
            additional_channel = 0

        self.normal_channel = normal_channel
        # PointNet++ encoder: two multi-scale-grouping stages, then a
        # group-all stage producing a single global feature.
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3+additional_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
        # Feature-propagation decoder back to per-point resolution.
        self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])

        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 256])

        # Per-point segmentation head over the propagated features.
        self.classifier = nn.Sequential(
            nn.Conv1d(256, 256, 1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.Conv1d(256, num_classes, 1)
        )

        # Shared attention block; the segmentation logits act as keys so
        # hand features are pooled per predicted class.
        self.attention_block = AttentionBlock()

        self.left_mano_regressor = MANORegressor(n_pose_params=n_pose_params)
        self.right_mano_regressor = MANORegressor(n_pose_params=n_pose_params)

        # Env-controlled variant: replace the last coordinate channel by
        # the mean of the extra feature channels (see forward).
        self.mhlnes = int(os.getenv('MHLNES', 0))

        # Per-hand query transforms (kernel 3, same-padding) feeding the
        # attention block.
        self.left_query_conv = nn.Sequential(
            nn.Conv1d(256, 256, 3, 1, 3//2),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.Conv1d(256, 256, 3, 1, 3//2),
            nn.BatchNorm1d(256),
        )

        self.right_query_conv = nn.Sequential(
            nn.Conv1d(256, 256, 3, 1, 3//2),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.Conv1d(256, 256, 3, 1, 3//2),
            nn.BatchNorm1d(256),
        )

    def forward(self, xyz, mano_hands):
        """Segment the cloud and regress both hands' MANO parameters.

        Args:
            xyz: (B, C, N) input; first 3 channels are coordinates,
                remaining channels are per-point features — assumed from
                the slicing below, TODO confirm.
            mano_hands: dict with 'left' and 'right' MANO hand models.

        Returns:
            dict with 'class_logits' (B, num_classes, N) and the 'left'/
            'right' MANORegressor output dicts.
        """
        device = xyz.device  # NOTE(review): unused local.

        l0_points = xyz

        l0_xyz = xyz[:, :3, :]

        if self.mhlnes:
            # NOTE(review): l0_xyz is a view of xyz, so this assignment
            # also overwrites the caller's input tensor in place —
            # confirm this side effect is intended.
            l0_xyz[:, -1, :] = xyz[:, 3:, :].mean(1)

        # Encoder.
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        # Decoder: propagate global features back to every point.
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        seg_out = self.classifier(l0_points)
        feat_fuse = l0_points

        # Class-guided pooling: segmentation logits are the attention
        # keys, per-point features the values, conv-transformed features
        # the queries.
        left_hand_features = self.attention_block(seg_out, feat_fuse, self.left_query_conv(feat_fuse))
        right_hand_features = self.attention_block(seg_out, feat_fuse, self.right_query_conv(feat_fuse))

        left = self.left_mano_regressor(l0_xyz, left_hand_features, mano_hands['left'])
        right = self.right_mano_regressor(l0_xyz, right_hand_features, mano_hands['right'])

        return {'class_logits': seg_out, 'left': left, 'right': right}
|
|
|
|
|
def main():
    """Smoke test: build TEHNet and push a random point cloud through it.

    NOTE(review): TEHNet.forward takes a required second argument
    ``mano_hands`` (dict with 'left'/'right' MANO models), so
    ``net(points)`` as written will raise a TypeError — confirm whether
    this entry point is stale.
    """

    net = TEHNet(n_pose_params=6)
    # (batch=4, channels=4: 3 coords + 1 feature, 128 points).
    points = torch.rand(4, 4, 128)
    net(points)


if __name__ == '__main__':
    main()
|
|