import torch, itertools, os

from collections import defaultdict, OrderedDict
from gan_training import checkpoints
from . import nethook, pbar, LBFGS
from seeing.nethook import InstrumentedModel
from torch.nn.functional import mse_loss

def refine_z_lbfgs(init_z, G, target_x, F=None, lambda_f=0,
                   R=None, lambda_r=0, num_steps=1000, quiet=False,
                   show=None, show_every=100):
    '''
    Starting at init_z, uses LBFGS to find a z for which G(z) -> target_x.
    By default uses l1_loss, but can also add mse(F(G(z)), F(target_x))
    weighted by lambda_f, and a regularizer R(z) weighted by lambda_r.
    '''
    z = init_z.clone()
    parameters = [z]
    nethook.set_requires_grad(False, G)
    nethook.set_requires_grad(True, *parameters)
    if lambda_f:
        with torch.no_grad():
            target_f = F(target_x)
        nethook.set_requires_grad(False, F)

    optimizer = LBFGS.FullBatchLBFGS(parameters)

    def closure():
        optimizer.zero_grad()
        current_x = G(z)
        loss = torch.nn.functional.l1_loss(target_x, current_x)
        if lambda_f:
            loss += torch.nn.functional.mse_loss(
                target_f, F(current_x)) * lambda_f
        if lambda_r:
            loss += R(z) * lambda_r
        return loss

    pb = (lambda x: x) if quiet else pbar

    with torch.enable_grad():
        for step_num in pb(range(num_steps + 1)):
            if step_num == 0:
                loss = closure()
                loss.backward()
                lr, F_eval, G_eval = 0, 0, 0
            else:
                options = {'closure': closure, 'current_loss': loss,
                           'max_ls': 10}
                loss, _, lr, _, _, _, _, _ = optimizer.step(options)
            if show and (step_num % show_every == 0):
                with torch.no_grad():
                    show(x=G(z), z=z, loss=loss, it=step_num)
    return z
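
# A minimal usage sketch (the names below are hypothetical): given a
# pretrained generator G, a target image batch target_x, and optionally a
# feature network F (e.g. a perceptual feature extractor), one might refine
# an encoder's initial estimate like this:
#
#     z0 = encoder(target_x)
#     z = refine_z_lbfgs(z0, G, target_x, F=F, lambda_f=0.1, num_steps=500)
#     reconstruction = G(z)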


def split_gen_layers(enc, gen, layername):
    '''
    Given an inverter layername, splits the generator sequence into three:
    (1) generator sequence before the layers to be inverted
    (2) sequence of generator layers to be inverted by enc[layername]
    (3) generator sequence after the layers to be inverted
    '''
    info = list(enc.inverse_info().items())
    index = [i for i, (n, v) in enumerate(info) if n == layername][0]
    upto_layer = info[index - 1][1]['first_layer'] if index > 0 else None
    args = info[index][1]
    first_layer = args['first_layer']
    layers = nethook.subsequence(gen,
                                 first_layer=first_layer,
                                 upto_layer=upto_layer)
    prev_layers = nethook.subsequence(gen, upto_layer=first_layer)
    next_layers = (nethook.subsequence(gen, first_layer=upto_layer) if
                   upto_layer else torch.nn.Sequential())
    return prev_layers, layers, next_layers
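
# Expected shape of enc.inverse_info(), inferred from the lookups above
# (layer names here are hypothetical): an ordered mapping from each inverter
# layer to a dict holding at least 'first_layer', the generator layer where
# that inverter layer's span begins.  Each span runs up to the previous
# entry's 'first_layer' (or to the end of gen for the first entry), and any
# remaining keys are passed along to train_inverse as keyword arguments.
#
#     OrderedDict([
#         ('inv4', dict(first_layer='layer4')),             # layer4 .. end
#         ('inv1', dict(first_layer='layer1', x_weight=0)), # layer1 .. layer3
#     ])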


def last_gen_layername(enc, gen, layername):
    _, layers, _ = split_gen_layers(enc, gen, layername)
    return [n for n, c in layers.named_children()][-1]


def layers_after(enc, layername):
    layernames = [n for n, c in enc.named_children()]
    index = layernames.index(layername)
    if index + 1 < len(layernames):
        return nethook.subsequence(enc, layernames[index + 1])
    else:
        return torch.nn.Sequential()


def train_inv_layer(enc, gen, dists, layername, combine_z=None,
                    batch_size=100, test_batches=10, checkpoint_dir='ckpts',
                    resume_from=None, logfile=None, **kwargs):
    '''
    Inverts a single layer of a multilayer inverter.
    Both enc and gen should be nn.Sequential subclasses, and
    layername specifies the layer of enc to train.  That layer
    of enc will be trained to invert a set of gen layers.
    Which gen layers, specifically, is determined by split_gen_layers,
    which relies on enc.inverse_info() to specify how each inverter
    layer relates to the layers of gen.
    '''
    if logfile is None:
        logfile = os.path.join(checkpoint_dir, 'log.txt')
    prev_layers, layers, next_layers = split_gen_layers(enc, gen, layername)
    inv = getattr(enc, layername)
    device = next(gen.parameters()).device
    args = enc.inverse_info()[layername]
    kwargs_out = {k: v for k, v in args.items() if k != 'first_layer'}
    kwargs_out.update(kwargs)
    if 'x_weight' not in kwargs_out:
        kwargs_out['x_weight'] = 0
    nethook.set_requires_grad(False, layers, prev_layers, next_layers)
    if combine_z is not None:
        nethook.set_requires_grad(False, combine_z)
        prev_layers_old = prev_layers
        prev_layers = lambda *a: prev_layers_old(combine_z(*a))
    zsampler = infinite_sampler(dists, prev_layers, batch_size, device)
    tsample = test_sampler(dists, prev_layers, batch_size, test_batches,
                           device)
    train_inverse(inv, layers, zsampler, test_sampler=tsample,
                  resume_from=resume_from,
                  checkpoint_dir=checkpoint_dir, logfile=logfile,
                  **kwargs_out)
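
# A layer-by-layer training loop might look like the following sketch
# (the encoder, generator, zdist, and directory names are placeholders):
#
#     for layername in encoder.inverse_info():
#         train_inv_layer(encoder, generator, [zdist], layername,
#                         checkpoint_dir='ckpts/%s' % layername)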


def train_inv_joint(enc, gen, dists, combine_z=None,
                    inv_layer=None, gen_layer=None,
                    batch_size=50, test_batches=10,
                    checkpoint_dir='ckpts',
                    logfile=None, **kwargs):
    '''
    Trains the whole inverter enc jointly against gen, matching the
    retained intermediate representations at inv_layer and gen_layer
    (the r and ir losses) in addition to the usual z and x losses.
    '''
    if logfile is None:
        logfile = os.path.join(checkpoint_dir, 'log.txt')
    device = next(gen.parameters()).device
    zsampler = infinite_sampler(dists, combine_z, batch_size, device)
    tsample = test_sampler(dists, combine_z, batch_size, test_batches, device)

    with InstrumentedModel(gen) as G, InstrumentedModel(enc) as E:
        G.retain_layer(gen_layer, detach=False)
        nethook.set_requires_grad(False, G)
        E.retain_layer(inv_layer, detach=False)
        train_inverse(E, G, zsampler, inv_layer, gen_layer,
                      r_weight=1.0, ir_weight=1.0, test_sampler=tsample,
                      checkpoint_dir=checkpoint_dir,
                      checkpoint_selector=lambda x: x.model,
                      logfile=logfile, **kwargs)
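
# For joint fine-tuning, inv_layer and gen_layer name the pair of layers
# whose activations should agree; for instance (hypothetical names):
#
#     train_inv_joint(encoder, generator, [zdist],
#                     inv_layer='inv4', gen_layer='layer4',
#                     lr=1e-4, num_epochs=20)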


def train_inverse(inv, gen, sampler, inv_layer=None, gen_layer=None,
                  z_weight=1.0, x_weight=1.0, r_weight=0.0, ir_weight=0.0,
                  reg_weight=0.0,
                  adjust_z=None, regularize_z=None,
                  test_sampler=None, lr=0.01, lr_milestones=None,
                  epoch_batches=100, num_epochs=100, save_every=50,
                  logfile=None,
                  checkpoint_dir=None, checkpoint_selector=None,
                  resume_from=None):
    '''
    To set this up: inv and gen should both be instrumented models,
    with inv_layer retained on inv and gen_layer retained on gen,
    in both cases without detach.
    '''
    if lr_milestones is None:
        lr_milestones = []
    optimizer = torch.optim.Adam(inv.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_milestones, gamma=0.1)

    ckpts = checkpoints.CheckpointIO(checkpoint_dir)
    ckpts.register_modules(optimizer=optimizer, scheduler=scheduler,
                           inv=inv if not checkpoint_selector
                           else checkpoint_selector(inv))
    if resume_from is not None:
        ckpts.load('ckpt_%d.pt' % resume_from)
        start_epoch = resume_from + 1
    else:
        start_epoch = 0

    def inv_loss(z):
        return sum(loss * weight
                   for loss, weight in component_loss(z).values())

    def component_loss(true_z):
        # Generate an image from the true z, recover an estimated z with the
        # inverter, then re-generate from that estimate; compare at each stage.
        observed_x = gen(true_z)
        if gen_layer:
            true_r = gen.retained_layer(gen_layer, clear=True)
        estimated_z = inv(observed_x)
        if inv_layer:
            inverted_r = inv.retained_layer(inv_layer, clear=True)
        if adjust_z:
            fixed_z = adjust_z(estimated_z, true_z)
        else:
            fixed_z = estimated_z
        if x_weight or r_weight:
            reconstructed_x = gen(fixed_z)
            if gen_layer:
                reconstructed_r = gen.retained_layer(gen_layer, clear=True)
        losses = OrderedDict()
        if reg_weight:
            losses['reg'] = (regularize_z(estimated_z, true_z), reg_weight)
        if z_weight:
            losses['z'] = (mse_loss(true_z, estimated_z), z_weight)
        if ir_weight:
            losses['ir'] = (cor_square_error(true_r, inverted_r), ir_weight)
        if x_weight:
            losses['x'] = (mse_loss(observed_x, reconstructed_x), x_weight)
        if r_weight:
            losses['r'] = (cor_square_error(true_r, reconstructed_r), r_weight)
        return losses

    with torch.no_grad():
        for epoch, epoch_loader in pbar(
                epoch_grouper(sampler, epoch_batches,
                              num_epochs=1 + num_epochs,
                              start_epoch=start_epoch),
                total=(1 + num_epochs - start_epoch)):
            if epoch > 0:
                for (z_batch,) in epoch_loader:
                    with torch.enable_grad():
                        optimizer.zero_grad()
                        loss = inv_loss(z_batch)
                        loss.backward()
                        pbar.post(l=loss.item())
                        optimizer.step()
                scheduler.step()
            if test_sampler is not None:
                stats = MeanStats()
                for (z_batch,) in test_sampler:
                    stats.add(component_loss(z_batch), len(z_batch))
                logline = stats.logline(epoch)
                pbar.print(logline)
                if logfile is not None:
                    with open(logfile, 'a') as f:
                        f.write(logline + '\n')
            elif epoch > 0:
                pbar.print('%d: loss=%4g' % (epoch, loss.item()))
            if epoch % save_every == 0 or epoch == num_epochs:
                ckpts.save(epoch, 'ckpt_%d.pt' % epoch)
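
# Loss components produced by component_loss above, for reference:
#   'z'   mse between the sampled z and the inverter's estimate
#   'x'   mse between gen(z) and the image re-generated from the estimate
#   'ir'  cor_square_error between gen's retained features and inv's
#   'r'   cor_square_error between gen's retained features and the features
#         retained when re-generating from the estimated z
#   'reg' a user-supplied regularize_z(estimated_z, true_z) term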


def infinite_sampler(dists, f, batch_size, device):
    '''
    Returns an infinite generator that runs through randomized z
    batches, forever.
    '''
    while True:
        zs = [dist.sample([batch_size]).to(device) for dist in dists]
        r = zs[0] if f is None else f(*zs)
        yield (r,)
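
# For example (a sketch; assumes dists holds torch.distributions objects):
#
#     zdist = torch.distributions.Normal(torch.zeros(128), torch.ones(128))
#     sampler = infinite_sampler([zdist], None, batch_size=32, device='cpu')
#     (z_batch,) = next(sampler)   # z_batch has shape [32, 128]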


def test_sampler(dists, f, batch_size, test_batches, device):
    class TestSampler():
        def __init__(self):
            self.num_batches = test_batches
            self.zs_batches = [
                dist.sample([test_batches, batch_size]).to(device)
                for dist in dists]

        def __iter__(self):
            for i in range(self.num_batches):
                zs = [uncombined[i] for uncombined in self.zs_batches]
                r = zs[0] if f is None else f(*zs)
                yield (r,)
    return TestSampler()


def epoch_grouper(loader, epoch_size, num_epochs=None, start_epoch=0):
    '''
    To use with an infinite training loader: groups the training data
    batches into epochs of the given size.
    '''
    it = iter(loader)
    epoch = start_epoch
    while True:
        chunk_it = itertools.islice(it, epoch_size)
        try:
            first_el = next(chunk_it)
        except StopIteration:
            return
        yield epoch, itertools.chain((first_el,), chunk_it)
        epoch += 1
        if num_epochs is not None and epoch >= num_epochs:
            return
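
# For example (a sketch, reusing the hypothetical zdist from the comment
# after infinite_sampler above):
#
#     sampler = infinite_sampler([zdist], None, 32, 'cpu')
#     for epoch, batches in epoch_grouper(sampler, epoch_size=100,
#                                         num_epochs=5):
#         for (z_batch,) in batches:
#             pass   # one training step per batch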


def cor_square_error(x, y, eps=1e-12):
    '''
    Returns the mean, over all positions, of one minus the cosine
    similarity of x and y computed along the channel dimension.
    '''
    return (1.0 - torch.nn.functional.cosine_similarity(x, y, eps=eps)).mean()

class MeanStats:
    '''
    Accumulates running means of named loss components for logging.
    '''
    def __init__(self):
        self.tally = defaultdict(float)
        self.count = 0

    def add(self, c, size):
        for n, (loss, weight) in c.items():
            self.tally[n] += loss.item() * size
        self.count += size

    def summary(self):
        return {n: v / self.count for n, v in self.tally.items()}

    def logline(self, i=None):
        prefix = '' if i is None else '%d: ' % i
        return prefix + ' '.join('%s=%4g' % (n, v)
                                 for n, v in self.summary().items())
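
# Typical use, mirroring the evaluation loop in train_inverse above:
#
#     stats = MeanStats()
#     for (z_batch,) in test_sampler:
#         stats.add(component_loss(z_batch), len(z_batch))
#     print(stats.logline(epoch))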