# DragGan44 / stylegan_human / stylemixing_video.py
# Varonblu — "Duplicate from DragGan/DragGan" (commit b61c951)
# Copyright (c) SenseTime Research. All rights reserved.
"""Here we demo style-mixing results using StyleGAN2 pretrained model.
Script reference: https://github.com/PDillis/stylegan2-fun """
import argparse
import legacy
import scipy
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
from typing import List
import re
import sys
import os
import click
import torch
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
import moviepy.editor
"""
Generate style mixing video.
Examples:
\b
python stylemixing_video.py --network=pretrained_models/stylegan_human_v2_1024.pkl --row-seed=3859 \\
--col-seeds=3098,31759,3791 --col-styles=8-12 --trunc=0.8 --outdir=outputs/stylemixing_video
"""
def _to_uint8_hwc(images):
    """Map a synthesis batch [N, C, H, W] to uint8 [N, H, W, C].

    Uses the same `* 127.5 + 128` scaling and clamping as the rest of this script;
    the input is presumably in roughly [-1, 1] (TODO confirm for this generator).
    """
    return (images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)


def _synthesize_batched(G, ws, noise_mode, batch_size):
    """Run `G.synthesis` over `ws` in chunks of `batch_size`.

    Returns a CPU numpy uint8 array of shape [N, H, W, C]. Chunking bounds peak GPU
    memory instead of synthesizing every latent in a single forward pass.
    """
    batch_size = max(1, int(batch_size))
    chunks = [
        _to_uint8_hwc(G.synthesis(ws[start:start + batch_size], noise_mode=noise_mode)).cpu()
        for start in range(0, ws.shape[0], batch_size)
    ]
    return torch.cat(chunks).numpy()


@click.command()
@click.option('--network', 'network_pkl', help='Path to network pickle filename', required=True)
@click.option('--row-seed', 'src_seed', type=legacy.num_range, help='Random seed to use for image source row', required=True)
@click.option('--col-seeds', 'dst_seeds', type=legacy.num_range, help='Random seeds to use for image columns (style)', required=True)
@click.option('--col-styles', 'col_styles', type=legacy.num_range, help='Style layer range (default: %(default)s)', default='0-6')
@click.option('--only-stylemix', 'only_stylemix', help='Add flag to only show the style-mixed images in the video', default=False)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi (default: %(default)s)', default=1)
@click.option('--duration-sec', 'duration_sec', type=float, help='Duration of video (default: %(default)s)', default=10)
@click.option('--fps', 'mp4_fps', type=int, help='FPS of generated video (default: %(default)s)', default=10)
@click.option('--indent-range', 'indent_range', type=int, default=30)
@click.option('--outdir', help='Root directory for run results (default: %(default)s)', default='outputs/stylemixing_video', metavar='DIR')
def style_mixing_video(network_pkl: str,
                       src_seed: List[int],  # Seed of the source image style (row); exactly one is supported
                       dst_seeds: List[int],  # Seeds of the destination image styles (columns)
                       col_styles: List[int],  # Style layers taken from the source w for each mixed image
                       truncation_psi: float = 1.0,
                       only_stylemix: bool = False,  # True to render only the style-mixed images
                       duration_sec: float = 10.0,
                       smoothing_sec: float = 1.0,  # Temporal Gaussian sigma for the source latent, in seconds
                       mp4_fps: int = 10,
                       mp4_codec: str = "libx264",
                       mp4_bitrate: str = "16M",
                       minibatch_size: int = 8,  # Max latents per synthesis forward pass
                       noise_mode: str = 'const',
                       indent_range: int = 30,  # Pixels each cell overlaps its left neighbour
                       outdir: str = 'outputs/stylemixing_video'):
    """Render a style-mixing video with a pretrained StyleGAN2 generator.

    The lower-left cell animates a smoothly varying source latent (--row-seed),
    the top row shows fixed destination images (--col-seeds), and each lower cell
    shows a destination with its --col-styles layers replaced by the current
    source w. With --only-stylemix, only the row of mixed images is drawn.
    """
    # `import scipy` alone does not guarantee the ndimage submodule is loaded.
    import scipy.ndimage

    if len(src_seed) != 1:
        raise click.BadParameter('exactly one source seed is supported', param_hint='--row-seed')
    if not dst_seeds:
        raise click.BadParameter('at least one column seed is required', param_hint='--col-seeds')

    print('col_seeds: ', dst_seeds)
    num_frames = int(np.rint(duration_sec * mp4_fps))
    if num_frames < 1:
        raise click.BadParameter('duration-sec * fps must yield at least one frame', param_hint='--duration-sec')

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        Gs = legacy.load_network_pkl(f)['G_ema'].to(device)
    print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
    max_style = int(2 * np.log2(Gs.img_resolution)) - 3
    if max(col_styles) > max_style:
        raise click.BadParameter(f"Maximum col-style allowed: {max_style}", param_hint='--col-styles')

    # Cell geometry: images are twice as tall as wide; neighbouring cells overlap by indent_range.
    H = Gs.img_resolution
    W = Gs.img_resolution // 2
    cell_w = W - indent_range
    if cell_w <= 0:
        raise click.BadParameter(f"indent-range must be smaller than the image width ({W})", param_hint='--indent-range')

    with torch.no_grad():
        # Source latents: one per frame, blurred along time only (sigma in frames).
        # mode="wrap" treats the frame sequence as periodic at its ends.
        print('Generating Source W vectors...')
        src_z = np.random.RandomState(src_seed[0]).randn(num_frames, Gs.z_dim).astype(np.float32)
        src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps, 0], mode="wrap")
        src_z /= np.sqrt(np.mean(np.square(src_z)))
        # Map into W and apply the truncation trick around the tracked average w.
        src_w = Gs.mapping(torch.from_numpy(src_z).to(device), None)
        w_avg = Gs.mapping.w_avg
        src_w = w_avg + (src_w - w_avg) * truncation_psi

        # Destination latents: fixed, one per column seed.
        print('Generating Destination W vectors...')
        dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds]).astype(np.float32)
        dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device), None)
        dst_w = w_avg + (dst_w - w_avg) * truncation_psi

        if only_stylemix:
            # Single row holding only the mixed images.
            src_images = None
            canvas = PIL.Image.new("RGB", (cell_w * len(dst_seeds), H), "white")
            x0, y0 = 0, 0
        else:
            src_images = _synthesize_batched(Gs, src_w, noise_mode, minibatch_size)
            dst_images = _synthesize_batched(Gs, dst_w, noise_mode, minibatch_size)
            print('Generating full video (including source and destination images)')
            canvas = PIL.Image.new("RGB", (cell_w * (len(dst_seeds) + 1), H * (len(src_seed) + 1)), "white")
            # Top row (static): destination images, shifted one cell right of the source column.
            for col, dst_image in enumerate(dst_images):
                canvas.paste(PIL.Image.fromarray(dst_image, "RGB"), ((col + 1) * cell_w, 0))
            x0, y0 = cell_w, H

    @torch.no_grad()
    def make_frame(t):
        """Return the canvas for time `t` (seconds) as a numpy array; called by moviepy."""
        frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
        if src_images is not None:
            # Source image goes in the lower-left cell, shifted left by indent_range.
            canvas.paste(PIL.Image.fromarray(src_images[frame_idx], "RGB"), (0 - indent_range, H))
        # Copy every destination w and overwrite the selected layers with this frame's source w.
        w_mix = dst_w.clone()
        w_mix[:, col_styles] = src_w[frame_idx, col_styles]
        mixed_images = _synthesize_batched(Gs, w_mix, noise_mode, minibatch_size)
        for col, image in enumerate(mixed_images):
            canvas.paste(PIL.Image.fromarray(image, "RGB"), (x0 + col * cell_w, y0))
        return np.array(canvas)

    print('Generating style-mixed video...')
    videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
    grid_size = [len(dst_seeds), len(src_seed)]
    mp4 = "{}x{}-style-mixing_{}_{}.mp4".format(*grid_size, min(col_styles), max(col_styles))
    os.makedirs(outdir, exist_ok=True)
    videoclip.write_videofile(os.path.join(outdir, mp4),
                              fps=mp4_fps,
                              codec=mp4_codec,
                              bitrate=mp4_bitrate)
if __name__ == '__main__':
    # Click reads sys.argv and supplies every option value, so no arguments are passed here.
    style_mixing_video()  # pylint: disable=no-value-for-parameter