# ev2hands / demo.py
# Uploaded by chris10 · commit 15bc41b ("init") · 7.14 kB
import sys
import os
os.environ['ERPC'] = '1'
import esim_py
import torch
import cv2
import time
import pyrender
import numpy as np
import trimesh
import arg_parser
from model import TEHNetWrapper
from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH, MAIN_CAMERA, REAL_TEST_DATA_PATH
def pc_normalize(pc, width=None, height=None):
    """Normalize a point tensor of (x, y, t, ...) columns in place.

    Column 0 is divided by ``width`` and column 1 by ``height``, then both are
    mapped with ``2 * v - 1``. Every column from index 2 onward is min-max
    scaled independently to [-1, 1].

    Args:
        pc: float torch tensor with at least 3 columns; modified in place.
        width: divisor for column 0 (defaults to OUTPUT_WIDTH).
        height: divisor for column 1 (defaults to OUTPUT_HEIGHT).

    Returns:
        ``pc`` itself, after normalization.
    """
    width = OUTPUT_WIDTH if width is None else width
    height = OUTPUT_HEIGHT if height is None else height

    pc[:, 0] /= width
    pc[:, 1] /= height
    pc[:, :2] = 2 * pc[:, :2] - 1

    ts = pc[:, 2:]
    t_min = ts.min(0).values
    t_span = ts.max(0).values - t_min
    # A column whose values are all equal has zero span. Dividing by it would
    # give NaN, so use a span of 1 there and such a column maps to -1.
    t_span = torch.where(t_span > 0, t_span, torch.ones_like(t_span))
    pc[:, 2:] = 2 * ((ts - t_min) / t_span) - 1
    return pc
def process_events(events, n_events=2048):
    """Convert a raw event array into the network's point-cloud input.

    Events are accumulated per pixel into mean timestamp, the count of events
    with p == 1, and the count of all other events. Then ``n_events`` occupied
    pixels are drawn at random *with replacement*, so the output always has
    exactly ``n_events`` points, however many pixels fired.

    Args:
        events: 2-D array whose columns are unpacked as x, y, t, p. x and y are
            cast to int32 and used as pixel indices into an
            (OUTPUT_HEIGHT, OUTPUT_WIDTH) grid. The array is not modified.
        n_events: number of points to sample (default 2048).

    Returns:
        dict with
          'event_frame': uint8 tensor (OUTPUT_HEIGHT, OUTPUT_WIDTH, 3). At each
              sampled pixel, channel 0 is the positive ratio * 255 and channel
              2 the negative ratio * 255.
          'events': float32 tensor (1, 5, n_events). Its rows are x, y, t_avg
              (normalized by pc_normalize), positive count and negative count.
          'coordinates': float32 tensor (n_events, 2) holding (y, x) per point.

    Raises:
        ValueError: if ``events`` is empty, since there is nothing to sample.
    """
    if len(events) == 0:
        raise ValueError("process_events() needs at least one event")

    x, y, t, p = events.T
    x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)
    # Timestamps relative to the first event. This builds a new array so the
    # caller's ``events`` is left untouched.
    t = t - events[0, 2]

    # Per-pixel accumulators:
    #   channel 0 sums timestamps,
    #   channel 1 counts events with p == 1,
    #   channel 2 counts all other events.
    # count_grid holds the total number of events per pixel.
    event_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.float32)
    count_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH), dtype=np.float32)
    # np.add.at is unbuffered, so repeated (y, x) pairs all accumulate.
    np.add.at(event_grid, (y, x, 0), t)
    np.add.at(event_grid, (y, x, 1), p == 1)
    np.add.at(event_grid, (y, x, 2), p != 1)
    np.add.at(count_grid, (y, x), 1)

    # One row per pixel that received at least one event.
    yi, xi = np.nonzero(count_grid)
    t_avg = event_grid[yi, xi, 0] / count_grid[yi, xi]
    pixels = np.column_stack([xi, yi, t_avg, event_grid[yi, xi, 1], event_grid[yi, xi, 2]])

    # Sample with replacement so the output size is fixed even when fewer
    # than n_events pixels fired.
    sampled_indices = np.random.choice(pixels.shape[0], n_events)
    points = torch.tensor(pixels[sampled_indices], dtype=torch.float32)

    # Read pixel positions before pc_normalize rewrites columns 0-2.
    xs, ys = points[:, 0].int(), points[:, 1].int()
    pos, neg = points[:, 3], points[:, 4]
    total = pos + neg  # > 0: every sampled pixel had at least one event
    coordinates = torch.stack([ys, xs], dim=1).to(torch.float32)

    event_frame = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
    ys_np, xs_np = ys.numpy(), xs.numpy()
    # Duplicate samples hit the same pixel with identical values, so the
    # vectorized assignment matches the old per-point loop.
    # astype truncates, just as assigning a float into a uint8 array does.
    event_frame[ys_np, xs_np, 0] = (pos / total * 255).numpy().astype(np.uint8)
    event_frame[ys_np, xs_np, -1] = (neg / total * 255).numpy().astype(np.uint8)

    # Normalize x, y and t_avg; pc_normalize writes through this view.
    points[:, :3] = pc_normalize(points[:, :3])

    return {
        'event_frame': torch.tensor(event_frame, dtype=torch.uint8),
        'events': points.permute(1, 0).unsqueeze(0),
        'coordinates': coordinates,
    }
def demo(net, device, data):
    """Run ``net`` on pre-processed events and unpack per-sample predictions.

    Args:
        net: model called as ``net(events)``. Its output dict is indexed below
            for 'class_logits' and, under 'left'/'right', for 'vertices' and
            'j3d'.
        device: torch device the events are moved to before inference.
        data: dict with 'events' (batched tensor) and 'coordinates' (tensor of
            (y, x) pixel positions, one row per point). process_events()
            produces this dict.

    Returns:
        list with one dict per batch element, containing:
          'left' / 'right': dicts of 'vertices' and 'j3d', moved to CPU;
          'seg_mask': uint8 (OUTPUT_HEIGHT, OUTPUT_WIDTH, 3) image. A point
              with predicted class 3 sets all channels to 255; any other class
              c sets channel c to 255.
    """
    net.eval()
    events = data['events'].to(device=device, dtype=torch.float32)

    # perf_counter is the monotonic high-resolution clock intended for timing.
    # On a CUDA device an accurate figure would also need
    # torch.cuda.synchronize().
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = net(events)
    print(f"inference time: {time.perf_counter() - start_time:.4f} s")

    # Predicted class per point is the argmax over dim 1. No softmax first:
    # it is monotonic, so it does not change which class scores highest.
    class_ids = outputs['class_logits'].argmax(1).int().cpu()

    # The same points are shared by every batch element, so convert them once.
    # .int() truncates exactly like the per-element y.int()/x.int() did.
    coordinates = data['coordinates'].int().cpu().numpy()
    ys, xs = coordinates[:, 0], coordinates[:, 1]

    frames = []
    for idx in range(events.shape[0]):
        cid = class_ids[idx, :len(ys)].numpy()
        seg_mask = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
        # Writes only ever set 255, so points landing on the same pixel union
        # their channels, as in a sequential loop.
        every = cid == 3
        seg_mask[ys[every], xs[every]] = 255
        seg_mask[ys[~every], xs[~every], cid[~every]] = 255
        frames.append({
            'left': {
                'vertices': outputs['left']['vertices'][idx].cpu(),
                'j3d': outputs['left']['j3d'][idx].cpu(),
            },
            'right': {
                'vertices': outputs['right']['vertices'][idx].cpu(),
                'j3d': outputs['right']['j3d'][idx].cpu(),
            },
            'seg_mask': seg_mask,
        })
    return frames
def _build_scene():
    """Return a pyrender scene with ambient light and three directional lights."""
    scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
    # Give each light its own pose array rather than mutating one shared
    # matrix between scene.add() calls; the node may keep a reference to it.
    for translation in ([0, -1, 1], [0, 1, 1], [1, 1, 2]):
        light_pose = np.eye(4)
        light_pose[:3, 3] = translation
        scene.add(light, pose=light_pose)
    return scene


def _render_hands(renderer, scene, frame, mano_hands, rot):
    """Render both predicted hand meshes of ``frame`` as a BGR uint8 image with a black background."""
    pred_meshes = []
    for hand_type in ['left', 'right']:
        faces = mano_hands[hand_type].faces
        # NOTE(review): the x1000 scale presumably matches MAIN_CAMERA units
        # (e.g. metres -> millimetres); confirm.
        pred_mesh = trimesh.Trimesh(frame[hand_type]['vertices'].cpu().numpy() * 1000, faces)
        pred_mesh.visual.vertex_colors = [255, 0, 0]
        pred_meshes.append(pred_mesh)
    merged = trimesh.util.concatenate(pred_meshes)
    merged.apply_transform(rot)

    mesh_node = pyrender.Node(mesh=pyrender.Mesh.from_trimesh(merged))
    scene.add_node(mesh_node)
    try:
        pred_rgb, depth = renderer.render(scene)
    finally:
        # Always detach the per-frame mesh so meshes never pile up in the scene.
        scene.remove_node(mesh_node)

    pred_rgb = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
    # Black out pixels where no geometry was rendered (depth 0).
    # A per-channel ``pred_rgb == 255`` test would also zero saturated
    # channels of the red hand mesh itself.
    pred_rgb[depth == 0] = 0
    return pred_rgb


def main():
    """Run the event-based two-hand demo on ``video.mp4``.

    Processing steps:
      1. Convert each input frame to log intensity.
      2. Pass it to esim_py, which generates events against the previous frame.
      3. Turn the events into network input with process_events().
      4. Run the model through demo().
      5. Show the result as [event frame | segmentation mask | rendered hands].

    Outputs:
      - Event images are written to ``outputs/event_frame_{idx}.png``.
      - The side-by-side stack goes to ``outputs/video.mp4`` and to an OpenCV
        window.

    Frames that produce no events are skipped. The run stops at the end of
    the video or when 'q' is pressed, and the writer is always finalized.
    The checkpoint path is read from the CHECKPOINT_PATH environment variable.
    """
    arg_parser.demo()
    os.makedirs('outputs', exist_ok=True)

    device = torch.device('cpu')
    net = TEHNetWrapper(device=device)
    checkpoint = torch.load(os.environ['CHECKPOINT_PATH'], map_location=device)
    net.load_state_dict(checkpoint['state_dict'], strict=True)
    mano_hands = net.hands

    scene = _build_scene()
    # The camera never moves, so attach it once instead of on every frame.
    scene.add_node(pyrender.Node(camera=MAIN_CAMERA, matrix=np.eye(4)))
    rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])

    POS_THRESHOLD = 0.5
    NEG_THRESHOLD = 0.5
    REF_PERIOD = 0.000
    esim = esim_py.EventSimulator(POS_THRESHOLD, NEG_THRESHOLD, REF_PERIOD, 1e-4, True)

    video_fps = 25
    renderer = pyrender.OffscreenRenderer(viewport_width=OUTPUT_WIDTH, viewport_height=OUTPUT_HEIGHT)
    input_video_stream = cv2.VideoCapture('video.mp4')
    video = cv2.VideoWriter('outputs/video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (3 * OUTPUT_WIDTH, OUTPUT_HEIGHT))
    try:
        if not input_video_stream.isOpened():
            raise RuntimeError("could not open input video 'video.mp4'")
        # cv2.CAP_PROP_FPS is a property *id*, not a frame rate, so ask the
        # stream. Fall back to the output rate when none is reported (0).
        fps = input_video_stream.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = video_fps
        ts_ns = 1e9 / fps  # frame period in nanoseconds

        is_init = False
        idx = 0
        while True:
            ok, frame_bgr = input_video_stream.read()
            if not ok:
                break  # end of stream; fall through so the writer is finalized
            frame_bgr = cv2.resize(frame_bgr, (OUTPUT_WIDTH, OUTPUT_HEIGHT))
            frame_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
            # The 1e-4 offset keeps black pixels away from log(0).
            frame_log = np.log(frame_gray.astype("float32") / 255 + 1e-4)
            current_ts_ns = idx * ts_ns
            idx += 1
            if not is_init:
                # The first frame only seeds the simulator; no events yet.
                esim.init(frame_log, current_ts_ns)
                is_init = True
                continue

            events = esim.generateEventFromCVImage(frame_log, current_ts_ns)
            if len(events) == 0:
                # Nothing changed between frames, so there is nothing to sample.
                continue
            data = process_events(events)
            event_frame = data['event_frame'].cpu().numpy().astype(dtype=np.uint8)
            cv2.imwrite(f"outputs/event_frame_{idx}.png", event_frame)
            print(idx, event_frame.shape)

            frame = demo(net=net, device=device, data=data)[0]
            pred_rgb = _render_hands(renderer, scene, frame, mano_hands, rot)

            img_stack = np.hstack([event_frame, frame['seg_mask'], pred_rgb])
            video.write(img_stack)
            cv2.imshow('image', img_stack)
            if (cv2.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        video.release()
        input_video_stream.release()
        renderer.delete()


if __name__ == '__main__':
    main()