# Author: Paul Engstler
# Commit: 92f0e98 "Initial commit" (file size 3.11 kB; scanned: no virus)
import torch, copy, argparse, re, os, numpy
from seeing import nethook, setting, renormalize, zdataset, pbar
from seeing import encoder_net, nethook
from seeing import workerpool
from seeing.encoder_loss import cor_square_error
from torch.nn.functional import mse_loss, l1_loss
from seeing.LBFGS import FullBatchLBFGS
torch.backends.cudnn.benchmark = True

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='church',
                    help='generator passed to setting.load_proggan; also the '
                         'results/<model>/ directory of the encoder snapshot')
parser.add_argument('--dataset', default='church_outdoor_train',
                    help='dataset to encode; also names the output directory')
parser.add_argument('--iterations', type=int, default=0,
                    help='L-BFGS refinement steps per batch (0 = encoder only)')
args = parser.parse_args()

batch_size = 32
# Output root; one .npy file per source image is written beneath it.
expdir = 'results/z_dataset/%s/it_%d' % (args.dataset, args.iterations)

# Generator. Uses --model (it was previously hard-coded to 'church', which
# silently ignored the flag); the default keeps the old behaviour.
G = setting.load_proggan(args.model).eval().cuda()

# Pretrained encoder that produces the initial latents.
E = encoder_net.HybridLayerNormEncoder()
filename = ('results/%s/invert_hybrid_bottom_b5/snapshots/epoch_1000.pth.tar'
            % args.model)
# Load onto the CPU first so a checkpoint saved from a different GPU still
# loads; .cuda() below moves the weights to the current device.
E.load_state_dict(torch.load(filename, map_location='cpu')['state_dict'])
E.eval().cuda()

# shuffle=False is required: the main loop maps batch positions back to
# dataset.images with a running index.
dataset = setting.load_dataset(args.dataset, full=True)
loader = torch.utils.data.DataLoader(dataset, shuffle=False,
                                     batch_size=batch_size, num_workers=10,
                                     pin_memory=True)

# Inference by default; refine_z_lbfgs re-enables autograd locally.
torch.set_grad_enabled(False)
class SaveNpyWorker(workerpool.WorkerBase):
    """Worker (run by ``workerpool.WorkerPool``) that saves arrays as .npy."""

    def work(self, filename, data):
        """Write ``data`` to ``filename`` with ``numpy.save``.

        Any missing parent directories are created first.

        Args:
            filename: destination path of the ``.npy`` file.
            data: array to store.
        """
        dirname = os.path.dirname(filename)
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when the path actually contains one.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        numpy.save(filename, data)


# Pool used by the main loop via saver_pool.add(filename, data); presumably it
# dispatches each item to SaveNpyWorker.work (verify in seeing.workerpool).
saver_pool = workerpool.WorkerPool(worker=SaveNpyWorker)
def target_filename_from_source(filename, dataset_name=None, outdir=None):
    """Map a source image path to the .npy path where its latent is saved.

    The path components that follow the first component equal to
    ``dataset_name`` are kept. The file extension is replaced with ``.npy``,
    and the result is placed under ``outdir``.

    Args:
        filename: '/'-separated path of a source image.
        dataset_name: path component that marks the dataset root; defaults
            to the module-level ``args.dataset``.
        outdir: output root directory; defaults to the module-level
            ``expdir``.

    Returns:
        The destination ``.npy`` path.

    Raises:
        ValueError: if ``dataset_name`` is not a component of ``filename`` or
            no component follows it.
    """
    if dataset_name is None:
        dataset_name = args.dataset
    if outdir is None:
        outdir = expdir
    parts = filename.split('/')
    try:
        root = parts.index(dataset_name)
    except ValueError:
        raise ValueError('dataset name %r is not a path component of %r'
                         % (dataset_name, filename)) from None
    relative = parts[root + 1:]
    if not relative:
        raise ValueError('no file path follows dataset name %r in %r'
                         % (dataset_name, filename))
    relative[-1] = os.path.splitext(relative[-1])[0] + '.npy'
    return os.path.join(outdir, *relative)
def refine_z_lbfgs(init_z, target_x, lambda_f=1.0, num_steps=100, max_ls=10):
    """Refine latents so that ``G(z)`` reconstructs ``target_x``.

    Minimises ``l1_loss(target_x, G(z))`` plus, when ``lambda_f`` is nonzero,
    ``lambda_f * mse_loss(E(target_x), E(G(z)))``. The optimisation runs over
    ``z`` only, using FullBatchLBFGS; G and E are frozen.

    Args:
        init_z: starting latents (the main loop passes ``E(target_x)``).
        target_x: images to reconstruct.
        lambda_f: weight of the encoder-feature term; 0 disables it.
        num_steps: number of optimizer steps after the initial evaluation.
        max_ls: ``max_ls`` option forwarded to ``FullBatchLBFGS.step``.

    Returns:
        The refined latents, detached from the autograd graph.
    """
    # Fresh leaf tensor, so requires_grad can be enabled on it even if init_z
    # carries autograd history.
    z = init_z.detach().clone()
    parameters = [z]
    F = E  # feature extractor for the feature-matching term
    with torch.no_grad():
        target_f = F(target_x)
    nethook.set_requires_grad(False, G, E, target_x, target_f)
    nethook.set_requires_grad(True, *parameters)
    optimizer = FullBatchLBFGS(parameters)

    def closure():
        # Objective evaluated by the optimizer's line search.
        optimizer.zero_grad()
        current_x = G(z)
        loss = l1_loss(target_x, current_x)
        if lambda_f:
            loss += mse_loss(target_f, F(current_x)) * lambda_f
        return loss

    # Autograd may be disabled globally (the script calls
    # torch.set_grad_enabled(False)), so re-enable it for the optimisation.
    with torch.enable_grad():
        # Initial objective and gradient at the starting point, as the
        # original step-0 iteration did before the first optimizer.step().
        loss = closure()
        loss.backward()
        for _ in pbar(range(num_steps)):
            options = dict(closure=closure, current_loss=loss, max_ls=max_ls)
            # First returned value is the new objective value.
            loss = optimizer.step(options)[0]
    return z.detach()
# Position in dataset.images of the current batch's first image; valid because
# the loader does not shuffle.
index = 0
for [im] in pbar(loader):
    im = im.cuda()
    z = E(im)  # initial latents from the encoder
    if args.iterations > 0:
        # Refine the whole batch in one joint L-BFGS run. Refining each image
        # separately (slices z[i:i+1], im[i:i+1]) is the alternative.
        z = refine_z_lbfgs(z, im, num_steps=args.iterations)
    # detach() so .numpy() works whether or not z still tracks gradients.
    cpu_z = z.detach().cpu().numpy()
    for i, z_i in enumerate(cpu_z):
        filename = target_filename_from_source(dataset.images[index + i][0])
        # Independent copy with a leading axis of size 1, handed to saver_pool.
        data = z_i.copy()[None]
        pbar.print(filename, data.shape)
        saver_pool.add(filename, data)
    index += len(im)
# Presumably blocks until all queued saves have completed (see workerpool).
saver_pool.join()