# Paul Engstler
# Initial commit
# 92f0e98
import torch, multiprocessing, itertools, os, shutil, PIL, argparse, numpy
from collections import OrderedDict
from numbers import Number
from torch.nn.functional import mse_loss, l1_loss
from . import pbar
from . import zdataset
from . import proggan, customnet, parallelfolder
from . import encoder_net, encoder_loss, setting
from torchvision import transforms, models
from torchvision.models.vgg import model_urls
from .pidfile import exit_if_job_done, mark_job_done
from . import nethook
from .pidfile import exit_if_job_done, mark_job_done
from .encoder_loss import cor_square_error
from .nethook import InstrumentedModel
torch.backends.cudnn.benchmark = True

# Command-line configuration.
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, help='Learning rate', default=0.01)
parser.add_argument('--image_number', type=int, help='Image number',
        default=95)
parser.add_argument('--image_source', type=str,
        help='Image source (e.g. val, train, gan, test)',
        default='test')
parser.add_argument('--redo', type=int, help='Nonzero to delete done.txt',
        default=0)
parser.add_argument('--residuals', nargs='*', help='Residuals to adjust',
        default=None)
parser.add_argument('--model', type=str, help='Dataset being modeled',
        default='church')
parser.add_argument('--halfsize', type=int,
        help='Set to 1 for half size encoder',
        default=0)
parser.add_argument('--snapshot_every', type=int,
        help='only generate snapshots every n iterations',
        default=1000)
args = parser.parse_args()

# Optimization schedule.
num_steps = 3000
# Steps at which the learning-rate scheduler decays the learning rate.
lr_milestones = [800, 1200, 1800]
residuals = (args.residuals if args.residuals is not None
        else ['layer1', 'layer2', 'layer3'])
# NOTE(review): global_seed is not referenced anywhere else in this file, so
# no random number generator is actually seeded with it.
global_seed = 1
learning_rate = args.lr
image_number = args.image_number

# Output locations: results/<model>/<expgroup>/cases/<expname> for this run,
# and results/<model>/<expgroup>/summary_<code> shared between runs.
expgroup = 'optimize_residuals'
# Use an explicit directory name for a different selection of residuals.
if args.residuals is not None:
    expgroup += '_' + '_'.join(residuals)
# One-letter code for the image source, used in directory names. Unknown
# sources fall back to their first character, so the source must be
# non-empty (otherwise indexing [0] below would raise IndexError).
if not args.image_source:
    parser.error('--image_source must be a non-empty string')
imagetypecode = (dict(val='i', train='n', gan='z', test='t')
        .get(args.image_source, args.image_source[0]))
expname = 'opt_%s_%d' % (imagetypecode, image_number)
expdir = os.path.join('results', args.model, expgroup, 'cases', expname)
sumdir = os.path.join('results', args.model, expgroup,
        'summary_%s' % imagetypecode)
os.makedirs(expdir, exist_ok=True)
os.makedirs(sumdir, exist_ok=True)
# Optimize the residuals to reconstruct a single target image (loaded via the
# test ParallelFolder dataset).
def main():
    """Optimize layer residuals so the generator reproduces one target image.

    Steps:
      1. Load the target image (and, when available, its ground-truth z).
      2. Split the pretrained Progressive GAN into `unwrapped_H` (up to
         layer4) and `F` (layer5 onward); load the pretrained hybrid encoder
         (split into `E` and `D`) and VGG16 conv features (`VF`).
      3. Wrap `unwrapped_H` in `encoder_net.ResidualGenerator`, initialized
         from the encoder's z estimate, and optimize only its top-level
         parameters with Adam + MultiStepLR on an L1 pixel loss, an L1 VGG
         feature loss and mean-squared penalties on whichever residual
         tensors (dz, d1..d4) the generator defines.
      4. Every `args.snapshot_every` steps (including step 0) and at the
         final step, log the losses, save images and write a checkpoint.
    """
    pbar.print('Running %s' % expdir)
    delete_log()
    # Grab a target image.
    os.makedirs(os.path.join(expdir, 'images'), exist_ok=True)
    loaded_image, loaded_z = setting.load_test_image(image_number,
            args.image_source, model=args.model)
    visualize_results((image_number, 'target'),
            loaded_image[0], summarize=True)

    # Load the pretrained generator model.
    gan_generator = setting.load_proggan(args.model)
    # We will wrap this model (everything up to and including layer4)...
    unwrapped_H = nethook.subsequence(gan_generator, last_layer='layer4')
    # ...and render its edited output through the remaining layers.
    F = nethook.subsequence(gan_generator, first_layer='layer5')

    # Load a pretrained gan inverter. map_location='cpu' lets the snapshot
    # load regardless of which device it was saved from; the modules are
    # moved to the GPU below.
    encoder = nethook.InstrumentedModel(
            encoder_net.HybridLayerNormEncoder(halfsize=args.halfsize))
    encoder_path = os.path.join('results', args.model,
            'invert_hybrid_cse/snapshots/epoch_1000.pth.tar')
    encoder.load_state_dict(
            torch.load(encoder_path, map_location='cpu')['state_dict'])
    encoder.eval()
    E = nethook.subsequence(encoder.model, last_layer='resnet')
    D = nethook.subsequence(encoder.model, first_layer='inv4')

    # Also make a conv features model from pretrained VGG (through
    # vgg.features layer '20').
    vgg = models.vgg16(pretrained=True)
    VF = nethook.subsequence(vgg.features, last_layer='20')

    # Move models to GPU.
    for m in [F, unwrapped_H, E, D, VF]:
        m.cuda()

    # Quantities that stay constant during the optimization.
    with torch.no_grad():
        # Our true image is constant.
        true_p = loaded_image.cuda()
        # Invert our image once!
        init_r = E(true_p)
        init_z = D(init_r)
        # Compute our features once!
        # NOTE(review): save_tensor_image maps images from [-1, 1]; VGG is
        # applied here without ImageNet mean/std normalization -- confirm
        # that this is intended.
        true_v = VF(true_p)
        # For GAN-generated images we have ground truth.
        # NOTE(review): true_r1..true_r3 are not used below.
        if loaded_z is None:
            true_z = None
            true_r = None
            true_r1, true_r2, true_r3 = None, None, None
        else:
            true_z = loaded_z.cuda()
            with InstrumentedModel(unwrapped_H) as inst_H:
                inst_H.retain_layers(['layer1', 'layer2', 'layer3'])
                true_r = inst_H(true_z)
                true_r1, true_r2, true_r3 = [inst_H.retained_layer(n)
                        for n in ['layer1', 'layer2', 'layer3']]

    # The model we learn are the top-level parameters of this wrapped model.
    H = encoder_net.ResidualGenerator(unwrapped_H, init_z, residuals)
    H.eval()
    H.cuda()

    # Set up optimizer: freeze every network, then re-enable gradients only
    # for H's own (non-recursive) parameters.
    set_requires_grad(False, F, H, E, D, VF)
    parameters = OrderedDict(H.named_parameters(recurse=False))
    for p in parameters.values():
        p.requires_grad = True
    optimizer = torch.optim.Adam(parameters.values(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=lr_milestones, gamma=0.5)

    # Phase 1: find a better r4 by seeking d1, d2, d3, etc.
    for step_num in pbar(range(num_steps + 1)):
        current_r = H()
        current_p = F(current_r)
        current_v = VF(current_p)
        loss_p = l1_loss(true_p, current_p)
        loss_v = l1_loss(true_v, current_v)
        # Mean-squared penalty on each residual H defines; the int 0 if absent.
        loss_z = H.dz.pow(2).mean() if hasattr(H, 'dz') else 0
        loss_1 = H.d1.pow(2).mean() if hasattr(H, 'd1') else 0
        loss_2 = H.d2.pow(2).mean() if hasattr(H, 'd2') else 0
        loss_3 = H.d3.pow(2).mean() if hasattr(H, 'd3') else 0
        loss_4 = H.d4.pow(2).mean() if hasattr(H, 'd4') else 0
        # Logged only (not part of `loss`), so no autograd graph is needed.
        loss_r = mse_loss(init_r, current_r.detach())
        loss = (loss_p + loss_v + loss_z + loss_1 + loss_2 + loss_3 + loss_4)
        all_loss = dict(loss=loss, loss_v=loss_v, loss_p=loss_p,
                loss_r=loss_r, loss_z=loss_z, loss_1=loss_1,
                loss_2=loss_2, loss_3=loss_3, loss_4=loss_4)
        # Report the tensor-valued terms as Python floats, skipping the int-0
        # placeholders used for residuals H does not have.
        all_loss = {k: v.item() for k, v in all_loss.items()
                if isinstance(v, torch.Tensor)}
        if (step_num % args.snapshot_every == 0) or (step_num == num_steps):
            with torch.no_grad():
                # Error metrics are stored as floats, like the losses above.
                if true_r is not None:
                    all_loss['err_r'] = float(
                            cor_square_error(current_r, true_r)) * 100
                all_loss['err_p'] = (
                        (current_p - true_p).pow(2).mean().item() * 100)
                log_progress('%d ' % step_num + ' '.join(
                    '%s=%.3f' % (k, all_loss[k])
                    for k in sorted(all_loss.keys())), phase='a')
                visualize_results((image_number, 'a', step_num), current_p,
                        summarize=(step_num in [0, num_steps]))
                checkpoint_dict = OrderedDict(all_loss)
                for layer_name in residuals:
                    idx = layer_name.replace('layer', '')
                    init_value = getattr(H, 'init_' + idx)
                    delta = getattr(H, 'd' + idx)
                    checkpoint_dict['init_%s' % idx] = init_value
                    checkpoint_dict['d_%s' % idx] = delta
                    checkpoint_dict['current_%s' % idx] = init_value + delta
                save_checkpoint(
                        phase='a',
                        step=step_num,
                        current_r=current_r,
                        current_p=current_p,
                        true_z=true_z,
                        true_r=true_r,
                        true_p=true_p,
                        lr=learning_rate,
                        optimizer=optimizer.state_dict(),
                        **checkpoint_dict)
        # The final iteration only evaluates and snapshots; no update (and so
        # no backward pass) is needed there.
        if step_num < num_steps:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
def delete_log():
    """Remove this experiment's log.txt, if present, so a rerun starts fresh.

    A missing file is ignored. Other OS errors (e.g. permission problems)
    propagate instead of being silently swallowed, as does
    KeyboardInterrupt.
    """
    try:
        os.remove(os.path.join(expdir, 'log.txt'))
    except FileNotFoundError:
        pass
def log_progress(s, phase='a'):
    """Append one line, '<phase> <s>', to <expdir>/log.txt and echo `s`.

    The phase tag is written to the log file only; the console echo via
    pbar.print shows `s` alone.
    """
    log_path = os.path.join(expdir, 'log.txt')
    with open(log_path, 'a') as log_file:
        log_file.write(''.join([phase, ' ', s, '\n']))
    pbar.print(s)
def save_checkpoint(**kwargs):
    """Write one snapshot as both a torch .pth.tar and a numpy .npz file.

    Both files go to <expdir>/snapshots/ and are named
    step_<phase>_<step>.pth.tar and step_<phase>_<step>.npz, so `phase`
    and `step` keyword arguments are required. The .pth.tar holds every
    keyword argument as given. The .npz keeps only numeric entries:
    numbers, numpy arrays, and tensors (detached and copied to CPU).
    Everything else (None, strings, dicts such as optimizer state) is left
    out of it.
    """
    snapshot_dir = os.path.join(expdir, 'snapshots')
    os.makedirs(snapshot_dir, exist_ok=True)
    stem = 'step_%s_%d' % (kwargs['phase'], kwargs['step'])
    torch.save(kwargs, os.path.join(snapshot_dir, stem + '.pth.tar'))
    # Also save as an .npz file for analysis outside of torch.
    numeric_data = {}
    for key, value in kwargs.items():
        if isinstance(value, torch.Tensor):
            numeric_data[key] = value.detach().cpu().numpy()
        elif isinstance(value, (Number, numpy.ndarray)):
            numeric_data[key] = value
    numpy.savez(os.path.join(snapshot_dir, stem + '.npz'), **numeric_data)
def visualize_results(step, img, summarize=False):
    """Save `img` as a PNG under <expdir>/images and optionally under sumdir.

    The file name comes from `step`. A tuple is joined with underscores, and
    any other value is converted with str(). Each destination directory gets
    a copy of lightbox.html, taken from this module's directory and saved as
    '+lightbox.html', unless one is already there.
    """
    def _save_with_lightbox(target_dir, png_name):
        # Write the image, then make sure the directory has a lightbox page.
        save_tensor_image(img, os.path.join(target_dir, png_name))
        lightbox = os.path.join(target_dir, '+lightbox.html')
        if not os.path.exists(lightbox):
            template = os.path.join(os.path.dirname(__file__),
                    'lightbox.html')
            shutil.copy(template, lightbox)

    name_parts = step if isinstance(step, tuple) else (step,)
    png_name = '%s.png' % '_'.join(str(part) for part in name_parts)
    image_dir = os.path.join(expdir, 'images')
    os.makedirs(image_dir, exist_ok=True)
    _save_with_lightbox(image_dir, png_name)
    if summarize:
        _save_with_lightbox(sumdir, png_name)
def save_tensor_image(img, filename):
    """Save an image tensor to `filename` using PIL.

    `img` is a (C, H, W) tensor, or a 4-D batch of which only the first
    image is saved. Values are mapped with x / 2 + 0.5 (i.e. [-1, 1] ->
    [0, 1]), scaled by 255, clamped to [0, 255] and truncated to uint8.
    A single-channel image is saved as grayscale.
    """
    # Import the submodule explicitly: a bare `import PIL` does not make
    # `PIL.Image` available unless some other module already imported it.
    from PIL import Image
    if len(img.shape) == 4:
        img = img[0]
    np_data = ((img.detach().permute(1, 2, 0) / 2 + 0.5) * 255
            ).clamp(0, 255).byte().cpu().numpy()
    if np_data.shape[2] == 1:
        # Pillow's fromarray does not accept an (H, W, 1) uint8 array;
        # drop the channel axis so it is saved as a grayscale ('L') image.
        np_data = np_data[:, :, 0]
    Image.fromarray(np_data).save(filename)
def set_requires_grad(requires_grad, *models):
    """Set `requires_grad` on the given modules and tensors.

    Each positional argument may be a `torch.nn.Module`, in which case all
    of its parameters (recursively) are updated, or a tensor such as a
    `torch.nn.Parameter`, which is updated directly.

    Raises:
        TypeError: if an argument is neither a Module nor a Tensor.
    """
    for model in models:
        if isinstance(model, torch.nn.Module):
            for param in model.parameters():
                param.requires_grad = requires_grad
        elif isinstance(model, torch.Tensor):
            # Covers torch.nn.Parameter, which is a Tensor subclass.
            model.requires_grad = requires_grad
        else:
            # Raised explicitly: an `assert` would vanish under `python -O`.
            raise TypeError('unknown type %r' % type(model))
def edit(x, units=None):
    """Return a copy of `x` with the selected entries along dim 1 zeroed.

    Args:
        x: tensor indexed as x[:, units]. Dim 1 is presumably the unit or
            channel axis (TODO confirm against callers).
        units: indices along dim 1 to zero (int, list, slice or index
            tensor). If omitted, the module-level EDIT_UNITS is used. That
            global is not defined anywhere in this file, so callers should
            normally pass `units` explicitly.

    Raises:
        ValueError: if `units` is omitted and no EDIT_UNITS global exists.
    """
    if units is None:
        units = globals().get('EDIT_UNITS')
        if units is None:
            raise ValueError(
                    'edit() requires `units`: EDIT_UNITS is not defined')
    x = x.clone()
    x[:, units] = 0
    return x
#unit_level99 = {}
#for cls in ablation_units:
# corpus = numpy.load('reltest/churchoutdoor/layer4/ace/%s/corpus.npz' % cls)
if __name__ == '__main__':
    # Script entry point. exit_if_job_done presumably stops here if expdir
    # is already marked done; a nonzero --redo is documented above as
    # "delete done.txt" (see pidfile for the exact behavior).
    exit_if_job_done(expdir, redo=args.redo)
    main()
    # Reached only when main() returns without raising.
    mark_job_done(expdir)