File size: 3,237 Bytes
7f49ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils
import os
import random
import argparse
from tqdm import tqdm
from models import Generator
def load_params(model, new_param):
    """Overwrite ``model``'s parameters in place with the tensors in ``new_param``.

    Parameters are paired positionally with ``new_param`` in the order
    ``model.parameters()`` yields them. Pairing stops at the shorter of the
    two sequences (``zip`` semantics), so surplus entries on either side are
    ignored.
    """
    pairs = zip(model.parameters(), new_param)
    for target, source in pairs:
        # Write through ``.data`` so the copy is not tracked by autograd.
        target.data.copy_(source)
def resize(img, size=256):
    """Resize ``img`` spatially with ``F.interpolate``.

    Args:
        img: tensor passed straight to ``F.interpolate``; presumably an
            (N, C, H, W) image batch — TODO confirm at call sites.
        size: target spatial size forwarded to ``F.interpolate``. The default
            of 256 preserves the previous hard-coded behaviour.

    Returns:
        The interpolated tensor, using ``F.interpolate``'s default mode.
    """
    return F.interpolate(img, size=size)
def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks and concatenate the outputs on the CPU.

    Args:
        zs: sliceable, non-empty sequence of generator inputs (e.g. a latent
            tensor); chunks are taken along its first dimension.
        netG: callable mapping a chunk of ``zs`` to a tensor.
        batch: maximum chunk size; the final chunk holds any remainder.

    Returns:
        ``torch.cat`` of every chunk's output, each moved to the CPU, in the
        same order as ``zs``.

    Raises:
        ValueError: if ``batch`` is not positive or ``zs`` is empty.
    """
    if batch < 1:
        raise ValueError('batch must be a positive integer, got %r' % (batch,))
    if len(zs) == 0:
        # torch.cat([]) would otherwise fail with a far less helpful message.
        raise ValueError('zs must contain at least one input')
    with torch.no_grad():
        # Stepping by `batch` yields full chunks plus a final partial chunk.
        outputs = [netG(zs[start:start + batch]).cpu()
                   for start in range(0, len(zs), batch)]
    return torch.cat(outputs)
def batch_save(images, folder_name):
    """Save each image in ``images`` to ``folder_name/<index>.jpg``.

    The folder (and any missing parents) is created if needed. Each image is
    mapped with ``(x + 1) * 0.5`` before saving. This assumes pixel values in
    [-1, 1], which become [0, 1]; TODO confirm against the generator's output
    range.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates intermediate directories.
    os.makedirs(folder_name, exist_ok=True)
    for i, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5),
                          os.path.join(folder_name, '%d.jpg' % i))
def _parse_args():
    """Build and parse the command-line interface.

    Every original flag is kept so existing invocations still parse.
    ``--ckpt``, ``--size`` and ``--big`` are accepted but not used below.
    """
    parser = argparse.ArgumentParser(
        description='generate images'
    )
    parser.add_argument('--ckpt', type=str)
    parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
    parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
    parser.add_argument('--start_iter', type=int, default=6)
    parser.add_argument('--end_iter', type=int, default=10)
    parser.add_argument('--dist', type=str, default='.',
                        help='root folder for the eval_<iter>/img output folders')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=16, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=2000)
    parser.add_argument('--big', action='store_true')
    parser.add_argument('--im_size', type=int, default=1024)
    parser.set_defaults(big=False)
    args = parser.parse_args()
    if args.batch < 1:
        parser.error('--batch must be a positive integer')
    return args


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

    Only the leading prefix is stripped, so keys that merely contain
    ``module.`` further in are left intact. The prefix is presumably added by a
    DataParallel-style wrapper at save time; verify against the training script.
    """
    prefix = 'module.'
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}


def _sample_to_folder(net_ig, device, noise_dim, n_sample, batch, out_dir, out_size=512):
    """Generate ``n_sample`` images with ``net_ig`` and save them as ``<index>.png`` in ``out_dir``.

    For each chunk:
      * Noise of shape (chunk, noise_dim) is drawn on ``device``.
      * Element ``[0]`` of the generator output is taken as the image batch.
      * The batch is resized to ``out_size`` with ``F.interpolate``.
      * Each image is mapped with ``(x + 1) * 0.5`` and written with
        ``vutils.save_image``. This assumes a [-1, 1] output range; TODO confirm.
    """
    # Ceil division so the trailing partial batch is generated too and exactly
    # n_sample images are written.
    n_batches = -(-n_sample // batch)
    with torch.no_grad():
        for i in tqdm(range(n_batches)):
            start = i * batch
            chunk = min(batch, n_sample - start)
            noise = torch.randn(chunk, noise_dim, device=device)
            g_imgs = net_ig(noise)[0]
            g_imgs = F.interpolate(g_imgs, out_size)
            for j, g_img in enumerate(g_imgs):
                vutils.save_image(g_img.add(1).mul(0.5),
                                  os.path.join(out_dir, '%d.png' % (start + j)))


def main():
    """Sample images from each checkpoint ``10000 * k`` for k in [start_iter, end_iter]."""
    args = _parse_args()
    noise_dim = 256
    # Fall back to CPU instead of failing later when CUDA is unavailable.
    if torch.cuda.is_available():
        device = torch.device('cuda:%d' % args.cuda)
    else:
        device = torch.device('cpu')

    # --big is parsed but not forwarded to Generator.
    net_ig = Generator(ngf=64, nz=noise_dim, nc=3, im_size=args.im_size)
    net_ig.to(device)
    # NOTE(review): the generator is never switched to eval() mode, and the
    # weights come from checkpoint['g'] rather than an EMA copy
    # (checkpoint['g_ema']). Both choices affect the samples; confirm they are
    # intended.

    for step in range(args.start_iter, args.end_iter + 1):
        epoch = 10000 * step
        ckpt = os.path.join(args.artifacts, 'models', '%d.pth' % epoch)
        checkpoint = torch.load(ckpt, map_location='cpu')
        # load_state_dict copies into the existing parameters, which are
        # already on `device`.
        net_ig.load_state_dict(_strip_module_prefix(checkpoint['g']))
        print('load checkpoint success, epoch %d' % epoch)
        # Drop the loaded checkpoint before sampling so it can be freed.
        del checkpoint

        out_dir = os.path.join(args.dist, 'eval_%d' % epoch, 'img')
        os.makedirs(out_dir, exist_ok=True)
        _sample_to_folder(net_ig, device, noise_dim, args.n_sample, args.batch, out_dir)


if __name__ == "__main__":
    main()
|