# ev2hands/inference.py
import os

os.environ['ERPC'] = '1'  # presumably read at import time by the model stack below

import time

import numpy as np
import torch
import trimesh

import arg_parser
from model import TEHNetWrapper
from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH
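

# Map an (N, 3) tensor of (x, y, t) points into [-1, 1]^3: pixel coordinates
# are scaled by the sensor resolution, timestamps by their observed min/max range.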
def pc_normalize(pc):
pc[:, 0] /= OUTPUT_WIDTH
pc[:, 1] /= OUTPUT_HEIGHT
pc[:, :2] = 2 * pc[:, :2] - 1
ts = pc[:, 2:]
t_max = ts.max(0).values
t_min = ts.min(0).values
ts = (2 * ((ts - t_min) / (t_max - t_min))) - 1
pc[:, 2:] = ts
return pc
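

# Convert a raw (N, 4) array of (x, y, t, p) events into the network input:
# events are merged per pixel, subsampled to a fixed count, normalized, and an
# RGB event frame is rendered for visualization.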
def process_events(events):
    n_events = 2048  # fixed number of events fed to the network
    events[:, 2] -= events[0, 2]  # make timestamps relative to the first event
event_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.float32)
count_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH), dtype=np.float32)
x, y, t, p = events.T
x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)
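    # Accumulate per-pixel timestamp sums and positive/negative polarity counts.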
np.add.at(event_grid, (y, x, 0), t)
np.add.at(event_grid, (y, x, 1), p == 1)
np.add.at(event_grid, (y, x, 2), p != 1)
np.add.at(count_grid, (y, x), 1)
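    # Collapse all events at the same pixel into one event with the mean
    # timestamp and per-polarity counts.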
yi, xi = np.nonzero(count_grid)
t_avg = event_grid[yi, xi, 0] / count_grid[yi, xi]
p_evn = event_grid[yi, xi, 1]
n_evn = event_grid[yi, xi, 2]
events = np.hstack([xi[:, None], yi[:, None], t_avg[:, None], p_evn[:, None], n_evn[:, None]])
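    # Draw a fixed-size sample (with replacement) so the network input has a
    # constant shape.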
sampled_indices = np.random.choice(events.shape[0], n_events)
events = events[sampled_indices]
events = torch.tensor(events, dtype=torch.float32)
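    # Render an event frame for visualization: the red channel encodes each
    # pixel's positive-polarity ratio, the blue channel the negative.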
coordinates = np.zeros((events.shape[0], 2))
event_frame = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
for idx, (x, y, t_avg, p_evn, n_evn) in enumerate(events):
y, x = y.int(), x.int()
coordinates[idx] = (y, x)
event_frame[y, x, 0] = (p_evn / (p_evn + n_evn)) * 255
event_frame[y, x, -1] = (n_evn / (p_evn + n_evn)) * 255
events[:, :3] = pc_normalize(events[:, :3])
hand_data = {
'event_frame': torch.tensor(event_frame, dtype=torch.uint8),
'events': events.permute(1, 0).unsqueeze(0),
'coordinates': torch.tensor(coordinates, dtype=torch.float32)
}
return hand_data
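

# Run the network on one batch of events and decode per-frame hand meshes,
# 3D joints, and an event segmentation mask.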
def demo(net, device, data):
net.eval()
events = data['events']
events = events.to(device=device, dtype=torch.float32)
start_time = time.time()
with torch.no_grad():
outputs = net(events)
end_time = time.time()
    N = events.shape[0]  # batch size
    print(f'inference time: {end_time - start_time:.3f} s')
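    # Reduce class logits to hard per-event labels: softmax over the class
    # dimension, then argmax -> (batch, events) integer labels.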
outputs['class_logits'] = outputs['class_logits'].softmax(1).argmax(1).int().cpu()
frames = list()
for idx in range(N):
hands = dict()
hands['left'] = {
'vertices': outputs['left']['vertices'][idx].cpu(),
'j3d': outputs['left']['j3d'][idx].cpu(),
}
hands['right'] = {
'vertices': outputs['right']['vertices'][idx].cpu(),
'j3d': outputs['right']['j3d'][idx].cpu(),
}
coordinates = data['coordinates']
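        # Rasterize the per-event class labels into an RGB segmentation mask.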
seg_mask = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
for edx, (y, x) in enumerate(coordinates):
y, x = y.int(), x.int()
cid = outputs['class_logits'][idx][edx]
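            # The class id selects a color channel (0=R, 1=G, 2=B); id 3,
            # presumably events claimed by both hands, is painted white.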
if cid == 3:
seg_mask[y, x] = 255
else:
seg_mask[y, x, cid] = 255
hands['seg_mask'] = seg_mask
frames.append(hands)
return frames
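

# Thin inference wrapper: loads the TEHNet checkpoint on CPU and turns one
# event batch into a single merged trimesh of both predicted hands.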
class Ev2Hands:
def __init__(self) -> None:
        arg_parser.demo()  # presumably configures demo-mode settings as a side effect
device = torch.device('cpu')
net = TEHNetWrapper(device=device)
save_path = os.environ['CHECKPOINT_PATH']
checkpoint = torch.load(save_path, map_location=device)
net.load_state_dict(checkpoint['state_dict'], strict=True)
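        # 180 degree rotation about the x-axis, presumably to flip the mesh
        # from the camera frame into the viewer's coordinate convention.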
rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
mano_hands = net.hands
self.net = net
self.device = device
self.mano_hands = mano_hands
self.rot = rot
def __call__(self, data):
net = self.net
device = self.device
mano_hands = self.mano_hands
rot = self.rot
frame = demo(net=net, device=device, data=data)[0]
        seg_mask = frame['seg_mask']  # unused below; kept for potential visualization
pred_meshes = list()
for hand_type in ['left', 'right']:
faces = mano_hands[hand_type].faces
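            # Scale vertices (presumably meters) to millimeters for the output mesh.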
pred_mesh = trimesh.Trimesh(frame[hand_type]['vertices'].cpu().numpy() * 1000, faces)
pred_mesh.visual.vertex_colors = [255, 0, 0]
pred_meshes.append(pred_mesh)
pred_meshes = trimesh.util.concatenate(pred_meshes)
pred_meshes.apply_transform(rot)
return pred_meshes
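

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline. Assumptions:
    # the CHECKPOINT_PATH environment variable points at a trained checkpoint,
    # and 'events.npy' is a hypothetical (N, 4) array of (x, y, t, p) rows
    # recorded from an event camera.
    raw_events = np.load('events.npy')  # hypothetical event dump
    hand_data = process_events(raw_events)
    model = Ev2Hands()
    mesh = model(hand_data)
    mesh.export('hands.ply')  # merged left+right hand mesh in millimeters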