import torch


def get_average_image(net, opts, device='cuda'):
    """Render the image the network produces from its average latent code.

    Args:
        net: Model exposing ``latent_avg`` and callable with the keyword
            arguments used below (``input_code``, ``average_code``, ...).
        opts: Options object; only ``opts.dataset_type`` is read.
        device: Device the returned tensor is moved to. Defaults to
            ``'cuda'`` to match the historical behavior.

    Returns:
        A detached float tensor: element ``[0]`` of the network output
        (presumably a single ``(C, H, W)`` image -- confirm against ``net``).
        For "cars" datasets only rows ``32:224`` of dim 1 are kept.
    """
    # The result is detached below anyway, so recording an autograd graph
    # for this forward pass would only waste memory.
    with torch.no_grad():
        avg_image = net(net.latent_avg.unsqueeze(0),
                        input_code=True,
                        randomize_noise=False,
                        return_latents=False,
                        average_code=True)[0]
    avg_image = avg_image.to(device).float().detach()
    if "cars" in opts.dataset_type:
        # assumes CHW layout: keep the 192-row band of dim 1 (height)
        avg_image = avg_image[:, 32:224, :]
    return avg_image


def run_on_batch(inputs, net, opts):
    """Iteratively refine outputs for ``inputs`` with ``net``.

    Runs ``opts.n_iters_per_batch`` forward passes. Each pass feeds ``net``
    the inputs concatenated along dim 1 with the previous output (or, on the
    first pass, with the average image from ``get_average_image``), plus the
    latent returned by the previous pass.

    Args:
        inputs: Batch tensor; dim 0 is the batch dimension and dim 1 is the
            concatenation axis (channels, assuming NCHW).
        net: Model providing ``latent_avg``, ``forward`` and ``face_pool``.
        opts: Options object; reads ``n_iters_per_batch``, ``dataset_type``
            and ``resize_outputs``.

    Returns:
        ``(y_hat, latent)`` from the last pass, with ``y_hat`` already
        cropped/pooled exactly as it would be fed to a following pass.
        Both are ``None`` when ``opts.n_iters_per_batch`` is 0.
    """
    is_cars = "cars" in opts.dataset_type
    y_hat, latent = None, None
    for step in range(opts.n_iters_per_batch):
        if step == 0:
            # Build the average image on the inputs' device so torch.cat
            # works for CPU and non-default GPUs, not just the default 'cuda'.
            avg_image = get_average_image(net, opts, device=inputs.device)
            previous = avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1)
        else:
            previous = y_hat
        x_input = torch.cat([inputs, previous], dim=1)

        y_hat, latent = net.forward(x_input, latent=latent,
                                    randomize_noise=False,
                                    return_latents=True,
                                    resize=opts.resize_outputs)

        # Resize the output before feeding it into the next iteration.
        if is_cars:
            # Crop the height band; offsets depend on whether net resized.
            if opts.resize_outputs:
                y_hat = y_hat[:, :, 32:224, :]
            else:
                y_hat = y_hat[:, :, 64:448, :]
            y_hat = torch.nn.functional.adaptive_avg_pool2d(y_hat, (192, 256))
        else:
            # NOTE(review): face_pool is defined on net; presumably pools to
            # the network's input resolution -- confirm.
            y_hat = net.face_pool(y_hat)
    return y_hat, latent