# ev2hands/model/model.py
# Source: Hugging Face file view, uploaded by chris10, commit 15bc41b ("init"),
# 1.85 kB, marked "No virus" by the hub's file scanner.
import torch
import torch.nn as nn
import trimesh
import numpy as np
from .TEHNet import TEHNet
from .utils import create_mano_layers
from settings import MANO_PATH, MANO_CMPS
class TEHNetWrapper:
    """Lightweight, ``nn.Module``-like facade around a ``TEHNet`` network.

    The wrapper owns three things:

    * the network (``self.net``);
    * the MANO hand layers (``self.hands``), passed to the network on every
      forward call;
    * a fixed 4x4 homogeneous rotation (``self.rot``), used by
      :meth:`P3dtoP2d`.

    The common module methods (``state_dict``, ``load_state_dict``,
    ``parameters``, ``train``, ``eval``) are forwarded to ``self.net``, and
    ``self.training`` mirrors the last mode that was set.
    """

    def __init__(self, device):
        """Build the network, the MANO layers and the rotation on ``device``.

        Args:
            device: torch device (or device string) used for the network, the
                MANO layers and the rotation matrix.
        """
        self.net = TEHNet(n_pose_params=MANO_CMPS).to(device)
        self.training = False
        self.hands = create_mano_layers(MANO_PATH, device, MANO_CMPS)
        # 4x4 homogeneous rotation of 180 degrees about the x axis. Up to
        # floating-point round-off, it negates the y and z coordinates.
        self.rot = torch.tensor(
            trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]),
            device=device,
        ).float()

    def __call__(self, inp):
        """Run the network on ``inp``, supplying the MANO hand layers."""
        return self.net(inp, self.hands)

    def state_dict(self, *args, **kwargs):
        """Return the wrapped network's state dict (arguments are forwarded)."""
        return self.net.state_dict(*args, **kwargs)

    def load_state_dict(self, params, *args, **kwargs):
        """Load ``params`` into the wrapped network.

        A key that starts with ``'module.'`` has that prefix removed once
        before loading. That prefix is presumably left by a checkpoint saved
        from a ``DataParallel``-style wrapper (verify against the checkpoints
        in use). Extra arguments such as ``strict`` are forwarded unchanged.

        Returns:
            Whatever ``self.net.load_state_dict`` returns. For a standard
            ``nn.Module`` this is a named tuple of ``missing_keys`` and
            ``unexpected_keys``, which matters when ``strict=False``.
        """
        prefix = 'module.'
        stripped = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in params.items()
        }
        return self.net.load_state_dict(stripped, *args, **kwargs)

    def parameters(self, *args, **kwargs):
        """Return the wrapped network's parameters (arguments are forwarded)."""
        return self.net.parameters(*args, **kwargs)

    def train(self, mode=True):
        """Put the network in training mode (or eval mode if ``mode`` is False).

        ``mode`` defaults to True, so a bare ``train()`` call behaves as before.
        Returns the value of ``self.net.train(mode)``.
        """
        # Call the network first so an invalid ``mode`` rejected by the
        # network does not leave ``self.training`` out of sync with it.
        result = self.net.train(mode)
        self.training = mode
        return result

    def eval(self):
        """Put the network in evaluation mode; returns ``self.net.eval()``."""
        self.training = False
        return self.net.eval()

    def P3dtoP2d(self, j3d, scale, translation):
        """Project 3D points to 2D with a fixed rotation, a scale and a shift.

        Each point is lifted to homogeneous coordinates and right-multiplied
        by ``self.rot``. The rotated x/y components are then scaled and
        translated per sample, and the rotated depth is discarded:
        ``u = t[0] + s[0] * x'`` and ``v = t[1] + s[1] * y'``.

        Args:
            j3d: (B, N, 3) tensor of 3D points. The last dim must be 3 so the
                homogeneous (B, N, 4) tensor matches the 4x4 ``self.rot``.
            scale: (B, >=2) tensor; columns 0 and 1 scale x and y.
            translation: (B, >=2) tensor; columns 0 and 1 offset x and y.

        Returns:
            (B, N, 2) tensor on ``j3d``'s device. Its dtype follows torch type
            promotion of the inputs (float32 for float32 inputs).
        """
        # ones_like keeps j3d's dtype and device, so float64 or half inputs
        # no longer collide with a hard-coded float32 column.
        homogeneous = torch.cat([j3d, torch.ones_like(j3d[:, :, :1])], dim=2)
        # Row-vector convention (p @ R). For the 180-degree x rotation built
        # in __init__, R is symmetric up to round-off, so this matches R @ p.
        rotated = homogeneous @ self.rot.detach().to(dtype=j3d.dtype)
        # (B, k) -> (B, 1, k) so per-sample values broadcast over the N points.
        scale = scale.unsqueeze(1)
        translation = translation.unsqueeze(1)
        u = translation[:, :, 0] + scale[:, :, 0] * rotated[:, :, 0]
        v = translation[:, :, 1] + scale[:, :, 1] * rotated[:, :, 1]
        return torch.stack([u, v], dim=2)