"""Sample images (and masks) from a trained 4-channel generator checkpoint.

The generator output is split into a 3-channel image (channels 0..2) and a
binary mask built from the last channel; both are written as PNG files.
"""

import argparse
import os
import random

import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from torchvision.datasets import ImageFolder
from tqdm import tqdm

from models import Generator


def load_params(model, new_param):
    """Copy the tensors in ``new_param`` into ``model``'s parameters, in order.

    Args:
        model: module whose ``parameters()`` are overwritten in place.
        new_param: iterable of tensors, paired positionally with the
            model's parameters (pairing stops at the shorter of the two).
    """
    for p, new_p in zip(model.parameters(), new_param):
        p.data.copy_(new_p)


def resize(img):
    """Return ``img`` interpolated to a spatial size of 256."""
    return F.interpolate(img, size=256)


def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks of ``batch`` and concatenate results.

    Args:
        zs: indexable batch of latent inputs (sliced along dimension 0).
        netG: callable applied to each chunk; its output must be a tensor.
        batch: maximum number of latents passed to ``netG`` at once.

    Returns:
        A CPU tensor with the outputs of every chunk concatenated.
    """
    g_images = []
    with torch.no_grad():
        # A single strided loop covers the full chunks and the trailing
        # partial chunk alike.
        for start in range(0, len(zs), batch):
            g_images.append(netG(zs[start:start + batch]).cpu())
    return torch.cat(g_images)


def batch_save(images, folder_name):
    """Save each image in ``images`` as ``<folder_name>/<index>.jpg``.

    Values are shifted with ``(x + 1) * 0.5`` before saving.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates missing parent directories.
    os.makedirs(folder_name, exist_ok=True)
    for i, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5), folder_name + '/%d.jpg' % i)


def _build_parser():
    """Create the command-line parser for the generation script."""
    parser = argparse.ArgumentParser(description='generate images')
    parser.add_argument('--ckpt', type=str,
                        default="pre_trained_checkpoint_4ch/all_50000.pth")
    parser.add_argument('--artifacts', type=str, default=".",
                        help='path to artifacts.')
    parser.add_argument('--cuda', type=int, default=0,
                        help='index of gpu to use')
    parser.add_argument('--start_iter', type=int, default=6)
    parser.add_argument('--end_iter', type=int, default=10)
    parser.add_argument('--dist', type=str, default='test_out')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=1, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=1000)
    parser.add_argument('--big', action='store_true')
    parser.add_argument('--im_size', type=int, default=256)
    parser.add_argument("--save_option", default="image_and_mask",
                        help="Options to svae output, image_only, mask_only, image_and_mask",
                        choices=["image_only", "mask_only", "image_and_mask"])
    parser.set_defaults(big=False)
    return parser


def _load_generator(args, device, noise_dim):
    """Build the 4-channel generator on ``device`` and load ``args.ckpt``."""
    net_ig = Generator(ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)
    net_ig.to(device)

    # map_location makes the load independent of the device the checkpoint
    # was saved from (e.g. a different GPU index).
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    checkpoint = torch.load(args.ckpt, map_location=device)
    # Remove prefix `module`.
    state = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
    net_ig.load_state_dict(state)
    # NOTE(review): the network is deliberately not switched to eval() here,
    # matching the original script; confirm that train-mode layers are
    # intended at inference time.
    print("load checkpoint success")
    return net_ig


def _split_image_and_mask(sample):
    """Split one generated sample into an image and a binary 3-channel mask.

    Args:
        sample: tensor indexed as (channel, H, W); channels 0..2 are used as
            the image and the last channel as the mask. Presumably values
            are in [-1, 1] -- TODO confirm against the generator.

    Returns:
        ``(image, mask)``: the image clamped to [0, 1], and the mask
        thresholded at 0.5 into {0.0, 1.0} and expanded to 3 channels.
    """
    rescaled = sample.add(1).mul(0.5).clamp(min=0, max=1)
    image = rescaled[0:3, :, :]
    mask = (rescaled[-1, :, :] > 0.5) * 1.0
    return image, mask.expand(3, -1, -1)


def main():
    """Parse CLI arguments, load the generator and write sampled outputs."""
    args = _build_parser().parse_args()

    noise_dim = 256
    device = torch.device('cuda:%d' % (args.cuda))
    net_ig = _load_generator(args, device, noise_dim)

    dist = args.dist
    os.makedirs(dist, exist_ok=True)

    # argparse `choices` guarantees save_option is one of the three values.
    save_img = args.save_option in ("image_only", "image_and_mask")
    save_mask = args.save_option in ("mask_only", "image_and_mask")

    with torch.no_grad():
        # Stride by the batch size so a final, smaller batch is produced when
        # n_sample is not a multiple of batch (previously those samples were
        # silently dropped).
        for start in tqdm(range(0, args.n_sample, args.batch)):
            current = min(args.batch, args.n_sample - start)
            noise = torch.randn(current, noise_dim).to(device)
            # Only the first element of the generator's output is used.
            g_imgs = net_ig(noise)[0]
            g_imgs = F.interpolate(g_imgs, 512)
            for j, sample in enumerate(g_imgs):
                g_img, g_mask = _split_image_and_mask(sample)
                index = start + j
                if save_img:
                    vutils.save_image(
                        g_img, os.path.join(dist, '%d_img.png' % index))
                if save_mask:
                    vutils.save_image(
                        g_mask, os.path.join(dist, '%d_mask.png' % index))


if __name__ == "__main__":
    main()