Wendy-Fly commited on
Commit
af5044e
·
verified ·
1 Parent(s): 386545a

Upload motion_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. motion_utils.py +173 -0
motion_utils.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
5
+ from mpl_toolkits.mplot3d import Axes3D
6
+ from matplotlib.animation import FuncAnimation, PillowWriter
7
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
8
+ import mpl_toolkits.mplot3d.axes3d as p3
9
+
10
+ def qrot(q, v):
11
+ """
12
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
13
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
14
+ where * denotes any number of dimensions.
15
+ Returns a tensor of shape (*, 3).
16
+ """
17
+ assert q.shape[-1] == 4
18
+ assert v.shape[-1] == 3
19
+ assert q.shape[:-1] == v.shape[:-1]
20
+
21
+ original_shape = list(v.shape)
22
+ # print(q.shape)
23
+ q = q.contiguous().view(-1, 4)
24
+ v = v.contiguous().view(-1, 3)
25
+
26
+ qvec = q[:, 1:]
27
+ uv = torch.cross(qvec, v, dim=1)
28
+ uuv = torch.cross(qvec, uv, dim=1)
29
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
30
+
31
+ def qinv(q):
32
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
33
+ mask = torch.ones_like(q)
34
+ mask[..., 1:] = -mask[..., 1:]
35
+ return q * mask
36
+
37
+ def recover_root_rot_pos(data):
38
+ rot_vel = data[..., 0]
39
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
40
+ '''Get Y-axis rotation from rotation velocity'''
41
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
42
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
43
+
44
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
45
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
46
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
47
+
48
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
49
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
50
+ '''Add Y-axis rotation to root position'''
51
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
52
+
53
+ r_pos = torch.cumsum(r_pos, dim=-2)
54
+
55
+ r_pos[..., 1] = data[..., 3]
56
+ return r_rot_quat, r_pos
57
+
58
+ def recover_from_ric(data, joints_num):
59
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
60
+ positions = data[..., 4:(joints_num - 1) * 3 + 4]
61
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
62
+
63
+ '''Add Y-axis rotation to local joints'''
64
+ positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
65
+
66
+ '''Add root XZ to joints'''
67
+ positions[..., 0] += r_pos[..., 0:1]
68
+ positions[..., 2] += r_pos[..., 2:3]
69
+
70
+ '''Concate root and joints'''
71
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
72
+
73
+ return positions
74
+
75
+ def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):
76
+ # matplotlib.use('Agg')
77
+
78
+ title_sp = title.split(' ')
79
+ if len(title_sp) > 10:
80
+ title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
81
+ def init():
82
+ # ax.set_xlim3d([-radius / 2, radius / 2])
83
+ # ax.set_ylim3d([0, radius])
84
+ # ax.set_zlim3d([0, radius])
85
+ # # print(title)
86
+ # fig.suptitle(title, fontsize=20)
87
+ # ax.grid(b=False)
88
+
89
+ nb_joints = joints.shape[1]
90
+ limits = 1000 if nb_joints == 21 else 2
91
+ ax.set_xlim(-limits, limits)
92
+ ax.set_ylim(-limits, limits)
93
+ ax.set_zlim(0, limits)
94
+ fig.suptitle(title, fontsize=20)
95
+ ax.grid(b=False)
96
+
97
+ def plot_xzPlane(minx, maxx, miny, minz, maxz):
98
+ ## Plot a plane XZ
99
+ verts = [
100
+ [minx, miny, minz],
101
+ [minx, miny, maxz],
102
+ [maxx, miny, maxz],
103
+ [maxx, miny, minz]
104
+ ]
105
+ xz_plane = Poly3DCollection([verts])
106
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
107
+ ax.add_collection3d(xz_plane)
108
+
109
+ # return ax
110
+
111
+ # (seq_len, joints_num, 3)
112
+ data = joints.copy().reshape(len(joints), -1, 3)
113
+ fig = plt.figure(figsize=figsize)
114
+ # ax = p3.Axes3D(fig)
115
+ ax = fig.add_subplot(111, projection='3d')
116
+ init()
117
+ MINS = data.min(axis=0).min(axis=0)
118
+ MAXS = data.max(axis=0).max(axis=0)
119
+ colors = ['red', 'blue', 'black', 'red', 'blue',
120
+ 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
121
+ 'darkred', 'darkred','darkred','darkred','darkred']
122
+ frame_number = data.shape[0]
123
+ # print(data.shape)
124
+
125
+ height_offset = MINS[1]
126
+ data[:, :, 1] -= height_offset
127
+ trajec = data[:, 0, [0, 2]]
128
+
129
+ data[..., 0] -= data[:, 0:1, 0]
130
+ data[..., 2] -= data[:, 0:1, 2]
131
+
132
+ # print(trajec.shape)
133
+
134
+ def update(index):
135
+ # print(index)
136
+ # ax.lines = []
137
+ # ax.collections = []
138
+ for line in ax.lines:
139
+ line.remove()
140
+ for collection in ax.collections:
141
+ collection.remove()
142
+ ax.view_init(elev=120, azim=-90)
143
+ ax.dist = 7.5
144
+ # ax =
145
+ plot_xzPlane(MINS[0]-trajec[index, 0], MAXS[0]-trajec[index, 0], 0, MINS[2]-trajec[index, 1], MAXS[2]-trajec[index, 1])
146
+ # ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3)
147
+
148
+ if index > 1:
149
+ ax.plot3D(trajec[:index, 0]-trajec[index, 0], np.zeros_like(trajec[:index, 0]), trajec[:index, 1]-trajec[index, 1], linewidth=1.0,
150
+ color='blue')
151
+ # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2])
152
+
153
+
154
+ for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
155
+ # print(color)
156
+ if i < 5:
157
+ linewidth = 4.0
158
+ else:
159
+ linewidth = 2.0
160
+ ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, color=color)
161
+ # print(trajec[:index, 0].shape)
162
+
163
+ plt.axis('off')
164
+ ax.set_xticklabels([])
165
+ ax.set_yticklabels([])
166
+ ax.set_zticklabels([])
167
+
168
+ ani = FuncAnimation(fig, update, frames=frame_number, interval=1000/fps, repeat=False)
169
+ # modify
170
+ FFwriter=animation.FFMpegWriter(fps=fps, extra_args=['-vcodec', 'libx264'])
171
+
172
+ ani.save(save_path, writer=FFwriter)
173
+ plt.close()