# python3.7
"""A simple tool to synthesize images with pre-trained models."""

import os
import argparse
import subprocess

from tqdm import tqdm
import numpy as np
import torch

from models import MODEL_ZOO
from models import build_generator
from utils.misc import bool_parser
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import postprocess_image
from utils.visualizer import save_image


def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser(
        description='Synthesize images with pre-trained models.')
    parser.add_argument('model_name', type=str,
                        help='Name of the pre-trained model.')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory to save the results. If not specified, '
                             'the results will be saved to '
                             '`work_dirs/synthesis/` by default. '
                             '(default: %(default)s)')
    parser.add_argument('--num', type=int, default=100,
                        help='Number of samples to synthesize. '
                             '(default: %(default)s)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size. (default: %(default)s)')
    parser.add_argument('--generate_html', type=bool_parser, default=True,
                        help='Whether to use an HTML page to visualize the '
                             'synthesized results. (default: %(default)s)')
    parser.add_argument('--save_raw_synthesis', type=bool_parser, default=False,
                        help='Whether to save the raw synthesized images. '
                             '(default: %(default)s)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed for sampling. (default: %(default)s)')
    parser.add_argument('--trunc_psi', type=float, default=0.7,
                        help='Psi factor used for truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--trunc_layers', type=int, default=8,
                        help='Number of layers to perform truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--randomize_noise', type=bool_parser, default=False,
                        help='Whether to randomize the layer-wise noise. This '
                             'is particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    return parser.parse_args()


def main():
    """Main function."""
    args = parse_args()
    if args.num <= 0:
        return
    if not args.save_raw_synthesis and not args.generate_html:
        return
    # Parse model configuration.
    if args.model_name not in MODEL_ZOO:
        raise SystemExit(f'Model `{args.model_name}` is not registered in '
                         f'`models/model_zoo.py`!')
    model_config = MODEL_ZOO[args.model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Get work directory and job name.
    if args.save_dir:
        work_dir = args.save_dir
    else:
        work_dir = os.path.join('work_dirs', 'synthesis')
    os.makedirs(work_dir, exist_ok=True)
    job_name = f'{args.model_name}_{args.num}'
    if args.save_raw_synthesis:
        os.makedirs(os.path.join(work_dir, job_name), exist_ok=True)

    # Build the generator and collect synthesis kwargs.
    print(f'Building generator for model `{args.model_name}` ...')
    generator = build_generator(**model_config)
    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
                            trunc_layers=args.trunc_layers,
                            randomize_noise=args.randomize_noise)
    print(f'Finish building generator.')

    # Load pre-trained weights.
    os.makedirs('/import/nobackup_mmv_ioannisp/jo001/genforce_models', exist_ok=True)
    checkpoint_path = os.path.join('/import/nobackup_mmv_ioannisp/jo001/genforce_models',
                                   args.model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
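    # Checkpoints are cached in the hard-coded directory above; if the file is
    # missing, it is downloaded once from the model-zoo URL via wget.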
    if not os.path.exists(checkpoint_path):
        print(f' Downloading checkpoint from `{url}` ...')
        subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
        print(f' Finish downloading checkpoint.')
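    # Prefer the 'generator_smooth' weights (typically a moving average of the
    # generator) when the checkpoint provides them; otherwise use 'generator'.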
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.cuda()
    generator.eval()
    print(f'Finish loading checkpoint.')

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Sample and synthesize.
    print(f'Synthesizing {args.num} samples ...')
    indices = list(range(args.num))
    if args.generate_html:
        html = HtmlPageVisualizer(grid_size=args.num)
    for batch_idx in tqdm(range(0, args.num, args.batch_size)):
        sub_indices = indices[batch_idx:batch_idx + args.batch_size]
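        # Sample a batch of latent codes from a standard normal in the
        # generator's Z space.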
        code = torch.randn(len(sub_indices), generator.z_space_dim).cuda()
        with torch.no_grad():
            images = generator(code, **synthesis_kwargs)['image']
        images = postprocess_image(images.detach().cpu().numpy())
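        # Save each sample to disk and/or place it into the HTML grid.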
        for sub_idx, image in zip(sub_indices, images):
            if args.save_raw_synthesis:
                save_path = os.path.join(
                    work_dir, job_name, f'{sub_idx:06d}.jpg')
                save_image(save_path, image)
            if args.generate_html:
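                # Map the flat sample index onto the HTML grid (row-major).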
                row_idx, col_idx = divmod(sub_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {sub_idx:06d}')

    if args.generate_html:
        html.save(os.path.join(work_dir, f'{job_name}.html'))
    print(f'Finish synthesizing {args.num} samples.')


if __name__ == '__main__':
    main()