"""Sample images from a BlendGAN generator with and without a style embedding.

For every sample, the generator is run twice on the same latent code: once
plainly and once with a style embedding (``z_embed``) injected. The two results
are concatenated along axis 1 and written as one JPEG per sample.
"""
import argparse
import os
import random

import cv2
import numpy as np
import torch
from tqdm import tqdm

from model import Generator
from utils import ten2cv, cv2ten

# Fix every RNG in use so that repeated runs draw the same latent codes.
seed = 0

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def generate(args, g_ema, device, mean_latent, sample_style, add_weight_index):
    """Write ``args.pics`` comparison images into ``args.outdir``.

    Args:
        args: Namespace providing ``sample_zs`` (path to saved latent codes or
            ``None``), ``pics``, ``latent``, ``truncation`` and ``outdir``.
        g_ema: Generator; switched to eval mode and called under ``no_grad``.
        device: Device on which random latent codes are created.
        mean_latent: Value forwarded as ``truncation_latent`` (may be ``None``).
        sample_style: Style embedding forwarded as ``z_embed``.
        add_weight_index: Forwarded unchanged to the generator.

    Raises:
        OSError: If OpenCV reports that an output image could not be written.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted paths.
    sample_zs = torch.load(args.sample_zs) if args.sample_zs is not None else None

    # Keyword arguments shared by the plain and the style-injected pass.
    common_kwargs = dict(
        truncation=args.truncation,
        truncation_latent=mean_latent,
        return_latents=False,
        randomize_noise=False,
    )

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            # NOTE(review): indexing assumes the saved latents hold at least
            # args.pics entries -- confirm against how the file was produced.
            if sample_zs is not None:
                sample_z = sample_zs[i]
            else:
                sample_z = torch.randn(1, args.latent, device=device)

            sample1, _ = g_ema([sample_z], **common_kwargs)
            sample2, _ = g_ema([sample_z], z_embed=sample_style,
                               add_weight_index=add_weight_index,
                               **common_kwargs)

            sample1 = ten2cv(sample1)
            sample2 = ten2cv(sample2)

            # Presumably side by side (axis 1 is width if ten2cv returns H x W x C).
            out = np.concatenate([sample1, sample2], axis=1)

            # os.path.join keeps an empty outdir relative to the working
            # directory instead of producing a path at the filesystem root.
            out_path = os.path.join(args.outdir, f'{i:06d}.jpg')
            # cv2.imwrite signals failure through its return value.
            if not cv2.imwrite(out_path, out):
                raise OSError(f'failed to write image: {out_path}')


def main():
    """Parse CLI arguments, load the checkpoint and run ``generate``.

    Raises:
        FileNotFoundError: If ``--style_img`` is given but cannot be read.
    """
    device = 'cuda'

    parser = argparse.ArgumentParser()

    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--pics', type=int, default=20, help='N_PICS')
    parser.add_argument('--truncation', type=float, default=0.75)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default='', help='path to BlendGAN checkpoint')
    parser.add_argument('--style_img', type=str, default=None, help='path to style image')
    parser.add_argument('--sample_zs', type=str, default=None)
    parser.add_argument('--add_weight_index', type=int, default=6)
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--outdir', type=str, default="")

    args = parser.parse_args()

    # An empty outdir means "current directory"; os.makedirs('') would raise.
    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)

    args.latent = 512
    args.n_mlp = 8

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.ckpt)
    model_dict = checkpoint['g_ema']
    latent_avg = checkpoint.get("latent_avg")
    # A truncation value stored in the checkpoint overrides the CLI value.
    if "truncation" in checkpoint:
        args.truncation = checkpoint["truncation"]
    print('ckpt: ', args.ckpt)
    print('truncation: ', args.truncation)

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.load_state_dict(model_dict)

    if args.truncation < 1:
        if latent_avg is not None:
            mean_latent = latent_avg
            print('### use mean_latent in ckpt["latent_avg"]')
        else:
            with torch.no_grad():
                mean_latent = g_ema.mean_latent(args.truncation_mean)
            print('### generate mean_latent with \'g_ema.mean_latent\'')
    else:
        mean_latent = None
        print('### args.truncation = 1, mean_latent is None')

    if args.style_img is not None:
        img = cv2.imread(args.style_img, 1)
        # cv2.imread returns None instead of raising when the file is missing
        # or cannot be decoded; fail here with a clear message.
        if img is None:
            raise FileNotFoundError(f'could not read style image: {args.style_img}')
        img = cv2ten(img, device)
        # The embedding is only used for inference, so skip autograd tracking.
        with torch.no_grad():
            sample_style = g_ema.get_z_embed(img)
    else:
        sample_style = torch.randn(1, args.latent, device=device)

    generate(args, g_ema, device, mean_latent, sample_style, args.add_weight_index)

    print('Done!')


if __name__ == '__main__':
    main()