rsax commited on
Commit
d4c761d
·
verified ·
1 Parent(s): 3ed0fbb

Update visualization/plot_3d_global.py

Browse files
Files changed (1) hide show
  1. visualization/plot_3d_global.py +7 -6
visualization/plot_3d_global.py CHANGED
@@ -106,12 +106,13 @@ def plot_3d_motion(args, figsize=(10, 10), fps=120, radius=4):
106
  return frame
107
 
108
  out = []
109
- for i in range(frame_number) :
110
- out.append(update(i))
111
- out = np.stack(out, axis=0)
112
- return torch.from_numpy(out)
113
-
114
-
 
115
 
116
  def draw_to_batch(smpl_joints_batch, title_batch=None, outname=None):
117
  batch_size = len(smpl_joints_batch)
 
106
  return frame
107
 
108
  out = []
109
+ for i in range(frame_number):
110
+ frame = update(i)
111
+ if frame.ndim == 3 and frame.shape[2] == 3: # Check that frame is H x W x 3
112
+ out.append(frame)
113
+ else:
114
+ print(f"Frame {i} has incorrect shape or channels: {frame.shape}")
115
+ return out
116
 
117
  def draw_to_batch(smpl_joints_batch, title_batch=None, outname=None):
118
  batch_size = len(smpl_joints_batch)