# Page metadata captured together with this file (not code):
# ethanNeuralImage's picture
# models
# 92ec8d3
# raw
# history blame
# No virus
# 2.29 kB
import torch
def run_inversion(inputs, net, opts, return_intermediate_results=False):
    """Iteratively refine an inversion of ``inputs`` by repeatedly calling ``net.forward``.

    Runs ``opts.n_iters_per_batch`` iterations. Each iteration passes the
    previous iteration's (downsampled) output, codes and weight deltas back
    into ``net.forward`` together with the original ``inputs``.

    Args:
        inputs: Batch of inputs; ``inputs.shape[0]`` is used as the number
            of samples when collecting intermediate results.
        net: Model exposing ``forward(...)`` (called with the keyword
            arguments below and expected to return a 5-tuple) and
            ``face_pool`` (used to downsample non-"cars" outputs between
            iterations).
        opts: Options object read for ``n_iters_per_batch``,
            ``resize_outputs`` and ``dataset_type``.
        return_intermediate_results: If True, collect per-sample results of
            every iteration instead of returning only the final state.

    Returns:
        If ``return_intermediate_results`` is True: a tuple of three dicts
        ``(results_batch, results_latent, results_deltas)`` keyed by sample
        index ``0 .. inputs.shape[0] - 1``, each list holding one entry per
        iteration (filled by ``store_intermediate_results``).

        Otherwise: ``(y_hat, latent, weights_deltas, codes)`` from the last
        iteration. This ``y_hat`` has already been pooled for a would-be
        next iteration, so it is not the full-resolution output. If
        ``opts.n_iters_per_batch`` is 0, all four values are None.
    """
    y_hat, latent, weights_deltas, codes = None, None, None, None

    if return_intermediate_results:
        batch_size = inputs.shape[0]
        results_batch = {idx: [] for idx in range(batch_size)}
        results_latent = {idx: [] for idx in range(batch_size)}
        results_deltas = {idx: [] for idx in range(batch_size)}
    else:
        results_batch, results_latent, results_deltas = None, None, None

    # Loop-invariant settings, computed once instead of on every iteration.
    is_cars = "cars" in opts.dataset_type
    if is_cars:
        # Crop window on the third axis of the 4-D output (presumably image
        # height -- TODO confirm); bounds depend on whether the output is resized.
        crop = slice(32, 224) if opts.resize_outputs else slice(64, 448)
        cars_pool = torch.nn.AdaptiveAvgPool2d((192, 256))
    else:
        crop = cars_pool = None

    for _ in range(opts.n_iters_per_batch):
        y_hat, latent, weights_deltas, codes, _ = net.forward(
            inputs,
            y_hat=y_hat,
            codes=codes,
            weights_deltas=weights_deltas,
            return_latents=True,
            resize=opts.resize_outputs,
            randomize_noise=False,
            return_weight_deltas_and_codes=True,
        )

        if is_cars:
            y_hat = y_hat[:, :, crop, :]

        if return_intermediate_results:
            store_intermediate_results(results_batch, results_latent, results_deltas,
                                       y_hat, latent, weights_deltas)

        # Downsample the output before it is fed into the next iteration:
        # to 192x256 for "cars" datasets, otherwise via net.face_pool.
        y_hat = cars_pool(y_hat) if is_cars else net.face_pool(y_hat)

    if return_intermediate_results:
        return results_batch, results_latent, results_deltas
    return y_hat, latent, weights_deltas, codes
def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas):
    """Append one iteration's per-sample results to the running result dicts.

    For every index ``idx`` in ``range(y_hat.shape[0])``:

    * ``y_hat[idx]`` is appended to ``results_batch[idx]`` (kept as a tensor);
    * ``latent[idx]`` is appended to ``results_latent[idx]`` as a NumPy array;
    * a list with ``w[idx]`` as a NumPy array for each ``w`` in
      ``weights_deltas`` (``None`` entries stay ``None``) is appended to
      ``results_deltas[idx]``.

    The dicts are mutated in place and must already hold a list for every
    such index.
    """
    for idx in range(y_hat.shape[0]):
        results_batch[idx].append(y_hat[idx])
        # detach() lets the NumPy conversion also succeed for tensors that
        # require grad (e.g. when called outside torch.no_grad()); values are
        # unchanged for tensors that do not.
        results_latent[idx].append(latent[idx].detach().cpu().numpy())
        results_deltas[idx].append(
            [w[idx].detach().cpu().numpy() if w is not None else None for w in weights_deltas]
        )