ybelkada's picture
commit files
4d6b877
raw
history blame
9.56 kB
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators."""
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
#----------------------------------------------------------------------------
# Helpers for loading and using pre-trained generators.
# Download locations of the pre-trained generator pickles; the trailing
# comment on each line gives the name of the file behind the link.
url_ffhq = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
url_celebahq = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl
url_bedrooms = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl
url_cars = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl
url_cats = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl

# Keyword arguments passed to every image-producing run() call in this file.
synthesis_kwargs = {
    'output_transform': {'func': tflib.convert_images_to_uint8, 'nchw_to_nhwc': True},
    'minibatch_size': 8,
}

# Generators already unpickled by load_Gs(), keyed by download URL.
_Gs_cache = {}
def load_Gs(url):
    """Return the third object stored in the pickle at `url`, loading it at most once.

    The result is memoized in the module-level `_Gs_cache`, so repeated calls
    with the same URL reuse the already-loaded object.
    """
    try:
        return _Gs_cache[url]
    except KeyError:
        pass # Not loaded yet; fall through and fetch it.

    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file, so only pass URLs whose contents are trusted.
    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as stream:
        _G, _D, Gs = pickle.load(stream)
    _Gs_cache[url] = Gs
    return Gs
#----------------------------------------------------------------------------
# Figures 2, 3, 10, 11, 12: Multi-resolution grid of uncurated result images.
def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed):
    """Draw a multi-resolution grid of generated images and save it to `png`.

    Each entry of `lods` becomes one column. A column with level-of-detail
    `lod` holds `rows * 2**lod` images, each cropped to the box
    (cx, cy, cx + cw, cy + ch) and then scaled down by a factor of `2**lod`.

    Args:
        png: Output file path (also printed as a progress message).
        Gs: Generator network; `Gs.input_shape[1]` gives the latent size.
        cx, cy: Top-left corner of the crop box in the generated image.
        cw, ch: Width and height of the crop box.
        rows: Number of images in a column whose lod is 0.
        lods: One level-of-detail value per column.
        seed: Seed for the random latent vectors.
    """
    print(png)
    num_images = sum(rows * 2**lod for lod in lods)
    latents = np.random.RandomState(seed).randn(num_images, Gs.input_shape[1])
    images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb]

    col_widths = [cw // 2**lod for lod in lods]
    canvas = PIL.Image.new('RGB', (sum(col_widths), ch * rows), 'white')
    image_iter = iter(images)
    x_offset = 0 # Left edge of the current column; grows by each column's width.
    for lod, tile_w in zip(lods, col_widths):
        tile_h = ch // 2**lod
        for row in range(rows * 2**lod):
            image = PIL.Image.fromarray(next(image_iter), 'RGB')
            image = image.crop((cx, cy, cx + cw, cy + ch))
            # LANCZOS is the filter formerly exposed as ANTIALIAS; the old
            # alias no longer exists in Pillow 10 and later.
            image = image.resize((tile_w, tile_h), PIL.Image.LANCZOS)
            canvas.paste(image, (x_offset, row * ch // 2**lod))
        x_offset += tile_w
    canvas.save(png)
#----------------------------------------------------------------------------
# Figure 3: Style mixing.
def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):
    """Draw the style-mixing grid and save it to `png`.

    Layout: the top row shows one image per source seed, the left column one
    image per destination seed. The cell at (row, col) is synthesized from the
    destination's dlatents with the layers listed in `style_ranges[row]`
    overwritten by those of source `col`.

    Args:
        png: Output file path (also printed as a progress message).
        Gs: Generator network exposing `components.mapping` and
            `components.synthesis`.
        w, h: Width and height of one grid cell in pixels.
        src_seeds: Seeds for the source latents (one column each).
        dst_seeds: Seeds for the destination latents (one row each).
        style_ranges: For each destination row, the layer indices to copy
            from the sources; needs one entry per destination seed.
    """
    print(png)
    latent_size = Gs.input_shape[1]
    # np.stack needs a real sequence; recent NumPy versions reject generators.
    src_latents = np.stack([np.random.RandomState(seed).randn(latent_size) for seed in src_seeds])
    dst_latents = np.stack([np.random.RandomState(seed).randn(latent_size) for seed in dst_seeds])
    src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
    dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]
    src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
    dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)

    canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')
    # Header row: the unmixed source images.
    for col, src_image in enumerate(src_images):
        canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
    for row, dst_image in enumerate(dst_images):
        # Header column: the unmixed destination image.
        canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
        # One copy of this destination's dlatents per source, then swap in
        # the selected layers from each source.
        row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))
        row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        for col, image in enumerate(row_images):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
    canvas.save(png)
#----------------------------------------------------------------------------
# Figure 4: Noise detail.
def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds):
    """Draw the noise-detail figure and save it to `png`.

    Each seed gets one row made of three w-by-h panels: the first generated
    sample, a 2x2 grid of detail crops taken from samples 1..4, and a
    grayscale map of the per-pixel standard deviation over all samples.
    Every sample in a row uses the same latent vector. Because samples 1..4
    are indexed directly, `num_samples` must be at least 5.
    """
    print(png)
    detail_box = (650, 180, 906, 436) # Fixed region used for the detail crops.
    canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white')
    for row, seed in enumerate(seeds):
        top = row * h
        latent = np.random.RandomState(seed).randn(Gs.input_shape[1])
        latents = np.stack([latent] * num_samples)
        images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs)

        # Panel 1: the first sample at full size.
        canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, top))

        # Panel 2: the same crop region from the next four samples, in a 2x2 grid.
        for i in range(4):
            grid_row, grid_col = divmod(i, 2)
            detail = PIL.Image.fromarray(images[i + 1], 'RGB').crop(detail_box)
            detail = detail.resize((w//2, h//2), PIL.Image.NEAREST)
            canvas.paste(detail, (w + grid_col * w//2, top + grid_row * h//2))

        # Panel 3: spread across samples of the channel-averaged intensity,
        # amplified by 4 and rounded into the uint8 range.
        gray = np.mean(images, axis=3)
        spread = np.std(gray, axis=0) * 4
        spread = np.clip(spread + 0.5, 0, 255).astype(np.uint8)
        canvas.paste(PIL.Image.fromarray(spread, 'L'), (w * 2, top))
    canvas.save(png)
#----------------------------------------------------------------------------
# Figure 5: Noise components.
def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips):
    """Draw the noise-components figure and save it to `png`.

    For every entry of `noise_ranges`, the noise variables whose index is in
    that range keep their value and all others are zeroed; the seeds are then
    rendered with that setting. Each seed occupies one w-wide column of a
    2w-by-2h canvas, split into four half-width tiles taken from the first
    four noise ranges: range 0 top-left, range 1 top-right, range 2
    bottom-left, range 3 bottom-right.

    Args:
        png: Output file path (also printed as a progress message).
        Gs: Generator network; a clone of it is used for rendering.
        w, h: Width and height of one generated image in pixels.
        seeds: Seeds for the latents; the canvas is sized for two of them.
        noise_ranges: Index ranges of the noise variables to keep enabled;
            at least four entries are required.
        flips: Indices (into `seeds`) of images to mirror horizontally.
    """
    print(png)
    # Render with a clone, presumably so the noise variables of the caller's
    # Gs are left as they were -- TODO confirm clone() copies variables.
    Gsc = Gs.clone()
    noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')]
    noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...]
    # np.stack needs a real sequence; recent NumPy versions reject generators.
    latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds])

    all_images = []
    for noise_range in noise_ranges:
        # Keep the original value for selected noise inputs, zero the rest.
        tflib.set_vars({var: val * (1 if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)})
        range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs)
        range_images[flips, :, :] = range_images[flips, :, ::-1]
        all_images.append(list(range_images))

    canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white')
    left_half = (0, 0, w//2, h)
    right_half = (w//2, 0, w, h)
    # zip(*all_images) regroups the results per seed: col_images[k] is this
    # seed rendered with noise_ranges[k].
    for col, col_images in enumerate(zip(*all_images)):
        canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop(left_half), (col * w, 0))
        canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop(right_half), (col * w + w//2, 0))
        canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop(left_half), (col * w, h))
        canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop(right_half), (col * w + w//2, h))
    canvas.save(png)
#----------------------------------------------------------------------------
# Figure 8: Truncation trick.
def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis):
    """Draw the truncation-trick figure and save it to `png`.

    Each seed is one row and each value in `psis` one column. A cell is
    synthesized from `dlatent_avg + psi * (dlatent - dlatent_avg)`, i.e. the
    seed's dlatents interpolated towards (or extrapolated away from) the
    generator's stored `dlatent_avg` variable.

    Args:
        png: Output file path (also printed as a progress message).
        Gs: Generator network exposing `components.mapping`,
            `components.synthesis` and the `dlatent_avg` variable.
        w, h: Width and height of one grid cell in pixels.
        seeds: Seeds for the latents (one row each).
        psis: Scaling factors applied to the offset from `dlatent_avg`
            (one column each).
    """
    print(png)
    # np.stack needs a real sequence; recent NumPy versions reject generators.
    latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds])
    dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component]
    dlatent_avg = Gs.get_var('dlatent_avg') # [component]
    # Shape [psi, 1, 1] so that it broadcasts against a [1, layer, component] dlatent.
    psi_column = np.reshape(psis, [-1, 1, 1])

    canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white')
    for row, dlatent in enumerate(dlatents):
        row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * psi_column + dlatent_avg
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        for col, image in enumerate(row_images):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h))
    canvas.save(png)
#----------------------------------------------------------------------------
# Main program.
def main():
    """Initialize TensorFlow and write every figure PNG into `config.result_dir`."""
    tflib.init_tf()
    os.makedirs(config.result_dir, exist_ok=True)

    def result_path(filename):
        # All figures are written into the configured result directory.
        return os.path.join(config.result_dir, filename)

    draw_uncurated_result_figure(
        result_path('figure02-uncurated-ffhq.png'), load_Gs(url_ffhq),
        cx=0, cy=0, cw=1024, ch=1024, rows=3, lods=[0,1,2,2,3,3], seed=5)
    draw_style_mixing_figure(
        result_path('figure03-style-mixing.png'), load_Gs(url_ffhq),
        w=1024, h=1024,
        src_seeds=[639,701,687,615,2268],
        dst_seeds=[888,829,1898,1733,1614,845],
        style_ranges=[range(0,4)]*3+[range(4,8)]*2+[range(8,18)])
    draw_noise_detail_figure(
        result_path('figure04-noise-detail.png'), load_Gs(url_ffhq),
        w=1024, h=1024, num_samples=100, seeds=[1157,1012])
    draw_noise_components_figure(
        result_path('figure05-noise-components.png'), load_Gs(url_ffhq),
        w=1024, h=1024, seeds=[1967,1555],
        noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)],
        flips=[1])
    draw_truncation_trick_figure(
        result_path('figure08-truncation-trick.png'), load_Gs(url_ffhq),
        w=1024, h=1024, seeds=[91,388], psis=[1, 0.7, 0.5, 0, -0.5, -1])
    draw_uncurated_result_figure(
        result_path('figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms),
        cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=0)
    draw_uncurated_result_figure(
        result_path('figure11-uncurated-cars.png'), load_Gs(url_cars),
        cx=0, cy=64, cw=512, ch=384, rows=4, lods=[0,1,2,2,3,3], seed=2)
    draw_uncurated_result_figure(
        result_path('figure12-uncurated-cats.png'), load_Gs(url_cats),
        cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=1)
#----------------------------------------------------------------------------
# Generate the figures only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
#----------------------------------------------------------------------------