Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from pytorch3d.transforms import axis_angle_to_quaternion | |
from network.mlp import MLPLinear | |
from utils.embedder import get_embedder | |
class HandAvatar(nn.Module):
    """Per-point colour field for a hand, conditioned on geometry and pose.

    For every query point, ``tex_mlp`` receives a concatenation of:

    * the positional encoding of the canonical coordinate,
    * the point's SDF value (one channel),
    * optionally, an encoding of the view direction,
    * the hand pose, converted from axis-angle triplets to quaternions and
      broadcast to every point.

    ``nn.Sigmoid()`` is handed to ``MLPLinear`` as ``last_op``. Presumably it
    is applied to the 3 output channels so colours land in (0, 1); confirm
    against ``MLPLinear``.

    Args:
        multires: Forwarded to ``get_embedder`` for 3-D positions (presumably
            the number of frequency bands; confirm in ``utils.embedder``).
        view_multires: Same, for view directions. ``-1`` disables the view
            branch: no embedder is built and ``view_dir`` is ignored.
        pose_dim: Width of the flattened quaternion pose vector. Each
            axis-angle triplet becomes 4 values, so this should equal
            ``4 * (per-sample hand_pose size) / 3``. The default ``15*4``
            presumably corresponds to 15 joints.
    """

    def __init__(self,
                 multires=4,
                 view_multires=-1,
                 pose_dim=15 * 4):
        super().__init__()

        # Encoder for 3-D canonical coordinates.
        self.pos_embedder, self.pos_dim = get_embedder(multires, 3)

        # -1 is the "no view conditioning" sentinel: no encoder, zero extra width.
        self.view_embedder, self.view_dim = (
            (None, 0) if view_multires == -1 else get_embedder(view_multires, 3)
        )

        self.pose_dim = pose_dim

        # Input width must mirror the concatenation order used in forward():
        # [position encoding | sdf (1) | view encoding | pose quaternions]
        feat_width = self.pos_dim + 1 + self.view_dim + pose_dim
        self.tex_mlp = MLPLinear(
            in_channels=feat_width,
            inter_channels=[64] * 5,
            out_channels=3,
            last_op=nn.Sigmoid(),
        )

    def forward(self, cano_xyz, sdf, view_dir, hand_pose):
        """Predict a colour for every query point.

        Args:
            cano_xyz: Canonical coordinates. The first two dims are read as
                (batch, points), and the tensor is fed to the 3-D position
                encoder, so presumably (B, N, 3). Confirm with the caller.
            sdf: Signed distance per point, joined on the last dim. The MLP
                input width reserves exactly one channel for it, so (B, N, 1)
                is expected.
            view_dir: Directions for the view encoder. Unused when the model
                was built with ``view_multires == -1``.
            hand_pose: Axis-angle rotations, reshaped to (B, -1, 3) before
                conversion to quaternions.

        Returns:
            The ``tex_mlp`` output with 3 channels per point (presumably
            shaped (B, N, 3); confirm against ``MLPLinear``).
        """
        n_batch, n_samples = cano_xyz.shape[:2]

        feats = [self.pos_embedder(cano_xyz), sdf]

        # Axis-angle triplets -> quaternions -> one flat pose vector per sample.
        pose_quat = axis_angle_to_quaternion(hand_pose.reshape(n_batch, -1, 3))
        pose_flat = pose_quat.flatten(start_dim=1)

        if self.view_embedder is not None:
            feats.append(self.view_embedder(view_dir))

        # Share the per-sample pose vector across all query points (no copy).
        feats.append(pose_flat.unsqueeze(1).expand(-1, n_samples, -1))

        return self.tex_mlp(torch.cat(feats, dim=-1))