# NOTE(review): extraction artifacts (file-size line, commit hash, and a pasted
# line-number index) were present here and are not valid Python; converted to a
# comment so the file parses.
import os
import os.path as osp
import argparse
import torch
import json
from hashlib import sha1
from torchvision.transforms import ToPILImage
from lib import GENFORCE_MODELS, update_progress, update_stdout
from models.load_generator import load_generator
def tensor2image(tensor, img_size=None, adaptive=False):
    """Convert an image tensor into a PIL image.

    Args:
        tensor (torch.Tensor): image tensor of shape [1, C, H, W] or [C, H, W].
            In default mode values are assumed to lie in [-1, 1]; in adaptive
            mode any range is accepted and min-max normalised.
        img_size (int, optional): if given, resize the output image to
            (img_size, img_size).
        adaptive (bool): if True, min-max normalise the tensor to [0, 1];
            otherwise map [-1, 1] -> [0, 1] and clamp.

    Returns:
        PIL.Image.Image: the converted (and optionally resized) image.
    """
    # Squeeze a singleton batch dimension, if present.
    tensor = tensor.squeeze(dim=0)
    if adaptive:
        # Min-max normalisation to [0, 1]; guard against a constant tensor,
        # which would otherwise divide by zero and produce NaNs.
        value_range = tensor.max() - tensor.min()
        if value_range > 0:
            tensor = (tensor - tensor.min()) / value_range
        else:
            tensor = torch.zeros_like(tensor)
    else:
        # Map [-1, 1] -> [0, 1] and clamp out-of-range values.
        # BUG FIX: `tensor.clamp(0, 1)` is NOT in-place -- the original
        # discarded its result, so out-of-range values wrapped around when
        # cast to uint8 below. Assign the clamped result instead.
        tensor = ((tensor + 1) / 2).clamp(0, 1)
    # Scale to [0, 255], cast to uint8 and build the PIL image.
    img = ToPILImage()((255 * tensor.cpu().detach()).to(torch.uint8))
    if img_size:
        img = img.resize((img_size, img_size))
    return img
def main():
    """A script for sampling from a pre-trained GAN's latent space and generating images. The generated images, along
    with the corresponding latent codes, will be stored under `experiments/latent_codes/<gan>/`.

    Options:
        -v, --verbose : set verbose mode on
        --gan         : set GAN generator (see GENFORCE_MODELS in lib/config.py)
        --truncation  : set W-space truncation parameter. If set, W-space codes will be truncated
        --num-samples : set the number of latent codes to sample for generating images
        --cuda        : use CUDA (default)
        --no-cuda     : do not use CUDA
    """
    parser = argparse.ArgumentParser(description="Sample a pre-trained GAN latent space and generate images")
    parser.add_argument('-v', '--verbose', action='store_true', help="verbose mode on")
    parser.add_argument('--gan', type=str, required=True, choices=GENFORCE_MODELS.keys(), help='GAN generator')
    parser.add_argument('--truncation', type=float, default=1.0, help="W-space truncation parameter")
    parser.add_argument('--num-samples', type=int, default=4, help="set number of latent codes to sample")
    parser.add_argument('--cuda', dest='cuda', action='store_true', help="use CUDA during training")
    parser.add_argument('--no-cuda', dest='cuda', action='store_false', help="do NOT use CUDA during training")
    parser.set_defaults(cuda=True)
    # ================================================================================================================ #

    # Parse given arguments
    args = parser.parse_args()

    # Create output dir for generated images.
    # NOTE(review): the effective directory is
    # `experiments/latent_codes/<gan>/<gan>-<num_samples>/`, one level deeper
    # than the docstring states -- confirm which layout is intended.
    out_dir = osp.join('experiments', 'latent_codes', args.gan)
    out_dir = osp.join(out_dir, '{}-{}'.format(args.gan, args.num_samples))
    os.makedirs(out_dir, exist_ok=True)

    # Save the run's arguments in a json file (reproducibility record)
    with open(osp.join(out_dir, 'args.json'), 'w') as args_json_file:
        json.dump(args.__dict__, args_json_file)

    # CUDA -- setting the default tensor type here makes later tensor
    # constructors (e.g. `torch.randn` below) allocate on the chosen device.
    use_cuda = False
    if torch.cuda.is_available():
        if args.cuda:
            use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            print("*** WARNING ***: It looks like you have a CUDA device, but aren't using CUDA.\n"
                  "                 Run with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    # Build GAN generator model and load with pre-trained weights
    if args.verbose:
        print("#. Build GAN generator model G and load with pre-trained weights...")
        print("  \\__GAN generator : {} (res: {})".format(args.gan, GENFORCE_MODELS[args.gan][1]))
        print("  \\__Pre-trained weights: {}".format(GENFORCE_MODELS[args.gan][0]))

    # `latent_is_w` selects W-space sampling for StyleGAN-family models;
    # other models are sampled directly in Z-space.
    G = load_generator(model_name=args.gan,
                       latent_is_w='stylegan' in args.gan,
                       verbose=args.verbose).eval()

    # Upload GAN generator model to GPU
    if use_cuda:
        G = G.cuda()

    # Latent codes sampling (honours the default tensor type set above)
    if args.verbose:
        print("#. Sample {} {}-dimensional latent codes...".format(args.num_samples, G.dim_z))
    zs = torch.randn(args.num_samples, G.dim_z)

    if use_cuda:
        zs = zs.cuda()

    if args.verbose:
        print("#. Generate images...")
        print("  \\__{}".format(out_dir))

    # Iterate over the sampled latent codes
    for i in range(args.num_samples):
        # Un-squeeze current latent code in shape [1, dim] and create a hash
        # code for it -- the SHA-1 of the raw code bytes serves as a stable,
        # content-derived directory name.
        z = zs[i, :].unsqueeze(0)
        latent_code_hash = sha1(z.cpu().numpy()).hexdigest()
        if args.verbose:
            update_progress(
                "  \\__.Latent code hash: {} [{:03d}/{:03d}] ".format(latent_code_hash, i + 1, args.num_samples),
                args.num_samples, i)

        # Create directory for current latent code
        latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash))
        os.makedirs(latent_code_dir, exist_ok=True)

        if 'stylegan' in args.gan:
            # Get the w+ and w codes for the given z code, save them, and save the image generated from the w code.
            # After the `[0, :, :]` indexing below, w+ has shape [num_layers, 512] (presumably one w per synthesis
            # layer -- confirm against G.get_w) and w has shape [1, 512], the first layer's code.
            # NOTE(review): the original comment stated the shapes the other way round ([1, 512] for w+ and
            # [18, 512] for w), which contradicts the indexing on the next two lines.
            w_plus = G.get_w(z, truncation=args.truncation)[0, :, :]
            w = w_plus[0, :].unsqueeze(0)
            torch.save(z.cpu(), osp.join(latent_code_dir, 'latent_code_z.pt'))
            torch.save(w.cpu(), osp.join(latent_code_dir, 'latent_code_w.pt'))
            torch.save(w_plus.cpu(), osp.join(latent_code_dir, 'latent_code_w+.pt'))
            img_w = G(w).cpu()
            tensor2image(img_w, adaptive=True).save(osp.join(latent_code_dir, 'image_w.jpg'),
                                                    "JPEG", quality=95, optimize=True, progressive=True)
        else:
            # Save latent code (Z-space), generate image for this code, and save the generated image
            torch.save(z.cpu(), osp.join(latent_code_dir, 'latent_code_z.pt'))
            img_z = G(z).cpu()
            tensor2image(img_z, adaptive=True).save(osp.join(latent_code_dir, 'image_z.jpg'),
                                                    "JPEG", quality=95, optimize=True, progressive=True)

    if args.verbose:
        update_stdout(1)
        print()
        print()
# Script entry point: run only when executed directly, not on import.
if __name__ == '__main__':
    main()