import argparse
import os

import torch
from torchvision import utils
from model import Generator
from tqdm import tqdm


def generate(args, g_ema, device, mean_latent):
    """Sample images from the generator and write them as PNG files.

    One file is written per iteration to ``sample/<index>.png``, where the
    index is zero-padded to six digits. The ``sample`` directory is created
    if it does not exist yet.

    Args:
        args: Parsed options. Reads ``pics`` (number of files to write),
            ``sample`` (number of latent vectors drawn per file),
            ``latent`` (size of each latent vector) and ``truncation``.
        g_ema: Generator to sample from; it is switched to eval mode here.
        device: Device on which the random latent vectors are allocated.
        mean_latent: Forwarded to the generator as ``truncation_latent``;
            the caller passes ``None`` when ``args.truncation`` is not
            below 1.
    """
    # save_image does not create missing parent directories.
    os.makedirs('sample', exist_ok=True)

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            sample_z = torch.randn(args.sample, args.latent, device=device)

            # The generator returns a pair; only the first element is saved.
            sample, _ = g_ema(
                [sample_z],
                truncation=args.truncation,
                truncation_latent=mean_latent,
            )

            # `value_range` is the current name of torchvision's former
            # `range` keyword, which newer releases no longer accept.
            utils.save_image(
                sample,
                f'sample/{str(i).zfill(6)}.png',
                nrow=1,
                normalize=True,
                value_range=(-1, 1),
            )


if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()

    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--sample', type=int, default=1)
    parser.add_argument('--pics', type=int, default=20)
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt")
    parser.add_argument('--channel_multiplier', type=int, default=2)

    args = parser.parse_args()

    # Fixed here rather than exposed on the command line.
    args.latent = 512
    args.n_mlp = 8

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    # map_location remaps tensors that were saved on another device (for
    # example a different GPU index) onto the device used here.
    checkpoint = torch.load(args.ckpt, map_location=device)
    g_ema.load_state_dict(checkpoint['g_ema'])

    # A mean latent is only computed when truncation is actually requested.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)