import os

# Set before the remaining imports (as the original ordering does);
# presumably some imported module reads ERPC at import time — confirm.
os.environ['ERPC'] = '1'

import torch
import cv2
import time
import numpy as np
import trimesh

import arg_parser
from model import TEHNetWrapper
from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH, MAIN_CAMERA, REAL_TEST_DATA_PATH


def pc_normalize(pc):
    """Normalize an event point cloud in place and return it.

    Column 0 is divided by OUTPUT_WIDTH and column 1 by OUTPUT_HEIGHT, then
    both are mapped with ``2 * v - 1``. Every remaining column is min-max
    scaled per column to [-1, 1].

    Args:
        pc: 2-D torch tensor of shape (N, >=3); modified in place.

    Returns:
        The same tensor ``pc``.
    """
    pc[:, 0] /= OUTPUT_WIDTH
    pc[:, 1] /= OUTPUT_HEIGHT
    pc[:, :2] = 2 * pc[:, :2] - 1

    ts = pc[:, 2:]
    t_min = ts.min(0).values
    t_span = ts.max(0).values - t_min
    # A constant column would otherwise give 0/0 = NaN; with a span of 1
    # such a column maps to -1 instead.
    t_span = t_span.masked_fill(t_span == 0, 1)
    pc[:, 2:] = (2 * ((ts - t_min) / t_span)) - 1
    return pc


def process_events(events, n_events=2048):
    """Convert a raw event array into the network input for one sample.

    Events are accumulated per pixel into (mean relative timestamp,
    count of p == 1, count of p != 1). ``n_events`` active pixels are then
    sampled with replacement, and a preview frame is rendered from the
    sampled pixels.

    Args:
        events: array of shape (M, 4) with columns (x, y, t, p). The caller's
            array is not modified.
        n_events: number of pixels to sample (default 2048).

    Returns:
        dict with
          'event_frame': uint8 tensor (OUTPUT_HEIGHT, OUTPUT_WIDTH, 3); channel 0
              holds 255 * (p == 1 count) / total and channel 2 holds
              255 * (p != 1 count) / total at each sampled pixel.
          'events': float32 tensor (1, 5, n_events) holding the sampled
              (x, y, t_avg, pos_count, neg_count), with columns 0-2 normalized
              by ``pc_normalize``.
          'coordinates': float32 tensor (n_events, 2) of (y, x) pixel indices.

    Raises:
        ValueError: if ``events`` contains no events.
    """
    # Work on a copy so the caller's array is not shifted in place.
    events = np.array(events, copy=True)
    if events.shape[0] == 0:
        raise ValueError('process_events() requires at least one event')
    events[:, 2] -= events[0, 2]  # timestamps relative to the first event

    event_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.float32)
    count_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH), dtype=np.float32)
    x, y, t, p = events.T
    x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)
    # np.add.at accumulates correctly when several events hit the same pixel.
    np.add.at(event_grid, (y, x, 0), t)
    np.add.at(event_grid, (y, x, 1), p == 1)
    np.add.at(event_grid, (y, x, 2), p != 1)
    np.add.at(count_grid, (y, x), 1)

    yi, xi = np.nonzero(count_grid)
    t_avg = event_grid[yi, xi, 0] / count_grid[yi, xi]
    p_evn = event_grid[yi, xi, 1]
    n_evn = event_grid[yi, xi, 2]
    pixels = np.hstack([xi[:, None], yi[:, None], t_avg[:, None], p_evn[:, None], n_evn[:, None]])

    # Sampling is with replacement, so fewer than n_events active pixels is fine.
    sampled_indices = np.random.choice(pixels.shape[0], n_events)
    sampled = torch.tensor(pixels[sampled_indices], dtype=torch.float32)

    # Integer pixel indices of the samples (taken before normalization).
    xs = sampled[:, 0].int().numpy()
    ys = sampled[:, 1].int().numpy()
    coordinates = np.stack([ys, xs], axis=1)

    # Every active pixel has at least one event, so the total is never zero.
    pos, neg = sampled[:, 3], sampled[:, 4]
    total = pos + neg
    event_frame = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
    # Float -> uint8 assignment truncates. Duplicate samples of one pixel carry
    # identical values, so write order does not matter.
    event_frame[ys, xs, 0] = ((pos / total) * 255).numpy()
    event_frame[ys, xs, -1] = ((neg / total) * 255).numpy()

    sampled[:, :3] = pc_normalize(sampled[:, :3])

    hand_data = {
        'event_frame': torch.tensor(event_frame, dtype=torch.uint8),
        'events': sampled.permute(1, 0).unsqueeze(0),
        'coordinates': torch.tensor(coordinates, dtype=torch.float32),
    }
    return hand_data


def demo(net, device, data):
    """Run ``net`` on pre-processed event data and unpack the per-sample results.

    Prints the inference wall time in seconds.

    Args:
        net: model called as ``net(events)``. It must return a dict with
            'class_logits' (argmax is taken over dim 1) and 'left'/'right'
            dicts holding batch-first 'vertices' and 'j3d'.
        device: torch device to run inference on.
        data: dict with 'events' (batched network input) and 'coordinates'
            ((N, 2) tensor of (y, x) pixel indices), as built by
            ``process_events``.

    Returns:
        A list with one dict per batch item. 'left'/'right' map to
        {'vertices', 'j3d'} CPU tensors. 'seg_mask' is a uint8
        (OUTPUT_HEIGHT, OUTPUT_WIDTH, 3) image in which each sampled pixel is
        white for class 3 and otherwise has channel ``class_id`` set to 255.
    """
    net.eval()
    events = data['events'].to(device=device, dtype=torch.float32)

    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = net(events)
    print(time.perf_counter() - start_time)

    # softmax is monotonic, so argmax over the raw logits gives the same classes.
    class_ids = outputs['class_logits'].argmax(1).int().cpu()

    # The same sampled coordinates apply to every batch item; convert them once.
    coords = data['coordinates'].cpu().int().numpy()
    ys, xs = coords[:, 0], coords[:, 1]

    frames = []
    for idx in range(events.shape[0]):
        hands = {
            side: {
                'vertices': outputs[side]['vertices'][idx].cpu(),
                'j3d': outputs[side]['j3d'][idx].cpu(),
            }
            for side in ('left', 'right')
        }

        cids = class_ids[idx].numpy()[: len(ys)]
        seg_mask = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
        # Writes only ever set channels to 255, so duplicate pixels with
        # different classes accumulate regardless of order.
        is_white = cids == 3
        seg_mask[ys[is_white], xs[is_white]] = 255
        others = ~is_white
        seg_mask[ys[others], xs[others], cids[others]] = 255

        hands['seg_mask'] = seg_mask
        frames.append(hands)
    return frames


class Ev2Hands:
    """Callable that turns event data into concatenated left + right hand meshes."""

    def __init__(self) -> None:
        # NOTE(review): presumably arg_parser.demo() populates
        # os.environ['CHECKPOINT_PATH'], which is read below — confirm.
        arg_parser.demo()
        device = torch.device('cpu')
        net = TEHNetWrapper(device=device)

        save_path = os.environ['CHECKPOINT_PATH']
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(save_path, map_location=device)
        net.load_state_dict(checkpoint['state_dict'], strict=True)

        self.net = net
        self.device = device
        self.mano_hands = net.hands
        # 180-degree rotation about the x axis, applied to every output mesh.
        self.rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])

    def __call__(self, data):
        """Predict hands for ``data`` (see ``demo``) and return one trimesh.

        Only the first batch item is used. Vertices are scaled by 1000
        (presumably metres -> millimetres — confirm), coloured red, and
        rotated by ``self.rot``.
        """
        frame = demo(net=self.net, device=self.device, data=data)[0]

        meshes = []
        for hand_type in ('left', 'right'):
            faces = self.mano_hands[hand_type].faces
            vertices = frame[hand_type]['vertices'].cpu().numpy() * 1000
            mesh = trimesh.Trimesh(vertices, faces)
            mesh.visual.vertex_colors = [255, 0, 0]
            meshes.append(mesh)

        combined = trimesh.util.concatenate(meshes)
        combined.apply_transform(self.rot)
        return combined