File size: 1,845 Bytes
15bc41b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
import torch.nn as nn
import trimesh
import numpy as np
from .TEHNet import TEHNet
from .utils import create_mano_layers
from settings import MANO_PATH, MANO_CMPS
class TEHNetWrapper:
    """Thin convenience wrapper around a ``TEHNet`` instance.

    The wrapper owns the network, the objects returned by
    ``create_mano_layers`` and a fixed rotation matrix, and forwards the
    usual module-style calls (``state_dict``, ``load_state_dict``,
    ``parameters``, ``train``, ``eval``) to the wrapped network.
    """

    # Prefix stripped from checkpoint keys in ``load_state_dict``.
    # NOTE(review): presumably the prefix added by nn.DataParallel /
    # DistributedDataParallel when the checkpoint was saved -- confirm.
    _PARALLEL_PREFIX = 'module.'

    def __init__(self, device):
        """Build the network and its helper tensors on ``device``."""
        self.net = TEHNet(n_pose_params=MANO_CMPS).to(device)
        self.training = False
        # Passed to the network on every forward call (see ``__call__``).
        self.hands = create_mano_layers(MANO_PATH, device, MANO_CMPS)
        # 4x4 homogeneous rotation of 180 degrees about the x axis, stored as
        # float32 on ``device``.
        self.rot = torch.tensor(
            trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]),
            device=device,
        ).float()

    def __call__(self, inp):
        """Run the wrapped network on ``inp`` and return its outputs."""
        return self.net(inp, self.hands)

    def state_dict(self):
        """Return the state dict of the wrapped network."""
        return self.net.state_dict()

    def load_state_dict(self, params, *args, **kwargs):
        """Load ``params`` into the wrapped network.

        A leading ``'module.'`` is removed from every key before loading.
        Extra positional/keyword arguments are forwarded unchanged to the
        network's ``load_state_dict``.

        Returns:
            Whatever the wrapped network's ``load_state_dict`` returns, so
            callers using ``strict=False`` can inspect missing/unexpected keys.
        """
        prefix = self._PARALLEL_PREFIX
        cleaned = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in params.items()
        }
        return self.net.load_state_dict(cleaned, *args, **kwargs)

    def parameters(self):
        """Return the parameters of the wrapped network."""
        return self.net.parameters()

    def train(self, mode=True):
        """Switch the wrapped network to training mode (or eval if ``mode`` is False).

        Returns the wrapped network, as before.
        """
        self.training = bool(mode)
        return self.net.train(self.training)

    def eval(self):
        """Switch the wrapped network to evaluation mode and return it."""
        self.training = False
        return self.net.eval()

    def P3dtoP2d(self, j3d, scale, translation):
        """Project 3D points to 2D with a per-sample scale and translation.

        Args:
            j3d: tensor of shape ``(B, N, 3)``.
            scale: tensor of shape ``(B, S)`` with ``S >= 2``; components 0
                and 1 scale the x and y coordinates respectively.
            translation: tensor of shape ``(B, T)`` with ``T >= 2``;
                components 0 and 1 are added to x and y respectively.

        Returns:
            Tensor of shape ``(B, N, 2)``:
            ``translation[:, None, :2] + scale[:, None, :2] * rotated[..., :2]``
            where ``rotated`` is the homogeneous ``j3d`` right-multiplied by
            ``self.rot``.
        """
        B, N = j3d.shape[:2]
        # Append a homogeneous 1 so the 4x4 rotation can be applied; the ones
        # follow j3d's dtype/device so no implicit promotion is needed.
        ones = torch.ones(B, N, 1, device=j3d.device, dtype=j3d.dtype)
        homogeneous = torch.cat([j3d, ones], 2)
        # self.rot is a constant; detach keeps it out of the autograd graph.
        rotated = homogeneous @ self.rot.detach().to(dtype=j3d.dtype)
        # (B, k) -> (B, 1, k) so the per-sample terms broadcast over the N points.
        offset = translation.unsqueeze(1)[:, :, :2]
        factor = scale.unsqueeze(1)[:, :, :2]
        return offset + factor * rotated[:, :, :2]
|