import torch import torch.nn as nn import trimesh import numpy as np from .TEHNet import TEHNet from .utils import create_mano_layers from settings import MANO_PATH, MANO_CMPS class TEHNetWrapper(): def state_dict(self): return self.net.state_dict() def load_state_dict(self, params, *args, **kwargs): modified_params = dict() for k, v in params.items(): if k.startswith('module.'): k = k[len('module.'):] modified_params[k] = v self.net.load_state_dict(modified_params, *args, **kwargs) def parameters(self): return self.net.parameters() def train(self): self.training = True return self.net.train() def eval(self): self.training = False return self.net.eval() def P3dtoP2d(self, j3d, scale, translation): B, N = j3d.shape[:2] homogeneous_j3d = torch.cat([j3d, torch.ones(B, N, 1, device=j3d.device)], 2) homogeneous_j3d = homogeneous_j3d @ self.rot.detach() translation = translation.unsqueeze(1) scale = scale.unsqueeze(1) j2d = torch.zeros(B, N, 2, device=j3d.device) j2d[:, :, 0] = translation[:, :, 0] + scale[:, :, 0] * homogeneous_j3d[:, :, 0] j2d[:, :, 1] = translation[:, :, 1] + scale[:, :, 1] * homogeneous_j3d[:, :, 1] return j2d def __init__(self, device): net = TEHNet(n_pose_params=MANO_CMPS).to(device) self.net = net self.training = False self.hands = create_mano_layers(MANO_PATH, device, MANO_CMPS) self.rot = torch.tensor(trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]), device=device).float() def __call__(self, inp): outputs = self.net(inp, self.hands) return outputs