import sys
import os

# Set before the remaining imports run; presumably read at import time by one
# of them -- keep this ordering (TODO confirm which module consumes ERPC).
os.environ['ERPC'] = '1'

import esim_py
import torch
import cv2
import time
import pyrender
import numpy as np
import trimesh

import arg_parser
from model import TEHNetWrapper
from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH, MAIN_CAMERA, REAL_TEST_DATA_PATH


def pc_normalize(pc):
    """Rescale an event point cloud in place and return it.

    Column 0 (x) is divided by ``OUTPUT_WIDTH`` and column 1 (y) by
    ``OUTPUT_HEIGHT``, then both are mapped with ``2 * v - 1``.  Every
    remaining column is min-max scaled to ``[-1, 1]`` along dimension 0.

    Args:
        pc: non-empty 2-D torch tensor with at least three columns; it is
            modified in place (a view of a larger tensor is fine).

    Returns:
        The same tensor object ``pc``.
    """
    pc[:, 0] /= OUTPUT_WIDTH
    pc[:, 1] /= OUTPUT_HEIGHT
    pc[:, :2] = 2 * pc[:, :2] - 1

    ts = pc[:, 2:]
    t_min = ts.min(0).values
    t_range = ts.max(0).values - t_min
    # When all values in a column are identical the range is zero; dividing
    # by it would fill the column with NaN.  Use 1 instead so such a column
    # maps to -1.
    t_range = torch.where(t_range > 0, t_range, torch.ones_like(t_range))
    pc[:, 2:] = 2 * ((ts - t_min) / t_range) - 1
    return pc


def process_events(events, n_events=2048):
    """Aggregate raw events per pixel and sample a fixed-size point set.

    Events are accumulated on an ``OUTPUT_HEIGHT x OUTPUT_WIDTH`` grid.  For
    each pixel that received at least one event the mean timestamp (relative
    to the first event), the number of events with ``p == 1`` and the number
    with ``p != 1`` are kept.  ``n_events`` of those pixels are then drawn
    with replacement.

    Args:
        events: non-empty ``(N, 4)`` array whose columns are ``x, y, t, p``.
            The array is not modified.
        n_events: number of pixels to sample (default 2048).

    Returns:
        dict with
          * ``'event_frame'``: uint8 tensor ``(OUTPUT_HEIGHT, OUTPUT_WIDTH, 3)``;
            channel 0 holds the ``p == 1`` share and channel 2 the ``p != 1``
            share of the events at each sampled pixel, scaled to 0..255.
          * ``'events'``: float32 tensor ``(1, 5, n_events)`` with rows
            ``x, y, t_avg, p_count, n_count``; the first three rows are
            normalised by ``pc_normalize``.
          * ``'coordinates'``: float32 tensor ``(n_events, 2)`` of ``(y, x)``
            pixel positions, in the same order as the sampled events.

    Raises:
        ValueError: if ``events`` is not a non-empty 2-D array.
    """
    events = np.asarray(events)
    if events.ndim != 2 or events.shape[0] == 0:
        raise ValueError('process_events() needs a non-empty (N, 4) event array')

    x, y, t, p = events.T
    x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)
    t = t - t[0]  # timestamps relative to the first event

    event_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.float32)
    count_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH), dtype=np.float32)
    # np.add.at accumulates correctly when a pixel occurs more than once.
    np.add.at(event_grid, (y, x, 0), t)
    np.add.at(event_grid, (y, x, 1), p == 1)
    np.add.at(event_grid, (y, x, 2), p != 1)
    np.add.at(count_grid, (y, x), 1)

    yi, xi = np.nonzero(count_grid)
    t_avg = event_grid[yi, xi, 0] / count_grid[yi, xi]
    p_evn = event_grid[yi, xi, 1]
    n_evn = event_grid[yi, xi, 2]
    pixels = np.stack([xi, yi, t_avg, p_evn, n_evn], axis=1)

    # Sampling is with replacement, so fewer than n_events active pixels is
    # fine; duplicates simply repeat the same pixel.
    sampled_indices = np.random.choice(pixels.shape[0], n_events)
    sampled = pixels[sampled_indices].astype(np.float32)

    xs = sampled[:, 0].astype(np.int32)
    ys = sampled[:, 1].astype(np.int32)
    pos, neg = sampled[:, 3], sampled[:, 4]
    total = pos + neg  # >= 1 for every sampled pixel

    event_frame = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
    # Duplicate pixels carry identical values, so write order is irrelevant.
    event_frame[ys, xs, 0] = (pos / total * 255).astype(np.uint8)
    event_frame[ys, xs, -1] = (neg / total * 255).astype(np.uint8)

    coordinates = np.stack([ys, xs], axis=1)

    events = torch.tensor(sampled, dtype=torch.float32)
    pc_normalize(events[:, :3])  # operates in place on the x, y, t view

    return {
        'event_frame': torch.tensor(event_frame, dtype=torch.uint8),
        'events': events.permute(1, 0).unsqueeze(0),
        'coordinates': torch.tensor(coordinates, dtype=torch.float32),
    }


def demo(net, device, data):
    """Run the network on one batch and collect per-sample results.

    Args:
        net: model; called as ``net(events)`` in eval mode under ``no_grad``.
            Its output must provide ``'class_logits'`` plus ``'left'`` and
            ``'right'`` entries that each hold ``'vertices'`` and ``'j3d'``.
        device: torch device the events are moved to.
        data: dict as returned by ``process_events`` (uses ``'events'`` and
            ``'coordinates'``).

    Returns:
        A list with one dict per batch element containing ``'left'`` and
        ``'right'`` (each with CPU ``'vertices'`` and ``'j3d'`` tensors) and
        ``'seg_mask'``, a uint8 array ``(OUTPUT_HEIGHT, OUTPUT_WIDTH, 3)``.

    The elapsed forward-pass time in seconds is printed.
    """
    net.eval()
    events = data['events'].to(device=device, dtype=torch.float32)

    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = net(events)
    print(time.perf_counter() - start_time)

    # Per-point class index: argmax over dimension 1 of the class scores.
    class_ids = outputs['class_logits'].softmax(1).argmax(1).int().cpu().numpy()

    # The same pixel coordinates are used for every batch element.
    coordinates = data['coordinates'].cpu().numpy().astype(np.int64)
    ys, xs = coordinates[:, 0], coordinates[:, 1]

    frames = []
    for idx in range(events.shape[0]):
        hands = {
            hand_type: {
                'vertices': outputs[hand_type]['vertices'][idx].cpu(),
                'j3d': outputs[hand_type]['j3d'][idx].cpu(),
            }
            for hand_type in ('left', 'right')
        }

        point_classes = class_ids[idx][:len(ys)]
        seg_mask = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
        # Class id 3 lights all three channels; any other id lights only the
        # channel with that index.
        all_channels = point_classes == 3
        seg_mask[ys[all_channels], xs[all_channels]] = 255
        seg_mask[ys[~all_channels], xs[~all_channels],
                 point_classes[~all_channels]] = 255
        hands['seg_mask'] = seg_mask

        frames.append(hands)
    return frames


def _build_scene():
    """Create the pyrender scene with ambient light and three directional lights."""
    scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
    for position in ([0, -1, 1], [0, 1, 1], [1, 1, 2]):
        # A fresh pose matrix per light, so the nodes never share one
        # mutable array.
        light_pose = np.eye(4)
        light_pose[:3, 3] = np.array(position)
        scene.add(light, pose=light_pose)
    return scene


def _render_hands(renderer, scene, frame, mano_hands, rot):
    """Render both predicted hand meshes and return the image as BGR uint8."""
    pred_meshes = []
    for hand_type in ('left', 'right'):
        faces = mano_hands[hand_type].faces
        # Vertices are scaled by 1000 (presumably metres -> millimetres to
        # match MAIN_CAMERA; TODO confirm).
        vertices = frame[hand_type]['vertices'].cpu().numpy() * 1000
        pred_mesh = trimesh.Trimesh(vertices, faces)
        pred_mesh.visual.vertex_colors = [255, 0, 0]
        pred_meshes.append(pred_mesh)
    merged = trimesh.util.concatenate(pred_meshes)
    merged.apply_transform(rot)

    mesh_node = pyrender.Node(mesh=pyrender.Mesh.from_trimesh(merged))
    scene.add_node(mesh_node)
    try:
        pred_rgb, _ = renderer.render(scene)
    finally:
        # Always detach the mesh so the next frame starts from a clean scene.
        scene.remove_node(mesh_node)

    pred_rgb = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
    # Zero every channel value that is exactly 255.
    pred_rgb[pred_rgb == 255] = 0
    return pred_rgb


def main():
    """Simulate events from ``video.mp4``, run the model and save a preview.

    For every input frame after the first, events are generated with the
    esim_py simulator, turned into model input by ``process_events`` and fed
    to the network.  The event frame, segmentation mask and rendered hand
    meshes are stacked side by side, written to ``outputs/video.mp4`` and
    shown in a window.  Each event frame is also saved as
    ``outputs/event_frame_<idx>.png``.  Processing stops at the end of the
    input video or when ``q`` is pressed.

    Reads the checkpoint location from the ``CHECKPOINT_PATH`` environment
    variable.

    Raises:
        RuntimeError: if the input video cannot be opened.
    """
    arg_parser.demo()
    os.makedirs('outputs', exist_ok=True)

    device = torch.device('cpu')
    net = TEHNetWrapper(device=device)
    save_path = os.environ['CHECKPOINT_PATH']
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(save_path, map_location=device)
    net.load_state_dict(checkpoint['state_dict'], strict=True)
    mano_hands = net.hands

    renderer = pyrender.OffscreenRenderer(viewport_width=OUTPUT_WIDTH,
                                          viewport_height=OUTPUT_HEIGHT)
    scene = _build_scene()
    # The camera never changes, so it is added to the scene once.
    scene.add_node(pyrender.Node(camera=MAIN_CAMERA, matrix=np.eye(4)))
    rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])

    POS_THRESHOLD = 0.5
    NEG_THRESHOLD = 0.5
    REF_PERIOD = 0.000
    esim = esim_py.EventSimulator(POS_THRESHOLD, NEG_THRESHOLD, REF_PERIOD, 1e-4, True)

    # camera = cv2.VideoCapture(0)
    input_video_stream = cv2.VideoCapture('video.mp4')
    video = None
    try:
        if not input_video_stream.isOpened():
            raise RuntimeError("could not open input video 'video.mp4'")

        video_fps = 25
        video = cv2.VideoWriter('outputs/video.mp4', cv2.VideoWriter_fourcc(*'mp4v'),
                                video_fps, (3 * OUTPUT_WIDTH, OUTPUT_HEIGHT))

        # Query the stream for its frame rate; cv2.CAP_PROP_FPS itself is
        # only the property id.  Fall back to the output rate if the
        # container does not report a usable value.
        fps = input_video_stream.get(cv2.CAP_PROP_FPS)
        if not np.isfinite(fps) or fps <= 0:
            fps = video_fps
        ts_s = 1 / fps
        ts_ns = ts_s * 1e9  # convert s to ns

        is_init = False
        idx = 0
        while True:
            ok, frame_bgr = input_video_stream.read()
            if not ok:
                break  # end of stream or read failure

            frame_bgr = cv2.resize(frame_bgr, (OUTPUT_WIDTH, OUTPUT_HEIGHT))
            frame_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
            frame_log = np.log(frame_gray.astype("float32") / 255 + 1e-4)

            current_ts_ns = idx * ts_ns
            idx += 1

            if not is_init:
                # The first frame only initialises the simulator.
                esim.init(frame_log, current_ts_ns)
                is_init = True
                continue

            events = esim.generateEventFromCVImage(frame_log, current_ts_ns)
            if events is None or len(events) == 0:
                # Nothing changed between frames; there is nothing to feed
                # the network, so skip this frame.
                continue

            data = process_events(events)
            event_frame = data['event_frame'].cpu().numpy().astype(dtype=np.uint8)
            cv2.imwrite(f"outputs/event_frame_{idx}.png", event_frame)
            print(idx, event_frame.shape)

            frame = demo(net=net, device=device, data=data)[0]
            pred_rgb = _render_hands(renderer, scene, frame, mano_hands, rot)

            img_stack = np.hstack([event_frame, frame['seg_mask'], pred_rgb])
            video.write(img_stack)
            cv2.imshow('image', img_stack)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Runs on normal completion, on 'q' and on errors, so the output file
        # is always finalised and native resources are freed.
        input_video_stream.release()
        if video is not None:
            video.release()
        cv2.destroyAllWindows()
        renderer.delete()


if __name__ == '__main__':
    main()