# python3.7
"""A simple tool to synthesize images with pre-trained models."""

import os
import argparse
import subprocess
from tqdm import tqdm
import numpy as np

import torch

from models import MODEL_ZOO
from models import build_generator
from utils.misc import bool_parser
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import postprocess_image
from utils.visualizer import save_image

# Directory where checkpoints are cached/downloaded when `--checkpoint_dir` is
# not given. This is a site-specific path, kept as the default only so that
# existing invocations keep using the same cache; pass `--checkpoint_dir` to
# use a different location.
DEFAULT_CHECKPOINT_DIR = '/import/nobackup_mmv_ioannisp/jo001/genforce_models'


def parse_args():
    """Parses command-line arguments.

    Returns:
        argparse.Namespace with the synthesis options defined below.
    """
    parser = argparse.ArgumentParser(
        description='Synthesize images with pre-trained models.')
    parser.add_argument('model_name', type=str,
                        help='Name to the pre-trained model.')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory to save the results. If not specified, '
                             'the results will be saved to '
                             '`work_dirs/synthesis/` by default. '
                             '(default: %(default)s)')
    parser.add_argument('--checkpoint_dir', type=str,
                        default=DEFAULT_CHECKPOINT_DIR,
                        help='Directory to load pre-trained checkpoints from '
                             '(missing checkpoints are downloaded into it). '
                             '(default: %(default)s)')
    parser.add_argument('--num', type=int, default=100,
                        help='Number of samples to synthesize. '
                             '(default: %(default)s)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size. (default: %(default)s)')
    parser.add_argument('--generate_html', type=bool_parser, default=True,
                        help='Whether to use HTML page to visualize the '
                             'synthesized results. (default: %(default)s)')
    parser.add_argument('--save_raw_synthesis', type=bool_parser, default=False,
                        help='Whether to save raw synthesis. '
                             '(default: %(default)s)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed for sampling. (default: %(default)s)')
    parser.add_argument('--trunc_psi', type=float, default=0.7,
                        help='Psi factor used for truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--trunc_layers', type=int, default=8,
                        help='Number of layers to perform truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--randomize_noise', type=bool_parser, default=False,
                        help='Whether to randomize the layer-wise noise. This '
                             'is particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    return parser.parse_args()


def _download_checkpoint(url, checkpoint_path):
    """Downloads `url` to `checkpoint_path` using `wget`.

    The data is first written to a temporary `<checkpoint_path>.part` file and
    only moved into place once `wget` exits successfully. `wget -O` creates
    (or truncates) its output file even when the download fails, so writing
    straight to `checkpoint_path` would leave a broken file that later runs
    would treat as a valid cached checkpoint.

    Args:
        url: URL of the checkpoint to download.
        checkpoint_path: Final destination path of the checkpoint.

    Raises:
        SystemExit: If `wget` cannot be executed or exits with a non-zero
            status.
    """
    tmp_path = checkpoint_path + '.part'
    print(f' Downloading checkpoint from `{url}` ...')
    try:
        return_code = subprocess.call(['wget', '--quiet', '-O', tmp_path, url])
    except OSError as err:  # E.g. `wget` is not installed.
        raise SystemExit(f'Failed to run `wget` to download `{url}`: '
                         f'{err}') from err
    if return_code != 0:
        # Do not leave a partial file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise SystemExit(f'Failed to download checkpoint from `{url}` '
                         f'(wget exit code {return_code})!')
    # Atomic on the same filesystem, so `checkpoint_path` is never partial.
    os.replace(tmp_path, checkpoint_path)
    print(' Finish downloading checkpoint.')


def main():
    """Main function.

    Builds the generator for `args.model_name`, loads its checkpoint
    (downloading it first if it is not cached), then samples `args.num`
    latent codes in batches of `args.batch_size` and saves the synthesized
    images as `.jpg` files and/or an HTML grid page.

    Raises:
        SystemExit: If the model is not registered, `--batch_size` is not
            positive, or the checkpoint download fails.
    """
    args = parse_args()
    if args.num <= 0:
        return
    if not args.save_raw_synthesis and not args.generate_html:
        return
    # A zero step would make `range()` below raise an opaque ValueError and a
    # negative one would silently synthesize nothing.
    if args.batch_size <= 0:
        raise SystemExit(f'`--batch_size` must be positive, '
                         f'but got {args.batch_size}!')

    # Parse model configuration.
    if args.model_name not in MODEL_ZOO:
        raise SystemExit(f'Model `{args.model_name}` is not registered in '
                         f'`models/model_zoo.py`!')
    model_config = MODEL_ZOO[args.model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Get work directory and job name.
    if args.save_dir:
        work_dir = args.save_dir
    else:
        work_dir = os.path.join('work_dirs', 'synthesis')
    os.makedirs(work_dir, exist_ok=True)
    job_name = f'{args.model_name}_{args.num}'
    if args.save_raw_synthesis:
        os.makedirs(os.path.join(work_dir, job_name), exist_ok=True)

    # Build generation and get synthesis kwargs.
    print(f'Building generator for model `{args.model_name}` ...')
    generator = build_generator(**model_config)
    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
                            trunc_layers=args.trunc_layers,
                            randomize_noise=args.randomize_noise)
    print('Finish building generator.')

    # Load pre-trained weights.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.checkpoint_dir,
                                   args.model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    if not os.path.exists(checkpoint_path):
        _download_checkpoint(url, checkpoint_path)
    # NOTE(security): `torch.load` unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted URLs/locations.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.cuda()
    generator.eval()
    print('Finish loading checkpoint.')

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Sample and synthesize.
    print(f'Synthesizing {args.num} samples ...')
    indices = list(range(args.num))
    if args.generate_html:
        html = HtmlPageVisualizer(grid_size=args.num)
    for batch_idx in tqdm(range(0, args.num, args.batch_size)):
        sub_indices = indices[batch_idx:batch_idx + args.batch_size]
        # Latent codes are drawn on the CPU (seeded above), then moved to GPU.
        code = torch.randn(len(sub_indices), generator.z_space_dim).cuda()
        with torch.no_grad():
            images = generator(code, **synthesis_kwargs)['image']
            images = postprocess_image(images.detach().cpu().numpy())
        for sub_idx, image in zip(sub_indices, images):
            if args.save_raw_synthesis:
                save_path = os.path.join(
                    work_dir, job_name, f'{sub_idx:06d}.jpg')
                save_image(save_path, image)
            if args.generate_html:
                # Fill the HTML grid in row-major order.
                row_idx, col_idx = divmod(sub_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {sub_idx:06d}')
    if args.generate_html:
        html.save(os.path.join(work_dir, f'{job_name}.html'))
    print(f'Finish synthesizing {args.num} samples.')


if __name__ == '__main__':
    main()