File size: 5,282 Bytes
aeba71c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import smplx
from constants import *  # NOTE: wildcard import; left in its original position so name shadowing is unchanged
import numpy as np  # explicit: `np` is used below and must not depend on what `constants` re-exports
from scipy.interpolate import interp1d
from torch import nn, einsum
import pytorch3d as T


class JointsToSMPLX(nn.Module):
    """Small fully-connected regressor.

    The network is two ``Linear -> BatchNorm1d -> ReLU`` stages of width
    ``hidden_dim`` followed by a final ``Linear`` projection to ``output_dim``.
    Judging by the class name it maps joint coordinates to SMPL-X pose
    parameters, but nothing in the class itself depends on that.

    Args:
        input_dim: Number of features per input sample.
        output_dim: Number of features per output sample.
        hidden_dim: Width of both hidden stages.
        **kwargs: Accepted and ignored, so extra config entries do not break
            construction.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, **kwargs):
        super().__init__()
        stages = []
        fan_in = input_dim
        # Two identical hidden stages; only the first one sees `input_dim`.
        for _ in range(2):
            stages.extend((
                nn.Linear(fan_in, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            ))
            fan_in = hidden_dim
        stages.append(nn.Linear(hidden_dim, output_dim))
        # Attribute name and module order are kept so state-dict keys
        # ("layers.0.weight", ...) stay compatible with saved checkpoints.
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        out = self.layers(x)
        return out


def optimize_smpl(pose_pred, joints, joints_ind, hand_pca=45, num_steps=100, lr=0.05, verbose=True):
    """Refine an SMPL-X pose and translation so the model joints match ``joints``.

    A male SMPL-X model is loaded from ``./smpl_models`` with a batch size equal
    to the number of frames. Starting from ``pose_pred`` and a zero translation,
    Adam minimises the MSE between the selected model joints and the target
    joints (shifted by the module-level ``pelvis_shift``). Both hands are held
    fixed at the module-level ``relaxed_hand_pose`` (first 45 values for the
    left hand, the remainder for the right hand).

    Args:
        pose_pred: Initial pose tensor, one row per frame. Column ``0:3`` is fed
            to SMPL-X as ``global_orient`` and ``3:`` as ``body_pose``
            (presumably axis-angle -- confirm against the caller). It is not
            modified.
        joints: Target joints; reshaped to ``(frames, -1, 3)``. Its device is
            used for the model and all optimisation tensors.
        joints_ind: Index into the SMPL-X output joints selecting the joints
            that correspond to ``joints``.
        hand_pca: Unused; kept so existing callers keep working.
        num_steps: Number of Adam iterations (default 100, as before).
        lr: Adam learning rate (default 0.05, as before).
        verbose: When true (default), print the loss after every step.

    Returns:
        Tuple ``(pose, transl, left_hand, right_hand, vertices)`` of NumPy
        arrays: the optimised pose, the optimised translation, the fixed
        left/right hand poses, and every 10th vertex of the SMPL-X mesh
        evaluated at the optimised parameters.
    """
    device = joints.device
    num_frames = joints.shape[0]  # was named `len`, which shadowed the builtin
    num_poses = pose_pred.shape[0]

    smpl_model = smplx.create(
        './smpl_models', model_type='smplx',
        gender='male', ext='npz',
        num_betas=10,
        use_pca=False,
        create_global_orient=True,
        create_body_pose=True,
        create_betas=True,
        create_left_hand_pose=True,
        create_right_hand_pose=True,
        create_expression=True,
        create_jaw_pose=True,
        create_leye_pose=True,
        create_reye_pose=True,
        create_transl=True,
        batch_size=num_frames,
    ).to(device)
    smpl_model.eval()

    # The targets are constants for the optimisation; detach so backward()
    # never walks into whatever graph produced them.
    target_joints = joints.detach().reshape(num_frames, -1, 3) + torch.tensor(pelvis_shift).to(device)

    # clone(): detach() alone shares storage with `pose_pred`, so the in-place
    # Adam updates would silently overwrite the caller's tensor.
    pose_input = torch.nn.Parameter(pose_pred.detach().clone(), requires_grad=True)
    transl = torch.nn.Parameter(torch.zeros(num_poses, 3).to(device), requires_grad=True)

    # Hands are not optimised: every frame gets the same relaxed hand pose.
    left_hand = torch.from_numpy(relaxed_hand_pose[:45].reshape(1, -1).repeat(num_poses, axis=0)).to(device)
    right_hand = torch.from_numpy(relaxed_hand_pose[45:].reshape(1, -1).repeat(num_poses, axis=0)).to(device)

    optimizer = torch.optim.Adam(params=[pose_input, transl], lr=lr)
    loss_fn = nn.MSELoss()

    def run_model():
        # Single place that defines how the parameters are fed to SMPL-X.
        return smpl_model(
            transl=transl,
            body_pose=pose_input[:, 3:],
            global_orient=pose_input[:, :3],
            return_verts=True,
            left_hand_pose=left_hand,
            right_hand_pose=right_hand,
        )

    for _ in range(num_steps):
        smpl_output = run_model()
        joints_output = smpl_output.joints[:, joints_ind].reshape(num_frames, -1, 3)
        loss = loss_fn(target_joints, joints_output)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if verbose:
            print(loss.item())

    # Evaluate the mesh once, at the final parameters. Previously the vertices
    # were copied to the CPU on every iteration and the returned ones came from
    # the forward pass *before* the last optimizer step, so they did not match
    # the returned pose/translation.
    with torch.no_grad():
        final_output = run_model()
    vertices_output = final_output.vertices[:, ::10].detach().cpu().numpy()

    return (
        pose_input.detach().cpu().numpy(),
        transl.detach().cpu().numpy(),
        left_hand.detach().cpu().numpy(),
        right_hand.detach().cpu().numpy(),
        vertices_output,
    )


def joints_to_smpl(model, joints, joints_ind, interp_s):
    """Convert a joint sequence to SMPL-X parameters and a coarse mesh.

    Steps: resample the sequence in time with ``interpolate_joints``, make the
    joints relative to joint 0 of each frame, predict a 6D rotation per joint
    with ``model``, convert those to axis-angle, refine them with
    ``optimize_smpl``, and finally move the result back to the original
    position of joint 0.

    Args:
        model: Callable mapping the flattened root-relative joints
            ``(frames, joints * 3)`` to an output that can be reshaped to
            ``(-1, 6)`` (6D rotations).
        joints: Tensor of joint positions, one row per frame.
        joints_ind: SMPL-X joint indices matching ``joints`` (forwarded to
            ``optimize_smpl``).
        interp_s: Temporal resampling factor passed to ``interpolate_joints``.

    Returns:
        Tuple ``(pose, transl, left_hand, right_hand, vertices)`` of NumPy
        arrays. ``transl`` is the global translation (joint-0 trajectory minus
        ``pelvis_shift`` plus the optimised residual) and ``vertices`` are the
        subsampled mesh vertices placed at that same global position.
    """
    # The rotation helpers live in `pytorch3d.transforms`; the top-level
    # `pytorch3d` package bound to `T` does not re-export them. Fall back to
    # `T` so a project-local shim exposing them directly keeps working.
    try:
        from pytorch3d.transforms import matrix_to_axis_angle, rotation_6d_to_matrix
    except ImportError:
        matrix_to_axis_angle = T.matrix_to_axis_angle
        rotation_6d_to_matrix = T.rotation_6d_to_matrix

    joints = interpolate_joints(joints, scale=interp_s)
    input_len = joints.shape[0]
    joints = joints.reshape(input_len, -1, 3)

    # Joint 0 of every frame is the root; remember its trajectory and express
    # all joints relative to it.
    root = joints[:, :1]
    trans_np = root[:, 0].detach().cpu().numpy()
    joints = (joints - root).reshape(input_len, -1)

    # The prediction is only an initial guess (it is detached inside
    # optimize_smpl), so there is no need to record an autograd graph.
    with torch.no_grad():
        pose_pred = model(joints)
    pose_pred = pose_pred.reshape(-1, 6)
    pose_pred = matrix_to_axis_angle(rotation_6d_to_matrix(pose_pred)).reshape(input_len, -1)

    pose_output, transl, left_hand, right_hand, vertices = optimize_smpl(pose_pred, joints, joints_ind)

    # Offset that maps the optimisation frame back to the input frame.
    root_offset = trans_np - np.array(pelvis_shift)

    # SMPL-X already adds the optimised `transl` to the vertices it returns, so
    # only the root offset is applied here; adding the full global translation
    # would count the optimised residual twice.
    vertices = vertices + root_offset.reshape(-1, 1, 3)

    transl = root_offset + transl

    return pose_output, transl, left_hand, right_hand, vertices


def interpolate_joints(joints, scale):
    """Linearly resample a joint sequence along its first (time) axis.

    Args:
        joints: Tensor whose first dimension indexes frames.
        scale: Length multiplier. ``1`` returns ``joints`` itself, untouched.

    Returns:
        For ``scale != 1``: a float32 tensor on the input's device with shape
        ``(int(frames * scale), features)``, where all non-time dimensions have
        been flattened into ``features``.
    """
    if scale == 1:
        return joints

    target_device = joints.device
    frames = joints.detach().cpu().numpy()
    n_in = frames.shape[0]
    n_out = int(n_in * scale)

    # One row per frame so interp1d interpolates every coordinate over time.
    flat = frames.reshape(n_in, -1)
    src_t = np.arange(n_in)
    dst_t = np.linspace(0, n_in - 1, n_out)
    resampled = interp1d(src_t, flat, axis=0)(dst_t)

    return torch.from_numpy(resampled).to(target_device).float()