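# Single-image GAN inversion by direct latent optimization.
# A pretrained hybrid encoder E provides an initial latent z for a target
# image; full-batch LBFGS then refines z so that G(z) matches the target
# under an L1 pixel loss plus an encoder-feature MSE term weighted by
# --lambda_f. Snapshots and visualizations are written under results/.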
import torch, multiprocessing, itertools, os, shutil, PIL, argparse, numpy
from collections import OrderedDict
from numbers import Number
from torch.nn.functional import mse_loss, l1_loss
from seeing import pbar
from seeing import zdataset, seededsampler
from seeing import proggan, customnet, parallelfolder
from seeing import encoder_net, encoder_loss, setting
from torchvision import transforms, models
from torchvision.models.vgg import model_urls
from seeing.pidfile import exit_if_job_done, mark_job_done
from seeing import nethook, LBFGS
from seeing.encoder_loss import cor_square_error
from seeing.nethook import InstrumentedModel
# Let cuDNN auto-tune convolution algorithms for the fixed input size.
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser()
parser.add_argument('--image_number', type=int, help='Image number',
        default=95)
parser.add_argument('--image_source', #choices=['val', 'train', 'gan', 'test'],
        default='test')
parser.add_argument('--redo', type=int, help='Nonzero to delete done.txt',
        default=0)
parser.add_argument('--model', type=str, help='Dataset being modeled',
        default='church')
parser.add_argument('--halfsize', type=int,
        help='Set to 1 for half size encoder',
        default=0)
parser.add_argument('--lambda_f', type=float, help='Feature regularizer',
        default=0.25)
parser.add_argument('--num_steps', type=int,
        help='run for n steps',
        default=3000)
parser.add_argument('--snapshot_every', type=int,
        help='only generate snapshots every n iterations',
        default=1000)
args = parser.parse_args()
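# Example invocation (hypothetical script name; assumes a pretrained encoder
# snapshot exists under results/<model>/invert_hybrid_cse/snapshots/):
#   python optimize_z_lbfgs.py --model church --image_source test \
#       --image_number 95 --num_steps 3000 --snapshot_every 1000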
num_steps = args.num_steps
global_seed = 1
image_number = args.image_number
expgroup = 'optimize_lbfgs'
imagetypecode = (dict(val='i', train='n', gan='z', test='t')
                 .get(args.image_source, args.image_source[0]))
expname = 'opt_%s_%d' % (imagetypecode, image_number)
expdir = os.path.join('results', args.model, expgroup, 'cases', expname)
sumdir = os.path.join('results', args.model, expgroup,
                      'summary_%s' % imagetypecode)
os.makedirs(expdir, exist_ok=True)
os.makedirs(sumdir, exist_ok=True)
# Load a single target image (via the test ParallelFolder dataset) and optimize its latent.
def main():
    pbar.print('Running %s' % expdir)
    delete_log()
    # Grab a target image
    dirname = os.path.join(expdir, 'images')
    os.makedirs(dirname, exist_ok=True)
    loaded_x, loaded_z = setting.load_test_image(image_number,
            args.image_source, model=args.model)
    visualize_results((image_number, 'target'),
            loaded_x[0], summarize=True)
    # Load the pretrained generator model.
    G = setting.load_proggan(args.model)
    # Load a pretrained gan inverter
    E = nethook.InstrumentedModel(
            encoder_net.HybridLayerNormEncoder(halfsize=args.halfsize))
    E.load_state_dict(torch.load(os.path.join('results', args.model,
            'invert_hybrid_cse/snapshots/epoch_1000.pth.tar'))['state_dict'])
    E.eval()
    G.cuda()
    E.cuda()
    F = E
    torch.set_grad_enabled(False)
    # Some constants for the GPU
    # Our true image is constant
    true_x = loaded_x.cuda()
    # Invert our image once!
    init_z = E(true_x)
    # For GAN-generated images we have ground truth.
    if loaded_z is None:
        true_z = None
    else:
        true_z = loaded_z.cuda()
    current_z = init_z.clone()
    target_x = loaded_x.clone().cuda()
    target_f = F(loaded_x.cuda())
    parameters = [current_z]
    show_every = args.snapshot_every
    nethook.set_requires_grad(False, G, E)
    nethook.set_requires_grad(True, *parameters)
    optimizer = LBFGS.FullBatchLBFGS(parameters)
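    # Objective (as implemented below): an L1 pixel loss between G(current_z)
    # and the target image, plus an MSE penalty in the encoder's feature
    # space, F(G(current_z)) vs F(target), scaled by --lambda_f. Note that
    # F is the encoder E itself, and the feature term is logged under the
    # key 'z' in all_loss.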
    def compute_all_loss():
        current_x = G(current_z)
        all_loss = {}
        all_loss['x'] = l1_loss(target_x, current_x)
        all_loss['z'] = 0.0 if not args.lambda_f else (
            mse_loss(target_f, F(current_x)) * args.lambda_f)
        return current_x, all_loss

    def closure():
        optimizer.zero_grad()
        _, all_loss = compute_all_loss()
        return sum(all_loss.values())
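    # Optimization protocol for FullBatchLBFGS: step 0 evaluates the loss and
    # gradients directly to seed the optimizer; each later step passes a
    # closure (plus the previous loss and a line-search budget of max_ls=10)
    # to optimizer.step, which returns the new loss and learning rate among
    # other line-search statistics.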
    with torch.enable_grad():
        for step_num in pbar(range(num_steps + 1)):
            if step_num == 0:
                loss = closure()
                loss.backward()
            else:
                options = {'closure': closure, 'current_loss': loss,
                           'max_ls': 10}
                loss, _, lr, _, _, _, _, _ = optimizer.step(options)
            if step_num % show_every == 0:
                with torch.no_grad():
                    current_x, all_loss = compute_all_loss()
                    log_progress('%d ' % step_num + ' '.join(
                        '%s=%.3f' % (k, all_loss[k])
                        for k in sorted(all_loss.keys())), phase='a')
                    visualize_results((image_number, 'a', step_num), current_x,
                            summarize=(step_num in [0, num_steps]))
                    checkpoint_dict = OrderedDict(all_loss)
                    checkpoint_dict['init_z'] = init_z
                    checkpoint_dict['target_x'] = target_x
                    checkpoint_dict['current_z'] = current_z
                    save_checkpoint(
                        phase='a',
                        step=step_num,
                        optimizer=optimizer.state_dict(),
                        **checkpoint_dict)
def delete_log():
    try:
        os.remove(os.path.join(expdir, 'log.txt'))
    except OSError:
        pass
def log_progress(s, phase='a'):
    with open(os.path.join(expdir, 'log.txt'), 'a') as f:
        f.write(phase + ' ' + s + '\n')
    pbar.print(s)
def save_checkpoint(**kwargs):
    dirname = os.path.join(expdir, 'snapshots')
    os.makedirs(dirname, exist_ok=True)
    filename = 'step_%s_%d.pth.tar' % (kwargs['phase'], kwargs['step'])
    torch.save(kwargs, os.path.join(dirname, filename))
    # Also save the numeric values as an .npz file for analysis.
    numeric_data = {
        k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
        for k, v in kwargs.items()
        if isinstance(v, (Number, numpy.ndarray, torch.Tensor))}
    filename = 'step_%s_%d.npz' % (kwargs['phase'], kwargs['step'])
    numpy.savez(os.path.join(dirname, filename), **numeric_data)
def visualize_results(step, img, summarize=False):
    # TODO: add editing etc.
    if isinstance(step, tuple):
        filename = '%s.png' % ('_'.join(str(i) for i in step))
    else:
        filename = '%s.png' % str(step)
    dirname = os.path.join(expdir, 'images')
    os.makedirs(dirname, exist_ok=True)
    save_tensor_image(img, os.path.join(dirname, filename))
    lbname = os.path.join(dirname, '+lightbox.html')
    if not os.path.exists(lbname):
        shutil.copy('seeing/lightbox.html', lbname)
    if summarize:
        save_tensor_image(img, os.path.join(sumdir, filename))
        lbname = os.path.join(sumdir, '+lightbox.html')
        if not os.path.exists(lbname):
            shutil.copy('seeing/lightbox.html', lbname)
def save_tensor_image(img, filename):
    if len(img.shape) == 4:
        img = img[0]
    # Map from the generator's [-1, 1] output range to [0, 255] bytes.
    np_data = ((img.permute(1, 2, 0) / 2 + 0.5) * 255
               ).clamp(0, 255).byte().cpu().numpy()
    PIL.Image.fromarray(np_data).save(filename)
def set_requires_grad(requires_grad, *models):
    for model in models:
        if isinstance(model, torch.nn.Module):
            for param in model.parameters():
                param.requires_grad = requires_grad
        elif isinstance(model, torch.nn.Parameter):
            model.requires_grad = requires_grad
        else:
            assert False, 'unknown type %r' % type(model)
if __name__ == '__main__':
    exit_if_job_done(expdir, redo=args.redo)
    main()
    mark_job_done(expdir)