# LSPW / utils/render.py
# Uploaded by fanduluhf ("Upload 4 files", commit f460dc5, verified).
import matplotlib
matplotlib.use('Agg') # Non-interactive backend
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from sklearn.decomposition import PCA
from scipy.spatial.transform import Rotation as R
# SMPL kinematic tree (approximate for visualization)
# 0: Pelvis
# 1: L_Hip, 2: R_Hip, 3: Spine1
# 4: L_Knee, 5: R_Knee, 6: Spine2
# 7: L_Ankle, 8: R_Ankle, 9: Spine3
# 10: L_Foot, 11: R_Foot, 12: Neck
# 13: L_Collar, 14: R_Collar, 15: Head
# 16: L_Shoulder, 17: R_Shoulder
# 18: L_Elbow, 19: R_Elbow
# 20: L_Wrist, 21: R_Wrist
# 22: L_Hand, 23: R_Hand
# Connectivity for drawing bones, as (start_joint, end_joint) index pairs.
_SMPL_CONNECTIONS = (
    (0, 1), (0, 2), (0, 3),
    (1, 4), (2, 5), (3, 6),
    (4, 7), (5, 8), (6, 9),
    (7, 10), (8, 11), (9, 12),
    (9, 13), (9, 14), (12, 15),
    (13, 16), (14, 17),
    (16, 18), (17, 19),
    (18, 20), (19, 21),
    (20, 22), (21, 23),
)
_NUM_SMPL_JOINTS = 24
_PELVIS, _HEAD = 0, 15
_FEET_INDICES = [10, 11]  # L_Foot, R_Foot (a list, so NumPy uses fancy indexing)
_GROUND_PADDING = 1.0  # margin added around the XY trajectory for the ground plane

# OpenCV fallback codecs, tried in order, each with the message printed when it
# fails to open (None for the last one: failure there is an error).
_OPENCV_CODECS = (
    ('avc1', "avc1 failed. Trying h264..."),
    ('h264', "h264 failed. Trying vp80..."),
    ('vp80', "vp80 failed. Trying mp4v (less compatible)..."),
    ('mp4v', None),
)


def _align_and_center(pose_data):
    """Return a copy of ``pose_data`` re-oriented so the ground is the z=0 plane.

    The ground normal is the least-variance principal axis of the foot joints
    over all frames, flipped to agree with the mean pelvis->head vector, and
    rotated onto +Z. X/Y are then centred on the mean of all points and Z is
    shifted so the lowest point sits at z=0.

    Args:
        pose_data (np.ndarray): float array of shape (Frames, Joints, 3).

    Returns:
        np.ndarray: a new array with the same shape; the input is not modified.
    """
    feet_points = pose_data[:, _FEET_INDICES, :].reshape(-1, 3)
    # Mean pelvis -> head vector tells us which side of the ground plane is "up".
    avg_body_up = np.mean(pose_data[:, _HEAD, :] - pose_data[:, _PELVIS, :], axis=0)

    if feet_points.shape[0] >= 3:
        # Fit a plane to the feet: the component with least variance is its normal.
        pca = PCA(n_components=3)
        pca.fit(feet_points)
        normal = pca.components_[2]
        # PCA's sign is arbitrary; make the normal point the same way as the body.
        if np.dot(normal, avg_body_up) < 0:
            normal = -normal
    else:
        # A single frame yields only two foot points, which cannot define a plane
        # (PCA(n_components=3) would raise); use the body's up axis instead.
        normal = avg_body_up

    if np.any(normal):
        # align_vectors(a, b) returns the rotation mapping b onto a, so this maps
        # the ground normal onto the Z axis.
        rot, _ = R.align_vectors([[0.0, 0.0, 1.0]], [normal])
    else:
        rot = R.identity()  # degenerate input: no usable up direction

    # Rotation.apply returns a new array, so the caller's data is untouched.
    aligned = rot.apply(pose_data.reshape(-1, 3)).reshape(pose_data.shape)

    aligned[:, :, 0] -= aligned[:, :, 0].mean()
    aligned[:, :, 1] -= aligned[:, :, 1].mean()
    aligned[:, :, 2] -= aligned[:, :, 2].min()  # ground level at z=0
    return aligned


def _view_bounds(pose_data):
    """Return ``(mid, max_range)``: box centre per axis and half of the largest extent."""
    flat = pose_data.reshape(-1, 3)
    lo, hi = flat.min(axis=0), flat.max(axis=0)
    # (min+max)/2 keeps the bounding box centred on the data.
    return (lo + hi) / 2, (hi - lo).max() / 2.0


def _ground_mesh(pose_data, padding=_GROUND_PADDING):
    """Return ``(X, Y, Z)`` for a flat quad at z=0 covering the whole XY trajectory plus ``padding``."""
    flat = pose_data.reshape(-1, 3)
    x = np.linspace(flat[:, 0].min() - padding, flat[:, 0].max() + padding, 2)
    y = np.linspace(flat[:, 1].min() - padding, flat[:, 1].max() + padding, 2)
    X, Y = np.meshgrid(x, y)
    return X, Y, np.zeros_like(X)


def _draw_frame(ax, pose, frame, ground, mid, max_range):
    """Clear ``ax`` and draw the ground plane, joints and bones of a single pose."""
    ax.clear()
    ax.set_axis_off()
    # Transparent gray ground plane at z=0.
    ax.plot_surface(*ground, color='gray', alpha=0.2, shade=False)
    # Scatter points for joints.
    ax.scatter(pose[:, 0], pose[:, 1], pose[:, 2], c='blue', s=20)
    # Draw bones.
    for start, end in _SMPL_CONNECTIONS:
        ax.plot(
            [pose[start, 0], pose[end, 0]],
            [pose[start, 1], pose[end, 1]],
            [pose[start, 2], pose[end, 2]],
            c='red',
        )
    # The same half-width on every axis keeps the 3D aspect ratio equal.
    ax.set_xlim(mid[0] - max_range, mid[0] + max_range)
    ax.set_ylim(mid[1] - max_range, mid[1] + max_range)
    ax.set_zlim(mid[2] - max_range, mid[2] + max_range)
    ax.set_title(f"Frame {frame}")


def _render_with_opencv(draw, n_frames, output_path, fps):
    """Render ``n_frames`` frames on a fresh figure and encode them with cv2.VideoWriter.

    ``draw(ax, frame)`` must redraw frame number ``frame`` onto ``ax``. Codecs
    in ``_OPENCV_CODECS`` are tried in order. The writer and figure are always
    released, even if rendering fails part-way.

    Raises:
        ImportError: if OpenCV (``cv2``) is not installed.
        RuntimeError: if no codec could open a VideoWriter.
    """
    import cv2

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    out = None
    try:
        # Draw once so the canvas reports its pixel size (about 1000x1000 at dpi=100).
        fig.canvas.draw()
        width, height = fig.canvas.get_width_height()

        for code, failure_msg in _OPENCV_CODECS:
            out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*code), fps, (width, height))
            if out.isOpened():
                break
            if failure_msg is not None:
                print(failure_msg)
        else:
            raise RuntimeError("Failed to open VideoWriter with any compatible codec.")

        print("Rendering frames directly to OpenCV VideoWriter...")
        for frame in range(n_frames):
            draw(ax, frame)
            fig.canvas.draw()
            try:
                # buffer_rgba is available on matplotlib 3.x canvases.
                img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
                img = img.reshape(height, width, 4)[:, :, :3]  # RGBA -> RGB
            except AttributeError:
                # Fallback for older matplotlib or a different backend.
                img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
                img = img.reshape(height, width, 3)
            out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    finally:
        if out is not None:
            out.release()
        plt.close(fig)
    print("OpenCV fallback rendering complete.")


def render_smpl(pose_data, output_path, fps=30):
    """Render SMPL 3D pose data to a video file.

    The sequence is re-oriented so the plane fitted to the feet becomes the z=0
    ground (Z up) and centred in X/Y. Each frame is then drawn as a skeleton over
    a translucent ground quad. Encoding uses matplotlib's ffmpeg writer when
    available and falls back to an OpenCV VideoWriter otherwise.

    Args:
        pose_data (array-like): joint positions of shape (Frames, Joints, 3) with
            Frames >= 1 and Joints >= 24, in SMPL joint order. Bones connect only
            the first 24 joints. The input is not modified.
        output_path (str): Path to save the MP4 video.
        fps (int): Frames per second.

    Returns:
        str: ``output_path``.

    Raises:
        ValueError: if ``pose_data`` does not have the shape described above.
        Exception: whatever the OpenCV fallback raised, if both ffmpeg and
            OpenCV fail.
    """
    pose_data = np.asarray(pose_data, dtype=float)
    if (pose_data.ndim != 3 or pose_data.shape[0] == 0
            or pose_data.shape[1] < _NUM_SMPL_JOINTS or pose_data.shape[2] != 3):
        raise ValueError(
            f"pose_data must have shape (Frames>=1, Joints>={_NUM_SMPL_JOINTS}, 3), "
            f"got {pose_data.shape}"
        )

    # --- Alignment & centering, then view/ground geometry (static across frames) ---
    pose_data = _align_and_center(pose_data)
    mid, max_range = _view_bounds(pose_data)
    ground = _ground_mesh(pose_data)
    n_frames = len(pose_data)

    def draw(ax, frame):
        _draw_frame(ax, pose_data[frame], frame, ground, mid, max_range)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    def update(frame):
        draw(ax, frame)

    ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=1000 / fps)

    # Save using the ffmpeg writer; fall back to OpenCV on any failure.
    print(f"Saving video to {output_path}...")
    try:
        if not animation.writers.is_available('ffmpeg'):
            raise RuntimeError("ffmpeg not available")
        writer = animation.FFMpegWriter(fps=fps, bitrate=5000)
        ani.save(output_path, writer=writer)
    except Exception as e:
        print(f"ffmpeg failed or not found ({e}). Using OpenCV fallback...")
        plt.close(fig)  # free the animation figure before the fallback opens another
        try:
            _render_with_opencv(draw, n_frames, output_path, fps)
        except Exception as cv_e:
            print(f"OpenCV fallback also failed: {cv_e}")
            raise
    finally:
        # Always release the figure: pyplot keeps open figures alive, which leaks
        # memory when this runs repeatedly in a long-lived process.
        plt.close(fig)
    return output_path