Spaces:
Running
Running
# From TEMOS: temos/render/anim.py | |
# Inspired by | |
# - https://github.com/anindita127/Complextext2animation/blob/main/src/utils/visualization.py | |
# - https://github.com/facebookresearch/QuaterNet/blob/main/common/visualization.py | |
import os | |
import logging | |
from dataclasses import dataclass | |
from typing import List, Tuple, Optional | |
import numpy as np | |
from src.tools.rifke import canonicalize_rotation | |
# Keep matplotlib's animation module quiet: only ERROR-level records get through.
logger = logging.getLogger("matplotlib.animation")
logger.setLevel(logging.ERROR)

# Default colour for each kinematic chain; render_animation pairs these with
# the chains of the selected entry in KINEMATIC_TREES.
colors = ("black", "magenta", "red", "green", "blue")


def _body_chains():
    """Return a freshly built list of the five joint-index chains (22 joints, ids 0-21)."""
    return [
        [0, 3, 6, 9, 12, 15],
        [9, 13, 16, 18, 20],
        [9, 14, 17, 19, 21],
        [0, 1, 4, 7, 10],
        [0, 2, 5, 8, 11],
    ]


# Skeleton type -> list of chains; each chain is a polyline through joint
# indices. Both layouts are currently identical but are built as independent
# lists, so mutating one entry never affects the other.
KINEMATIC_TREES = {
    "smpljoints": _body_chains(),
    "guoh3djoints": _body_chains(),  # no hands
}
# eq=False keeps identity-based __eq__/__hash__ (as for a plain class), so
# instances stay hashable; the decorator only adds keyword construction and a repr.
@dataclass(eq=False)
class MatplotlibRender:
    """Callable wrapper around ``render_animation`` holding default render settings.

    The fields below are defaults; ``fps`` and ``canonicalize`` can be
    overridden per call of ``__call__``.
    """

    jointstype: str = "smpljoints"  # key into KINEMATIC_TREES
    fps: float = 20.0  # default frame rate when __call__ gets fps=None
    colors: List[str] = colors  # one colour per kinematic chain
    figsize: int = 4  # used as plt.figure(figsize=(figsize, figsize))
    fontsize: int = 15  # base font size before the title-length heuristic
    canonicalize: bool = False  # default when __call__ gets canonicalize=None

    def __call__(
        self,
        joints,
        output,
        fps=None,
        highlights=None,
        title: str = "",
        canonicalize=None,
    ):
        """Render ``joints`` with this instance's settings.

        Args:
            joints: joint array indexed as ``joints[frame, joint, ...]``. When
                ``joints.shape[1] == 24`` only the first 22 joints are kept.
            output: ``"notebook"`` or an output path, forwarded to
                ``render_animation``.
            fps: frame rate for this call; ``None`` falls back to ``self.fps``.
            highlights: optional per-frame flags forwarded to ``render_animation``.
            title: figure title.
            canonicalize: ``None`` falls back to ``self.canonicalize``.

        Returns:
            Whatever ``render_animation`` returns.
        """
        canonicalize = canonicalize if canonicalize is not None else self.canonicalize
        fps = fps if fps is not None else self.fps
        if joints.shape[1] == 24:
            # remove the hands: keep joints 0-21
            joints = joints[:, :22]
        return render_animation(
            joints,
            title=title,
            highlights=highlights,
            output=output,
            jointstype=self.jointstype,
            # Use the resolved per-call value; passing self.fps here would
            # silently ignore the caller's fps argument.
            fps=fps,
            colors=self.colors,
            figsize=(self.figsize, self.figsize),
            fontsize=self.fontsize,
            canonicalize=canonicalize,
        )
def init_axis(fig, title, radius=1.5):
    """Add a single 3D subplot to ``fig``, configured for skeleton rendering.

    The view is fixed at elevation 20 degrees and azimuth -60 degrees.
    Ticks, tick labels, axis panes and the grid are hidden.

    Args:
        fig: matplotlib Figure to draw into.
        title: text placed centred above the axes (wrapped if long).
        radius: sets the visible extent. x and y span ``[-radius/2, radius/2]``
            and z spans ``[0, radius]``.

    Returns:
        The newly created 3D axes.
    """
    ax = fig.add_subplot(1, 1, 1, projection="3d")
    ax.view_init(elev=20.0, azim=-60)
    # Horizontal half-extent is radius / fact; vertical extent is the full radius.
    fact = 2
    ax.set_xlim3d([-radius / fact, radius / fact])
    ax.set_ylim3d([-radius / fact, radius / fact])
    ax.set_zlim3d([0, radius])
    ax.set_aspect("auto")
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    ax.set_axis_off()
    # Pass the visibility flag positionally. The ``b=`` keyword was
    # deprecated in Matplotlib 3.5 in favour of ``visible``; the positional
    # form works on old and new versions alike.
    ax.grid(False)
    ax.set_title(title, loc="center", wrap=True)
    return ax
def plot_floor(ax, minx, maxx, miny, maxy, minz):
    """Draw two horizontal floor quads at height ``minz`` on the 3D axes ``ax``.

    1. An opaque grey quad that exactly covers ``[minx, maxx] x [miny, maxy]``.
    2. A half-transparent grey square with the same centre and half-side
       ``max(maxx - minx, maxy - miny)``.

    Both quads lie in a plane of constant z (the vertical axis in this module).

    Returns:
        ``ax``, for chaining.
    """
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    mid_x = (maxx + minx) / 2
    mid_y = (maxy + miny) / 2
    half_side = max((maxx - minx), (maxy - miny))

    # (x_lo, x_hi, y_lo, y_hi) bounds and RGBA face colour, in drawing order.
    quads = (
        ((minx, maxx, miny, maxy), (0.5, 0.5, 0.5, 1)),
        (
            (mid_x - half_side, mid_x + half_side, mid_y - half_side, mid_y + half_side),
            (0.5, 0.5, 0.5, 0.5),
        ),
    )
    for (x_lo, x_hi, y_lo, y_hi), rgba in quads:
        corners = [
            [x_lo, y_lo, minz],
            [x_lo, y_hi, minz],
            [x_hi, y_hi, minz],
            [x_hi, y_lo, minz],
        ]
        floor_patch = Poly3DCollection([corners], zorder=1)
        floor_patch.set_facecolor(rgba)
        ax.add_collection3d(floor_patch)
    return ax
def update_camera(ax, root, radius=1.5):
    """Re-centre the horizontal view of ``ax`` on ``root``.

    x and y limits become ``root[i] -/+ radius / 2``; z limits are left as-is.
    """
    half_width = radius / 2
    center_x, center_y = root[0], root[1]
    ax.set_xlim3d([-half_width + center_x, half_width + center_x])
    ax.set_ylim3d([-half_width + center_y, half_width + center_y])
def render_animation(
    joints: np.ndarray,
    output: str = "notebook",
    highlights: Optional[np.ndarray] = None,
    jointstype: str = "smpljoints",
    title: str = "",
    fps: float = 20.0,
    colors: List[str] = colors,
    figsize: Tuple[int, int] = (4, 4),
    fontsize: int = 15,
    canonicalize: bool = False,
    agg=True,
):
    """Animate a joint sequence as a 3D stick figure with matplotlib.

    Args:
        joints: array indexed as ``joints[frame, joint, axis]`` with three
            coordinates per joint. Axis 2 is treated as vertical: the figure
            is shifted so its lowest point sits at z = 0.
        output: ``"notebook"`` builds an inline HTML/JS player. Any other
            value is a file path passed to ``FuncAnimation.save``.
        highlights: optional per-frame flags, one entry per frame. In frames
            whose flag is truthy, every chain is drawn in red.
        jointstype: key into ``KINEMATIC_TREES`` selecting the bone chains.
        title: figure title. Longer titles get a smaller font (see heuristic below).
        fps: frame rate used for the playback interval and when saving.
        colors: one colour per kinematic chain.
        figsize: figure size passed to ``plt.figure``.
        fontsize: base font size before the title-length adjustment.
        canonicalize: if true, pre-process with ``canonicalize_rotation``.
        agg: if true, switch matplotlib to the non-interactive Agg backend.

    Returns:
        The ``IPython.display.HTML`` player when ``output == "notebook"``,
        otherwise ``None``.
    """
    if agg:
        import matplotlib

        matplotlib.use("Agg")

    if highlights is not None:
        assert len(highlights) == len(joints), "highlights needs one entry per frame"

    assert jointstype in KINEMATIC_TREES, f"unknown jointstype {jointstype!r}"
    kinematic_tree = KINEMATIC_TREES[jointstype]

    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation
    import matplotlib.patheffects as pe

    mean_fontsize = fontsize
    # heuristic to change fontsize: -1 pt per 20 title characters beyond 30
    fontsize = mean_fontsize - (len(title) - 30) / 20
    # NOTE(review): this updates the global rcParams and never restores them,
    # so later figures in the same process inherit this font size.
    plt.rcParams.update({"font.size": fontsize})

    # Z is gravity here
    x, y, z = 0, 1, 2

    # Never modify the caller's array.
    joints = joints.copy()
    if canonicalize:
        joints = canonicalize_rotation(joints, jointstype=jointstype)

    # Create a figure and initialize 3d plot
    fig = plt.figure(figsize=figsize)
    # Close the figure even if drawing or saving fails, so repeated calls
    # do not accumulate open figures.
    try:
        ax = init_axis(fig, title)

        # Root-joint (joint 0) path on the floor, drawn as a trail around the
        # current frame.
        trajectory = joints[:, 0, [x, y]]
        segment_lengths = np.linalg.norm(np.diff(trajectory, axis=0), axis=1)
        # A single frame has no segments. np.mean of an empty array is NaN,
        # which would make int() below raise, so fall back to 0 (the 1e-3
        # keeps the divisor nonzero).
        mean_length = segment_lengths.mean() if segment_lengths.size else 0.0
        avg_segment_length = mean_length + 1e-3
        # Trail half-width in frames: the smaller the mean per-frame
        # displacement, the more frames are included.
        draw_offset = int(25 / avg_segment_length)
        (spline_line,) = ax.plot(*trajectory.T, zorder=10, color="white")

        # Create a floor
        minx, miny, _ = joints.min(axis=(0, 1))
        maxx, maxy, _ = joints.max(axis=(0, 1))
        plot_floor(ax, minx, maxx, miny, maxy, 0)

        # Put the character on the floor
        height_offset = np.min(joints[:, :, z])  # Min height
        # Own the array before the in-place shift. The value returned by
        # canonicalize_rotation may not be a fresh array (unverified here).
        joints = joints.copy()
        joints[:, :, z] -= height_offset

        # Initialization for redrawing
        lines = []
        initialized = False

        def update(frame):
            """Draw frame ``frame``: move the skeleton lines, camera and trail."""
            nonlocal initialized
            skeleton = joints[frame]
            root = skeleton[0]
            update_camera(ax, root)

            hcolors = colors
            if highlights is not None and highlights[frame]:
                hcolors = ("red", "red", "red", "red", "red")

            for index, (chain, color) in enumerate(
                zip(reversed(kinematic_tree), reversed(hcolors))
            ):
                if not initialized:
                    # First call: create one line artist per chain.
                    lines.append(
                        ax.plot(
                            skeleton[chain, x],
                            skeleton[chain, y],
                            skeleton[chain, z],
                            linewidth=6.0,
                            color=color,
                            zorder=20,
                            path_effects=[pe.SimpleLineShadow(), pe.Normal()],
                        )
                    )
                else:
                    # Later calls: update the existing artists in place.
                    lines[index][0].set_xdata(skeleton[chain, x])
                    lines[index][0].set_ydata(skeleton[chain, y])
                    lines[index][0].set_3d_properties(skeleton[chain, z])
                    lines[index][0].set_color(color)

            left = max(frame - draw_offset, 0)
            right = min(frame + draw_offset, trajectory.shape[0])

            spline_line.set_xdata(trajectory[left:right, 0])
            spline_line.set_ydata(trajectory[left:right, 1])
            spline_line.set_3d_properties(np.zeros_like(trajectory[left:right, 0]))
            initialized = True

        fig.tight_layout()
        frames = joints.shape[0]
        anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, repeat=False)

        if output == "notebook":
            from IPython.display import HTML

            # Hand the player back so a notebook cell can display it; an
            # HTML object that is built but not returned shows nothing.
            return HTML(anim.to_jshtml())

        # anim.save(output, writer='ffmpeg', fps=fps)
        anim.save(output, fps=fps)
        return None
    finally:
        plt.close(fig)