import os
import re
import json
from typing import List, Tuple, Union, Optional, Type
from collections import OrderedDict
from locale import atof

import click
import numpy as np
import torch

import dnnlib
import legacy

# ----------------------------------------------------------------------------

channels_dict = {1: 'L', 3: 'RGB', 4: 'RGBA'}

# ----------------------------------------------------------------------------

available_cfgs = ['stylegan2', 'stylegan2-ext', 'stylegan3-t', 'stylegan3-r']

# ----------------------------------------------------------------------------

def create_image_grid(images: np.ndarray, grid_size: Optional[Tuple[int, int]] = None):
    """
    Create a grid with the fed images
    Args:
        images (np.array): array of images
        grid_size (tuple(int)): size of grid (grid_width, grid_height)
    Returns:
        grid (np.array): image grid of size grid_size
    """
    # Sanity check
    assert images.ndim == 3 or images.ndim == 4, f'Images has {images.ndim} dimensions (shape: {images.shape})!'
    # Grayscale images without an explicit channel axis: add one so the unpacking below works
    if images.ndim == 3:
        images = images[:, :, :, np.newaxis]
    num, img_h, img_w, c = images.shape
    # If the user specifies the grid shape, use it
    if grid_size is not None:
        grid_w, grid_h = tuple(grid_size)
        # If one of the sides is None, then we must infer it (this was divine inspiration)
        if grid_w is None:
            grid_w = num // grid_h + min(num % grid_h, 1)
        elif grid_h is None:
            grid_h = num // grid_w + min(num % grid_w, 1)
    # Otherwise, we can infer it by the number of images (priority is given to grid_w)
    else:
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)
    # Sanity check
    assert grid_w * grid_h >= num, 'The grid must have at least as many cells as there are images!'
    # Get the grid
    grid = np.zeros([grid_h * img_h, grid_w * img_w] + list(images.shape[-1:]), dtype=images.dtype)
    # Paste each image in the grid
    for idx in range(num):
        x = (idx % grid_w) * img_w
        y = (idx // grid_w) * img_h
        grid[y:y + img_h, x:x + img_w, ...] = images[idx]
    return grid
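
# Rough usage sketch (hypothetical shapes): ten 64x64 RGB images.
#   images = np.zeros([10, 64, 64, 3], dtype=np.uint8)
#   create_image_grid(images).shape             ->  (192, 256, 3)  # inferred 4x3 grid
#   create_image_grid(images, (5, None)).shape  ->  (128, 320, 3)  # grid_h inferred as 2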
# ----------------------------------------------------------------------------


def parse_fps(fps: Union[str, int]) -> int:
    """Return FPS for the video; at worst, the video will be 1 FPS, but no lower.
    Useful if we don't have click, else simply use click.IntRange(min=1)."""
    if isinstance(fps, int):
        return max(fps, 1)
    try:
        fps = int(atof(fps))
        return max(fps, 1)
    except ValueError:
        print(f'Typo in "--fps={fps}", will use default value of 30')
        return 30

def num_range(s: str, remove_repeated: bool = False) -> List[int]:
    """
    Extended helper function from the original (original is contained here).
    Accept a comma-separated list of numbers 'a,b,c', a range 'a-c', or a combination
    of both 'a,b-c', 'a-b,c', 'a,b-c,d,e-f,...', and return it as a list of ints.
    """
    nums = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for el in s.split(','):
        match = range_re.match(el)
        if match:
            # Sanity check 1: accept ranges 'a-b' or 'b-a', with a <= b
            lower, upper = int(match.group(1)), int(match.group(2))
            if lower <= upper:
                r = list(range(lower, upper + 1))
            else:
                r = list(range(upper, lower + 1))
            # We will extend nums as r is also a list
            nums.extend(r)
        else:
            # It's a single number, so just append it (if it's an int)
            try:
                nums.append(int(atof(el)))
            except ValueError:
                continue  # we ignore bad values
    # Sanity check 2: if requested, remove repeated numbers, but keep the order given by the user
    if remove_repeated:
        nums = list(OrderedDict.fromkeys(nums))
    return nums
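
# Parsing sketch (hypothetical inputs):
#   num_range('1,3-5,7,2-4', remove_repeated=True)  ->  [1, 3, 4, 5, 7, 2]
#   num_range('10-8')                               ->  [8, 9, 10]
#   num_range('1,a,3')                              ->  [1, 3]   # bad values are ignored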

def float_list(s: str) -> List[float]:
    """
    Helper function for parsing a string of comma-separated floats and returning each float
    """
    str_list = s.split(',')
    nums = []
    float_re = re.compile(r'^(\d+\.\d+)$')  # escape the dot so it only matches a literal '.'
    for el in str_list:
        match = float_re.match(el)
        if match:
            nums.append(float(match.group(1)))
        else:
            try:
                nums.append(float(el))
            except ValueError:
                continue  # Ignore bad values
    return nums
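
# Parsing sketch (hypothetical input):
#   float_list('0.5,1,.75,abc')  ->  [0.5, 1.0, 0.75]   # 'abc' is ignored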

def parse_slowdown(slowdown: Union[str, int]) -> int:
    """Parse the user's 'slowdown' parameter; it will be approximated to the nearest power of 2."""
    # TODO: slowdown should be any int
    if not isinstance(slowdown, int):
        try:
            slowdown = atof(slowdown)
        except ValueError:
            print(f'Typo in "{slowdown}"; will use default value of 1')
            slowdown = 1
    assert slowdown > 0, '"slowdown" cannot be negative or 0!'
    # Round slowdown to the closest power of 2 (nothing happens if it's already a power of 2)
    slowdown = 2 ** int(np.rint(np.log2(slowdown)))
    return max(slowdown, 1)  # Guard against 0.5, 0.25, ... cases
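
# Rounding sketch (hypothetical values):
#   parse_slowdown('3')  ->  4    # log2(3) ~ 1.58, rounds to 2
#   parse_slowdown(6)    ->  8    # log2(6) ~ 2.58, rounds to 3
#   parse_slowdown(1)    ->  1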

def parse_new_center(s: str) -> Tuple[str, Union[int, Tuple[np.ndarray, Optional[str]]]]:
    """Get a new center for the W latent space (a seed or projected dlatent; to be transformed later)"""
    try:
        new_center = int(s)  # it's a seed
        return s, new_center
    except ValueError:
        new_center = get_latent_from_file(s, return_ext=False)  # it's a projected dlatent
        return s, new_center


def parse_all_projected_dlatents(s: str) -> List[torch.Tensor]:
    """Get all the dlatents (.npy/.npz files) in a given directory"""
    # Get all the files in the directory and subdirectories
    files = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(s)) for f in fn]
    # Keep only the .npy or .npz files
    files = [f for f in files if f.endswith('.npy') or f.endswith('.npz')]
    # Sort them by name, but only according to the last digits in the name (in case there's an error before)
    files = sorted(files, key=lambda x: int(''.join(filter(str.isdigit, x))))
    # Get the full path
    # files = [os.path.join(s, f) for f in files]
    # Get the dlatents
    dlatents = [get_latent_from_file(f, return_ext=False) for f in files]
    return dlatents

def load_network(name: str, network_pkl: Union[str, os.PathLike], cfg: Optional[str], device: torch.device):
    """Load and return the requested network ('G', 'D', or 'G_ema') from a trained network pickle."""
    # Resolve the model shorthand if a config is given
    if cfg is not None:
        assert network_pkl in resume_specs[cfg], f'This model is not available for config {cfg}!'
        network_pkl = resume_specs[cfg][network_pkl]
    print(f'Loading networks from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as f:
        net = legacy.load_network_pkl(f)[name].eval().requires_grad_(False).to(device)  # type: ignore
    return net
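
# Loading sketch (hypothetical device; 'ffhq1024' is one of the resume_specs shorthands defined below):
#   G = load_network('G_ema', 'ffhq1024', cfg='stylegan2', device=torch.device('cuda'))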

def parse_class(G, class_idx: int, ctx: click.Context) -> Union[int, Type[None]]:
    """Parse class_idx and return it, if it's allowed by the conditional model G"""
    if G.c_dim == 0:
        # Unconditional model
        return None
    # Conditional model, so a class must be specified by the user
    if class_idx is None:
        ctx.fail('Must specify class label with --class when using a conditional network!')
    if class_idx not in range(G.c_dim):
        ctx.fail(f'Your class label can be at most {G.c_dim - 1}!')
    print(f'Using class {class_idx} (available labels: range({G.c_dim}))...')
    return class_idx
# ----------------------------------------------------------------------------


def save_video_from_images(run_dir: str,
                           image_names: str,
                           video_name: str,
                           fps: int = 30,
                           reverse_video: bool = True,
                           crf: int = 20,
                           pix_fmt: str = 'yuv420p') -> None:
    """Save a .mp4 video from the images in the run_dir directory; the video can also be saved in reverse."""
    print('Saving video...')
    try:
        import ffmpeg
    except ImportError:
        raise ImportError('ffmpeg-python not found! Install it via "pip install ffmpeg-python"')
    # Get the ffmpeg command for the current OS (not tested on macOS!)
    if os.name == 'nt':
        ffmpeg_command = r'C:\Ffmpeg\bin\ffmpeg.exe'
    else:
        # Find the ffmpeg binary via `whereis ffmpeg` in the terminal
        ffmpeg_command = os.popen('whereis ffmpeg').read().split(' ')[1:]
        # Remove any ffprobe and ffplay commands
        ffmpeg_command = [c for c in ffmpeg_command if 'ffprobe' not in c and 'ffplay' not in c]
        # If there are more, just select the first one and remove the newline character
        ffmpeg_command = ffmpeg_command[0].replace('\n', '')

    stream = ffmpeg.input(os.path.join(run_dir, image_names), framerate=fps)
    stream = ffmpeg.output(stream, os.path.join(run_dir, f'{video_name}.mp4'), crf=crf, pix_fmt=pix_fmt)
    ffmpeg.run(stream, capture_stdout=True, capture_stderr=True, cmd=ffmpeg_command)

    # Save the reversed video apart from the original one, so the user can compare both
    if reverse_video:
        stream = ffmpeg.input(os.path.join(run_dir, f'{video_name}.mp4'))
        stream = stream.video.filter('reverse')
        stream = ffmpeg.output(stream, os.path.join(run_dir, f'{video_name}_reversed.mp4'), crf=crf, pix_fmt=pix_fmt)
        ffmpeg.run(stream, capture_stdout=True, capture_stderr=True, cmd=ffmpeg_command)  # ibidem
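
# Usage sketch (hypothetical paths; image_names is a printf-style ffmpeg pattern):
#   save_video_from_images(run_dir='out/00000-project', image_names='frame_%06d.jpg',
#                          video_name='interpolation', fps=30, reverse_video=False)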

def compress_video(
        original_video: Union[str, os.PathLike],
        original_video_name: Union[str, os.PathLike],
        outdir: Union[str, os.PathLike],
        ctx: click.Context) -> None:
    """Helper function to compress the original_video using ffmpeg-python. moviepy creates huge videos, so use
    ffmpeg to 'compress' it (won't be perfect, 'compression' will depend on the video dimensions). ffmpeg
    can also be used to e.g. resize the video, make a GIF, save all frames in the video to the outdir, etc.
    """
    try:
        import ffmpeg
    except (ModuleNotFoundError, ImportError):
        ctx.fail('Missing ffmpeg! Install it via "pip install ffmpeg-python"')

    print('Compressing the video...')
    resized_video_name = os.path.join(outdir, f'{original_video_name}-compressed.mp4')
    ffmpeg.input(original_video).output(resized_video_name).run(capture_stdout=True, capture_stderr=True)
    print('Success!')
# ----------------------------------------------------------------------------


def interpolation_checks(
        t: Union[float, np.ndarray],
        v0: np.ndarray,
        v1: np.ndarray) -> Tuple[Union[float, np.ndarray], np.ndarray, np.ndarray]:
    """Sanity checks shared by the interpolation functions."""
    # Make sure 0.0 <= t <= 1.0
    assert np.min(t) >= 0.0 and np.max(t) <= 1.0
    # Guard against v0 and v1 not being NumPy arrays
    if not isinstance(v0, np.ndarray):
        v0 = np.array(v0)
    if not isinstance(v1, np.ndarray):
        v1 = np.array(v1)
    # Both should have the same shape in order to interpolate between them
    assert v0.shape == v1.shape, f'Incompatible shapes! v0: {v0.shape}, v1: {v1.shape}'
    return t, v0, v1


def lerp(
        t: Union[float, np.ndarray],
        v0: Union[float, list, tuple, np.ndarray],
        v1: Union[float, list, tuple, np.ndarray]) -> np.ndarray:
    """
    Linear interpolation between v0 (starting) and v1 (final) vectors; for optimal results,
    use t as an np.ndarray to return all results at once via broadcasting
    """
    t, v0, v1 = interpolation_checks(t, v0, v1)
    v2 = (1.0 - t) * v0 + t * v1
    return v2

def slerp(
        t: Union[float, np.ndarray],
        v0: Union[float, list, tuple, np.ndarray],
        v1: Union[float, list, tuple, np.ndarray],
        dot_threshold: float = 0.9995) -> np.ndarray:
    """
    Spherical linear interpolation between v0 (starting) and v1 (final) vectors; for optimal
    results, use t as an np.ndarray to return all results at once via broadcasting.
    dot_threshold is the threshold for considering the two vectors as collinear (not recommended to alter).
    Adapted from the Python code at: https://en.wikipedia.org/wiki/Slerp (at the time, now no longer available).
    Most likely taken from Jonathan Blow's code in C++:
    http://number-none.com/product/Understanding%20Slerp,%20Then%20Not%20Using%20It
    """
    t, v0, v1 = interpolation_checks(t, v0, v1)
    # Copy vectors to reuse them later
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)
    # Normalize the vectors to get the directions and angles
    v0 = v0 / np.linalg.norm(v0)
    v1 = v1 / np.linalg.norm(v1)
    # Dot product with the normalized vectors (can't always use np.dot, so we use the definition)
    dot = np.sum(v0 * v1)
    # If it's ~1, the vectors are ~collinear, so use lerp on the original vectors
    if np.abs(dot) > dot_threshold:
        return lerp(t, v0_copy, v1_copy)
    # Stay within the domain of arccos
    dot = np.clip(dot, -1.0, 1.0)
    # Calculate the initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Divide the angle into t steps
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    # Finish the slerp algorithm
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    v2 = s0 * v0_copy + s1 * v1_copy
    return v2
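
# Interpolation sketch (hypothetical 2D unit vectors):
#   lerp(0.5, [1.0, 0.0], [0.0, 1.0])   ->  array([0.5, 0.5])
#   slerp(0.5, [1.0, 0.0], [0.0, 1.0])  ->  array([0.7071..., 0.7071...])  # stays on the unit circle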

def interpolate(
        v0: Union[float, list, tuple, np.ndarray],
        v1: Union[float, list, tuple, np.ndarray],
        n_steps: int,
        interp_type: str = 'spherical',
        smooth: bool = False) -> np.ndarray:
    """
    Interpolation function between two vectors, v0 and v1. We will either do a 'linear' or 'spherical' interpolation,
    taking n_steps. The steps can be 'smooth'-ed out, so that the transition between vectors isn't too drastic.
    """
    t_array = np.linspace(0, 1, num=n_steps, endpoint=False)
    # TODO: have a dictionary with easing functions that contains my 'smooth' one (might be useful for someone else)
    if smooth:
        # Smooth out the interpolation with a polynomial of order 3 (cubic function f)
        # Constructed f by setting f'(0) = f'(1) = 0, and f(0) = 0, f(1) = 1 => f(t) = -2t^3 + 3t^2 = t^2 (3 - 2t)
        # NOTE: I've merely rediscovered the Smoothstep function S_1(x): https://en.wikipedia.org/wiki/Smoothstep
        t_array = t_array ** 2 * (3 - 2 * t_array)  # One line thanks to NumPy arrays
    # TODO: this might be possible to optimize by using the fact they're numpy arrays, but haven't found a nice way yet
    funcs_dict = {'linear': lerp, 'spherical': slerp}
    vectors = np.array([funcs_dict[interp_type](t, v0, v1) for t in t_array], dtype=np.float32)
    return vectors
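
# Smoothstep sketch: f(t) = t^2 * (3 - 2t) eases both endpoints, e.g. f(0.25) = 0.15625,
# f(0.5) = 0.5, f(0.75) = 0.84375. Hypothetical usage with 512-dimensional latents:
#   z0, z1 = np.random.randn(512), np.random.randn(512)
#   frames = interpolate(z0, z1, n_steps=120, interp_type='spherical', smooth=True)
#   frames.shape  ->  (120, 512)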
# ----------------------------------------------------------------------------


def double_slowdown(latents: np.ndarray, duration: float, frames: int) -> Tuple[np.ndarray, float, int]:
    """
    Auxiliary function to slow down the video by 2x. We return the new latents, duration, and frames of the video
    """
    # Make an empty latent vector with double the amount of frames, but keep the others the same
    z = np.empty(np.multiply(latents.shape, [2, 1, 1]), dtype=np.float32)
    # In the even frames, populate it with the latents
    for i in range(len(latents)):
        z[2 * i] = latents[i]
    # Interpolate in the odd frames
    for i in range(1, len(z), 2):
        # slerp (at t=0.5) between the neighboring even frames; for the last frame, we loop back to the first one (z[0])
        z[i] = slerp(0.5, z[i - 1], z[i + 1]) if i != len(z) - 1 else slerp(0.5, z[0], z[i - 1])
    # TODO: we could change this to any slowdown: slerp(1/slowdown, ...), and we return z, slowdown * duration, ...
    # Return the new latents, and the respective new duration and number of frames
    return z, 2 * duration, 2 * frames
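
# Shape sketch (hypothetical 1-minute video at 1 fps, latents of shape [frames, 1, G.z_dim]):
#   z, dur, n = double_slowdown(np.random.randn(60, 1, 512).astype(np.float32), duration=60.0, frames=60)
#   z.shape, dur, n  ->  ((120, 1, 512), 120.0, 120)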

def global_pulsate_psi(psi_start: float, psi_end: float, n_steps: int, frequency: float = 1.0) -> torch.Tensor:
    """
    Pulsate the truncation psi parameter between start and end, taking n_steps on a sinusoidal wave.
    """
    alpha = (psi_start + psi_end) / (psi_start - psi_end)
    beta = 2 / (psi_start - psi_end)
    total_time = 2 * np.pi  # This value doesn't matter in the end
    timesteps = torch.arange(0, total_time, total_time / n_steps)

    truncation_psi = (torch.cos(frequency * timesteps) + alpha) / beta

    return truncation_psi
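
# Boundary-condition check (hypothetical psi_start=0.7, psi_end=0.3):
#   alpha = (0.7 + 0.3) / (0.7 - 0.3) = 2.5,  beta = 2 / (0.7 - 0.3) = 5.0
#   cos(...) =  1  ->  ( 1 + 2.5) / 5.0 = 0.7 = psi_start
#   cos(...) = -1  ->  (-1 + 2.5) / 5.0 = 0.3 = psi_end
# so the returned tensor of length n_steps oscillates between psi_start and psi_end.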

def wave_pulse_truncation_psi(psi_start: float,
                              psi_end: float,
                              n_steps: int,
                              grid_shape: Tuple[int, int],
                              frequency: int,
                              time: int) -> torch.Tensor:  # Output shape: [num_grid_cells, 1, 1]
    """
    Pulsate the truncation psi parameter between start and end, taking n_steps on a sinusoidal wave on a grid.
    Note: the output shape should be [math.prod(grid_shape), 1, 1]
    """
    # Let's save some headaches, shall we?
    if psi_start == psi_end:
        import math
        return torch.ones(math.prod(grid_shape), 1, 1) * psi_start

    # We define a total time, but note it's specific to our definition of the wave below (the 2*pi in the conditions)
    total_time = 5 * torch.pi  # T
    timesteps = torch.arange(0, total_time, total_time / n_steps)

    # Envelope function
    def envelope(time):
        """Envelope function that will regulate the amplitude of the wave; usage: envelope(time) * wave(time)"""
        # Example: a 1D Gabor filter
        # gaussian = torch.exp(-(time - total_time / 2) ** 2 / 16)
        # sinusoid = torch.exp(1j * torch.pi * (time - total_time / 2) / 2)
        # return torch.sin(time * torch.pi / total_time) / 2 + 0.5
        return torch.tensor(1.0)

    # Define the grid itself as a 2D grid where we will evaluate our wave function/psi
    width, height = grid_shape
    xs = torch.arange(0, 2 * torch.pi, 2 * torch.pi / width)
    ys = torch.arange(0, 2 * torch.pi, 2 * torch.pi / height)
    x, y = torch.meshgrid(xs, ys, indexing='xy')

    # Define the wave equation (go crazy here!)
    # In my case, I will use a sinusoidal wave with its source at the upper-left corner of the grid.
    # The wave will travel radially from the source, and will be truncated at the edges of the grid with the psi_start value
    r = torch.sqrt(x ** 2 + y ** 2)
    # The wave function is defined by parts, that is, we keep it constant (psi_start) before and after the wave; its
    # general shape in 1D will be psi(x, t) = (cos(f(x-t)) + alpha) / beta, where alpha and beta are defined so as to
    # satisfy the boundary conditions (psi(x, 0) = psi_start, psi(x, T/2) = psi_end, psi(x, T) = psi_start)
    alpha = (psi_start + psi_end) / (psi_start - psi_end)
    beta = 2 / (psi_start - psi_end)

    def truncate(value):
        """
        Auxiliary function to interpolate between the start and end psi. Used to translate from "value=0" (psi_start)
        to "value=1" (psi_end)
        """
        return psi_start + value * (psi_end - psi_start)

    # Define the wave function by parts, that is, keep it constant (psi_start) before and after the wave
    truncation_psi = torch.where(torch.gt(r, timesteps[time]) | torch.lt(r, timesteps[time] - 2 * torch.pi),
                                 torch.tensor(psi_start),
                                 (torch.cos(frequency * (r - timesteps[time])) + alpha) / beta)

    # Make sure the output is of the right shape
    truncation_psi = truncation_psi.view(width * height, 1, 1)
    return truncation_psi
# ----------------------------------------------------------------------------


def make_affine_transform(m: Optional[Union[torch.Tensor, np.ndarray]] = None,
                          angle: float = 0.0,
                          translate_x: float = 0.0,
                          translate_y: float = 0.0,
                          scale_x: float = 1.0,
                          scale_y: float = 1.0,
                          shear_x: float = 0.0,
                          shear_y: float = 0.0,
                          mirror_x: bool = False,
                          mirror_y: bool = False) -> np.ndarray:
    """Make an affine transformation with the given parameters. If none are passed, the identity is returned.
    As a guide for affine transformations: https://en.wikipedia.org/wiki/Affine_transformation"""
    # m is the starting affine transformation matrix (e.g., G.synthesis.input.transform)
    if m is None:
        m = np.eye(3, dtype=np.float64)
    elif isinstance(m, torch.Tensor):
        m = m.cpu().numpy()
    elif isinstance(m, np.ndarray):
        pass

    # Remember these are the inverse transformations!
    # Rotation matrix
    rotation_matrix = np.array([[np.cos(angle), np.sin(angle), 0.0],
                                [-np.sin(angle), np.cos(angle), 0.0],
                                [0.0, 0.0, 1.0]], dtype=np.float64)
    # Translation matrix
    translation_matrix = np.array([[1.0, 0.0, -translate_x],
                                   [0.0, 1.0, -translate_y],
                                   [0.0, 0.0, 1.0]], dtype=np.float64)
    # Scale matrix (don't let it go into negative or 0)
    scale_matrix = np.array([[1. / max(scale_x, 1e-4), 0.0, 0.0],
                             [0.0, 1. / max(scale_y, 1e-4), 0.0],
                             [0.0, 0.0, 1.0]], dtype=np.float64)
    # Shear matrix
    shear_matrix = np.array([[1.0, -shear_x, 0.0],
                             [-shear_y, 1.0, 0.0],
                             [0.0, 0.0, 1.0]], dtype=np.float64)
    # Mirror/reflection in x matrix
    xmirror_matrix = np.array([[1.0 - 2 * mirror_x, 0.0, 0.0],
                               [0.0, 1.0, 0.0],
                               [0.0, 0.0, 1.0]], dtype=np.float64)
    # Mirror/reflection in y matrix
    ymirror_matrix = np.array([[1.0, 0.0, 0.0],
                               [0.0, 1.0 - 2 * mirror_y, 0.0],
                               [0.0, 0.0, 1.0]], dtype=np.float64)

    # Compose the resulting affine transformation (note that these are non-commutative, so we *choose* this order)
    m = m @ rotation_matrix @ translation_matrix @ scale_matrix @ shear_matrix @ xmirror_matrix @ ymirror_matrix
    return m
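
# Usage sketch (these are the *inverse* transforms; assumes a StyleGAN3 generator G):
#   m = make_affine_transform(angle=np.pi / 2, translate_x=0.1)   # rotate 90 degrees, shift in x
#   G.synthesis.input.transform.copy_(torch.from_numpy(m))        # hypothetical way to apply it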

def anchor_latent_space(G) -> None:
    # Thanks to @RiversHaveWings and @nshepperd1
    if hasattr(G.synthesis, 'input'):
        # Unconditional models differ by a bit
        if G.c_dim == 0:
            shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0)).squeeze(0)
        else:
            shift = G.synthesis.input.affine(G.mapping.w_avg).mean(0)
        G.synthesis.input.affine.bias.data.add_(shift)
        G.synthesis.input.affine.weight.data.zero_()


def force_fp32(G) -> None:
    """Force fp32 operations, as during training"""
    G.synthesis.num_fp16_res = 0
    for name, layer in G.synthesis.named_modules():
        if hasattr(layer, 'conv_clamp'):
            layer.conv_clamp = None
            layer.use_fp16 = False


def use_cpu(G) -> None:
    """Use the CPU instead of the GPU; force_fp32 must be set to True, apart from the device setting"""
    # @nurpax found this before: https://github.com/NVlabs/stylegan2-ada-pytorch/issues/54#issuecomment-793713965, but we
    # will use @JCBrouwer's solution: https://github.com/NVlabs/stylegan2-ada-pytorch/issues/105#issuecomment-838577639
    import functools
    G.forward = functools.partial(G.forward, force_fp32=True)
# ----------------------------------------------------------------------------


resume_specs = {
    # For StyleGAN2/ADA models; --cfg=stylegan2
    'stylegan2': {
        # Official NVIDIA models
        'ffhq256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl',
        'ffhqu256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-256x256.pkl',
        'ffhq512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl',
        'ffhq1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl',
        'ffhqu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-1024x1024.pkl',
        'celebahq256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-celebahq-256x256.pkl',
        'lsundog256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-lsundog-256x256.pkl',
        'afhqcat512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl',
        'afhqdog512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqdog-512x512.pkl',
        'afhqwild512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqwild-512x512.pkl',
        'afhq512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqv2-512x512.pkl',
        'brecahad512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-brecahad-512x512.pkl',
        'cifar10': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-cifar10-32x32.pkl',
        'metfaces1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfaces-1024x1024.pkl',
        'metfacesu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfacesu-1024x1024.pkl',
        # Other configs are available at: https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/, but I will list here the config-f only
        'lsuncar512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-f.pkl',  # config-f
        'lsuncat256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl',  # config-f
        'lsunchurch256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-church-config-f.pkl',  # config-f
        'lsunhorse256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-horse-config-f.pkl',  # config-f
        # Community models. More can be found at: https://github.com/justinpinkney/awesome-pretrained-stylegan2 by @justinpinkney, but weren't added here
        'minecraft1024': 'https://github.com/jeffheaton/pretrained-gan-minecraft/releases/download/v1/minecraft-gan-2020-12-22.pkl',  # Thanks to @jeffheaton
        'imagenet512': 'https://battle.shawwn.com/sdc/stylegan2-imagenet-512/model.ckpt-533504.pkl',  # Thanks to @shawwn
        'wikiart1024-C': 'https://archive.org/download/wikiart-stylegan2-conditional-model/WikiArt5.pkl',  # Thanks to @pbaylies; conditional (167 classes in total: --class=0 to 166)
        'wikiart1024-U': 'https://archive.org/download/wikiart-stylegan2-conditional-model/WikiArt_Uncond2.pkl',  # Thanks to @pbaylies; unconditional
        'maps1024': 'https://archive.org/download/mapdreamer/mapdreamer.pkl',  # Thanks to @tjukanov
        'fursona512': 'https://thisfursonadoesnotexist.com/model/network-e621-r-512-3194880.pkl',  # Thanks to @arfafax
        'mlpony512': 'https://thisponydoesnotexist.net/model/network-ponies-1024-151552.pkl',  # Thanks to @arfafax
        'lhq1024': 'https://huggingface.co/justinpinkney/lhq-sg2-1024/resolve/main/lhq-sg2-1024.pkl',  # Thanks to @justinpinkney
        # Deceive-D/APA models (ignoring the faces models): https://github.com/EndlessSora/DeceiveD
        'afhqcat256': 'https://drive.google.com/u/0/uc?export=download&confirm=zFoN&id=1P9ouHIK-W8JTb6bvecfBe4c_3w6gmMJK',
        'anime256': 'https://drive.google.com/u/0/uc?export=download&confirm=6Uie&id=1EWOdieqELYmd2xRxUR4gnx7G10YI5dyP',
        'cub256': 'https://drive.google.com/u/0/uc?export=download&confirm=KwZS&id=1J0qactT55ofAvzddDE_xnJEY8s3vbo1_',
        # Self-Distilled StyleGAN (full body representation of each class): https://github.com/self-distilled-stylegan/self-distilled-internet-photos
        'sddogs1024': 'https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl',
        'sdelephant512': 'https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl',
        'sdhorses256': 'https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl',
        'sdbicycles256': 'https://storage.googleapis.com/self-distilled-stylegan/bicycles_256_pytorch.pkl',
        'sdlions512': 'https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl',
        'sdgiraffes512': 'https://storage.googleapis.com/self-distilled-stylegan/giraffes_512_pytorch.pkl',
        'sdparrots512': 'https://storage.googleapis.com/self-distilled-stylegan/parrots_512_pytorch.pkl'
    },
    # For StyleGAN2 extended (--cfg=stylegan2-ext)
    'stylegan2-ext': {
        'anime512': 'https://drive.google.com/u/0/uc?export=download&confirm=zFoN&id=1A-E_E32WAtTHRlOzjhhYhyyBDXLJN9_H'  # Thanks to @aydao
    },
    # For StyleGAN3 config-r models (--cfg=stylegan3-r)
    'stylegan3-r': {
        # Official NVIDIA models
        'afhq512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl',
        'ffhq1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl',
        'ffhqu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl',
        'ffhqu256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-256x256.pkl',
        'metfaces1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-metfaces-1024x1024.pkl',
        'metfacesu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-metfacesu-1024x1024.pkl',
    },
    # For StyleGAN3 config-t models (--cfg=stylegan3-t)
    'stylegan3-t': {
        # Official NVIDIA models
        'afhq512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-afhqv2-512x512.pkl',
        'ffhq1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl',
        'ffhqu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl',
        'ffhqu256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-256x256.pkl',
        'metfaces1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfaces-1024x1024.pkl',
        'metfacesu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl',
        # Community models, found in: https://github.com/justinpinkney/awesome-pretrained-stylegan3 by @justinpinkney
        'landscapes256': 'https://drive.google.com/u/0/uc?export=download&confirm=eJHe&id=14UGDDOusZ9TMb-pOrF0PAjMGVWLSAii1',  # Thanks to @justinpinkney
        'wikiart1024': 'https://drive.google.com/u/0/uc?export=download&confirm=2tz5&id=18MOpwTMJsl_Z17q-wQVnaRLCUFZYSNkj',  # Thanks to @justinpinkney
        # -> Wombo Dream-based models found in: https://github.com/edstoica/lucid_stylegan3_datasets_models by @edstoica; TODO: more to come, update the list as they are released!
        'mechfuture256': 'https://www.dropbox.com/s/v2oie53cz62ozvu/network-snapshot-000029.pkl?dl=1',  # Thanks to @edstoica; 29kimg tick
        'vivflowers256': 'https://www.dropbox.com/s/o33lhgnk91hstvx/network-snapshot-000069.pkl?dl=1',  # Thanks to @edstoica; 68kimg tick
        'alienglass256': 'https://www.dropbox.com/s/gur14k0e7kspguy/network-snapshot-000038.pkl?dl=1',  # Thanks to @edstoica; 38kimg tick
        'scificity256': 'https://www.dropbox.com/s/1kfsmlct4mriphc/network-snapshot-000210.pkl?dl=1',  # Thanks to @edstoica; 210kimg tick
        'scifiship256': 'https://www.dropbox.com/s/02br3mjkma1hubc/network-snapshot-000162.pkl?dl=1',  # Thanks to @edstoica; 168kimg tick
    }
}

# ----------------------------------------------------------------------------

# TODO: all of the following functions must work for RGBA images

def w_to_img(G, dlatents: Union[List[torch.Tensor], torch.Tensor],
             noise_mode: str = 'const',
             new_w_avg: Optional[torch.Tensor] = None,
             truncation_psi: float = 1.0) -> np.ndarray:
    """
    Get an image/np.ndarray from a dlatent W using G and the selected noise_mode. The final shape of the
    returned image will be [len(dlatents), G.img_resolution, G.img_resolution, G.img_channels].
    Note: this function should be used after doing the truncation trick!
    Note: optionally, you can also pass a new_w_avg to use instead of the one in G, with a reverse
    truncation trick
    """
    # If we have a single dlatent, we need to add a batch dimension
    assert isinstance(dlatents, torch.Tensor), f'dlatents should be a torch.Tensor!: "{type(dlatents)}"'
    if len(dlatents.shape) == 2:
        dlatents = dlatents.unsqueeze(0)  # An individual dlatent => [1, G.mapping.num_ws, G.mapping.w_dim]
    if new_w_avg is not None:
        new_w_avg = new_w_avg.to(next(G.parameters()).device)
        dlatents = (dlatents - new_w_avg) * (1 - truncation_psi) + new_w_avg
    synth_image = G.synthesis(dlatents, noise_mode=noise_mode)
    synth_image = (synth_image + 1) * 255 / 2  # [-1.0, 1.0] -> [0.0, 255.0]
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8).cpu().numpy()  # NCHW => NHWC
    return synth_image

def z_to_dlatent(G, latents: torch.Tensor, label: torch.Tensor, truncation_psi: float = 1.0) -> torch.Tensor:
    """Get the dlatent from the given latent, class label, and truncation psi"""
    assert isinstance(latents, torch.Tensor), f'latents should be a torch.Tensor!: "{type(latents)}"'
    assert isinstance(label, torch.Tensor), f'label should be a torch.Tensor!: "{type(label)}"'
    if len(latents.shape) == 1:
        latents = latents.unsqueeze(0)  # An individual latent => [1, G.z_dim]
    dlatents = G.mapping(z=latents, c=label, truncation_psi=truncation_psi)
    return dlatents


def z_to_img(G, latents: torch.Tensor, label: torch.Tensor, truncation_psi: float, noise_mode: str = 'const') -> np.ndarray:
    """
    Get an image/np.ndarray from a latent Z using G, the label, truncation_psi, and noise_mode. The shape
    of the output image/np.ndarray will be [len(latents), G.img_resolution, G.img_resolution, G.img_channels]
    """
    dlatents = z_to_dlatent(G=G, latents=latents, label=label, truncation_psi=1.0)
    # Apply the truncation trick: w' = w_avg + psi * (w - w_avg)
    dlatents = G.mapping.w_avg + (dlatents - G.mapping.w_avg) * truncation_psi
    img = w_to_img(G=G, dlatents=dlatents, noise_mode=noise_mode)  # Let's not redo code
    return img
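
# Truncation-trick reminder: w' = w_avg + psi * (w - w_avg), so psi=1.0 keeps the raw mapped
# dlatent and psi=0.0 collapses everything to the model's average. Hypothetical usage:
#   z = torch.randn(4, G.z_dim, device=device)
#   imgs = z_to_img(G, z, label=torch.zeros([4, G.c_dim], device=device), truncation_psi=0.7)
#   imgs.shape  ->  (4, G.img_resolution, G.img_resolution, G.img_channels)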

def get_w_from_seed(G, device: torch.device, seed: int, truncation_psi: float,
                    new_w_avg: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Get the dlatent from a random seed, using the truncation trick (this could be optional)"""
    z = np.random.RandomState(seed).randn(1, G.z_dim)
    w = G.mapping(torch.from_numpy(z).to(device), None)
    w_avg = G.mapping.w_avg if new_w_avg is None else new_w_avg.to(device)
    w = w_avg + (w - w_avg) * truncation_psi
    return w


def get_latent_from_file(file: Union[str, os.PathLike],
                         return_ext: bool = False,
                         named_latent: str = 'w') -> Tuple[np.ndarray, Optional[str]]:
    """Get the dlatent (w) from a .npy or .npz file"""
    filename, file_extension = os.path.splitext(file)
    assert file_extension in ['.npy', '.npz'], f'"{file}" has the wrong file format! Only ".npy" or ".npz" are allowed'
    if file_extension == '.npy':
        latent = np.load(file)
        extension = '.npy'
    else:
        latent = np.load(file)[named_latent]
        extension = '.npz'
    if len(latent.shape) == 4:
        latent = latent[0]
    return (latent, extension) if return_ext else latent
# ----------------------------------------------------------------------------


def save_config(ctx: click.Context, run_dir: Union[str, os.PathLike], save_name: str = 'config.json') -> None:
    """Save the configuration stored in ctx.obj into a JSON file at the output directory."""
    with open(os.path.join(run_dir, save_name), 'w') as f:
        json.dump(ctx.obj, f, indent=4, sort_keys=True)


# ----------------------------------------------------------------------------


def make_run_dir(outdir: Union[str, os.PathLike], desc: str, dry_run: bool = False) -> str:
    """Reject modernity, return to automatically creating the run dir."""
    # Pick the output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):  # sanity check, but click.Path() should clear this one
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1  # start with 00000
    run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
    assert not os.path.exists(run_dir)  # make sure it doesn't already exist

    # Don't create the dir if it's a dry run
    if not dry_run:
        print('Creating output directory...')
        os.makedirs(run_dir)
    return run_dir
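
# Naming sketch: if outdir already contains '00000-project' and '00001-project', then
# make_run_dir(outdir, 'project') would create and return '<outdir>/00002-project'.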