import torch


def run_inversion(inputs, net, opts, return_intermediate_results=False):
    """Iteratively invert ``inputs`` by calling ``net.forward`` repeatedly.

    Each iteration passes the previous iteration's ``y_hat``, ``codes`` and
    ``weights_deltas`` back into ``net.forward`` so the network can refine
    its reconstruction.

    Args:
        inputs: Batch of input images; ``inputs.shape[0]`` is used as the
            batch size.
        net: Model exposing ``forward(...)``, which must return a 5-tuple
            ``(y_hat, latent, weights_deltas, codes, _)`` when called with
            ``return_weight_deltas_and_codes=True``. For datasets whose
            ``dataset_type`` does not contain "cars", ``net.face_pool`` is
            also called.
        opts: Options object read for ``n_iters_per_batch``,
            ``resize_outputs`` and ``dataset_type``.
        return_intermediate_results: If True, collect per-sample outputs of
            every iteration and return those instead of the final tensors.

    Returns:
        If ``return_intermediate_results`` is True, a tuple
        ``(results_batch, results_latent, results_deltas)``. Each is a dict
        keyed by batch index, holding one entry per iteration (see
        ``store_intermediate_results``).

        Otherwise, ``(y_hat, latent, weights_deltas, codes)`` from the last
        iteration. The returned ``y_hat`` is the pooled tensor prepared as
        the next iteration's input, not the raw ``net.forward`` output. All
        four are None when ``opts.n_iters_per_batch`` is 0.
    """
    y_hat, latent, weights_deltas, codes = None, None, None, None
    batch_size = inputs.shape[0]

    # Loop-invariant: decide once whether the cars-specific crop/pool applies.
    is_cars = "cars" in opts.dataset_type
    cars_pool = torch.nn.AdaptiveAvgPool2d((192, 256)) if is_cars else None

    if return_intermediate_results:
        results_batch = {idx: [] for idx in range(batch_size)}
        results_latent = {idx: [] for idx in range(batch_size)}
        results_deltas = {idx: [] for idx in range(batch_size)}
    else:
        results_batch, results_latent, results_deltas = None, None, None

    for _step in range(opts.n_iters_per_batch):
        y_hat, latent, weights_deltas, codes, _ = net.forward(
            inputs,
            y_hat=y_hat,
            codes=codes,
            weights_deltas=weights_deltas,
            return_latents=True,
            resize=opts.resize_outputs,
            randomize_noise=False,
            return_weight_deltas_and_codes=True,
        )

        if is_cars:
            # Keep only a central band of rows: 32:224 when outputs are
            # resized, 64:448 otherwise.
            # NOTE(review): presumably strips vertical padding around the
            # wide car images -- confirm against the generator config.
            if opts.resize_outputs:
                y_hat = y_hat[:, :, 32:224, :]
            else:
                y_hat = y_hat[:, :, 64:448, :]

        if return_intermediate_results:
            store_intermediate_results(results_batch, results_latent, results_deltas,
                                       y_hat, latent, weights_deltas)

        # Downsample y_hat before feeding it into the next iteration.
        y_hat = cars_pool(y_hat) if is_cars else net.face_pool(y_hat)

    if return_intermediate_results:
        return results_batch, results_latent, results_deltas
    return y_hat, latent, weights_deltas, codes


def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas):
    """Append one iteration's per-sample outputs to the result dicts.

    For every sample index ``idx`` in ``y_hat``:
      * ``results_batch[idx]`` receives the tensor ``y_hat[idx]`` unchanged;
      * ``results_latent[idx]`` receives ``latent[idx]`` as a NumPy array;
      * ``results_deltas[idx]`` receives a list with one entry per element
        of ``weights_deltas``: that element's ``[idx]`` slice as a NumPy
        array, or None where the element itself is None.

    The dicts must already contain a list for every index in
    ``range(y_hat.shape[0])``; they are mutated in place.
    """
    for idx in range(y_hat.shape[0]):
        results_batch[idx].append(y_hat[idx])
        # detach() first: Tensor.numpy() raises on tensors that track
        # gradients, e.g. when inversion runs outside torch.no_grad().
        results_latent[idx].append(latent[idx].detach().cpu().numpy())
        results_deltas[idx].append(
            [w[idx].detach().cpu().numpy() if w is not None else None for w in weights_deltas]
        )