|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Generate images using a pretrained network pickle."""
|
|
|
import os |
|
import re |
|
from typing import List, Optional |
|
|
|
import click |
|
import dnnlib |
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
|
|
import legacy |
|
|
|
|
|
|
|
|
|
def num_range(s: str) -> List[int]:
    """Parse a seed specification string into a list of integers.

    Two spellings are understood:

    * an inclusive range ``"a-c"`` such as ``"600-605"``;
    * a comma-separated list ``"a,b,c"`` such as ``"85,265,297"``.

    Args:
        s: The raw option text.

    Returns:
        The integers described by ``s``. A range whose end is smaller than its
        start produces an empty list.

    Raises:
        ValueError: If an item of a comma-separated list is not an integer.
    """
    bounds = re.match(r"^(\d+)-(\d+)$", s)
    if bounds is None:
        # Not an "a-c" range: treat the text as a comma-separated list.
        return list(map(int, s.split(",")))
    first, last = (int(group) for group in bounds.groups())
    # range() excludes its stop value, so extend by one to include `last`.
    return list(range(first, last + 1))
|
|
|
|
|
|
|
|
|
|
|
def _save_image(img: torch.Tensor, path: str) -> None:
    """Write the first image of a generator output batch to ``path`` as 8-bit RGB.

    ``img`` is channels-first, (N, C, H, W), as implied by the permute to
    channels-last below. Only ``img[0]`` is written.
    """
    # Rescale to [0, 255] and quantize; presumably the generator emits values
    # roughly in [-1, 1] -- anything outside is clamped.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    PIL.Image.fromarray(img[0].cpu().numpy(), "RGB").save(path)


@click.command()
@click.pass_context
@click.option("--network", "network_pkl", help="Network pickle filename", required=True)
@click.option("--seeds", type=num_range, help="List of random seeds")
@click.option(
    "--trunc",
    "truncation_psi",
    type=float,
    help="Truncation psi",
    default=1,
    show_default=True,
)
@click.option(
    "--class",
    "class_idx",
    type=int,
    help="Class label (unconditional if not specified)",
)
@click.option(
    "--noise-mode",
    help="Noise mode",
    type=click.Choice(["const", "random", "none"]),
    default="const",
    show_default=True,
)
@click.option("--projected-w", help="Projection result file", type=str, metavar="FILE")
@click.option(
    "--outdir",
    help="Where to save the output images",
    type=str,
    required=True,
    metavar="DIR",
)
def generate_images(
    ctx: click.Context,
    network_pkl: str,
    seeds: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    class_idx: Optional[int],
    projected_w: Optional[str],
):
    """Generate images using a pretrained network pickle.

    Examples:

    \b
    # Generate curated MetFaces images without truncation (Fig.10 left)
    python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
    python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
    python generate.py --outdir=out --seeds=0-35 --class=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl

    \b
    # Render an image from projected W
    python generate.py --outdir=out --projected-w=projected_w.npz \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    """
    print('Loading networks from "%s"...' % network_pkl)
    # The network is moved to a CUDA device unconditionally.
    device = torch.device("cuda")
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)["G_ema"].to(device)

    os.makedirs(outdir, exist_ok=True)

    # Mode 1: render images directly from previously projected W latents.
    if projected_w is not None:
        if seeds is not None:
            print("warn: --seeds is ignored when using --projected-w")
        print(f'Generating images from projected W "{projected_w}"')
        # np.load on an .npz keeps the file open until closed; use a context manager.
        with np.load(projected_w) as data:
            ws = data["w"]
        ws = torch.tensor(ws, device=device)
        # Explicit check (not `assert`) so it also runs under `python -O` and is
        # reported as a CLI usage error rather than a bare traceback.
        if ws.shape[1:] != (G.num_ws, G.w_dim):
            ctx.fail(
                f"--projected-w: expected W of shape (N, {G.num_ws}, {G.w_dim}), "
                f"got {tuple(ws.shape)}"
            )
        with torch.no_grad():  # pure inference; no autograd graph needed
            for idx, w in enumerate(ws):
                img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
                _save_image(img, f"{outdir}/proj{idx:02d}.png")
        return

    # Mode 2: sample latents from the given seeds.
    if seeds is None:
        ctx.fail("--seeds option is required when not using --projected-w")

    # One-hot class label for conditional networks; zero-width for unconditional ones.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            ctx.fail(
                "Must specify class label with --class when using a conditional network"
            )
        # Reject out-of-range labels; a negative index would otherwise silently
        # select a class counted from the end.
        if not 0 <= class_idx < G.c_dim:
            ctx.fail(f"--class must be in the range [0, {G.c_dim - 1}], got {class_idx}")
        label[:, class_idx] = 1
    elif class_idx is not None:
        print("warn: --class=lbl ignored when running on an unconditional network")

    with torch.no_grad():  # pure inference; no autograd graph needed
        for seed_idx, seed in enumerate(seeds):
            # Report 1-based progress so the last image reads "(N/N)".
            print("Generating image for seed %d (%d/%d) ..." % (seed, seed_idx + 1, len(seeds)))
            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
            img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
            _save_image(img, f"{outdir}/seed{seed:04d}.png")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Click parses sys.argv and injects the context plus every declared option,
    # so the command object is invoked without arguments.
    generate_images()  # pylint: disable=no-value-for-parameter
|
|
|
|
|
|