|
import sys |
|
import os |
|
os.environ['ERPC'] = '1' |
|
|
|
import esim_py |
|
|
|
import torch |
|
import cv2 |
|
import time |
|
import pyrender |
|
import numpy as np |
|
import trimesh |
|
|
|
import arg_parser |
|
|
|
from model import TEHNetWrapper |
|
from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH, MAIN_CAMERA, REAL_TEST_DATA_PATH |
|
|
|
|
|
def pc_normalize(pc):
    """Rescale an event point cloud to roughly [-1, 1], in place.

    Column 0 is divided by ``OUTPUT_WIDTH`` and column 1 by ``OUTPUT_HEIGHT``,
    then both are mapped with ``2 * v - 1``. Every remaining column (index 2
    onward) is min-max normalised independently to [-1, 1].

    Args:
        pc: 2-D floating-point ``torch.Tensor`` with one point per row.
            It is modified in place (a slice/view of a larger tensor is
            therefore updated too).

    Returns:
        The same tensor object ``pc``, after normalisation.
    """
    # Nothing to normalise; min()/max() over an empty dimension would raise.
    if pc.shape[0] == 0:
        return pc

    pc[:, 0] /= OUTPUT_WIDTH
    pc[:, 1] /= OUTPUT_HEIGHT
    pc[:, :2] = 2 * pc[:, :2] - 1

    ts = pc[:, 2:]
    t_min = ts.min(0).values
    t_range = ts.max(0).values - t_min

    # A constant column has zero range; dividing by it would fill the column
    # with NaN. Use a unit range instead so such a column maps to -1.
    t_range = torch.where(t_range == 0, torch.ones_like(t_range), t_range)

    pc[:, 2:] = 2 * ((ts - t_min) / t_range) - 1

    return pc
|
|
|
|
|
|
|
def process_events(events):
    """Convert raw events into the network input and a preview image.

    Events are accumulated per pixel (mean timestamp, count of events whose
    polarity equals 1, count of all other events). ``n_events`` (2048) of the
    active pixels are then drawn at random *with replacement*, so the same
    pixel may appear more than once.

    Args:
        events: array-like of shape (N, 4) whose columns are unpacked as
            (x, y, t, polarity). The input is copied, so the caller's array
            is left untouched.

    Returns:
        dict with:
            'event_frame': uint8 tensor (OUTPUT_HEIGHT, OUTPUT_WIDTH, 3);
                channel 0 holds the fraction of polarity==1 events and the
                last channel the fraction of the other events, scaled to
                0..255, for every sampled pixel.
            'events': float32 tensor (1, 5, n_events) with rows
                (x, y, t_avg, pos_count, neg_count); the first three rows
                are normalised by ``pc_normalize``.
            'coordinates': float32 tensor (n_events, 2) of (row, col) pixel
                positions of the sampled points.

    Raises:
        ValueError: if ``events`` is not a non-empty 2-D array.
    """
    n_events = 2048

    # Work on a copy: the timestamp shift below must not leak to the caller.
    events = np.array(events)
    if events.ndim != 2 or events.shape[0] == 0:
        raise ValueError('process_events() needs a non-empty (N, 4) array of events')

    # Make timestamps relative to the first event.
    events[:, 2] -= events[0, 2]

    event_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.float32)
    count_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH), dtype=np.float32)

    x, y, t, p = events.T
    x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)

    # np.add.at accumulates correctly even when a pixel occurs several times.
    np.add.at(event_grid, (y, x, 0), t)
    np.add.at(event_grid, (y, x, 1), p == 1)
    np.add.at(event_grid, (y, x, 2), p != 1)
    np.add.at(count_grid, (y, x), 1)

    # One point per pixel that received at least one event.
    yi, xi = np.nonzero(count_grid)
    t_avg = event_grid[yi, xi, 0] / count_grid[yi, xi]
    p_evn = event_grid[yi, xi, 1]
    n_evn = event_grid[yi, xi, 2]

    points = np.hstack([xi[:, None], yi[:, None], t_avg[:, None], p_evn[:, None], n_evn[:, None]])

    # Fixed-size sample (with replacement) so the network input shape is constant.
    sampled_indices = np.random.choice(points.shape[0], n_events)
    sampled = points[sampled_indices].astype(np.float32)

    cols = sampled[:, 0].astype(np.int64)
    rows = sampled[:, 1].astype(np.int64)
    pos_count = sampled[:, 3]
    neg_count = sampled[:, 4]
    # Never zero: every sampled pixel received at least one event.
    total_count = pos_count + neg_count

    # Duplicated pixels carry identical values, so assignment order is irrelevant.
    event_frame = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
    event_frame[rows, cols, 0] = ((pos_count / total_count) * 255).astype(np.uint8)
    event_frame[rows, cols, -1] = ((neg_count / total_count) * 255).astype(np.uint8)

    coordinates = np.stack([rows, cols], axis=1)

    events = torch.tensor(sampled, dtype=torch.float32)
    events[:, :3] = pc_normalize(events[:, :3])

    hand_data = {
        'event_frame': torch.tensor(event_frame, dtype=torch.uint8),
        'events': events.permute(1, 0).unsqueeze(0),
        'coordinates': torch.tensor(coordinates, dtype=torch.float32)
    }

    return hand_data
|
|
|
|
|
|
|
def demo(net, device, data):
    """Run the network on one prepared sample and collect per-sample results.

    The forward-pass duration (seconds) is printed to stdout.

    Args:
        net: model object; ``net.eval()`` is called and ``net(events)`` must
            return a dict with 'class_logits' plus 'left' and 'right'
            sub-dicts, each holding 'vertices' and 'j3d' tensors indexed by
            batch element.
        device: torch device the input tensor is moved to.
        data: dict with 'events' (network input tensor, batch first) and
            'coordinates' (tensor of (row, col) pixel positions, one per
            input point).

    Returns:
        list with one dict per batch element containing 'left' and 'right'
        (each with CPU 'vertices' and 'j3d') and 'seg_mask', a uint8 image
        of shape (OUTPUT_HEIGHT, OUTPUT_WIDTH, 3). In the mask a point with
        class id 3 sets all three channels to 255; any other class id sets
        only the channel with that index.
    """
    net.eval()

    events = data['events'].to(device=device, dtype=torch.float32)

    # perf_counter is monotonic, which makes it the right clock for intervals.
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = net(events)
    end_time = time.perf_counter()

    N = events.shape[0]
    print(end_time - start_time)

    # softmax is monotonic, so argmax over the raw logits picks the same class.
    class_ids = outputs['class_logits'].argmax(1).int().cpu()

    # The pixel coordinates are shared by every batch element: convert once.
    coordinates = data['coordinates'].to(torch.int64).cpu().numpy()
    rows, cols = coordinates[:, 0], coordinates[:, 1]

    frames = list()
    for idx in range(N):
        hands = dict()

        for hand_type in ('left', 'right'):
            hands[hand_type] = {
                'vertices': outputs[hand_type]['vertices'][idx].cpu(),
                'j3d': outputs[hand_type]['j3d'][idx].cpu(),
            }

        point_ids = class_ids[idx][:len(rows)].numpy()
        all_channels = point_ids == 3
        one_channel = ~all_channels

        # Every write stores 255, so the vectorised form gives the same mask
        # as painting the points one by one.
        seg_mask = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
        seg_mask[rows[all_channels], cols[all_channels]] = 255
        seg_mask[rows[one_channel], cols[one_channel], point_ids[one_channel]] = 255

        hands['seg_mask'] = seg_mask

        frames.append(hands)

    return frames
|
|
|
|
|
|
|
def _build_scene():
    """Create the pyrender scene with ambient light and three directional lights."""
    scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
    for position in ([0, -1, 1], [0, 1, 1], [1, 1, 2]):
        # A fresh pose matrix per light so the nodes never share one array.
        light_pose = np.eye(4)
        light_pose[:3, 3] = np.array(position)
        scene.add(light, pose=light_pose)
    return scene


def _frame_to_log_intensity(frame_bgr):
    """Resize a BGR frame to the output size and return its log-intensity image."""
    frame_bgr = cv2.resize(frame_bgr, (OUTPUT_WIDTH, OUTPUT_HEIGHT))
    frame_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    # The 1e-4 offset keeps log() finite for black pixels.
    return np.log(frame_gray.astype("float32") / 255 + 1e-4)


def _render_hands(renderer, scene, mano_hands, frame, rot):
    """Render the predicted left/right hand meshes and return a BGR image."""
    pred_meshes = list()
    for hand_type in ['left', 'right']:
        faces = mano_hands[hand_type].faces
        # NOTE(review): the factor 1000 looks like a metres -> millimetres
        # conversion matching MAIN_CAMERA; confirm against settings.
        pred_mesh = trimesh.Trimesh(frame[hand_type]['vertices'].cpu().numpy() * 1000, faces)
        pred_mesh.visual.vertex_colors = [255, 0, 0]
        pred_meshes.append(pred_mesh)

    pred_meshes = trimesh.util.concatenate(pred_meshes)
    pred_meshes.apply_transform(rot)

    mesh_node = pyrender.Node(mesh=pyrender.Mesh.from_trimesh(pred_meshes))
    scene.add_node(mesh_node)
    try:
        pred_rgb, _ = renderer.render(scene)
    finally:
        # Always detach the mesh so the next frame starts from a clean scene.
        scene.remove_node(mesh_node)

    pred_rgb = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
    # NOTE(review): this zeroes every individual channel value equal to 255,
    # not only fully white pixels; confirm that is the intended masking.
    pred_rgb[pred_rgb == 255] = 0
    return pred_rgb


def main():
    """Run the hand-reconstruction demo on 'video.mp4'.

    Frames are read from 'video.mp4', converted to events with the esim_py
    simulator, passed through the network, and the event frame, segmentation
    mask and rendered hands are stacked side by side. Each stacked image is
    shown in a window and appended to 'outputs/video.mp4'; every event frame
    is also saved as 'outputs/event_frame_<idx>.png'. Press 'q' in the window
    to stop early.

    The checkpoint path is read from the CHECKPOINT_PATH environment
    variable after ``arg_parser.demo()`` has run.

    Raises:
        RuntimeError: if the input video cannot be opened.
        KeyError: if CHECKPOINT_PATH is not set.
    """
    arg_parser.demo()
    os.makedirs('outputs', exist_ok=True)

    device = torch.device('cpu')

    net = TEHNetWrapper(device=device)

    save_path = os.environ['CHECKPOINT_PATH']

    checkpoint = torch.load(save_path, map_location=device)
    net.load_state_dict(checkpoint['state_dict'], strict=True)

    input_video_stream = cv2.VideoCapture('video.mp4')
    if not input_video_stream.isOpened():
        input_video_stream.release()
        raise RuntimeError("Could not open input video 'video.mp4'")

    video_fps = 25

    POS_THRESHOLD = 0.5
    NEG_THRESHOLD = 0.5
    REF_PERIOD = 0.000

    renderer = None
    video = None
    try:
        renderer = pyrender.OffscreenRenderer(viewport_width=OUTPUT_WIDTH, viewport_height=OUTPUT_HEIGHT)
        scene = _build_scene()

        # The camera never changes, so its node is added once, not per frame.
        scene.add_node(pyrender.Node(camera=MAIN_CAMERA, matrix=np.eye(4)))

        rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])

        mano_hands = net.hands

        video = cv2.VideoWriter('outputs/video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (3 * OUTPUT_WIDTH, OUTPUT_HEIGHT))

        esim = esim_py.EventSimulator(POS_THRESHOLD, NEG_THRESHOLD, REF_PERIOD, 1e-4, True)

        # cv2.CAP_PROP_FPS is only the property *id*; the real frame rate has
        # to be queried from the capture. Some containers report 0 or NaN, in
        # which case the output frame rate is used as a fallback.
        fps = input_video_stream.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = video_fps
        ts_s = 1 / fps
        # NOTE(review): timestamps are handed to the simulator in nanoseconds;
        # confirm this is the unit the esim_py build in use expects.
        ts_ns = ts_s * 1e9

        idx = 0
        while True:
            ok, frame_bgr = input_video_stream.read()
            if not ok:
                # End of stream (or read failure): stop instead of crashing.
                break

            frame_log = _frame_to_log_intensity(frame_bgr)

            current_ts_ns = idx * ts_ns
            idx += 1

            if idx == 1:
                # The first frame only initialises the simulator state.
                esim.init(frame_log, current_ts_ns)
                continue

            events = esim.generateEventFromCVImage(frame_log, current_ts_ns)
            if events is None or len(events) == 0:
                # No brightness change crossed the thresholds; nothing to process.
                print(idx, 'no events generated; frame skipped')
                continue

            data = process_events(events)

            event_frame = data['event_frame'].cpu().numpy().astype(dtype=np.uint8)

            cv2.imwrite(f"outputs/event_frame_{idx}.png", event_frame)

            print(idx, event_frame.shape)

            frame = demo(net=net, device=device, data=data)[0]
            seg_mask = frame['seg_mask']

            pred_rgb = _render_hands(renderer, scene, mano_hands, frame, rot)

            img_stack = np.hstack([event_frame, seg_mask, pred_rgb])
            video.write(img_stack)

            cv2.imshow('image', img_stack)
            c = cv2.waitKey(1)

            if c == ord('q'):
                break
    finally:
        # Release everything on normal end, on 'q', and on errors alike, so
        # the output video is always finalised.
        input_video_stream.release()
        if video is not None:
            video.release()
        if renderer is not None:
            renderer.delete()
        cv2.destroyAllWindows()
|
|
|
|
|
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|