|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from IPython.display import HTML |
|
from matplotlib import animation |
|
from ganime.visualization.images import unnormalize_if_necessary |
|
|
|
|
|
def display_videos(data, ground_truth=None, n_rows=3, n_cols=3, *, interval=200):
    """Render a grid of videos as an HTML5 animation for notebook display.

    Each cell of an ``n_rows`` x ``n_cols`` grid plays one video from
    ``data``; all cells advance in lockstep, one frame every ``interval`` ms.
    Videos beyond ``n_rows * n_cols`` are not shown; if there are fewer
    videos than cells, the surplus cells are left blank.

    Args:
        data: Batch of videos indexed as ``data[video_idx][frame_idx]``.
            Presumably ``(batch, frames, height, width, channels)`` — TODO
            confirm. ``data.shape[1]`` is used as the number of frames.
        ground_truth: Optional batch with the same layout as ``data``. When
            given, it is concatenated to ``data`` along axis 2 (the height
            axis under the layout assumed above), so each cell would show
            the ``data`` video stacked above its ground truth.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        interval: Delay between frames in milliseconds (keyword-only).

    Returns:
        An ``IPython.display.HTML`` object wrapping the encoded HTML5 video.
    """
    if ground_truth is not None:
        data = np.concatenate((data, ground_truth), axis=2)

    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False, figsize=(16, 9))
    plt.setp(axs, xticks=[], yticks=[])
    fig.subplots_adjust(wspace=0, hspace=0)

    # Only fill as many cells as there are videos; indexing data[idx] for
    # every cell would raise IndexError on a batch smaller than the grid.
    n_videos = min(len(data), n_rows * n_cols)

    # axs.flat walks the grid row-major, i.e. cell (i, j) -> i * n_cols + j.
    ims = []
    for idx, ax in enumerate(axs.flat):
        if idx < n_videos:
            first_frame = unnormalize_if_necessary(data[idx][0])
            ims.append(ax.imshow(first_frame, animated=True))
        else:
            ax.set_axis_off()

    # Close this figure so the notebook's inline backend does not also show a
    # static copy of it next to the animation.
    plt.close(fig)

    def _show_frame(frame_id):
        # Indexing with a single subscript works for any per-frame rank
        # (e.g. grayscale H x W as well as H x W x C).
        for idx, im in enumerate(ims):
            im.set_data(unnormalize_if_necessary(data[idx][frame_id]))
        # blit=True requires returning the updated artists.
        return ims

    def init():
        return _show_frame(0)

    anim = animation.FuncAnimation(
        fig,
        _show_frame,
        init_func=init,
        frames=data.shape[1],
        blit=True,
        interval=interval,
    )

    return HTML(anim.to_html5_video())
|
|