Spaces:
Build error
Build error
# python3.7 | |
"""A simple tool to synthesize images with pre-trained models.""" | |
import os | |
import argparse | |
import subprocess | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
from models import MODEL_ZOO | |
from models import build_generator | |
from utils.misc import bool_parser | |
from utils.visualizer import HtmlPageVisualizer | |
from utils.visualizer import postprocess_image | |
from utils.visualizer import save_image | |
def parse_args(): | |
"""Parses arguments.""" | |
parser = argparse.ArgumentParser( | |
description='Synthesize images with pre-trained models.') | |
parser.add_argument('model_name', type=str, | |
help='Name to the pre-trained model.') | |
parser.add_argument('--save_dir', type=str, default=None, | |
help='Directory to save the results. If not specified, ' | |
'the results will be saved to ' | |
'`work_dirs/synthesis/` by default. ' | |
'(default: %(default)s)') | |
parser.add_argument('--num', type=int, default=100, | |
help='Number of samples to synthesize. ' | |
'(default: %(default)s)') | |
parser.add_argument('--batch_size', type=int, default=1, | |
help='Batch size. (default: %(default)s)') | |
parser.add_argument('--generate_html', type=bool_parser, default=True, | |
help='Whether to use HTML page to visualize the ' | |
'synthesized results. (default: %(default)s)') | |
parser.add_argument('--save_raw_synthesis', type=bool_parser, default=False, | |
help='Whether to save raw synthesis. ' | |
'(default: %(default)s)') | |
parser.add_argument('--seed', type=int, default=0, | |
help='Seed for sampling. (default: %(default)s)') | |
parser.add_argument('--trunc_psi', type=float, default=0.7, | |
help='Psi factor used for truncation. This is ' | |
'particularly applicable to StyleGAN (v1/v2). ' | |
'(default: %(default)s)') | |
parser.add_argument('--trunc_layers', type=int, default=8, | |
help='Number of layers to perform truncation. This is ' | |
'particularly applicable to StyleGAN (v1/v2). ' | |
'(default: %(default)s)') | |
parser.add_argument('--randomize_noise', type=bool_parser, default=False, | |
help='Whether to randomize the layer-wise noise. This ' | |
'is particularly applicable to StyleGAN (v1/v2). ' | |
'(default: %(default)s)') | |
return parser.parse_args() | |
def main(): | |
"""Main function.""" | |
args = parse_args() | |
if args.num <= 0: | |
return | |
if not args.save_raw_synthesis and not args.generate_html: | |
return | |
# Parse model configuration. | |
if args.model_name not in MODEL_ZOO: | |
raise SystemExit(f'Model `{args.model_name}` is not registered in ' | |
f'`models/model_zoo.py`!') | |
model_config = MODEL_ZOO[args.model_name].copy() | |
url = model_config.pop('url') # URL to download model if needed. | |
# Get work directory and job name. | |
if args.save_dir: | |
work_dir = args.save_dir | |
else: | |
work_dir = os.path.join('work_dirs', 'synthesis') | |
os.makedirs(work_dir, exist_ok=True) | |
job_name = f'{args.model_name}_{args.num}' | |
if args.save_raw_synthesis: | |
os.makedirs(os.path.join(work_dir, job_name), exist_ok=True) | |
# Build generation and get synthesis kwargs. | |
print(f'Building generator for model `{args.model_name}` ...') | |
generator = build_generator(**model_config) | |
synthesis_kwargs = dict(trunc_psi=args.trunc_psi, | |
trunc_layers=args.trunc_layers, | |
randomize_noise=args.randomize_noise) | |
print(f'Finish building generator.') | |
# Load pre-trained weights. | |
os.makedirs('/import/nobackup_mmv_ioannisp/jo001/genforce_models', exist_ok=True) | |
checkpoint_path = os.path.join('/import/nobackup_mmv_ioannisp/jo001/genforce_models', args.model_name + '.pth') | |
print(f'Loading checkpoint from `{checkpoint_path}` ...') | |
if not os.path.exists(checkpoint_path): | |
print(f' Downloading checkpoint from `{url}` ...') | |
subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url]) | |
print(f' Finish downloading checkpoint.') | |
checkpoint = torch.load(checkpoint_path, map_location='cpu') | |
if 'generator_smooth' in checkpoint: | |
generator.load_state_dict(checkpoint['generator_smooth']) | |
else: | |
generator.load_state_dict(checkpoint['generator']) | |
generator = generator.cuda() | |
generator.eval() | |
print(f'Finish loading checkpoint.') | |
# Set random seed. | |
np.random.seed(args.seed) | |
torch.manual_seed(args.seed) | |
# Sample and synthesize. | |
print(f'Synthesizing {args.num} samples ...') | |
indices = list(range(args.num)) | |
if args.generate_html: | |
html = HtmlPageVisualizer(grid_size=args.num) | |
for batch_idx in tqdm(range(0, args.num, args.batch_size)): | |
sub_indices = indices[batch_idx:batch_idx + args.batch_size] | |
code = torch.randn(len(sub_indices), generator.z_space_dim).cuda() | |
with torch.no_grad(): | |
images = generator(code, **synthesis_kwargs)['image'] | |
images = postprocess_image(images.detach().cpu().numpy()) | |
for sub_idx, image in zip(sub_indices, images): | |
if args.save_raw_synthesis: | |
save_path = os.path.join( | |
work_dir, job_name, f'{sub_idx:06d}.jpg') | |
save_image(save_path, image) | |
if args.generate_html: | |
row_idx, col_idx = divmod(sub_idx, html.num_cols) | |
html.set_cell(row_idx, col_idx, image=image, | |
text=f'Sample {sub_idx:06d}') | |
if args.generate_html: | |
html.save(os.path.join(work_dir, f'{job_name}.html')) | |
print(f'Finish synthesizing {args.num} samples.') | |
if __name__ == '__main__': | |
main() | |