Spaces:
Build error
Build error
import argparse | |
import functools | |
import logging | |
import os | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from PIL import Image | |
from tqdm import tqdm | |
import checkpoint | |
from stylegan2.generator import Generator | |
# Configure the root logger at import time. With ``force=True``, any handlers
# already attached to the root logger are removed and closed first, so this
# format applies even if logging was configured earlier.
logging.basicConfig(
    format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s',
    level=logging.INFO,
    force=True,
)
logger = logging.getLogger(name=__name__)
def _build_generator(config):
    """Build the EMA generator module described by a checkpoint ``config``.

    Args:
        config: The ``'config'`` entry of a loaded checkpoint. Its attributes
            ``resolution``, ``img_channels``, ``z_dim``, ``c_dim``, ``w_dim``
            and ``fmap_base`` are read.

    Returns:
        A ``Generator`` with 8 mapping layers and float32 dtype.
    """
    return Generator(
        resolution=config.resolution,
        num_channels=config.img_channels,
        z_dim=config.z_dim,
        c_dim=config.c_dim,
        w_dim=config.w_dim,
        # 2 * log2(resolution) - 3; int() truncates, so this assumes a
        # power-of-two resolution -- TODO confirm against the training config.
        num_ws=int(np.log2(config.resolution)) * 2 - 3,
        num_mapping_layers=8,
        fmap_base=config.fmap_base,
        dtype=jnp.float32,
    )


def _to_pil_image(images):
    """Convert the first image of a generator output batch to an 8-bit PIL image.

    The whole batch is min-max normalized to [0, 1] (contrast is stretched
    per call rather than mapped from a fixed output range), scaled by 255,
    clipped and truncated to uint8.

    Args:
        images: Generator output. ``images[0]`` must be ``(H, W)`` or
            channels-last ``(H, W, C)``, as PIL expects.

    Returns:
        A ``PIL.Image.Image``. A constant batch, or one containing NaNs,
        yields an all-black image instead of NaNs from a 0/0 division.
    """
    images = np.asarray(images, dtype=np.float32)
    low = images.min()
    span = images.max() - low
    if span > 0:
        images = (images - low) / span
    else:
        images = np.zeros_like(images)
    pixels = np.clip(images[0] * 255, 0, 255).astype(np.uint8)
    # PIL cannot build an image from an (H, W, 1) array, so drop a singleton
    # channel axis (img_channels == 1) to produce a grayscale image.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)


def generate_images(args):
    """Sample images from a checkpointed generator and save them as PNG files.

    For every seed in ``args.seeds``:

    - the seed initializes a JAX PRNG key;
    - a ``(1, z_dim)`` standard-normal latent is drawn from that key;
    - the latent is run through the EMA generator;
    - the result is written to ``<args.out_path>/<seed>.png``.

    Args:
        args: Namespace providing the following attributes:
            ``checkpoint`` (path to load);
            ``out_path`` (output directory, created if missing);
            ``truncation_psi`` (forwarded to the generator);
            ``seeds`` (iterable of ints).
    """
    logger.info(f"Loading checkpoint '{args.checkpoint}'...")
    ckpt = checkpoint.load_checkpoint(args.checkpoint)
    config = ckpt['config']
    params_ema_G = ckpt['params_ema_G']

    generator_ema = _build_generator(config)
    # Compile once outside the loop. Truncation, train flag and noise mode
    # are the same for every seed.
    generator_apply = jax.jit(
        functools.partial(generator_ema.apply, truncation_psi=args.truncation_psi, train=False, noise_mode='const')
    )

    # The CLI entry point also creates this directory. Doing it here keeps
    # direct (library) callers from failing on the first save.
    os.makedirs(args.out_path, exist_ok=True)

    logger.info(f"Generating {len(args.seeds)} images with truncation {args.truncation_psi}...")
    for seed in tqdm(args.seeds):
        rng = jax.random.PRNGKey(seed)
        z_latent = jax.random.normal(rng, shape=(1, config.z_dim))
        # The trailing None presumably fills the conditioning-label slot
        # (unconditional sampling); confirm against Generator.__call__.
        images = generator_apply(params_ema_G, jax.lax.stop_gradient(z_latent), None)
        _to_pil_image(images).save(os.path.join(args.out_path, f'{seed}.png'))
    logger.info(f"Images saved in '{args.out_path}/'")
if __name__ == '__main__':
    # Command-line entry point: parse the options, make sure the output
    # directory exists, then generate one image per seed.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint',
        type=str,
        required=True,
        help='Path to the checkpoint.',
    )
    parser.add_argument(
        '--out_path',
        type=str,
        default='generated_images',
        help='Path where the generated images are stored.',
    )
    parser.add_argument(
        '--truncation_psi',
        type=float,
        default=0.5,
        help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.',
    )
    parser.add_argument(
        '--seeds',
        type=int,
        nargs='*',
        default=[0],
        help='List of random seeds.',
    )
    args = parser.parse_args()
    os.makedirs(args.out_path, exist_ok=True)
    generate_images(args)