Spaces:
Runtime error
Runtime error
import torch | |
def get_average_image(net, opts):
    """Render the image the network produces for its average latent code.

    Args:
        net: Callable model exposing a ``latent_avg`` tensor. It is called
            with ``latent_avg.unsqueeze(0)`` and the flags
            ``input_code=True``, ``randomize_noise=False``,
            ``return_latents=False`` and ``average_code=True``; element
            ``[0]`` of the result is used (presumably the single image of a
            batch of one -- confirm against the model).
        opts: Options object; only ``opts.dataset_type`` is read.

    Returns:
        A detached float32 tensor. When ``"cars"`` occurs in
        ``opts.dataset_type`` only indices 32..223 along dim 1 are kept.
    """
    avg_image = net(
        net.latent_avg.unsqueeze(0),
        input_code=True,
        randomize_noise=False,
        return_latents=False,
        average_code=True,
    )[0]

    # Move to CUDA only when it is available: on CPU-only hosts an
    # unconditional .to('cuda') raises, so there the tensor stays on the
    # device the network produced it on.
    if torch.cuda.is_available():
        avg_image = avg_image.to('cuda')
    avg_image = avg_image.float().detach()

    if "cars" in opts.dataset_type:
        # Crop dim 1 to the 32:224 band used for the cars datasets.
        avg_image = avg_image[:, 32:224, :]
    return avg_image
def run_on_batch(inputs, net, opts):
    """Run ``opts.n_iters_per_batch`` refinement passes of ``net`` on a batch.

    On every pass the inputs are concatenated along dim 1 with a feedback
    image -- the average image (repeated ``inputs.shape[0]`` times) on the
    first pass, the previous pass's pooled output afterwards -- and handed to
    ``net.forward`` together with the latent returned by the previous pass.

    Args:
        inputs: Batch tensor; ``inputs.shape[0]`` is used as the batch size.
        net: Model providing ``forward`` and, for non-cars datasets,
            ``face_pool``.
        opts: Options object; reads ``n_iters_per_batch``,
            ``resize_outputs`` and ``dataset_type``.

    Returns:
        Tuple ``(y_hat, latent)`` from the final pass, where ``y_hat`` has
        already been pooled for feedback. Both are ``None`` when
        ``opts.n_iters_per_batch`` is 0.
    """
    mean_image = get_average_image(net, opts)
    is_cars = "cars" in opts.dataset_type

    output, codes = None, None
    for step in range(opts.n_iters_per_batch):
        if step:
            feedback = output
        else:
            # First pass: no previous output yet, so feed the average image.
            feedback = mean_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1)
        net_input = torch.cat([inputs, feedback], dim=1)

        output, codes = net.forward(
            net_input,
            latent=codes,
            randomize_noise=False,
            return_latents=True,
            resize=opts.resize_outputs,
        )

        if is_cars:
            # Crop dim 2; the band depends on whether outputs were resized.
            band = slice(32, 224) if opts.resize_outputs else slice(64, 448)
            output = output[:, :, band, :]
            # Pool to 192x256 before feeding into the next pass.
            output = torch.nn.AdaptiveAvgPool2d((192, 256))(output)
        else:
            # Pool with the network's own face pooling before the next pass.
            output = net.face_pool(output)

    return output, codes