# ev2hands/model/TEHNet.py
# Uploaded by chris10 — commit "init" (15bc41b); views: raw · history · blame
# Hosting-page scan result: No virus; file size 6.87 kB
import numpy as np
import torch.nn as nn
import torch
import os
import torch.nn.functional as F
from .pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction, PointNetFeaturePropagation
class AttentionBlock(nn.Module):
    """Parameter-free attention that pools ``value`` with key/query affinities.

    The block registers no learnable weights; any projection of the inputs is
    done by the caller before they are passed in.
    """

    def __init__(self):
        super().__init__()

    def forward(self, key, value, query):
        """Attend from the key channels over the value rows.

        Args:
            key: tensor of shape (N, Ck, ...); trailing dims are flattened to L.
            value: tensor of shape (N, Cv, ...); trailing dims are flattened
                to Lv.
            query: rank-3 tensor of shape (N, Cq, L). Cq must equal Cv for the
                second batched matmul to be valid.

        Returns:
            Tensor of shape (N, Ck, Lv). The (N, Ck, Cq) affinity matrix is
            scaled by Cv ** -0.5 and softmax-normalised over dim 1 (the key
            channels, not the contracted value channels) before it weights
            the rows of ``value``.
        """
        # (N, Cq, L) -> (N, L, Cq) so it can be contracted with the keys.
        queries = query.permute(0, 2, 1)

        flat_key = key.view(key.shape[0], key.shape[1], -1)
        val_channels = value.shape[1]
        flat_value = value.view(value.shape[0], val_channels, -1)

        # NOTE(review): the temperature uses the *value* channel count (the
        # contracted axis here is L); kept as-is to match trained weights.
        affinity = torch.bmm(flat_key, queries) * val_channels ** -.5
        weights = affinity.softmax(dim=1)
        return torch.bmm(weights, flat_value)
class MANORegressor(nn.Module):
    """Regress MANO parameters from a point set and run a MANO layer on them.

    A two-stage PointNet++ encoder (multi-scale set abstraction, then a
    grouped-all set abstraction) pools the input into a 512-d vector. An MLP
    maps that vector to ``3 + n_pose_params + n_shape_params + 3`` values,
    which are split into global orientation, hand pose, shape (betas) and
    translation.
    """

    def __init__(self, n_inp_features=4, n_pose_params=6, n_shape_params=10):
        super(MANORegressor, self).__init__()
        # Per-point features are always fed to the first set-abstraction stage.
        normal_channel = True
        additional_channel = n_inp_features if normal_channel else 0
        self.normal_channel = normal_channel
        self.sa1 = PointNetSetAbstractionMsg(128, [0.4, 0.8], [64, 128], additional_channel, [[128, 128, 256], [128, 196, 256]])
        self.sa2 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512], group_all=True)
        self.n_pose_params = n_pose_params
        self.n_mano_params = n_pose_params + n_shape_params
        # Output layout: [global_orient(3) | hand_pose | betas | transl(3)].
        self.mano_regressor = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.3),
            nn.Linear(1024, 3 + self.n_mano_params + 3),
        )

    def J3dtoJ2d(self, j3d, scale):
        """Scale the x/y components of ``j3d`` by ``scale``.

        Not used by ``forward``. Returns a default-dtype tensor of shape
        (B, N, 2) on ``j3d``'s device, where ``j2d[..., i]`` equals
        ``scale[..., i] * j3d[..., i]`` for i in {0, 1}. ``scale`` must
        broadcast against (B, N) once its last dim is indexed.
        """
        B, N = j3d.shape[:2]
        device = j3d.device
        j2d = torch.zeros(B, N, 2, device=device)
        j2d[:, :, 0] = scale[:, :, 0] * j3d[:, :, 0]
        j2d[:, :, 1] = scale[:, :, 1] * j3d[:, :, 1]
        return j2d

    def forward(self, xyz, features, mano_hand, previous_mano_params=None):
        """Encode the points, regress MANO parameters and evaluate ``mano_hand``.

        Args:
            xyz: point coordinates for the encoder; dim 0 is the batch.
            features: per-point features for ``sa1`` (``n_inp_features``
                channels).
            mano_hand: MANO layer. Its ``shapedirs.device`` decides where the
                parameters are moved. It is called with ``global_orient``,
                ``hand_pose``, ``betas`` and ``transl`` keyword arguments and
                must return an object with ``vertices`` and ``joints``. In
                eval mode its ``faces`` are also read.
            previous_mano_params: accepted for API compatibility; unused.

        Returns:
            dict with ``'vertices'``, ``'j3d'``, the four MANO parameter
            tensors and, when not training, ``'faces'`` (``mano_hand.faces``
            tiled once per batch item with ``np.tile``).
        """
        batch_size = xyz.shape[0]

        l1_xyz, l1_points = self.sa1(xyz, features)
        _, l2_points = self.sa2(l1_xyz, l1_points)
        # Drop the trailing singleton axis left by the grouped-all stage so the
        # pooled feature is (batch, 512) for the Linear head.
        global_feature = l2_points.squeeze(-1)

        mano_params = self.mano_regressor(global_feature)
        n_pose = self.n_pose_params
        mano_device = mano_hand.shapedirs.device
        mano_args = {
            'global_orient': mano_params[:, :3].to(mano_device),
            'hand_pose': mano_params[:, 3:3 + n_pose].to(mano_device),
            'betas': mano_params[:, 3 + n_pose:-3].to(mano_device),
            'transl': mano_params[:, -3:].to(mano_device),
        }

        output = mano_hand(**mano_args)
        mano_outs = {'vertices': output.vertices, 'j3d': output.joints}
        mano_outs.update(mano_args)
        if not self.training:
            mano_outs['faces'] = np.tile(mano_hand.faces, (batch_size, 1, 1))
        return mano_outs
class TEHNet(nn.Module):
    """Two-hand network: per-point segmentation plus left/right MANO regression.

    The PointNet++ backbone (three set-abstraction and three
    feature-propagation stages) produces 256-channel per-point features.
    ``classifier`` maps them to ``num_classes`` per-point logits. For each
    hand, ``attention_block`` pools the features, with the logits as keys and
    a hand-specific conv branch as query, into a ``num_classes``-channel map.
    That map is regressed to MANO parameters by the hand's
    ``MANORegressor``.

    Environment variables read at construction time:
        ERPC: integer added to the number of extra (non-xyz) input channels
            (default 0, i.e. one extra channel).
        MHLNES: if non-zero, ``forward`` replaces input channel 2 with the
            mean of channels 3 onward before running the backbone.
    """

    def __init__(self, n_pose_params, num_classes=4):
        super(TEHNet, self).__init__()
        normal_channel = True
        # Extra per-point channels beyond xyz: one by default, plus ERPC more.
        additional_channel = (1 + int(os.getenv('ERPC', 0))) if normal_channel else 0
        self.normal_channel = normal_channel
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3 + additional_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.4, 0.8], [64, 128], 128 + 128 + 64, [[128, 128, 256], [128, 196, 256]])
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
        self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 256])
        self.classifier = nn.Sequential(
            nn.Conv1d(256, 256, 1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.Conv1d(256, num_classes, 1),
        )
        self.attention_block = AttentionBlock()
        # The attention output has one channel per segmentation class, so the
        # regressors must expect ``num_classes`` input features (4 by default).
        self.left_mano_regressor = MANORegressor(n_inp_features=num_classes, n_pose_params=n_pose_params)
        self.right_mano_regressor = MANORegressor(n_inp_features=num_classes, n_pose_params=n_pose_params)
        self.mhlnes = int(os.getenv('MHLNES', 0))
        # Built in the same order as before so seeded init and state_dict keys match.
        self.left_query_conv = self._make_query_conv()
        self.right_query_conv = self._make_query_conv()

    @staticmethod
    def _make_query_conv():
        """Build one per-hand query branch: two 256->256, kernel-3, padding-1 Conv1d layers."""
        return nn.Sequential(
            nn.Conv1d(256, 256, 3, 1, 3 // 2),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.Conv1d(256, 256, 3, 1, 3 // 2),
            nn.BatchNorm1d(256),
        )

    def forward(self, xyz, mano_hands):
        """Segment the points and regress one MANO hand per side.

        Args:
            xyz: rank-3 (batch, channels, points) tensor. Channels 0-2 are used
                as coordinates, and all channels (coordinates included) are
                used as per-point features. The caller's tensor is never
                modified.
            mano_hands: mapping with ``'left'`` and ``'right'`` MANO layers,
                forwarded to the matching regressor.

        Returns:
            dict with ``'class_logits'`` (classifier output) and ``'left'`` /
            ``'right'`` (the regressors' output dicts).
        """
        if self.mhlnes:
            # ``xyz[:, :3, :]`` is a view, so writing into it would modify the
            # caller's tensor; work on a copy instead. ``l0_points`` aliases
            # the same copy and therefore also sees the replaced channel.
            # NOTE(review): unclear if that aliasing is intended; it is kept
            # for compatibility with trained weights.
            xyz = xyz.clone()
            xyz[:, 2, :] = xyz[:, 3:, :].mean(1)
        l0_points = xyz
        l0_xyz = xyz[:, :3, :]

        # Set-abstraction (encoder) stages.
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        # Feature-propagation (decoder) stages back to the input points.
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        seg_out = self.classifier(l0_points)
        feat_fuse = l0_points
        # Segmentation logits act as keys, backbone features as values and a
        # hand-specific conv branch as the query.
        left_hand_features = self.attention_block(seg_out, feat_fuse, self.left_query_conv(feat_fuse))
        right_hand_features = self.attention_block(seg_out, feat_fuse, self.right_query_conv(feat_fuse))

        left = self.left_mano_regressor(l0_xyz, left_hand_features, mano_hands['left'])
        right = self.right_mano_regressor(l0_xyz, right_hand_features, mano_hands['right'])
        return {'class_logits': seg_out, 'left': left, 'right': right}
def main():
    """Smoke-test TEHNet on random points.

    ``TEHNet.forward`` requires a ``mano_hands`` mapping. To keep this check
    self-contained, minimal stand-in layers are passed that provide only what
    ``MANORegressor.forward`` touches (``shapedirs.device``, ``faces`` and a
    call returning ``vertices``/``joints``). Their outputs are NOT a real hand
    mesh; use real MANO layers for anything beyond a shape check.
    """
    from types import SimpleNamespace

    class _StandInMano:
        """Minimal object exposing the attributes MANORegressor.forward reads."""

        def __init__(self):
            self.shapedirs = torch.zeros(1)  # only its .device is read
            self.faces = np.zeros((1, 3), dtype=np.int64)  # tiled in eval mode

        def __call__(self, global_orient, hand_pose, betas, transl):
            # Echo the translation as a single vertex/joint per batch item.
            point = transl.unsqueeze(1)
            return SimpleNamespace(vertices=point, joints=point)

    # xyz + one extra channel + ERPC more, matching TEHNet.__init__.
    n_channels = 3 + 1 + int(os.getenv('ERPC', 0))
    net = TEHNet(n_pose_params=6)
    points = torch.rand(4, n_channels, 128)
    mano_hands = {'left': _StandInMano(), 'right': _StandInMano()}
    with torch.no_grad():
        outputs = net(points, mano_hands)
    print('class_logits', tuple(outputs['class_logits'].shape))


if __name__ == '__main__':
    main()