File size: 6,041 Bytes
af5044e 18df527 af5044e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import torch
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import mpl_toolkits.mplot3d.axes3d as p3
def qrot(q, v):
    """
    Rotate vector(s) v by the rotation encoded in quaternion(s) q.

    Quaternions are laid out real-part first: q[..., 0] is the scalar part and
    q[..., 1:] the vector part. The closed form used below assumes q is
    normalised.

    Args:
        q: tensor of shape (*, 4).
        v: tensor of shape (*, 3). Its leading dimensions must equal q's
            exactly; no broadcasting is performed.

    Returns:
        Tensor of shape (*, 3) holding the rotated vectors.
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    out_shape = list(v.shape)
    flat_q = q.contiguous().view(-1, 4)
    flat_v = v.contiguous().view(-1, 3)

    w = flat_q[:, :1]  # scalar part, kept 2-D so it broadcasts over x/y/z
    u = flat_q[:, 1:]  # vector part

    # Expansion of q * v * q^-1 for unit q:
    #   v' = v + 2 * (w * (u x v) + u x (u x v))
    u_cross_v = torch.cross(u, flat_v, dim=1)
    u_cross_u_cross_v = torch.cross(u, u_cross_v, dim=1)
    rotated = flat_v + 2 * (w * u_cross_v + u_cross_u_cross_v)
    return rotated.view(out_shape)
def qinv(q):
    """
    Return the conjugate of quaternion(s) q (real part first).

    The scalar component q[..., 0] is kept and the vector components
    q[..., 1:] are negated; for unit quaternions this is the inverse rotation.
    A new tensor is returned and q itself is left unchanged.

    Args:
        q: tensor of shape (*, 4).

    Returns:
        Tensor with the same shape, dtype and device as q.
    """
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    scalar_part = q[..., :1]
    vector_part = q[..., 1:]
    return torch.cat((scalar_part, -vector_part), dim=-1)
def recover_root_rot_pos(data):
    """
    Integrate per-frame root features into a root rotation and root position.

    Features read from the last axis of ``data``:
        data[..., 0]    rotation velocity about the Y axis
        data[..., 1:3]  per-frame displacement written into the X and Z components
        data[..., 3]    root height (Y), copied through without integration
    The second-to-last axis is the frame axis that values are accumulated along.

    Args:
        data: tensor of shape (..., frames, features) with features >= 4.

    Returns:
        (r_rot_quat, r_pos):
            r_rot_quat: (..., frames, 4) quaternions (real part first) whose only
                non-zero components are [0] = cos(angle) and [2] = sin(angle).
            r_pos: (..., frames, 3) root positions.
        Both outputs have the same dtype and device as ``data``.
    """
    rot_vel = data[..., 0]
    # Frame 0 starts at angle 0; frame t accumulates the velocities of frames 0..t-1.
    r_rot_ang = torch.zeros_like(rot_vel)
    r_rot_ang[..., 1:] = rot_vel[..., :-1]
    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)

    # Allocate with data's dtype/device. torch.zeros would use the global
    # default dtype, so non-float32 input would yield float32 outputs and mix
    # dtypes in later qrot calls (e.g. in recover_from_ric).
    r_rot_quat = data.new_zeros(data.shape[:-1] + (4,))
    r_rot_quat[..., 0] = torch.cos(r_rot_ang)
    r_rot_quat[..., 2] = torch.sin(r_rot_ang)

    # Per-frame XZ displacement, shifted by one frame like the angle.
    r_pos = data.new_zeros(data.shape[:-1] + (3,))
    r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
    # Apply the Y-axis rotation (via its inverse quaternion), then integrate over frames.
    r_pos = qrot(qinv(r_rot_quat), r_pos)
    r_pos = torch.cumsum(r_pos, dim=-2)

    # Height is taken directly from the features rather than integrated.
    r_pos[..., 1] = data[..., 3]
    return r_rot_quat, r_pos
def recover_from_ric(data, joints_num):
    """
    Reconstruct global joint positions from root-relative joint features.

    Steps:
      1. Obtain the root rotation and trajectory from recover_root_rot_pos.
      2. Read the (joints_num - 1) non-root joints as xyz triples from
         data[..., 4:4 + 3 * (joints_num - 1)].
      3. Rotate them by the inverse root rotation.
      4. Shift them by the root's X and Z position; Y is left as stored.
      5. Prepend the root position as joint 0.

    Args:
        data: tensor of shape (..., frames, features).
        joints_num: total number of joints, root included.

    Returns:
        Tensor of shape (..., frames, joints_num, 3).
    """
    root_quat, root_pos = recover_root_rot_pos(data)

    start = 4
    stop = start + 3 * (joints_num - 1)
    local = data[..., start:stop]
    local = local.view(*local.shape[:-1], -1, 3)

    # One inverse root rotation per frame, broadcast to every non-root joint.
    inv_root = qinv(root_quat).unsqueeze(-2).expand(*local.shape[:-1], 4)
    world = qrot(inv_root, local)

    # Translate horizontally by the root's X and Z position.
    world[..., 0] += root_pos[..., 0].unsqueeze(-1)
    world[..., 2] += root_pos[..., 2].unsqueeze(-1)

    return torch.cat((root_pos.unsqueeze(-2), world), dim=-2)
def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):
    """
    Render a skeleton motion sequence to an animation file.

    Each frame draws, in order:
      - a translucent grey floor quad,
      - the root joint's past XZ trajectory,
      - every kinematic chain of the current pose.
    Poses are shifted so the lowest Y value of the whole sequence sits at 0.
    Each frame is re-centred horizontally on joint 0.

    Args:
        save_path: output path passed to ``FuncAnimation.save``; the writer is
            chosen by Matplotlib's defaults.
        kinematic_tree: sequence of joint-index chains, each drawn as a polyline.
            Only the first 15 chains are drawn (one per entry in ``colors``).
        joints: numpy array reshapeable to (frames, joints, 3). It is copied and
            never modified.
        title: figure title; titles longer than 10 words are split onto two lines.
        figsize: Matplotlib figure size.
        fps: frame rate used for both the frame interval and the saved file.
        radius: unused; kept for interface compatibility. Axis limits are derived
            from the joint count instead (see ``init``).
    """
    # Split long titles: first 10 words on one line, the rest on the next.
    words = title.split(' ')
    if len(words) > 10:
        title = '\n'.join([' '.join(words[:10]), ' '.join(words[10:])])

    # (frames, joints, 3); copied so the caller's array is not mutated below.
    data = joints.copy().reshape(len(joints), -1, 3)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    def init():
        """Set fixed axis limits and the title, and hide the grid."""
        # Use the reshaped joint count so flattened (frames, joints*3) input
        # gets the same limits as (frames, joints, 3) input.
        nb_joints = data.shape[1]
        # NOTE(review): 1000 vs 2 presumably reflects the coordinate units of the
        # 21-joint skeleton versus others — confirm.
        limits = 1000 if nb_joints == 21 else 2
        ax.set_xlim(-limits, limits)
        ax.set_ylim(-limits, limits)
        ax.set_zlim(0, limits)
        fig.suptitle(title, fontsize=20)
        # Positional on purpose: the `b=` keyword was renamed to `visible` in
        # Matplotlib 3.5 and is rejected by newer versions.
        ax.grid(False)

    def plot_xzPlane(minx, maxx, miny, minz, maxz):
        """Draw a translucent grey quad at height miny over [minx, maxx] x [minz, maxz]."""
        verts = [
            [minx, miny, minz],
            [minx, miny, maxz],
            [maxx, miny, maxz],
            [maxx, miny, minz],
        ]
        xz_plane = Poly3DCollection([verts])
        xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
        ax.add_collection3d(xz_plane)

    init()
    MINS = data.min(axis=0).min(axis=0)
    MAXS = data.max(axis=0).max(axis=0)
    colors = ['red', 'blue', 'black', 'red', 'blue',
              'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
              'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
    frame_number = data.shape[0]

    # Shift so the lowest Y over all frames/joints is 0, the height of the floor quad.
    height_offset = MINS[1]
    data[:, :, 1] -= height_offset
    # Root XZ trajectory. Advanced indexing returns a copy, so the
    # re-centring below does not affect it.
    trajec = data[:, 0, [0, 2]]
    # Re-centre every frame horizontally on its own root joint.
    data[..., 0] -= data[:, 0:1, 0]
    data[..., 2] -= data[:, 0:1, 2]

    # Camera settings are frame-invariant, so set them once rather than per frame.
    ax.view_init(elev=120, azim=-90)
    ax.dist = 7.5

    def update(index):
        """Redraw frame ``index``: floor quad, past root trajectory and skeleton."""
        # Iterate over copies: ax.lines / ax.collections are live sequences, and
        # removing while iterating them skips artists, which then pile up.
        for line in list(ax.lines):
            line.remove()
        for collection in list(ax.collections):
            collection.remove()

        root_x = trajec[index, 0]
        root_z = trajec[index, 1]
        plot_xzPlane(MINS[0] - root_x, MAXS[0] - root_x, 0, MINS[2] - root_z, MAXS[2] - root_z)

        if index > 1:
            # Past trajectory at height 0, relative to the current root position.
            ax.plot3D(trajec[:index, 0] - root_x, np.zeros_like(trajec[:index, 0]),
                      trajec[:index, 1] - root_z, linewidth=1.0, color='blue')

        for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
            linewidth = 4.0 if i < 5 else 2.0
            ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2],
                      linewidth=linewidth, color=color)

        ax.set_axis_off()
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

    ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
    try:
        ani.save(save_path, fps=fps)
    finally:
        # Close this figure even if saving fails, so figures are not leaked.
        plt.close(fig)
|