|
import os |
|
os.environ['ERPC'] = '1' |
|
|
|
import torch |
|
import cv2 |
|
import time |
|
import numpy as np |
|
import trimesh |
|
|
|
import arg_parser |
|
|
|
from model import TEHNetWrapper |
|
from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH, MAIN_CAMERA, REAL_TEST_DATA_PATH |
|
|
|
|
|
def pc_normalize(pc):
    """Normalize an event point cloud to [-1, 1] per axis, in place.

    Columns are interpreted as (x, y, t): x is divided by OUTPUT_WIDTH and
    y by OUTPUT_HEIGHT before being mapped from [0, 1] to [-1, 1]; the
    remaining column(s) (timestamps) are min-max scaled to [-1, 1].

    Note: `pc` is mutated in place (callers rely on this; see
    `process_events`) and also returned for convenience.

    Args:
        pc: (N, 3+) torch tensor — assumes float dtype and torch semantics
            (`.max(0).values`); confirmed by the torch caller.

    Returns:
        The same tensor object, normalized.
    """
    pc[:, 0] /= OUTPUT_WIDTH
    pc[:, 1] /= OUTPUT_HEIGHT
    pc[:, :2] = 2 * pc[:, :2] - 1

    ts = pc[:, 2:]

    t_max = ts.max(0).values
    t_min = ts.min(0).values

    # Guard against a zero time span (all events sharing one timestamp),
    # which would otherwise divide by zero and fill the column with NaN/inf.
    # In that degenerate case every value maps to -1.
    t_range = t_max - t_min
    t_range[t_range == 0] = 1

    pc[:, 2:] = (2 * ((ts - t_min) / t_range)) - 1

    return pc
|
|
|
|
|
|
|
def process_events(events):
    """Aggregate raw events per pixel and build the network input sample.

    Pipeline: shift timestamps relative to the first event, accumulate
    per-pixel statistics (timestamp sum, positive/negative polarity counts),
    collapse to one averaged "event" per active pixel, randomly sample a
    fixed number of them, then normalize and pack into tensors.

    Args:
        events: (N, 4) array-like with columns (x, y, t, p). The time shift
            below uses events[0, 2] as the origin — assumes rows are sorted
            by timestamp (TODO confirm with the event source).

    Returns:
        dict with:
            'event_frame': (OUTPUT_HEIGHT, OUTPUT_WIDTH, 3) uint8 tensor,
                polarity-ratio visualization (channel 0 = positive ratio,
                last channel = negative ratio, scaled to 0..255).
            'events': (1, 5, n_events) float32 tensor — sampled rows
                (x, y, t_avg, pos_count, neg_count) transposed and batched.
            'coordinates': (n_events, 2) float32 tensor of (y, x) pixel
                coordinates for the sampled events.
    """
    # Fixed sample size fed to the network.
    n_events = 2048

    # Make timestamps relative to the first event.
    events[:, 2] -= events[0, 2]

    # Per-pixel accumulators: channel 0 = timestamp sum,
    # channel 1 = positive-polarity count, channel 2 = negative-polarity count.
    event_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.float32)
    count_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH), dtype=np.float32)

    x, y, t, p = events.T
    x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)

    # np.add.at handles repeated (y, x) indices correctly (unbuffered add),
    # unlike plain fancy-index assignment.
    np.add.at(event_grid, (y, x, 0), t)
    np.add.at(event_grid, (y, x, 1), p == 1)
    np.add.at(event_grid, (y, x, 2), p != 1)

    np.add.at(count_grid, (y, x), 1)

    # Collapse to one row per active pixel with the mean timestamp and the
    # two polarity counts.
    yi, xi = np.nonzero(count_grid)
    t_avg = event_grid[yi, xi, 0] / count_grid[yi, xi]
    p_evn = event_grid[yi, xi, 1]
    n_evn = event_grid[yi, xi, 2]

    events = np.hstack([xi[:, None], yi[:, None], t_avg[:, None], p_evn[:, None], n_evn[:, None]])

    # NOTE(review): np.random.choice samples WITH replacement by default, so
    # duplicates are possible (and this also works when fewer than n_events
    # pixels are active). No seed is set — output is non-deterministic.
    sampled_indices = np.random.choice(events.shape[0], n_events)
    events = events[sampled_indices]

    events = torch.tensor(events, dtype=torch.float32)

    # Build the (y, x) coordinate list and the polarity-ratio visualization
    # frame from the sampled events.
    coordinates = np.zeros((events.shape[0], 2))
    event_frame = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
    for idx, (x, y, t_avg, p_evn, n_evn) in enumerate(events):
        y, x = y.int(), x.int()

        coordinates[idx] = (y, x)
        event_frame[y, x, 0] = (p_evn / (p_evn + n_evn)) * 255
        event_frame[y, x, -1] = (n_evn / (p_evn + n_evn)) * 255

    # Normalize (x, y, t) in place to [-1, 1]; the count columns are left raw.
    events[:, :3] = pc_normalize(events[:, :3])

    hand_data = {
        'event_frame': torch.tensor(event_frame, dtype=torch.uint8),
        'events': events.permute(1, 0).unsqueeze(0),
        'coordinates': torch.tensor(coordinates, dtype=torch.float32)
    }

    return hand_data
|
|
|
|
|
|
|
def demo(net, device, data):
    """Run TEHNet inference on one preprocessed sample and unpack results.

    Args:
        net: TEHNetWrapper model (set to eval mode here).
        device: torch device to run inference on.
        data: dict from `process_events` — uses 'events' (1, 5, n_events)
            as network input and 'coordinates' (n_events, 2) of (y, x)
            pixel positions to paint the segmentation mask.

    Returns:
        List of one dict per batch element, each with:
            'left' / 'right': {'vertices', 'j3d'} prediction tensors (CPU).
            'seg_mask': (OUTPUT_HEIGHT, OUTPUT_WIDTH, 3) uint8 image where
                each event pixel is colored by its predicted class
                (class 3 → white, otherwise channel `cid` set to 255).
    """
    net.eval()

    events = data['events']
    events = events.to(device=device, dtype=torch.float32)

    start_time = time.time()
    with torch.no_grad():
        outputs = net(events)

    end_time = time.time()

    # Batch size (events is (batch, channels, n_events)).
    N = events.shape[0]
    print(end_time - start_time)  # inference latency in seconds (debug output)

    # Per-event class prediction. argmax of the raw logits equals argmax of
    # their softmax (softmax is strictly monotone along the reduced dim), so
    # the previous `.softmax(1)` call was redundant and has been removed.
    outputs['class_logits'] = outputs['class_logits'].argmax(1).int().cpu()

    frames = list()
    for idx in range(N):
        hands = dict()

        hands['left'] = {
            'vertices': outputs['left']['vertices'][idx].cpu(),
            'j3d': outputs['left']['j3d'][idx].cpu(),
        }

        hands['right'] = {
            'vertices': outputs['right']['vertices'][idx].cpu(),
            'j3d': outputs['right']['j3d'][idx].cpu(),
        }

        coordinates = data['coordinates']

        # Paint each sampled event pixel with its predicted class color.
        seg_mask = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
        for edx, (y, x) in enumerate(coordinates):
            y, x = y.int(), x.int()

            cid = outputs['class_logits'][idx][edx]

            if cid == 3:
                # Class 3 gets all channels (white); presumably "both hands"
                # or background — TODO confirm class semantics.
                seg_mask[y, x] = 255
            else:
                seg_mask[y, x, cid] = 255

        hands['seg_mask'] = seg_mask

        frames.append(hands)

    return frames
|
|
|
|
|
class Ev2Hands:
    """Callable wrapper: loads a pretrained TEHNet checkpoint on CPU and
    converts a preprocessed event sample into a combined left+right hand
    mesh.

    Reads the checkpoint path from the CHECKPOINT_PATH environment variable.
    """

    def __init__(self) -> None:
        # Sets up module-level demo configuration (side effects only —
        # return value unused).
        arg_parser.demo()

        device = torch.device('cpu')
        net = TEHNetWrapper(device=device)

        save_path = os.environ['CHECKPOINT_PATH']

        checkpoint = torch.load(save_path, map_location=device)
        net.load_state_dict(checkpoint['state_dict'], strict=True)

        # 180° rotation about the x-axis, applied to the final mesh
        # (presumably converts between the model's and the viewer's
        # coordinate conventions — TODO confirm).
        rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])

        self.net = net
        self.device = device
        self.mano_hands = net.hands  # per-hand MANO models, keyed 'left'/'right'
        self.rot = rot

    def __call__(self, data):
        """Run inference on `data` (output of `process_events`) and return a
        single trimesh.Trimesh containing both hands, in millimeters.
        """
        # Only the first (and only) batch element is used.
        frame = demo(net=self.net, device=self.device, data=data)[0]
        # NOTE: the previously unused local read of frame['seg_mask'] was
        # dead code and has been removed.

        pred_meshes = list()
        for hand_type in ['left', 'right']:
            faces = self.mano_hands[hand_type].faces

            # Vertices are predicted in meters; scale to millimeters.
            pred_mesh = trimesh.Trimesh(frame[hand_type]['vertices'].cpu().numpy() * 1000, faces)
            pred_mesh.visual.vertex_colors = [255, 0, 0]
            pred_meshes.append(pred_mesh)

        pred_meshes = trimesh.util.concatenate(pred_meshes)
        pred_meshes.apply_transform(self.rot)

        return pred_meshes
|
|