"""SeFa.""" import os import argparse from tqdm import tqdm import numpy as np import torch from models import parse_gan_type from utils import to_tensor from utils import postprocess from utils import load_generator from utils import factorize_weight from utils import HtmlPageVisualizer def parse_args(): """Parses arguments.""" parser = argparse.ArgumentParser( description='Discover semantics from the pre-trained weight.') parser.add_argument('model_name', type=str, help='Name to the pre-trained model.') parser.add_argument('--save_dir', type=str, default='results', help='Directory to save the visualization pages. ' '(default: %(default)s)') parser.add_argument('-L', '--layer_idx', type=str, default='all', help='Indices of layers to interpret. ' '(default: %(default)s)') parser.add_argument('-N', '--num_samples', type=int, default=5, help='Number of samples used for visualization. ' '(default: %(default)s)') parser.add_argument('-K', '--num_semantics', type=int, default=5, help='Number of semantic boundaries corresponding to ' 'the top-k eigen values. (default: %(default)s)') parser.add_argument('--start_distance', type=float, default=-3.0, help='Start point for manipulation on each semantic. ' '(default: %(default)s)') parser.add_argument('--end_distance', type=float, default=3.0, help='Ending point for manipulation on each semantic. ' '(default: %(default)s)') parser.add_argument('--step', type=int, default=11, help='Manipulation step on each semantic. ' '(default: %(default)s)') parser.add_argument('--viz_size', type=int, default=256, help='Size of images to visualize on the HTML page. ' '(default: %(default)s)') parser.add_argument('--trunc_psi', type=float, default=0.7, help='Psi factor used for truncation. This is ' 'particularly applicable to StyleGAN (v1/v2). ' '(default: %(default)s)') parser.add_argument('--trunc_layers', type=int, default=8, help='Number of layers to perform truncation. This is ' 'particularly applicable to StyleGAN (v1/v2). ' '(default: %(default)s)') parser.add_argument('--seed', type=int, default=0, help='Seed for sampling. (default: %(default)s)') parser.add_argument('--gpu_id', type=str, default='0', help='GPU(s) to use. (default: %(default)s)') return parser.parse_args() def main(): """Main function.""" args = parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id os.makedirs(args.save_dir, exist_ok=True) # Factorize weights. generator = load_generator(args.model_name) gan_type = parse_gan_type(generator) layers, boundaries, values = factorize_weight(generator, args.layer_idx) # Set random seed. np.random.seed(args.seed) torch.manual_seed(args.seed) # Prepare codes. codes = torch.randn(args.num_samples, generator.z_space_dim).cuda() if gan_type == 'pggan': codes = generator.layer0.pixel_norm(codes) elif gan_type in ['stylegan', 'stylegan2']: codes = generator.mapping(codes)['w'] codes = generator.truncation(codes, trunc_psi=args.trunc_psi, trunc_layers=args.trunc_layers) codes = codes.detach().cpu().numpy() # Generate visualization pages. 
    # Generate visualization pages.
    distances = np.linspace(args.start_distance, args.end_distance, args.step)
    num_sam = args.num_samples
    num_sem = args.num_semantics
    vizer_1 = HtmlPageVisualizer(num_rows=num_sem * (num_sam + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)
    vizer_2 = HtmlPageVisualizer(num_rows=num_sam * (num_sem + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)

    headers = [''] + [f'Distance {d:.2f}' for d in distances]
    vizer_1.set_headers(headers)
    vizer_2.set_headers(headers)
    for sem_id in range(num_sem):
        value = values[sem_id]
        vizer_1.set_cell(sem_id * (num_sam + 1), 0,
                         text=f'Semantic {sem_id:03d}<br>({value:.3f})',
                         highlight=True)
        for sam_id in range(num_sam):
            vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, 0,
                             text=f'Sample {sam_id:03d}')
    for sam_id in range(num_sam):
        vizer_2.set_cell(sam_id * (num_sem + 1), 0,
                         text=f'Sample {sam_id:03d}',
                         highlight=True)
        for sem_id in range(num_sem):
            value = values[sem_id]
            vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, 0,
                             text=f'Semantic {sem_id:03d}<br>({value:.3f})')
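    # For each sample and each semantic, shift the latent code along the
    # discovered direction by every distance in `distances`: PGGAN codes are
    # edited directly in Z space, whereas StyleGAN (v1/v2) codes are edited
    # only at the layers selected above before calling the synthesis network.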
    for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
        code = codes[sam_id:sam_id + 1]
        for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
            boundary = boundaries[sem_id:sem_id + 1]
            for col_id, d in enumerate(distances, start=1):
                temp_code = code.copy()
                if gan_type == 'pggan':
                    temp_code += boundary * d
                    image = generator(to_tensor(temp_code))['image']
                elif gan_type in ['stylegan', 'stylegan2']:
                    temp_code[:, layers, :] += boundary * d
                    image = generator.synthesis(to_tensor(temp_code))['image']
                image = postprocess(image)[0]
                vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, col_id,
                                 image=image)
                vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, col_id,
                                 image=image)

    prefix = (f'{args.model_name}_'
              f'N{num_sam}_K{num_sem}_L{args.layer_idx}_seed{args.seed}')
    vizer_1.save(os.path.join(args.save_dir, f'{prefix}_sample_first.html'))
    vizer_2.save(os.path.join(args.save_dir, f'{prefix}_semantic_first.html'))


if __name__ == '__main__':
    main()
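# Example invocation (assuming this file is saved as `sefa.py`; the model name
# is a placeholder for any checkpoint that `load_generator` recognizes):
#
#   python sefa.py <model_name> -L all -N 5 -K 5 --save_dir results
#
# The script writes two HTML pages, `<prefix>_sample_first.html` and
# `<prefix>_semantic_first.html`, into the save directory.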