# NOTE: hosting-page status text captured with this file ("Spaces: Runtime error");
# it is not part of the script itself.
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
"""Generate lerp videos using pretrained network pickle.""" | |
import copy | |
import os | |
import re | |
from typing import List, Optional, Tuple, Union | |
import click | |
import dnnlib | |
import imageio | |
import numpy as np | |
import scipy.interpolate | |
import torch | |
from tqdm import tqdm | |
import legacy | |
#---------------------------------------------------------------------------- | |
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into one grid image.

    Args:
        img: Torch tensor of shape (batch, channels, height, width).
        grid_w: Number of grid columns. When None it is derived as
            ``batch // grid_h``.
        grid_h: Number of grid rows.
        float_to_uint8: If True, apply ``img * 127.5 + 128``, clamp to
            [0, 255] and cast to ``torch.uint8``
            (presumably the input is in [-1, 1] -- confirm against the caller).
        chw_to_hwc: If True, move the channel axis last (H, W, C).
        to_numpy: If True, return a NumPy array (copied to CPU) instead of a
            torch tensor.

    Returns:
        The grid image of size (grid_h * height, grid_w * width) with the
        channel axis first or last depending on ``chw_to_hwc``.  Batch item
        ``row * grid_w + col`` lands in grid cell (row, col).

    Raises:
        ValueError: If ``img`` is not 4-D, if ``grid_h`` is not positive while
            ``grid_w`` has to be derived from it, or if the batch size does
            not equal ``grid_w * grid_h``.
    """
    batch_size, channels, img_h, img_w = img.shape
    if grid_w is None:
        if grid_h < 1:
            raise ValueError(f'grid_h must be positive, got {grid_h}')
        grid_w = batch_size // grid_h
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if batch_size != grid_w * grid_h:
        raise ValueError(
            f'Batch size {batch_size} does not match grid {grid_w}x{grid_h}')
    if float_to_uint8:
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # (rows, cols, C, H, W) -> (C, rows, H, cols, W) -> (C, rows*H, cols*W)
    img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
    img = img.permute(2, 0, 3, 1, 4)
    img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
    if chw_to_hwc:
        img = img.permute(1, 2, 0)
    if to_numpy:
        img = img.cpu().numpy()
    return img
#---------------------------------------------------------------------------- | |
def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs):
    """Render a grid of latent-space interpolations to a video file.

    Args:
        G: Generator exposing ``z_dim``, ``mapping(z=, c=, truncation_psi=)``
            and ``synthesis(ws=, noise_mode=)``.
        mp4: Output path handed to ``imageio.get_writer``.
        seeds: Sequence of integer seeds; each one seeds a
            ``np.random.RandomState`` that draws one latent ``z``.
        shuffle_seed: If not None, seed used to shuffle the expanded seed list.
        w_frames: Number of rendered frames per keyframe interval.
        kind: Interpolation kind passed to ``scipy.interpolate.interp1d``.
        grid_dims: ``(grid_w, grid_h)`` of the output grid.
        num_keyframes: Keyframes per grid cell.  When None it is derived as
            ``len(seeds) // (grid_w * grid_h)``; when given, ``seeds`` is
            cycled to fill ``num_keyframes * grid_w * grid_h`` slots.
        wraps: How many times the keyframe sequence is repeated on each side
            of the rendered span before fitting the interpolant.
        psi: Truncation psi forwarded to ``G.mapping``.
        device: Torch device used for the latents.
        **video_kwargs: Extra options for ``imageio.get_writer``.  ``fps``
            (default 60) and ``codec`` (default ``'libx264'``) may be
            overridden here.

    Raises:
        ValueError: If the grid dimensions are not positive, ``seeds`` is
            empty, ``num_keyframes`` is not positive, or (when
            ``num_keyframes`` is None) the seed count is not divisible by the
            grid size.
    """
    grid_w = grid_dims[0]
    grid_h = grid_dims[1]
    if grid_w < 1 or grid_h < 1:
        raise ValueError(f'Grid dimensions must be positive, got {grid_w}x{grid_h}')
    if len(seeds) == 0:
        raise ValueError('At least one seed is required')
    num_cells = grid_w * grid_h

    if num_keyframes is None:
        if len(seeds) % num_cells != 0:
            raise ValueError('Number of input seeds must be divisible by grid W*H')
        num_keyframes = len(seeds) // num_cells
    if num_keyframes < 1:
        raise ValueError(f'num_keyframes must be positive, got {num_keyframes}')

    # Cycle through the given seeds until every (cell, keyframe) slot is filled.
    num_slots = num_keyframes * num_cells
    all_seeds = np.array([seeds[idx % len(seeds)] for idx in range(num_slots)], dtype=np.int64)
    if shuffle_seed is not None:
        rng = np.random.RandomState(seed=shuffle_seed)
        rng.shuffle(all_seeds)

    zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)
    ws = G.mapping(z=zs, c=None, truncation_psi=psi)
    _ = G.synthesis(ws[:1]) # warm up
    ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])

    # Interpolation.
    # The keyframes of each cell are tiled (2 * wraps + 1) times along the
    # keyframe axis and placed at integer positions centred on the rendered
    # span [0, num_keyframes), so the curve at the end of the span heads back
    # to the first keyframe.
    # NOTE(review): the 3-element tile reps assume each cell's ws slice is
    # 3-D, i.e. G.mapping returns (N, num_ws, w_dim) -- confirm.
    x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
    grid = []
    for yi in range(grid_h):
        row = []
        for xi in range(grid_w):
            y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
            row.append(scipy.interpolate.interp1d(x, y, kind=kind, axis=0))
        grid.append(row)

    # Render video.
    # Caller-supplied options take precedence over the built-in defaults.
    writer_opts = {'fps': 60, 'codec': 'libx264', **video_kwargs}
    video_out = imageio.get_writer(mp4, mode='I', **writer_opts)
    try:
        for frame_idx in tqdm(range(num_keyframes * w_frames)):
            imgs = []
            for yi in range(grid_h):
                for xi in range(grid_w):
                    interp = grid[yi][xi]
                    # frame_idx / w_frames is the fractional keyframe position.
                    w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
                    img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
                    imgs.append(img)
            video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
    finally:
        # Always finalize the file, even if synthesis fails mid-render.
        video_out.close()
#---------------------------------------------------------------------------- | |
def parse_range(s: Union[str, List[int]]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges are inclusive on both ends; whitespace around tokens is ignored.
    A list argument is returned unchanged.  A range whose end is smaller than
    its start contributes no values.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Raises:
        ValueError: If a token is neither an integer nor an 'A-B' range.
    '''
    if isinstance(s, list):
        return s
    range_re = re.compile(r'^(\d+)\s*-\s*(\d+)$')
    values: List[int] = []
    for token in s.split(','):
        token = token.strip()
        if m := range_re.match(token):
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            values.append(int(token))
    return values
#---------------------------------------------------------------------------- | |
def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
    '''Parse a 'M,N' or 'MxN' integer tuple.

    Whitespace around the numbers and the separator is ignored.  A tuple
    argument is returned unchanged.

    Example:
        '4x2' returns (4,2)
        '0,1' returns (0,1)

    Raises:
        ValueError: If the string is not two non-negative integers separated
            by 'x' or ','.
    '''
    if isinstance(s, tuple):
        return s
    if m := re.match(r'^\s*(\d+)\s*[x,]\s*(\d+)\s*$', s):
        return (int(m.group(1)), int(m.group(2)))
    raise ValueError(f'cannot parse tuple {s}')
#---------------------------------------------------------------------------- | |
# NOTE(review): the docstring below documents command-line options and `click`
# is imported by this file, but no option decorators are visible on this
# function here -- confirm how the parameters are meant to be supplied.
def generate_images(
    network_pkl: str,               # Path or URL of the network pickle; opened with dnnlib.util.open_url.
    seeds: List[int],               # Seeds forwarded to gen_interp_video.
    shuffle_seed: Optional[int],    # Optional seed for shuffling the seed order.
    truncation_psi: float,          # Forwarded to gen_interp_video as `psi`.
    grid: Tuple[int,int],           # Forwarded to gen_interp_video as `grid_dims`.
    num_keyframes: Optional[int],   # Keyframe count; None derives it from the seeds.
    w_frames: int,                  # Frames rendered per keyframe interval.
    output: str                     # Output .mp4 filename.
):
    """Render a latent vector interpolation video.

    Examples:

    \b
    # Render a 4x2 grid of interpolations for seeds 0 through 31.
    python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl

    Animation length and seed keyframes:

    The animation length is either determined based on the --seeds value or explicitly
    specified using the --num-keyframes option.

    When num keyframes is specified with --num-keyframes, the output video length
    will be 'num_keyframes*w_frames' frames.

    If --num-keyframes is not specified, the number of seeds given with
    --seeds must be divisible by grid size W*H (--grid). In this case the
    output video length will be '# seeds/(w*h)*w_frames' frames.
    """
    print('Loading networks from "%s"...' % network_pkl)
    # Prefer the GPU, but fall back to the CPU instead of failing outright on
    # hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    # Pass the device explicitly so the latents are created on the same device
    # as the network rather than relying on gen_interp_video's default.
    gen_interp_video(
        G=G,
        mp4=output,
        bitrate='12M',
        grid_dims=grid,
        num_keyframes=num_keyframes,
        w_frames=w_frames,
        seeds=seeds,
        shuffle_seed=shuffle_seed,
        psi=truncation_psi,
        device=device,
    )
#---------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # NOTE(review): generate_images is invoked without arguments, which only
    # works when it is wrapped by a command-line parser that fills them in --
    # no such wrapper is visible in this file; confirm.
    generate_images() # pylint: disable=no-value-for-parameter
#---------------------------------------------------------------------------- | |