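# File: inversion_testing/models/hyperstyle/utils/restyle_inference_utils.py
# Page-header residue from the file viewer, kept for provenance:
#   uploader: ethanNeuralImage | commit message: "models" | revision: 92ec8d3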
import torch
def get_average_image(net, opts):
    """Render the image generated from the network's average latent code.

    Args:
        net: Callable model exposing a ``latent_avg`` tensor. It is invoked
            on ``latent_avg`` (with a leading axis added) using
            ``input_code=True``, ``randomize_noise=False``,
            ``return_latents=False`` and ``average_code=True``.
        opts: Options object; only ``opts.dataset_type`` is read.

    Returns:
        The first element of the network output as a detached ``float32``
        tensor. When ``"cars"`` appears in ``opts.dataset_type`` only
        indices 32..223 along axis 1 are kept.
    """
    generated = net(
        net.latent_avg.unsqueeze(0),
        input_code=True,
        randomize_noise=False,
        return_latents=False,
        average_code=True,
    )
    # Leave the tensor on the device the network produced it on. Forcing
    # 'cuda' here crashes on CPU-only hosts and can silently move the image
    # to a different GPU than the one the model and inputs live on.
    avg_image = generated[0].float().detach()
    if "cars" in opts.dataset_type:
        # Crop for the cars domain (keeps 192 of the entries along axis 1).
        avg_image = avg_image[:, 32:224, :]
    return avg_image
def run_on_batch(inputs, net, opts):
    """Run ``opts.n_iters_per_batch`` refinement passes of ``net`` on a batch.

    On the first pass the batch is concatenated (along dim 1) with the
    average image repeated once per sample; on later passes it is
    concatenated with the previous pass's (resized) output.

    Args:
        inputs: Batch tensor; ``inputs.shape[0]`` is used as the batch size
            and the tensor is concatenated with images along dim 1.
        net: Model providing ``forward(...)`` returning ``(y_hat, latent)``,
            a ``face_pool`` callable, and whatever ``get_average_image``
            requires of it.
        opts: Options object; ``n_iters_per_batch``, ``dataset_type`` and
            ``resize_outputs`` are read.

    Returns:
        ``(y_hat, latent)`` from the final pass, where ``y_hat`` has already
        been cropped/pooled as it would be for a following pass. Returns
        ``(None, None)`` when ``opts.n_iters_per_batch`` is 0.
    """
    avg_image = get_average_image(net, opts)
    # Loop-invariant: decide once whether the cars-specific crop/pool applies.
    is_cars = "cars" in opts.dataset_type

    y_hat, latent = None, None
    for step in range(opts.n_iters_per_batch):
        if step == 0:
            # Seed the first pass with the average image, one copy per sample.
            avg_image_for_batch = avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1)
            x_input = torch.cat([inputs, avg_image_for_batch], dim=1)
        else:
            x_input = torch.cat([inputs, y_hat], dim=1)

        y_hat, latent = net.forward(
            x_input,
            latent=latent,
            randomize_noise=False,
            return_latents=True,
            resize=opts.resize_outputs,
        )

        # Resize the output before feeding it into the next iteration.
        if is_cars:
            # Crop window depends on whether the network already resized
            # its output (192 rows kept) or not (384 rows kept).
            if opts.resize_outputs:
                y_hat = y_hat[:, :, 32:224, :]
            else:
                y_hat = y_hat[:, :, 64:448, :]
            # Functional form of AdaptiveAvgPool2d((192, 256)); avoids
            # building a new module object on every iteration.
            y_hat = torch.nn.functional.adaptive_avg_pool2d(y_hat, (192, 256))
        else:
            y_hat = net.face_pool(y_hat)

    return y_hat, latent