Paul Engstler
Initial commit
92f0e98
raw
history blame
No virus
11.5 kB
import torch, itertools, os
from collections import defaultdict, OrderedDict
from gan_training import checkpoints
from . import nethook, pbar, LBFGS
from seeing.nethook import InstrumentedModel
from torch.nn.functional import mse_loss
def refine_z_lbfgs(init_z, G, target_x, F=None, lambda_f=0,
R=None, lambda_r=0, num_steps=1000, quiet=False,
show=None, show_every=100):
'''
Starting at init_z, uss LBFGS to find a z for which G(z) -> target_x.
By default uses l1_loss, but can also mse(F(G(z)), F(target_x))
'''
z = init_z.clone()
parameters = [z]
nethook.set_requires_grad(False, G)
nethook.set_requires_grad(True, *parameters)
if lambda_f:
with torch.no_grad():
target_f = F(target_x)
nethook.set_requires_grad(False, F)
optimizer = LBFGS.FullBatchLBFGS(parameters)
def closure():
optimizer.zero_grad()
current_x = G(z)
loss = torch.nn.functional.l1_loss(target_x, current_x)
if lambda_f:
loss += torch.nn.functional.mse_loss(target_f, F(current_x)
) * lambda_f
if lambda_r:
loss += R(z) * lambda_r
return loss
pb = (lambda x: x) if quiet else pbar
with torch.enable_grad():
for step_num in pb(range(num_steps + 1)):
if step_num == 0:
loss = closure()
loss.backward()
lr, F_eval, G_eval = 0, 0, 0
else:
options = {'closure': closure, 'current_loss': loss,
'max_ls': 10}
loss, _, lr, _, _, _, _, _ = optimizer.step(options)
if show and (step_num % show_every == 0):
with torch.no_grad():
show(x=G(z), z=z, loss=loss, it=step_num)
return z
def split_gen_layers(enc, gen, layername):
'''
Given an inverter layername, splits the generator sequence into three:
(1) generator sequence before the layers to be inverted
(2) sequence of generator layers to be inverted by enc[layername]
(3) generator sequence after the layers to be inverted
'''
info = list(enc.inverse_info().items())
index = [i for i, (n, v) in enumerate(info) if n == layername][0]
upto_layer = info[index - 1][1]['first_layer'] if index > 0 else None
args = info[index][1]
first_layer = args['first_layer']
layers = nethook.subsequence(gen,
first_layer=first_layer, upto_layer=upto_layer)
prev_layers = nethook.subsequence(gen, upto_layer=first_layer)
next_layers = (nethook.subsequence(gen, first_layer=upto_layer) if
upto_layer else torch.nn.Sequential())
return prev_layers, layers, next_layers
def last_gen_layername(enc, gen, layername):
_, layers, _ = split_gen_layers(enc, gen, layername)
return [n for n, c in layers.named_children()][-1]
def layers_after(enc, layername):
layernames = [n for n, c in enc.named_children()]
index = layernames.index(layername)
if index + 1 < len(layernames):
return nethook.subsequence(enc, layernames[index + 1])
else:
return torch.nn.Sequential()
def train_inv_layer(enc, gen, dists, layername, combine_z=None,
batch_size=100, test_batches=10, checkpoint_dir='ckpts',
resume_from=None, logfile=None, **kwargs):
'''
Inverts a single layer of a multilayer inverter.
Both enc and should be a nn.Sequential subclasses, and
layername specifies the layer of enc to train. That layer
of enc will be trained to invert a set of gen layers.
Which layers specifically are determined by split_gen_layers,
which depends on enc.inverse_info(), to specify how each inverter
layer relates to layers of gen.
'''
if logfile is None:
logfile = os.path.join(checkpoint_dir, 'log.txt')
prev_layers, layers, next_layers = split_gen_layers(enc, gen, layername)
inv = getattr(enc, layername)
device = next(gen.parameters()).device
args = enc.inverse_info()[layername]
kwargs_out = {k: v for k, v in args.items() if k != 'first_layer'}
kwargs_out.update(kwargs)
if 'x_weight' not in kwargs_out:
kwargs_out['x_weight'] = 0
nethook.set_requires_grad(False, layers, prev_layers, next_layers)
if combine_z is not None:
nethook.set_requires_grad(False, combine_z)
prev_layers_old = prev_layers
prev_layers = lambda *a: prev_layers_old(combine_z(*a))
zsampler = infinite_sampler(dists, prev_layers, batch_size, device)
tsample = test_sampler(dists, prev_layers, batch_size, test_batches, device)
train_inverse(inv, layers, zsampler, test_sampler=tsample,
resume_from=resume_from,
checkpoint_dir=checkpoint_dir, logfile=logfile,
**kwargs_out)
def train_inv_joint(enc, gen, dists, combine_z=None,
inv_layer=None, gen_layer=None,
batch_size=50, test_batches=10,
checkpoint_dir='ckpts',
logfile=None, **kwargs):
if logfile is None:
logfile = os.path.join(checkpoint_dir, 'log.txt')
device = next(gen.parameters()).device
zsampler = infinite_sampler(dists, combine_z, batch_size, device)
tsample = test_sampler(dists, combine_z, batch_size, test_batches, device)
with InstrumentedModel(gen) as G, InstrumentedModel(enc) as E:
G.retain_layer(gen_layer, detach=False)
nethook.set_requires_grad(False, G)
E.retain_layer(inv_layer, detach=False)
train_inverse(E, G, zsampler, inv_layer, gen_layer,
r_weight=1.0, ir_weight=1.0, test_sampler=tsample,
checkpoint_dir=checkpoint_dir,
checkpoint_selector=lambda x: x.model,
logfile=logfile, **kwargs)
def train_inverse(inv, gen, sampler, inv_layer=None, gen_layer=None,
z_weight=1.0, x_weight=1.0, r_weight=0.0, ir_weight=0.0, reg_weight=0.0,
adjust_z=None, regularize_z=None,
test_sampler=None, lr=0.01, lr_milestones=None,
epoch_batches=100, num_epochs=100, save_every=50,
logfile=None,
checkpoint_dir=None, checkpoint_selector=None, resume_from=None):
'''
To set this up:
inv and gen should both be instrumented models,
and inv layer and gen layer should be retained on both of them
without detach.
'''
if lr_milestones is None:
lr_milestones = []
optimizer = torch.optim.Adam(inv.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=lr_milestones, gamma=0.1)
ckpts = checkpoints.CheckpointIO(checkpoint_dir)
ckpts.register_modules(optimizer=optimizer, scheduler=scheduler,
inv=inv if not checkpoint_selector else checkpoint_selector(inv))
if resume_from is not None:
ckpts.load('ckpt_%d.pt' % resume_from)
start_epoch = resume_from + 1
else:
start_epoch = 0
def inv_loss(z):
return sum(loss * weight
for loss, weight in component_loss(z).values())
def component_loss(true_z):
observed_x = gen(true_z)
if gen_layer:
true_r = gen.retained_layer(gen_layer, clear=True)
estimated_z = inv(observed_x)
if inv_layer:
inverted_r = inv.retained_layer(inv_layer, clear=True)
if adjust_z:
fixed_z = adjust_z(estimated_z, true_z)
else:
fixed_z = estimated_z
if x_weight or r_weight:
reconstructed_x = gen(fixed_z)
if gen_layer:
reconstructed_r = gen.retained_layer(gen_layer, clear=True)
losses = OrderedDict()
if reg_weight:
losses['reg'] = (regularize_z(estimated_z, true_z), reg_weight)
if z_weight:
losses['z'] = (mse_loss(true_z, estimated_z), z_weight)
if ir_weight:
losses['ir'] = (cor_square_error(true_r, inverted_r), ir_weight)
if x_weight:
losses['x'] = (mse_loss(observed_x, reconstructed_x), x_weight)
if r_weight:
losses['r'] = (cor_square_error(true_r, reconstructed_r), r_weight)
return losses
with torch.no_grad():
for epoch, epoch_loader in pbar(
epoch_grouper(sampler, epoch_batches, num_epochs=1+num_epochs,
start_epoch=start_epoch),
total=(1+num_epochs-start_epoch)):
if epoch > 0:
for (z_batch,) in epoch_loader:
with torch.enable_grad():
optimizer.zero_grad()
loss = inv_loss(z_batch)
loss.backward()
pbar.post(l=loss.item())
optimizer.step()
scheduler.step()
if test_sampler is not None:
stats = MeanStats()
for (z_batch,) in test_sampler:
stats.add(component_loss(z_batch), len(z_batch))
logline = stats.logline(epoch)
pbar.print(logline)
if logfile is not None:
with open(logfile, 'a') as f:
f.write(logline + '\n')
elif epoch > 0:
pbar.print('%d: loss=%4g' % (epoch, loss.item()))
if epoch % save_every == 0 or epoch == num_epochs:
ckpts.save(epoch, 'ckpt_%d.pt' % epoch)
def infinite_sampler(dists, f, batch_size, device):
'''
Returns an infinite generator that runs through randomized z
batches, forever.
'''
while True:
zs = [dist.sample([batch_size]).to(device) for dist in dists]
r = zs[0] if f is None else f(*zs)
yield (r,)
def test_sampler(dists, f, batch_size, test_batches, device):
class TestSampler():
def __init__(self):
self.num_batches = test_batches
self.zs_batches = [
dist.sample([test_batches, batch_size]).to(device)
for dist in dists]
def __iter__(self):
for i in range(self.num_batches):
zs = [uncombined[i] for uncombined in self.zs_batches]
r = zs[0] if f is None else f(*zs)
yield (r,)
return TestSampler()
def epoch_grouper(loader, epoch_size, num_epochs=None, start_epoch=0):
'''
To use with an infinite training loader: groups the training data
batches into epochs of the given size.
'''
it = iter(loader)
epoch = start_epoch
while True:
chunk_it = itertools.islice(it, epoch_size)
try:
first_el = next(chunk_it)
except StopIteration:
return
yield epoch, itertools.chain((first_el,), chunk_it)
epoch += 1
if num_epochs is not None and epoch >= num_epochs:
return
def cor_square_error(x, y, eps=1e-12):
# Analogous to MSE, but in terms of Pearson's correlation
return (1.0 - torch.nn.functional.cosine_similarity(x, y, eps=eps)).mean()
class MeanStats:
def __init__(self):
self.tally = defaultdict(float)
self.count = 0
def add(self, c, size):
for n, (loss, weight) in c.items():
self.tally[n] += loss.item() * size
self.count += size
def summary(self):
return {n: v / self.count for n, v in self.tally.items()}
def logline(self, i=None):
prefix = '' if i is None else '%d: ' % i
return prefix + ' '.join('%s=%4g' % (n, v)
for n, v in self.summary().items())