|
|
|
|
|
import matplotlib
|
|
|
matplotlib.use('Agg')
|
|
|
import matplotlib.pyplot as plt
|
|
|
import matplotlib.animation as animation
|
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
|
import numpy as np
|
|
|
from sklearn.decomposition import PCA
|
|
|
from scipy.spatial.transform import Rotation as R
|
|
|
|
|
|
# (parent, child) joint-index pairs drawn as bones for the 24-joint SMPL skeleton.
_SMPL_CONNECTIONS = (
    (0, 1), (0, 2), (0, 3),
    (1, 4), (2, 5), (3, 6),
    (4, 7), (5, 8), (6, 9),
    (7, 10), (8, 11), (9, 12),
    (9, 13), (9, 14), (12, 15),
    (13, 16), (14, 17),
    (16, 18), (17, 19),
    (18, 20), (19, 21),
    (20, 22), (21, 23),
)
_SMPL_NUM_JOINTS = 24
_FEET_INDICES = [10, 11]  # foot joints used to estimate the ground plane
_PELVIS_INDEX = 0
_HEAD_INDEX = 15

# OpenCV fourcc codes tried in order, each paired with the notice printed
# before attempting it (None for the first attempt).
_OPENCV_CODECS = (
    ('avc1', None),
    ('h264', "avc1 failed. Trying h264..."),
    ('vp80', "h264 failed. Trying vp80..."),
    ('mp4v', "vp80 failed. Trying mp4v (less compatible)..."),
)


def _validate_pose_data(pose_data):
    """Return ``pose_data`` as a float ndarray, raising ValueError if it is unusable."""
    poses = np.asarray(pose_data, dtype=float)
    if poses.ndim != 3 or poses.shape[1] < _SMPL_NUM_JOINTS or poses.shape[2] != 3:
        raise ValueError(f"pose_data must have shape (frames, 24, 3), got {poses.shape}")
    if poses.shape[0] == 0:
        raise ValueError("pose_data must contain at least one frame")
    return poses


def _align_to_ground(poses):
    """Return a rotated, re-centred copy of ``poses`` standing on the z = 0 plane.

    The ground normal is the direction of least variance of all foot-joint
    positions (the third principal component), oriented to agree with the
    average pelvis->head vector and then rotated onto +z. Afterwards x/y are
    centred on their mean and z is shifted so the lowest joint sits at z = 0.
    The input array is not modified.
    """
    feet_points = poses[:, _FEET_INDICES, :].reshape(-1, 3)
    centered = feet_points - feet_points.mean(axis=0)
    # Eigen-decomposition of the 3x3 scatter matrix: same axes as PCA, but it
    # always yields three directions (PCA(n_components=3) raises for a single
    # frame, which has only two foot points). eigh sorts eigenvalues in
    # ascending order, so column 0 is the least-variance direction.
    eigvals, eigvecs = np.linalg.eigh(centered.T @ centered)
    normal = eigvecs[:, 0]

    body_up = np.mean(poses[:, _HEAD_INDEX, :] - poses[:, _PELVIS_INDEX, :], axis=0)

    if eigvals[1] <= 1e-9 * eigvals[2]:
        # Foot points are (nearly) collinear -- e.g. a single frame -- so every
        # direction perpendicular to the foot axis fits equally well. Pick the
        # one closest to the body-up vector instead of an arbitrary one.
        foot_axis = eigvecs[:, 2]
        candidate = body_up - np.dot(body_up, foot_axis) * foot_axis
        candidate_norm = np.linalg.norm(candidate)
        if candidate_norm > 0:
            normal = candidate / candidate_norm
    # NOTE(review): if the feet barely move in one of the two in-plane axes,
    # the two smallest variances can be close and the fitted normal may wobble.

    # Make the normal point the same way as the body (pelvis -> head).
    if np.dot(normal, body_up) < 0:
        normal = -normal

    # align_vectors(a, b) returns the rotation mapping b onto a.
    rotation, _ = R.align_vectors([[0.0, 0.0, 1.0]], [normal])
    aligned = rotation.apply(poses.reshape(-1, 3)).reshape(poses.shape)

    aligned[:, :, 0] -= aligned[:, :, 0].mean()
    aligned[:, :, 1] -= aligned[:, :, 1].mean()
    aligned[:, :, 2] -= aligned[:, :, 2].min()
    return aligned


def _make_frame_drawer(poses, padding=1.0):
    """Return ``draw(ax, frame)`` that renders one frame with view bounds shared by all frames."""
    flat = poses.reshape(-1, 3)
    lo = flat.min(axis=0)
    hi = flat.max(axis=0)
    centre = (lo + hi) / 2
    # Half of the largest per-axis extent, used on all three axes so the view is cubic.
    half_extent = (hi - lo).max() / 2.0
    # Ground quad covering the XY footprint of the whole sequence plus padding.
    ground_x, ground_y = np.meshgrid(
        [lo[0] - padding, hi[0] + padding],
        [lo[1] - padding, hi[1] + padding],
    )
    ground_z = np.zeros_like(ground_x)

    def draw(ax, frame):
        pose = poses[frame]
        ax.clear()
        ax.set_axis_off()

        ax.plot_surface(ground_x, ground_y, ground_z, color='gray', alpha=0.2, shade=False)
        ax.scatter(pose[:, 0], pose[:, 1], pose[:, 2], c='blue', s=20)
        for start, end in _SMPL_CONNECTIONS:
            bone = pose[[start, end]]
            ax.plot(bone[:, 0], bone[:, 1], bone[:, 2], c='red')

        ax.set_xlim(centre[0] - half_extent, centre[0] + half_extent)
        ax.set_ylim(centre[1] - half_extent, centre[1] + half_extent)
        ax.set_zlim(centre[2] - half_extent, centre[2] + half_extent)
        ax.set_title(f"Frame {frame}")

    return draw


def _new_figure():
    """Create the 10x10 inch figure with a single 3D axes used for rendering."""
    fig = plt.figure(figsize=(10, 10))
    return fig, fig.add_subplot(111, projection='3d')


def _save_with_ffmpeg(draw, n_frames, output_path, fps):
    """Encode all frames through matplotlib's FFMpegWriter; RuntimeError if ffmpeg is missing."""
    if not animation.writers.is_available('ffmpeg'):
        raise RuntimeError("ffmpeg not available")
    fig, ax = _new_figure()
    try:
        ani = animation.FuncAnimation(
            fig, lambda frame: draw(ax, frame), frames=n_frames, interval=1000 / fps
        )
        ani.save(output_path, writer=animation.FFMpegWriter(fps=fps, bitrate=5000))
    finally:
        plt.close(fig)


def _canvas_to_rgb(fig):
    """Draw ``fig`` and return its pixels as a contiguous (height, width, 3) uint8 RGB array."""
    fig.canvas.draw()
    try:
        rgb = np.asarray(fig.canvas.buffer_rgba())[:, :, :3]
    except AttributeError:
        # Canvases without buffer_rgba(); tostring_rgb() is deprecated in newer matplotlib.
        width, height = fig.canvas.get_width_height()
        rgb = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(height, width, 3)
    # Copy: the canvas reuses its RGBA buffer on the next draw, and the channel
    # slice is not contiguous.
    return np.ascontiguousarray(rgb)


def _open_video_writer(output_path, fps, frame_size):
    """Open a cv2.VideoWriter with the first codec that works, or raise RuntimeError."""
    import cv2

    for codec, notice in _OPENCV_CODECS:
        if notice is not None:
            print(notice)
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*codec), fps, frame_size)
        if writer.isOpened():
            return writer
        writer.release()
    raise RuntimeError("Failed to open VideoWriter with any compatible codec.")


def _save_with_opencv(draw, n_frames, output_path, fps):
    """Render frames one at a time and encode them with an OpenCV VideoWriter."""
    import cv2  # optional dependency, only needed for this fallback

    fig, ax = _new_figure()
    writer = None
    try:
        # Size the writer from a real rendered frame so its dimensions match
        # the pixel buffer exactly (VideoWriter drops mismatched frames).
        draw(ax, 0)
        first = _canvas_to_rgb(fig)
        height, width = first.shape[:2]
        writer = _open_video_writer(output_path, fps, (width, height))

        print("Rendering frames directly to OpenCV VideoWriter...")
        writer.write(cv2.cvtColor(first, cv2.COLOR_RGB2BGR))
        for frame in range(1, n_frames):
            draw(ax, frame)
            writer.write(cv2.cvtColor(_canvas_to_rgb(fig), cv2.COLOR_RGB2BGR))
    finally:
        if writer is not None:
            writer.release()
        plt.close(fig)
    print("OpenCV fallback rendering complete.")


def render_smpl(pose_data, output_path, fps=30):
    """
    Render SMPL 3D pose data to a video file.

    The skeleton is first rotated so the plane fitted to the foot joints
    becomes z = 0 with the body pointing up +z. It is then centred in x/y and
    shifted so its lowest joint touches the ground. The video is written with
    matplotlib's ffmpeg writer when available, otherwise with an OpenCV
    VideoWriter.

    Args:
        pose_data (np.ndarray): Shape (Frames, 24, 3). Not modified.
        output_path (str): Path to save the MP4 video.
        fps (int): Frames per second; must be positive.

    Returns:
        str: ``output_path``, unchanged.

    Raises:
        ValueError: If ``pose_data`` is empty or not shaped (Frames, 24, 3),
            or if ``fps`` is not positive.
        Exception: If both writers fail, whatever the OpenCV fallback raised
            (e.g. ImportError when cv2 is missing, RuntimeError when no codec
            can be opened).
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    poses = _align_to_ground(_validate_pose_data(pose_data))
    draw = _make_frame_drawer(poses)
    n_frames = len(poses)

    print(f"Saving video to {output_path}...")
    try:
        _save_with_ffmpeg(draw, n_frames, output_path, fps)
    except Exception as e:  # deliberate: any ffmpeg problem triggers the fallback
        print(f"ffmpeg failed or not found ({e}). Using OpenCV fallback...")
        try:
            _save_with_opencv(draw, n_frames, output_path, fps)
        except Exception as cv_e:
            print(f"OpenCV fallback also failed: {cv_e}")
            raise

    return output_path
|
|
|
|
|
|
|
|
|
|