import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from PIL import Image

from PTI.configs import global_config


def _first_image_as_uint8(img):
    """Convert a batched image tensor to a uint8 numpy array of its first item.

    Axis 1 is moved to the last position, values are scaled by 127.5, shifted
    by 128, clamped to [0, 255] and cast to ``torch.uint8``.
    NOTE(review): presumably the input is an NCHW batch with values in
    roughly [-1, 1] -- confirm against the generator in use.
    """
    scaled = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
    return scaled.to(torch.uint8).detach().cpu().numpy()[0]


def log_image_from_w(w, G, name):
    """Synthesize an image from latent ``w`` with ``G`` and log it to wandb.

    The image is logged under the key ``name`` at step
    ``global_config.training_step``.
    """
    pillow_image = Image.fromarray(get_image_from_w(w, G))
    wandb.log(
        {f"{name}": [
            wandb.Image(pillow_image, caption=f"current inversion {name}")]},
        step=global_config.training_step)


def log_images_from_w(ws, G, names):
    """Log one image per ``(name, w)`` pair to wandb.

    Each latent is moved to ``global_config.device`` before synthesis. The
    pairs come from ``zip``, so iteration stops at the shorter of ``names``
    and ``ws``.
    """
    for name, w in zip(names, ws):
        log_image_from_w(w.to(global_config.device), G, name)


def plot_image_from_w(w, G):
    """Synthesize an image from latent ``w`` with ``G`` and show it with matplotlib."""
    pillow_image = Image.fromarray(get_image_from_w(w, G))
    plt.imshow(pillow_image)
    plt.show()


def plot_image(img):
    """Show the first image of a batched image tensor with matplotlib."""
    pillow_image = Image.fromarray(_first_image_as_uint8(img))
    plt.imshow(pillow_image)
    plt.show()


def save_image(name, method_type, results_dir, image, run_id=None):
    """Save ``image`` as a JPEG inside ``results_dir``.

    The file is named ``{method_type}_{name}_{run_id}.jpg``. When ``run_id``
    is ``None`` the trailing ``_{run_id}`` part is omitted, which lets callers
    without a run identifier (such as ``save_w`` with its default) still save.
    ``image`` must provide a ``save(path)`` method.
    """
    if run_id is None:
        file_name = f'{method_type}_{name}.jpg'
    else:
        file_name = f'{method_type}_{name}_{run_id}.jpg'
    image.save(f'{results_dir}/{file_name}')


def save_w(w, G, name, method_type, results_dir, run_id=None):
    """Synthesize an image from latent ``w`` with ``G`` and save it as a JPEG.

    ``run_id`` is forwarded to ``save_image``. It is optional so existing
    five-argument callers keep working; previously this function called
    ``save_image`` without the ``run_id`` argument it required, which raised
    ``TypeError`` on every call.
    """
    im = Image.fromarray(get_image_from_w(w, G), mode='RGB')
    save_image(name, method_type, results_dir, im, run_id)


def save_concat_image(base_dir, image_latents, new_inv_image_latent, new_G,
                      old_G,
                      file_name,
                      extra_image=None):
    """Save a single JPEG that places several images next to each other.

    The images, in order, are: ``extra_image`` (only if it is not ``None``),
    one image per latent in ``image_latents`` synthesized with ``old_G``, and
    finally ``new_inv_image_latent`` synthesized with ``new_G``. The result is
    written to ``{base_dir}/{file_name}.jpg``.
    """
    images_to_save = []
    if extra_image is not None:
        images_to_save.append(extra_image)
    images_to_save.extend(
        get_image_from_w(latent, old_G) for latent in image_latents)
    images_to_save.append(get_image_from_w(new_inv_image_latent, new_G))
    result_image = create_alongside_images(images_to_save)
    result_image.save(f'{base_dir}/{file_name}.jpg')


def save_single_image(base_dir, image_latent, G, file_name):
    """Synthesize one image from ``image_latent`` and save it as ``{base_dir}/{file_name}.jpg``."""
    image_to_save = Image.fromarray(get_image_from_w(image_latent, G), mode='RGB')
    image_to_save.save(f'{base_dir}/{file_name}.jpg')


def create_alongside_images(images):
    """Concatenate ``images`` along axis 1 and return the result as an RGB PIL image.

    Each element is converted with ``np.array`` first. For H x W x C arrays,
    axis 1 is the width, so the images end up side by side.
    """
    res = np.concatenate([np.array(image) for image in images], axis=1)
    return Image.fromarray(res, mode='RGB')


def get_image_from_w(w, G):
    """Synthesize an image from latent ``w`` and return it as a uint8 numpy array.

    A latent with at most two dimensions gets a leading dimension added before
    being passed to ``G.synthesis`` (called with ``noise_mode='const'``).
    Synthesis runs under ``torch.no_grad()``. Only the first image of the
    synthesized batch is returned.
    """
    if w.dim() <= 2:
        w = w.unsqueeze(0)
    with torch.no_grad():
        img = G.synthesis(w, noise_mode='const')
        return _first_image_as_uint8(img)