|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
import imageio |
|
import io |
|
|
|
def random_rotation_matrix():
    """Draw a random 3-D rotation matrix.

    Four standard-normal samples are normalised onto the unit 3-sphere and
    read as a quaternion ``(w, x, y, z)``. This gives a uniformly distributed
    rotation. The quaternion is then expanded with the standard
    quaternion-to-rotation-matrix formula.

    Returns:
        torch.Tensor: a ``(3, 3)`` rotation matrix on the CPU.
    """
    quat = torch.randn(4)
    quat = quat / torch.norm(quat)
    w, x, y, z = quat  # four 0-d tensors

    rows = (
        (1 - 2 * y ** 2 - 2 * z ** 2, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w),
        (2 * x * y + 2 * z * w, 1 - 2 * x ** 2 - 2 * z ** 2, 2 * y * z - 2 * x * w),
        (2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x ** 2 - 2 * y ** 2),
    )
    return torch.stack([torch.stack(row) for row in rows])
|
|
|
def augment_data(data):
    """Apply an independent random 3-D rotation to every sensor of every sample.

    The channel axis is read as consecutive groups of six channels. Per the
    local names, each group is an accelerometer triplet followed by a
    gyroscope triplet. Within one group both triplets are rotated by the same
    matrix. Every (sample, group) pair draws a fresh matrix from
    ``random_rotation_matrix``.

    Args:
        data: floating-point tensor of shape ``(B, T, M)``, where ``M`` is a
            multiple of 6.

    Returns:
        A new tensor with the same shape, dtype and device as ``data``.

    Raises:
        ValueError: if ``M`` is not a multiple of 6.
        TypeError: if ``data`` is not floating point (a rotation cast to an
            integer dtype would be meaningless).
    """
    B, _, M = data.shape
    if M % 6 != 0:
        raise ValueError(
            f"expected the channel dimension to be a multiple of 6, got {M}")
    if not torch.is_floating_point(data):
        raise TypeError(f"expected a floating-point tensor, got {data.dtype}")

    augmented_data = torch.zeros_like(data)
    for i in range(B):
        for c in range(0, M, 6):
            # Follow the data's device and dtype instead of hard-coding .cuda().
            # This works on CPU, on any GPU index and for float64 input.
            R = random_rotation_matrix().to(device=data.device, dtype=data.dtype)
            acc = data[i, :, c:c + 3]
            gyro = data[i, :, c + 3:c + 6]
            # (T, 3) @ R^T == (R @ (3, T))^T: rotate each time step's 3-vector.
            augmented_data[i, :, c:c + 3] = acc @ R.T
            augmented_data[i, :, c + 3:c + 6] = gyro @ R.T

    return augmented_data
|
|
|
def update_limits(data, padding=0.1):
    """Compute padded axis limits for plotting a skeleton sequence.

    The plotting helpers put data coordinate 0 on matplotlib's X axis,
    coordinate 2 on its Y axis and coordinate 1 on its Z axis. The limits
    are returned in that display order.

    Args:
        data: array whose last axis holds the three coordinates, e.g.
            ``(frames, joints, 3)``. Min/max are taken over all other axes.
        padding: fraction of each axis' span added on both ends.

    Returns:
        ``(xlims, ylims, zlims)``, each a ``(low, high)`` pair.
    """
    # Display order: X <- coord 0, Y <- coord 2, Z <- coord 1.
    bounds = [(np.min(data[..., k]), np.max(data[..., k])) for k in (0, 2, 1)]
    spans = [hi - lo for lo, hi in bounds]

    # A zero-span axis would give matplotlib identical low/high limits and a
    # zero box aspect. Give it the widest span instead (1.0 if every axis is flat).
    widest = max(spans)
    fallback = widest if widest > 0 else 1.0

    limits = []
    for (lo, hi), span in zip(bounds, spans):
        if span <= 0:
            half = fallback / 2
            lo, hi, span = lo - half, hi + half, fallback
        limits.append((lo - padding * span, hi + padding * span))
    return tuple(limits)
|
|
|
def plot_skeleton(frame_data, xlims, ylims, zlims, dataset):
    """Render a single skeleton frame and return it as an image array.

    Args:
        frame_data: joint positions indexed ``[joint, coord]`` (presumably
            ``(num_joints, 3)``; confirm with callers). Coordinate 0 is
            drawn on X, 2 on Y and 1 on Z.
        xlims, ylims, zlims: ``(low, high)`` limits for the X, Y and Z axes.
        dataset: ``'t2m'``, ``'kit'`` or ``'ntu'``. Selects which joint pairs
            are connected by line segments.

    Returns:
        The figure rendered to PNG and decoded with ``imageio.imread``.

    Raises:
        ValueError: if ``dataset`` is not one of the known layouts.
    """
    # Each inner list is a chain of joint indices; consecutive entries are
    # joined by a line segment.
    connections_by_dataset = {
        't2m': [
            [0, 2, 5, 8, 11],
            [0, 1, 4, 7, 10],
            [0, 3, 6, 9, 12, 15],
            [9, 14, 17, 19, 21],
            [9, 13, 16, 18, 20],
        ],
        'kit': [
            [0, 11, 12, 13, 14, 15],
            [0, 16, 17, 18, 19, 20],
            [0, 1, 2, 3, 4],
            [3, 5, 6, 7],
            [3, 8, 9, 10],
        ],
        'ntu': [
            [0, 12, 13, 14, 15],
            [0, 16, 17, 18, 19],
            [0, 1, 20, 2, 3],
            [20, 4, 5, 6, 7, 21],
            [7, 22],
            [20, 8, 9, 10, 11, 23],
            [11, 24],
        ],
    }
    # Validate before creating the figure so a bad name cannot leak one
    # (previously this surfaced as an UnboundLocalError).
    if dataset not in connections_by_dataset:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of "
            f"{sorted(connections_by_dataset)}")
    connections = connections_by_dataset[dataset]

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(frame_data[:, 0], frame_data[:, 2], frame_data[:, 1])

        for chain in connections:
            for start_joint, end_joint in zip(chain, chain[1:]):
                ax.plot([frame_data[start_joint, 0], frame_data[end_joint, 0]],
                        [frame_data[start_joint, 2], frame_data[end_joint, 2]],
                        [frame_data[start_joint, 1], frame_data[end_joint, 1]])

        ax.view_init(elev=10, azim=90)
        # Scale the 3-D box to the data extents so proportions are not distorted.
        ax.set_box_aspect((np.ptp(xlims), np.ptp(ylims), np.ptp(zlims)))

        ax.set_xlim(xlims)
        ax.set_ylim(ylims)
        ax.set_zlim(zlims)
        # Axis labels name the data coordinate shown on each axis.
        ax.set_xlabel('X')
        ax.set_ylabel('Z')
        ax.set_zlabel('Y')

        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            img = imageio.imread(buf)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    return img
|
|
|
def plot_skeleton_gif(data, dataset, path='./skeleton_animation.gif', fps=20):
    """Render each frame of a skeleton sequence and save the frames as a GIF.

    Axis limits are computed once over the whole sequence, so every frame
    uses the same framing.

    Args:
        data: sequence of frames. Each frame is passed to ``plot_skeleton``.
        dataset: skeleton layout name forwarded to ``plot_skeleton``.
        path: output file (default ``'./skeleton_animation.gif'``).
        fps: playback frame rate passed to ``imageio.mimsave``.
    """
    xlims, ylims, zlims = update_limits(data)
    images = [plot_skeleton(frame, xlims, ylims, zlims, dataset) for frame in data]
    imageio.mimsave(path, images, fps=fps)
|
|
|
def plot_single_skeleton(data, dataset, frame=0, path='skeleton.pdf'):
    """Plot one frame of a skeleton sequence and save it to ``path``.

    Axis limits come from ``update_limits`` over the whole sequence, not
    just the selected frame.

    Args:
        data: joint positions indexed ``[frame, joint, coord]``.
        dataset: ``'t2m'``, ``'kit'`` or ``'ntu'``. Selects which joint pairs
            are connected by line segments.
        frame: index of the frame to draw.
        path: output file; matplotlib infers the format from the extension.

    Raises:
        ValueError: if ``dataset`` is not one of the known layouts.
    """
    # Each inner list is a chain of joint indices; consecutive entries are
    # joined by a line segment.
    connections_by_dataset = {
        't2m': [
            [0, 2, 5, 8, 11],
            [0, 1, 4, 7, 10],
            [0, 3, 6, 9, 12, 15],
            [9, 14, 17, 19, 21],
            [9, 13, 16, 18, 20],
        ],
        'kit': [
            [0, 11, 12, 13, 14, 15],
            [0, 16, 17, 18, 19, 20],
            [0, 1, 2, 3, 4],
            [3, 5, 6, 7],
            [3, 8, 9, 10],
        ],
        'ntu': [
            [0, 12, 13, 14, 15],
            [0, 16, 17, 18, 19],
            [0, 1, 20, 2, 3],
            [20, 4, 5, 6, 7, 21],
            [7, 22],
            [20, 8, 9, 10, 11, 23],
            [11, 24],
        ],
    }
    # Previously an unknown name surfaced as an UnboundLocalError.
    if dataset not in connections_by_dataset:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of "
            f"{sorted(connections_by_dataset)}")
    connections = connections_by_dataset[dataset]

    xlims, ylims, zlims = update_limits(data)
    frame_data = data[frame]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Coordinate 0 on X, 2 on Y, 1 on Z (same convention as update_limits).
    ax.scatter(frame_data[:, 0], frame_data[:, 2], frame_data[:, 1])

    for chain in connections:
        for start_joint, end_joint in zip(chain, chain[1:]):
            ax.plot([frame_data[start_joint, 0], frame_data[end_joint, 0]],
                    [frame_data[start_joint, 2], frame_data[end_joint, 2]],
                    [frame_data[start_joint, 1], frame_data[end_joint, 1]])

    # Scale the 3-D box to the data extents so proportions are not distorted.
    ax.set_box_aspect((np.ptp(xlims), np.ptp(ylims), np.ptp(zlims)))

    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.set_zlim(zlims)

    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    ax.set_zlabel('Y')

    # NOTE(review): the figure is left open, as before, presumably so an
    # interactive session can still display it. Callers plotting many frames
    # in a loop should close it themselves.
    fig.savefig(path, bbox_inches='tight')
|
|
|
def compute_height(joints, head_index, l_foot_index, r_foot_index):
    """Estimate height from the first frame of a joint sequence.

    The estimate is the mean of two coordinate-1 differences in frame 0:
    head minus left foot, and head minus right foot.

    Args:
        joints: NumPy array, torch tensor or nested sequence indexed
            ``[frame, joint, coord]``.
        head_index: index of the head joint.
        l_foot_index: index of the left-foot joint.
        r_foot_index: index of the right-foot joint.

    Returns:
        A 0-d ``torch.Tensor`` with the dtype of ``joints``.
    """
    # as_tensor shares memory with ndarrays (like from_numpy) and also accepts
    # tensors and lists, which from_numpy rejects.
    joints = torch.as_tensor(joints)
    first = joints[0]  # only frame 0 is used; no need to diff every frame
    head_y = first[head_index, 1]
    left = head_y - first[l_foot_index, 1]
    right = head_y - first[r_foot_index, 1]
    return (left + right) / 2
|
|
|
def compute_metrics_np(similarity_matrix, correct_labels):
    """Compute Recall@1..5 and mean reciprocal rank from a similarity matrix.

    Args:
        similarity_matrix: array of shape ``(B, N)``. Row ``i`` scores every
            candidate for query ``i``; higher means more similar.
        correct_labels: length-``B`` sequence. ``correct_labels[i]`` is the
            column index of query ``i``'s correct candidate.

    Returns:
        ``(R@1, R@2, R@3, R@4, R@5, MRR)`` as NumPy floats.

    Raises:
        ValueError: if ``correct_labels`` does not have length ``B`` or
            contains an index that is not a column of ``similarity_matrix``.
    """
    similarity_matrix = np.asarray(similarity_matrix)
    correct_labels = np.asarray(correct_labels).reshape(-1)
    B, _ = similarity_matrix.shape
    if correct_labels.shape[0] != B:
        raise ValueError(
            f"expected {B} labels, got {correct_labels.shape[0]}")

    # Candidates in descending-similarity order (argsort breaks ties arbitrarily).
    ranked_indices = np.argsort(-similarity_matrix, axis=1)

    # Each row is a permutation of the columns, so a valid label matches once.
    hits = ranked_indices == correct_labels[:, None]
    if not hits.any(axis=1).all():
        raise ValueError("correct_labels contains indices outside the candidate range")
    correct_label_ranks = hits.argmax(axis=1) + 1  # 1-based rank

    R_at_1, R_at_2, R_at_3, R_at_4, R_at_5 = (
        np.mean(correct_label_ranks <= k) for k in range(1, 6))
    MRR = np.mean(1.0 / correct_label_ranks)

    return R_at_1, R_at_2, R_at_3, R_at_4, R_at_5, MRR
|
|