# PTI / utils / log_utils.py
# ucalyptus's picture
# pushes
# f799d59
import numpy as np
from PIL import Image
from configs import global_config
import torch
import matplotlib.pyplot as plt
def log_image_from_w(w, G, name):
    """Render the image for latent ``w`` with ``G`` and log it to wandb.

    The image is logged under the key ``name`` with the caption
    ``"current inversion {name}"`` at step ``global_config.training_step``.
    """
    # NOTE(review): `wandb` is used here but no `import wandb` is visible
    # among this file's imports -- confirm it is brought into scope.
    rendered = Image.fromarray(get_image_from_w(w, G))
    caption = f"current inversion {name}"
    payload = {f"{name}": [wandb.Image(rendered, caption=caption)]}
    wandb.log(payload, step=global_config.training_step)
def log_images_from_w(ws, G, names):
    """Log one image per (name, latent) pair via ``log_image_from_w``.

    Each latent is moved to ``global_config.device`` before it is rendered.
    ``names`` and ``ws`` are paired with ``zip``, so surplus entries in the
    longer of the two are ignored.
    """
    for label, latent in zip(names, ws):
        log_image_from_w(latent.to(global_config.device), G, label)
def plot_image_from_w(w, G):
    """Render the image for latent ``w`` with ``G`` and show it with matplotlib."""
    pixels = get_image_from_w(w, G)
    plt.imshow(Image.fromarray(pixels))
    plt.show()
def plot_image(img):
    """Show the first image of a tensor batch with matplotlib.

    ``img`` is permuted with ``(0, 2, 3, 1)``, so it must be 4-D; presumably
    laid out as (batch, channels, height, width) -- TODO confirm with callers.
    Values are mapped with ``x * 127.5 + 128`` and clamped to ``[0, 255]``
    before being converted to ``uint8``.
    """
    channels_last = img.permute(0, 2, 3, 1)
    as_bytes = (channels_last * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    batch = as_bytes.detach().cpu().numpy()
    # Only the first element of the batch is displayed.
    plt.imshow(Image.fromarray(batch[0]))
    plt.show()
def save_image(name, method_type, results_dir, image, run_id):
    """Save ``image`` to ``{results_dir}/{method_type}_{name}_{run_id}.jpg``.

    ``image`` only needs to expose a ``save(path)`` method.
    """
    target_path = f'{results_dir}/{method_type}_{name}_{run_id}.jpg'
    image.save(target_path)
def save_w(w, G, name, method_type, results_dir, run_id=None):
    """Render the image for latent ``w`` with ``G`` and save it as a JPEG.

    Args:
        w: Latent passed to ``get_image_from_w``.
        G: Generator passed to ``get_image_from_w``.
        name: Image name used in the output file name.
        method_type: Method label used in the output file name.
        results_dir: Directory the file is written to.
        run_id: Optional run identifier. When given, the file is written by
            ``save_image`` as ``{method_type}_{name}_{run_id}.jpg``. When
            omitted, the file is written as ``{method_type}_{name}.jpg``.

    The previous implementation called ``save_image`` without its required
    ``run_id`` argument, so every call raised ``TypeError``.
    """
    im = Image.fromarray(get_image_from_w(w, G), mode='RGB')
    if run_id is None:
        # save_image always appends a run id, so write the un-suffixed
        # file directly instead of producing a "..._None.jpg" name.
        im.save(f'{results_dir}/{method_type}_{name}.jpg')
    else:
        save_image(name, method_type, results_dir, im, run_id)
def save_concat_image(base_dir, image_latents, new_inv_image_latent, new_G,
                      old_G,
                      file_name,
                      extra_image=None):
    """Save a strip of rendered images to ``{base_dir}/{file_name}.jpg``.

    The strip is assembled in this order: ``extra_image`` (when it is not
    ``None``), one image per latent in ``image_latents`` rendered with
    ``old_G``, then ``new_inv_image_latent`` rendered with ``new_G``. The
    pieces are joined by ``create_alongside_images``.
    """
    tiles = [] if extra_image is None else [extra_image]
    tiles.extend(get_image_from_w(latent, old_G) for latent in image_latents)
    tiles.append(get_image_from_w(new_inv_image_latent, new_G))

    strip = create_alongside_images(tiles)
    strip.save(f'{base_dir}/{file_name}.jpg')
def save_single_image(base_dir, image_latent, G, file_name):
    """Render ``image_latent`` with ``G`` and save it to ``{base_dir}/{file_name}.jpg``."""
    pixels = get_image_from_w(image_latent, G)
    Image.fromarray(pixels, mode='RGB').save(f'{base_dir}/{file_name}.jpg')
def create_alongside_images(images):
    """Join ``images`` along axis 1 and return the result as an RGB PIL image.

    Each item is converted with ``np.array`` first, so array-like inputs are
    accepted. For (height, width, channels) arrays, axis 1 places the images
    side by side.
    """
    arrays = [np.array(tile) for tile in images]
    strip = np.concatenate(arrays, axis=1)
    return Image.fromarray(strip, mode='RGB')
def get_image_from_w(w, G):
    """Synthesize an image from latent ``w`` and return it as a uint8 array.

    A latent with at most two dimensions gets a leading dimension added
    before it is passed to ``G.synthesis(w, noise_mode='const')``. The output
    is permuted with ``(0, 2, 3, 1)``, mapped with ``x * 127.5 + 128``,
    clamped to ``[0, 255]`` and converted to ``uint8``.

    Returns:
        The first image of the synthesized batch as a NumPy array.
    """
    if w.dim() <= 2:
        w = w.unsqueeze(0)

    with torch.no_grad():
        synthesized = G.synthesis(w, noise_mode='const')
        channels_last = synthesized.permute(0, 2, 3, 1)
        as_bytes = (channels_last * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        batch = as_bytes.detach().cpu().numpy()

    # Only the first image of the batch is returned.
    return batch[0]