# ControlAR: utils/video.py
# (Header from the original source listing: author wondervictor, commit 2422035 "update README", 3.48 kB.)
import math
import numpy as np
import skvideo.io
from PIL import Image
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` to position ``dest_dim``.

    The relative order of all other dimensions is preserved. For example,
    ``shift_dim(x, 1, -1)`` turns a (b, c, t, h, w) tensor into (b, t, h, w, c).

    Args:
        x: tensor supporting ``.permute`` (and ``.contiguous`` when requested).
        src_dim: dimension to move; negative values count from the end.
        dest_dim: position the moved dimension ends up at; negative values
            count from the end.
        make_contiguous: if True, return a contiguous copy of the permuted view.

    Returns:
        The permuted tensor.
    """
    ndim = len(x.shape)
    src = src_dim + ndim if src_dim < 0 else src_dim
    dest = dest_dim + ndim if dest_dim < 0 else dest_dim
    assert 0 <= src < ndim and 0 <= dest < ndim

    # Remaining dims keep their order; the moved dim is spliced in at `dest`.
    order = [d for d in range(ndim) if d != src]
    order.insert(dest, src)

    shifted = x.permute(order)
    return shifted.contiguous() if make_contiguous else shifted
def view_range(x, i, j, shape):
    """Reshape the dimensions ``[i, j)`` of ``x`` into ``shape`` via ``.view``.

    Example: if ``x.shape == (b, thw, c)`` then ``view_range(x, 1, 2, (t, h, w))``
    returns a view of shape (b, t, h, w, c).

    Args:
        x: tensor supporting ``.view``.
        i: first dimension to replace (inclusive); negative counts from the end.
        j: last dimension to replace (exclusive); ``None`` means "through the
            last dimension", negative counts from the end.
        shape: iterable of ints that replaces ``x.shape[i:j]``.

    Returns:
        ``x.view(...)`` with the new shape (so ``x`` must be view-compatible).
    """
    new_dims = tuple(shape)
    ndim = len(x.shape)
    start = i + ndim if i < 0 else i
    if j is None:
        stop = ndim
    elif j < 0:
        stop = j + ndim
    else:
        stop = j
    assert 0 <= start < stop <= ndim

    full = x.shape
    return x.view(full[:start] + new_dims + full[stop:])
def tensor_slice(x, begin, size):
    """Return the block of ``x`` that starts at ``begin`` and has extent ``size``.

    Args:
        x: array-like with ``.shape`` and multidimensional slice indexing
            (works for both torch tensors and NumPy arrays).
        begin: per-dimension start offsets; all must be non-negative.
        size: per-dimension extents; ``-1`` means "to the end of that dim".
            Dimensions beyond ``len(begin)`` / ``len(size)`` are kept whole.

    Returns:
        ``x[b0:b0+s0, b1:b1+s1, ...]``.
    """
    assert all(b >= 0 for b in begin)
    size = [dim - b if s == -1 else s
            for s, b, dim in zip(size, begin, x.shape)]
    assert all(s >= 0 for s in size)
    # Index with a *tuple* of slices: NumPy >= 1.23 rejects a list of slices
    # (non-tuple multidimensional indexing), while a tuple works everywhere.
    index = tuple(slice(b, b + s) for b, s in zip(begin, size))
    return x[index]
def _make_video_grid(video, nrow=None, padding=1):
    """Tile a batch of videos into one uint8 array of grid frames.

    Args:
        video: tensor of shape (b, c, t, h, w); ``.detach``, ``.permute`` and
            ``.cpu`` are called on it, so a torch tensor is expected. Values
            are assumed to lie in [0, 1]; anything outside is clipped.
        nrow: number of grid rows; defaults to ``ceil(sqrt(b))``.
        padding: width in pixels of the black border around every tile.

    Returns:
        np.ndarray of dtype uint8 and shape (t, H, W, c), where
        ``H = (padding + h) * nrow + padding`` and
        ``W = (padding + w) * ncol + padding`` with ``ncol = ceil(b / nrow)``.

    Raises:
        ValueError: if the batch is empty.
    """
    b, c, t, h, w = video.shape
    if b == 0:
        raise ValueError('cannot build a video grid from an empty batch')
    # detach(): .numpy() refuses tensors that require grad.
    frames = video.detach().permute(0, 2, 3, 4, 1).cpu().numpy()
    # Clip before casting: converting out-of-range floats to uint8 wraps
    # around (or is platform-dependent), producing speckled artifacts.
    frames = (np.clip(frames, 0.0, 1.0) * 255).astype('uint8')

    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)
    grid = np.zeros((t, (padding + h) * nrow + padding,
                     (padding + w) * ncol + padding, c), dtype='uint8')
    for idx in range(b):
        row, col = divmod(idx, ncol)
        # The leading `padding` offset gives every tile a border on all
        # sides; the grid size above already reserves that extra strip.
        top = padding + (padding + h) * row
        left = padding + (padding + w) * col
        grid[:, top:top + h, left:left + w] = frames[idx]
    return grid


def save_video_grid(video, fname, nrow=None, fps=5):
    """Tile a (b, c, t, h, w) video batch into a grid and write it as a video.

    Args:
        video: tensor of shape (b, c, t, h, w) with values in [0, 1].
        fname: output path passed to ``skvideo.io.vwrite``.
        nrow: number of grid rows; defaults to ``ceil(sqrt(b))``.
        fps: frame rate, passed to ffmpeg as the input rate option ``-r``.
    """
    video_grid = _make_video_grid(video, nrow)
    skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '{}'.format(fps)})


def save_gif_grid(video, file_name, nrow=None, fps=5):
    """Tile a (b, c, t, h, w) video batch into a grid and save it as a looping GIF.

    Args:
        video: tensor of shape (b, c, t, h, w) with values in [0, 1];
            c is expected to be 1 (grayscale), 3 (RGB) or 4 (RGBA).
        file_name: output path for the GIF.
        nrow: number of grid rows; defaults to ``ceil(sqrt(b))``.
        fps: playback rate; each frame is shown for ``int(1000 / fps)`` ms.

    Raises:
        ValueError: if the batch is empty or the videos have no frames.
    """
    video_grid = _make_video_grid(video, nrow)
    if video_grid.shape[0] == 0:
        raise ValueError('cannot save a GIF with no frames')
    if video_grid.shape[-1] == 1:
        # Image.fromarray cannot handle (H, W, 1); drop the channel so PIL
        # builds a grayscale ('L') image instead.
        video_grid = video_grid[..., 0]
    images = [Image.fromarray(frame) for frame in video_grid]
    # Save the first image and append the rest as frames of the GIF.
    # `duration` is the display time per frame in milliseconds; `loop=0`
    # makes the GIF loop forever.
    images[0].save(file_name, save_all=True, append_images=images[1:],
                   optimize=False, duration=int(1000 / fps), loop=0)