# kpop-face-generator / stylegan3-fun / visual_reactive.py
# Web file-viewer header captured along with the source (not part of the program):
#   rossellison's picture | Upload 194 files | e3a6a57 | raw | history blame | 16.5 kB
import os
import click
from tqdm import tqdm
try:
import ffmpeg
except ImportError:
raise ImportError('ffmpeg-python not found! Install it via "pip install ffmpeg-python"')
try:
import skvideo.io
except ImportError:
raise ImportError('scikit-video not found! Install it via "pip install scikit-video"')
import PIL
from PIL import Image
import scipy.ndimage as nd
from fractions import Fraction
import numpy as np
import torch
from torchvision import transforms
from typing import Union, Tuple
import dnnlib
import legacy
from torch_utils import gen_utils
from network_features import VGG16FeaturesNVIDIA, DiscriminatorFeatures
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide'
import moviepy.editor
# ----------------------------------------------------------------------------
def normalize_image(image: Union[PIL.Image.Image, np.ndarray]) -> np.ndarray:
    """Rescale pixel values from [0, 255] to [-1, 1] and return them as a float32 array."""
    pixels = np.array(image, dtype=np.float32)
    # 127.5 is the midpoint of [0, 255], so 0 maps to -1 and 255 maps to +1
    return pixels / 127.5 - 1.0
def _probe_fps(video_stream) -> int:
    """Return the rounded frame rate of an ffprobe video stream, or 0 if no usable rate is found."""
    # '@avg_frame_rate' can be reported as '0/0' or '0/1'; '@r_frame_rate' is tried as a fallback
    for key in ('@avg_frame_rate', '@r_frame_rate'):
        raw_rate = video_stream.get(key)
        if raw_rate is None:
            continue
        try:
            fps = int(np.rint(float(Fraction(raw_rate))))
        except (ValueError, ZeroDivisionError):
            continue
        if fps > 0:
            return fps
    return 0


def get_video_information(mp4_filename: Union[str, os.PathLike],
                          max_length_seconds: Union[float, None] = None,
                          starting_second: float = 0.0) -> Tuple[int, float, int, int, int, int]:
    """Probe a video file and work out which span of it should be processed.

    Args:
        mp4_filename: Path to the video; it is passed to ``skvideo.io.ffprobe``.
        max_length_seconds: How many seconds to keep, counted from ``starting_second``.
            ``None`` means "until the end of the video".
        starting_second: Second to start from; clamped to the video duration.

    Returns:
        ``(fps, returned_duration, starting_frame, max_frames, video_width, video_height)``,
        where ``fps`` is the frame rate rounded to the nearest integer, ``returned_duration``
        is the length in seconds of the selected span, ``starting_frame`` is the index of the
        first selected frame and ``max_frames`` is the number of frames to take at most.

    Raises:
        ValueError: If no positive frame rate can be read from the probed metadata.
    """
    metadata = skvideo.io.ffprobe(mp4_filename)
    video_stream = metadata['video']

    # Get video properties
    fps = _probe_fps(video_stream)
    if fps == 0:
        raise ValueError(f'Could not determine a valid frame rate for "{mp4_filename}"')
    video_duration = float(video_stream['@duration'])
    video_width = int(video_stream['@width'])
    video_height = int(video_stream['@height'])

    # Not every stream reports '@nb_frames'; estimate it from the duration when it is missing
    nb_frames = video_stream.get('@nb_frames')
    if nb_frames is not None:
        total_video_num_frames = int(nb_frames)
    else:
        total_video_num_frames = int(np.rint(video_duration * fps))

    # Maximum number of frames to return (if not provided, return the full video)
    if max_length_seconds is None:
        print('Considering the full video...')
        max_length_seconds = video_duration
    if starting_second != 0.0:
        print('Using part of the video...')
        starting_second = min(starting_second, video_duration)
        # Never ask for more than what is left after the starting point
        max_length_seconds = min(video_duration - starting_second, max_length_seconds)
    max_num_frames = int(np.rint(max_length_seconds * fps))
    max_frames = min(total_video_num_frames, max_num_frames)

    returned_duration = min(video_duration, max_length_seconds)
    # Frame to start from
    starting_frame = int(np.rint(starting_second * fps))

    return fps, returned_duration, starting_frame, max_frames, video_width, video_height
def get_video_frames(mp4_filename: Union[str, os.PathLike],
                     run_dir: Union[str, os.PathLike],
                     starting_frame: int,
                     max_frames: int,
                     center_crop: bool = False,
                     save_selected_frames: bool = False) -> np.ndarray:
    """Read a whole video into memory and return the selected frames as one array.

    The video is read with ``skvideo.io.vread``, sliced to at most ``max_frames`` frames
    starting at ``starting_frame`` and transposed with axes ``(0, 3, 2, 1)``. If
    ``center_crop`` is set, the two trailing axes are cropped to a centered square. If
    ``save_selected_frames`` is set, the result (transposed back) is also written to
    ``selected_frames.mp4`` inside ``run_dir``.
    """
    # DEPRECATED
    print('Getting video frames...')
    all_frames = skvideo.io.vread(mp4_filename)  # TODO: crazy things with scikit-video
    last_frame = min(starting_frame + max_frames, len(all_frames))
    selected = np.transpose(all_frames[starting_frame:last_frame, :, :, :], (0, 3, 2, 1))  # NHWC => NCWH

    if center_crop:
        width, height = selected.shape[2], selected.shape[3]
        side = min(width, height)
        w_start, w_stop = (width - side) // 2, (width + side) // 2
        h_start, h_stop = (height - side) // 2, (height + side) // 2
        selected = selected[:, :, w_start:w_stop, h_start:h_stop]

    if save_selected_frames:
        out_path = os.path.join(run_dir, 'selected_frames.mp4')
        # Undo the transpose so the writer gets the layout that vread produced
        skvideo.io.vwrite(out_path, np.transpose(selected, (0, 3, 2, 1)))

    return selected
# ----------------------------------------------------------------------------
@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
# Encoder options
@click.option('--encoder', type=click.Choice(['discriminator', 'vgg16', 'clip']), help='Choose the model to encode each frame into the latent space Z.', default='discriminator', show_default=True)
@click.option('--vgg16-layer', type=click.Choice(['conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3', 'adavgpool', 'fc1', 'fc2']), help='Choose the layer to use from VGG16 (if used as encoder)', default='adavgpool', show_default=True)
# Source video options
@click.option('--source-video', '-video', 'video_file', type=click.Path(exists=True, dir_okay=False), help='Path to video file', required=True)
@click.option('--max-video-length', type=click.FloatRange(min=0.0, min_open=True), help='How many seconds of the video to take (from the starting second)', default=None, show_default=True)
@click.option('--starting-second', type=click.FloatRange(min=0.0), help='Second to start the video from', default=0.0, show_default=True)
@click.option('--frame-transform', type=click.Choice(['none', 'center-crop', 'resize']), help='TODO: Transform to apply to the individual frame.')
@click.option('--center-crop', is_flag=True, help='Center-crop each frame of the video')
@click.option('--save-selected-frames', is_flag=True, help='Save the selected frames of the input video after the selected transform')
# Synthesis options
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--new-center', type=gen_utils.parse_new_center, help='New center for the W latent space; a seed (int) or a path to a dlatent (.npy/.npz)', default=None)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
# Video options
@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)')
# Extra parameters for saving the results
@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out','visual-reactive'), show_default=True, metavar='DIR')
@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='', show_default=True)
def visual_reactive_interpolation(
        ctx: click.Context,
        network_pkl: Union[str, os.PathLike],
        encoder: str,
        vgg16_layer: str,
        video_file: Union[str, os.PathLike],
        max_video_length: float,
        starting_second: float,
        frame_transform: str,
        center_crop: bool,
        save_selected_frames: bool,
        truncation_psi: float,
        new_center: Tuple[str, Union[int, np.ndarray]],
        noise_mode: str,
        anchor_latent_space: bool,
        outdir: Union[str, os.PathLike],
        description: str,
        compress: bool,
        smoothing_sec: float = 0.1  # For Gaussian blur; the lower, the faster the reaction; higher leads to more generated frames being the same
):
    """Generate a video with the Generator that reacts to the frames of a source video.

    Each selected source frame is encoded into a latent with the chosen encoder, mapped
    to a dlatent by the Generator, smoothed over time and synthesized into an output frame.
    """
    # NOTE: click shows the docstring above as the command's --help text, so it is kept short.
    # `smoothing_sec` is not exposed as a CLI option; it always takes its default when run via click.
    # `frame_transform` is accepted but not applied yet (see the TODO in its option help).
    print(f'Loading networks from "{network_pkl}"...')

    # Define the model (load both D, G, and the features of D)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if encoder == 'discriminator':
        print('Loading Discriminator and its features...')
        with dnnlib.util.open_url(network_pkl) as f:
            D = legacy.load_network_pkl(f)['D'].eval().requires_grad_(False).to(device)  # type: ignore
        D_features = DiscriminatorFeatures(D).requires_grad_(False).to(device)
        del D
    elif encoder == 'vgg16':
        print('Loading VGG16 and its features...')
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
        with dnnlib.util.open_url(url) as f:
            vgg16 = torch.jit.load(f).eval().to(device)
        vgg16_features = VGG16FeaturesNVIDIA(vgg16).requires_grad_(False).to(device)
        del vgg16
        # Built once here instead of once per frame inside the loop
        vgg16_preprocess = transforms.Compose([transforms.ToTensor(),
                                               transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                    std=[0.229, 0.224, 0.225])])
    elif encoder == 'clip':
        print('Loading CLIP model...')
        try:
            import clip
        except ImportError:
            raise ImportError('clip not installed! Install it via "pip install git+https://github.com/openai/CLIP.git"')
        clip_model, clip_preprocess = clip.load('ViT-B/32', device=device)
        clip_model = clip_model.requires_grad_(False)  # Otherwise OOM

    print('Loading Generator...')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().requires_grad_(False).to(device)  # type: ignore

    if anchor_latent_space:
        gen_utils.anchor_latent_space(G)

    if new_center is None:
        # Stick to the tracked center of W during training
        w_avg = G.mapping.w_avg
    else:
        new_center, new_center_value = new_center
        # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
        if isinstance(new_center_value, int):
            new_center = f'seed_{new_center}'
            w_avg = gen_utils.get_w_from_seed(G, device, new_center_value, truncation_psi=1.0)  # We want the pure dlatent
        elif isinstance(new_center_value, np.ndarray):
            w_avg = torch.from_numpy(new_center_value).to(device)
        else:
            ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')

    # Create the run dir with the given name description
    description = description if description else 'visual-reactive'
    run_dir = gen_utils.make_run_dir(outdir, description)

    # Name of the video (file name without directory and extension)
    video_name = os.path.splitext(os.path.basename(video_file))[0]
    mp4_name = f'visual-reactive_{video_name}'
    selected_frames_path = os.path.join(run_dir, f'selected-frames_{video_name}.mp4')

    # Get the properties of the video and the span of frames to use
    # TODO: resize the frames to the size of the network (G.img_resolution)
    fps, max_video_length, starting_frame, max_frames, width, height = get_video_information(video_file,
                                                                                            max_video_length,
                                                                                            starting_second)

    videogen = skvideo.io.vreader(video_file)
    fake_dlatents = list()

    writer = None
    if save_selected_frames:
        # skvideo.io.vwrite sets FPS=25, so we have to manually enter it via FFmpeg
        # TODO: use only ffmpeg-python
        writer = skvideo.io.FFmpegWriter(selected_frames_path, inputdict={'-r': str(fps)})

    # Exclusive end of the selection: exactly `max_frames` frames are used, no more
    last_frame = starting_frame + max_frames

    try:
        for idx, frame in enumerate(tqdm(videogen, desc=f'Getting frames+latents of "{video_name}"', unit='frames')):
            # Only use the frames that the user has selected
            if idx < starting_frame:
                continue
            if idx >= last_frame:
                break

            if center_crop:
                frame_width, frame_height = frame.shape[1], frame.shape[0]
                min_side = min(frame_width, frame_height)
                frame = frame[(frame_height - min_side) // 2:(frame_height + min_side) // 2,
                              (frame_width - min_side) // 2:(frame_width + min_side) // 2, :]

            if writer is not None:
                writer.writeFrame(frame)

            # Get fake latents
            if encoder == 'discriminator':
                frame = normalize_image(frame)  # [0, 255] => [-1, 1]
                frame = torch.from_numpy(np.transpose(frame, (2, 1, 0))).unsqueeze(0).to(device)  # HWC => CWH => NCWH, N=1
                fake_z = D_features.get_layers_features(frame, layers=['fc'])[0]
            elif encoder == 'vgg16':
                frame = vgg16_preprocess(frame).unsqueeze(0).to(device)
                fake_z = vgg16_features.get_layers_features(frame, layers=[vgg16_layer])[0]
                fake_z = fake_z.view(1, 512, -1).mean(2)  # [1, C, H, W] => [1, C]; can be used in any layer
            elif encoder == 'clip':
                frame = Image.fromarray(frame)  # [0, 255]
                frame = clip_preprocess(frame).unsqueeze(0).to(device)
                fake_z = clip_model.encode_image(frame)

            # Normalize the latent so that it's ~N(0, 1), or divide by its .max()
            # fake_z = fake_z / fake_z.max()
            fake_z = (fake_z - fake_z.mean()) / fake_z.std()

            # Get dlatent
            fake_w = G.mapping(fake_z, None)
            # Truncation trick
            fake_w = w_avg + (fake_w - w_avg) * truncation_psi
            fake_dlatents.append(fake_w)
    finally:
        # Close the video writer even if encoding a frame failed, so the file is finalized
        if writer is not None:
            writer.close()

    if not fake_dlatents:
        # torch.cat on an empty list would fail with an unhelpful message
        ctx.fail('No frames were selected from the source video! Check --starting-second and --max-video-length.')

    # Set the fake_dlatents as a torch tensor; we can't just do torch.tensor(fake_dlatents) as with NumPy :(
    fake_dlatents = torch.cat(fake_dlatents, 0)
    # Smooth out (along the frame axis only) so larger changes in the scene are the ones that affect the generation
    smoothed = nd.gaussian_filter(fake_dlatents.cpu().numpy(), sigma=[smoothing_sec * fps, 0, 0])
    fake_dlatents = torch.from_numpy(smoothed).to(device)

    # Auxiliary function for moviepy
    def make_frame(t):
        # Get the frame, dlatent, and respective image
        frame_idx = int(np.clip(np.round(t * fps), 0, len(fake_dlatents) - 1))
        fake_w = fake_dlatents[frame_idx]
        image = gen_utils.w_to_img(G, fake_w, noise_mode)
        # Create grid for this timestamp
        grid = gen_utils.create_image_grid(image, (1, 1))
        # Grayscale => RGB
        if grid.shape[2] == 1:
            grid = grid.repeat(3, 2)
        return grid

    # Generate video using the respective make_frame function; the duration is set here, so no
    # extra set_duration() call is needed (its return value was being discarded anyway)
    videoclip = moviepy.editor.VideoClip(make_frame, duration=max_video_length)

    # Change the video parameters (codec, bitrate) if you so desire
    final_video = os.path.join(run_dir, f'{mp4_name}.mp4')
    videoclip.write_videofile(final_video, fps=fps, codec='libx264', bitrate='16M')

    # Compress the video (lower file size, same resolution, if successful)
    if compress:
        gen_utils.compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)

    # Merge the selected source frames and the generated video side by side (best effort)
    if save_selected_frames:
        # GUIDE: https://github.com/kkroening/ffmpeg-python/issues/150
        # hstack needs both inputs to have the same height, so both are scaled to a common one;
        # it is made even (and the width too, via -2) because libx264/yuv420p rejects odd sizes
        source_height = min(width, height) if center_crop else height
        common_height = min(source_height, G.img_resolution)
        common_height -= common_height % 2
        generated_name = f'{mp4_name}-compressed.mp4' if compress else f'{mp4_name}.mp4'
        left = ffmpeg.input(selected_frames_path).filter('scale', -2, common_height)
        right = ffmpeg.input(os.path.join(run_dir, generated_name)).filter('scale', -2, common_height)
        side_by_side = ffmpeg.filter([left, right], 'hstack').output(os.path.join(run_dir, 'side-by-side.mp4'))
        try:
            side_by_side.run(overwrite_output=True, quiet=True)
        except ffmpeg.Error as err:
            # Do not lose the run (and its config) because the optional comparison video failed
            details = err.stderr.decode(errors='replace') if err.stderr else str(err)
            print(f'Could not create the side-by-side video: {details}')

    # Save the configuration used
    new_center = 'w_avg' if new_center is None else new_center
    ctx.obj = {
        'network_pkl': network_pkl,
        'encoder_options': {
            'encoder': encoder,
            'vgg16_layer': vgg16_layer,
        },
        'source_video_options': {
            'source_video': video_file,
            # Key spelling kept as-is so previously saved configs keep the same schema
            'sorce_video_params': {
                'fps': fps,
                'height': height,
                'width': width,
                'length': max_video_length,
                'starting_frame': starting_frame,
                'total_frames': max_frames
            },
            'max_video_length': max_video_length,
            'starting_second': starting_second,
            'frame_transform': frame_transform,
            'center_crop': center_crop,
            'save_selected_frames': save_selected_frames
        },
        'synthesis_options': {
            'truncation_psi': truncation_psi,
            'new_center': new_center,
            'noise_mode': noise_mode,
            'smoothing_sec': smoothing_sec
        },
        'video_options': {
            'compress': compress
        },
        'extra_parameters': {
            'outdir': run_dir,
            'description': description
        }
    }
    gen_utils.save_config(ctx=ctx, run_dir=run_dir)
# ----------------------------------------------------------------------------
# Script entry point: invoke the click command declared above when this file is run directly
if __name__ == '__main__':
    visual_reactive_interpolation()
# ----------------------------------------------------------------------------