File size: 4,884 Bytes
7f49ac7 c7c9ff6 7f49ac7 c7c9ff6 7f49ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils
import os
import random
import argparse
from tqdm import tqdm
from models import Generator
def load_params(model, new_param):
    """Overwrite ``model``'s parameters in place with the tensors in ``new_param``.

    Tensors are paired positionally with ``model.parameters()``. Because the pairing
    uses ``zip``, it stops at the shorter of the two sequences; surplus entries on
    either side are silently ignored.

    Args:
        model: module whose parameters are overwritten.
        new_param: iterable of tensors, in the same order as ``model.parameters()``.
    """
    paired = zip(model.parameters(), new_param)
    for target, source in paired:
        # Writing through ``.data`` keeps the copy out of autograd tracking.
        target.data.copy_(source)
def resize(img, size=256):
    """Resize ``img`` with ``F.interpolate`` using its default (nearest) mode.

    Args:
        img: tensor accepted by ``F.interpolate``; presumably a (N, C, H, W) image
            batch — TODO confirm with callers.
        size: target spatial size. For a 4-D input, an int yields a square
            ``size x size`` output. Defaults to 256, the previously hard-coded value.

    Returns:
        The resized tensor.
    """
    return F.interpolate(img, size=size)
def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks of ``batch`` and concatenate the results on CPU.

    ``zs`` is split into ``len(zs) // batch`` full chunks. Any leftover rows go
    through ``netG`` as one final, smaller chunk. All calls run under
    ``torch.no_grad()``, and each output is moved to CPU before concatenation.

    Args:
        zs: sliceable input (e.g. a latent tensor) whose first axis is batched.
        netG: callable returning a tensor for each chunk.
        batch: chunk size.

    Returns:
        ``torch.cat`` of the per-chunk outputs along dim 0.
    """
    full_chunks, leftover = divmod(len(zs), batch)
    outputs = []
    with torch.no_grad():
        for chunk_idx in range(full_chunks):
            lo = chunk_idx * batch
            outputs.append(netG(zs[lo:lo + batch]).cpu())
        if leftover > 0:
            # The tail chunk is exactly the last `leftover` rows.
            outputs.append(netG(zs[-leftover:]).cpu())
    return torch.cat(outputs)
def batch_save(images, folder_name):
    """Save each image in ``images`` as ``<folder_name>/<index>.jpg``.

    Each image is mapped with ``(x + 1) * 0.5`` before saving. This assumes inputs
    lie in [-1, 1] — TODO confirm with callers.

    Args:
        images: iterable of image tensors accepted by ``vutils.save_image``.
        folder_name: output directory. It is created, including missing parents,
            if it does not already exist.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the exists()/mkdir() race.
    os.makedirs(folder_name, exist_ok=True)
    for i, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5), os.path.join(folder_name, '%d.jpg' % i))
if __name__ == "__main__":
    # Sample images (and binary masks) from a 4-channel generator checkpoint.
    # Channels 0-2 are saved as the RGB image; the last channel is thresholded
    # into a mask.
    parser = argparse.ArgumentParser(
        description='generate images'
    )
    parser.add_argument('--ckpt', type=str, default="pre_trained_checkpoint_4ch/all_50000.pth")
    parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
    parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
    parser.add_argument('--start_iter', type=int, default=6)
    parser.add_argument('--end_iter', type=int, default=10)
    parser.add_argument('--dist', type=str, default='test_out')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=1, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=1000)
    parser.add_argument('--big', action='store_true')
    parser.add_argument('--im_size', type=int, default=256)
    parser.add_argument('--out_size', type=int, default=512,
                        help='spatial size samples are resized to before saving')
    parser.add_argument("--save_option", default="image_and_mask",
                        help="Options to save output, image_only, mask_only, image_and_mask",
                        choices=["image_only", "mask_only", "image_and_mask"])
    parser.set_defaults(big=False)
    args = parser.parse_args()
    if args.batch < 1:
        parser.error('--batch must be a positive integer')

    noise_dim = 256
    # Fall back to CPU so the script still runs on machines without CUDA.
    if torch.cuda.is_available():
        device = torch.device('cuda:%d' % args.cuda)
    else:
        print("CUDA is not available; falling back to CPU.")
        device = torch.device('cpu')

    net_ig = Generator(ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)
    net_ig.to(device)

    # map_location lets a checkpoint saved on any device load onto `device`.
    checkpoint = torch.load(args.ckpt, map_location=device)
    # Strip a leading `module.` (presumably added by a DataParallel/DDP wrapper at
    # save time). Only the prefix is removed: str.replace would also mangle keys
    # such as `submodule.weight`.
    prefix = 'module.'
    checkpoint['g'] = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['g'].items()
    }
    net_ig.load_state_dict(checkpoint['g'])
    # net_ig.eval() is not called, matching the original script.
    # NOTE(review): confirm whether normalization layers should use batch statistics here.
    print("load checkpoint success")
    del checkpoint

    dist = args.dist
    os.makedirs(dist, exist_ok=True)

    save_img = args.save_option in ("image_only", "image_and_mask")
    save_mask = args.save_option in ("mask_only", "image_and_mask")

    # Full batches plus one trailing partial batch, so exactly n_sample samples are
    # written even when n_sample is not a multiple of --batch.
    n_full, remainder = divmod(args.n_sample, args.batch)
    batch_sizes = [args.batch] * n_full + ([remainder] if remainder else [])

    sample_idx = 0
    with torch.no_grad():
        for bs in tqdm(batch_sizes):
            noise = torch.randn(bs, noise_dim, device=device)
            # Presumably the generator returns a sequence whose first entry is the
            # full-resolution 4-channel output — TODO confirm in models.Generator.
            g_imgs = net_ig(noise)[0]
            g_imgs = F.interpolate(g_imgs, args.out_size)
            for g_img in g_imgs:
                # (x + 1) * 0.5 maps [-1, 1] to [0, 1] (assumes generator output in
                # [-1, 1]); clamp removes any out-of-range values.
                g_img = g_img.add(1).mul(0.5).clamp(0, 1)
                rgb = g_img[0:3, :, :]
                # Binarise the last channel at 0.5 and replicate it to 3 channels.
                mask = (g_img[-1, :, :] > 0.5).float().expand(3, -1, -1)
                if save_img:
                    vutils.save_image(rgb, os.path.join(dist, '%d_img.png' % sample_idx))
                if save_mask:
                    vutils.save_image(mask, os.path.join(dist, '%d_mask.png' % sample_idx))
                sample_idx += 1
|