# File size: 2,328 Bytes
# 81170fd
import argparse
import functools
import logging
import os

import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
from tqdm import tqdm

import checkpoint
from stylegan2.generator import Generator

# Root-logger line layout: timestamp, level name padded/truncated to 5 chars,
# logger name padded/truncated to 12 chars, then the message.
_LOG_FORMAT = '%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s'
# force=True removes any handlers already attached to the root logger so this
# configuration always takes effect.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO, force=True)
logger = logging.getLogger(__name__)


def _to_uint8_image(image):
    """Min-max normalise a single generated image to [0, 255] and return a PIL-ready uint8 array.

    Args:
        image: One generated image (batch axis already removed). Presumably laid out
            as (H, W, C) — TODO confirm against the Generator's output layout.

    Returns:
        A ``numpy.uint8`` array with the same shape, except that a trailing channel
        axis of size 1 is dropped, because ``PIL.Image.fromarray`` only accepts
        single-channel uint8 data as a 2-D array.
    """
    image = np.asarray(image, dtype=np.float32)
    lo = image.min()
    value_range = image.max() - lo
    if value_range > 0:
        image = (image - lo) / value_range
    else:
        # A constant image has no dynamic range; dividing by zero would yield NaNs,
        # whose uint8 conversion is undefined. Emit a black image instead.
        image = np.zeros_like(image)
    # Truncating cast, same as the original np.uint8(...) conversion.
    pixels = np.clip(image * 255, 0, 255).astype(np.uint8)
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    return pixels


def generate_images(args):
    """Render one image per seed with a checkpointed EMA generator and save each as a PNG.

    Args:
        args: Parsed command-line namespace. Reads ``checkpoint`` (path passed to
            ``checkpoint.load_checkpoint``), ``out_path`` (output directory, created
            if missing), ``truncation_psi`` (forwarded to the generator's ``apply``)
            and ``seeds`` (sized iterable of ints; each seed is written to
            ``<out_path>/<seed>.png``).
    """
    logger.info(f"Loading checkpoint '{args.checkpoint}'...")
    ckpt = checkpoint.load_checkpoint(args.checkpoint)
    config = ckpt['config']
    params_ema_G = ckpt['params_ema_G']

    generator_ema = Generator(
        resolution=config.resolution,
        num_channels=config.img_channels,
        z_dim=config.z_dim,
        c_dim=config.c_dim,
        w_dim=config.w_dim,
        # Number of w vectors the generator expects, derived from the resolution.
        num_ws=int(np.log2(config.resolution)) * 2 - 3,
        num_mapping_layers=8,
        fmap_base=config.fmap_base,
        dtype=jnp.float32
    )

    # Bake the inference-time options into the callable so jit compiles it once
    # and reuses it for every seed (latent shape is identical across seeds).
    generator_apply = jax.jit(
        functools.partial(generator_ema.apply, truncation_psi=args.truncation_psi, train=False, noise_mode='const')
    )

    # Only the CLI entry point created this directory before; direct callers need it too.
    os.makedirs(args.out_path, exist_ok=True)

    logger.info(f"Generating {len(args.seeds)} images with truncation {args.truncation_psi}...")
    for seed in tqdm(args.seeds):
        z_latent = jax.random.normal(jax.random.PRNGKey(seed), shape=(1, config.z_dim))
        # NOTE(review): the label argument is always None, so presumably only
        # unconditional checkpoints (c_dim == 0) are supported — confirm before
        # using this with a conditional model.
        image = generator_apply(params_ema_G, z_latent, None)
        # The batch holds a single latent, so normalising image[0] matches
        # normalising the whole batch.
        pixels = _to_uint8_image(image[0])
        Image.fromarray(pixels).save(os.path.join(args.out_path, f'{seed}.png'))
    logger.info(f"Images saved in '{args.out_path}/'")


if __name__ == '__main__':
    # Command-line flags as (flag, add_argument keywords), registered in this
    # order so the parser and its --help output match the flag list exactly.
    cli_options = (
        ('--checkpoint', dict(type=str, help='Path to the checkpoint.', required=True)),
        ('--out_path', dict(type=str, default='generated_images', help='Path where the generated images are stored.')),
        ('--truncation_psi', dict(type=float, default=0.5, help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.')),
        ('--seeds', dict(type=int, nargs='*', default=[0], help='List of random seeds.')),
    )
    cli_parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli_parser.add_argument(flag, **options)
    args = cli_parser.parse_args()

    # Images are written as <out_path>/<seed>.png, so make sure the directory exists.
    os.makedirs(args.out_path, exist_ok=True)

    generate_images(args)