|
|
|
|
|
|
|
import numpy as np |
|
from PIL import Image |
|
import wandb |
|
from pti.pti_configs import global_config |
|
import torch |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def log_image_from_w(w, G, name):
    """Render the latent ``w`` with generator ``G`` and log it to wandb.

    The image is logged under the key ``name`` at the step read from
    ``global_config.training_step``.

    Args:
        w: Latent tensor forwarded to ``get_image_from_w``.
        G: Generator forwarded to ``get_image_from_w``.
        name: Key used for the wandb entry and embedded in the caption.
    """
    rendered = Image.fromarray(get_image_from_w(w, G))
    caption = f"current inversion {name}"
    # wandb receives a one-element list of images under the given key.
    payload = {f"{name}": [wandb.Image(rendered, caption=caption)]}
    wandb.log(payload, step=global_config.training_step)
|
|
|
|
|
def log_images_from_w(ws, G, names):
    """Log one rendered image per (name, latent) pair.

    ``names`` and ``ws`` are paired positionally with ``zip``, so iteration
    stops at the shorter of the two.

    Args:
        ws: Iterable of latent tensors.
        G: Generator forwarded to ``log_image_from_w``.
        names: Iterable of keys, one per latent.
    """
    for label, latent in zip(names, ws):
        # Move each latent to the configured device before rendering.
        log_image_from_w(latent.to(global_config.device), G, label)
|
|
|
|
|
def plot_image_from_w(w, G):
    """Render the latent ``w`` with generator ``G`` and display it.

    Args:
        w: Latent tensor forwarded to ``get_image_from_w``.
        G: Generator forwarded to ``get_image_from_w``.
    """
    rendered = Image.fromarray(get_image_from_w(w, G))
    plt.imshow(rendered)
    plt.show()
|
|
|
|
|
def plot_image(img):
    """Display the first image of a batched image tensor with matplotlib.

    Args:
        img: 4-D tensor; the permute below moves axis 1 to the end, so it is
            presumably laid out as (batch, channels, height, width) with
            values around [-1, 1] -- TODO confirm against callers.
    """
    channels_last = img.permute(0, 2, 3, 1)
    # Affine map to the 0..255 range, clamped before the uint8 cast.
    as_bytes = (channels_last * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    batch = as_bytes.detach().cpu().numpy()
    # Only the first item of the batch is shown.
    first = Image.fromarray(batch[0])
    plt.imshow(first)
    plt.show()
|
|
|
|
|
def save_image(name, method_type, results_dir, image, run_id):
    """Save ``image`` as ``<results_dir>/<method_type>_<name>_<run_id>.jpg``.

    Args:
        name: Image identifier embedded in the file name.
        method_type: Method label used as the file-name prefix.
        results_dir: Directory part of the destination path.
        image: Object exposing ``save(path)`` (e.g. a PIL image).
        run_id: Run identifier used as the file-name suffix.
    """
    destination = f'{results_dir}/{method_type}_{name}_{run_id}.jpg'
    image.save(destination)
|
|
|
|
|
def save_w(w, G, name, method_type, results_dir, run_id=None):
    """Render the latent ``w`` with generator ``G`` and save it as a JPEG.

    ``save_image`` requires a ``run_id`` argument; calling it with only four
    arguments raises ``TypeError``. ``run_id`` is therefore accepted here as
    an optional parameter so existing five-argument callers keep working.

    Args:
        w: Latent tensor forwarded to ``get_image_from_w``.
        G: Generator forwarded to ``get_image_from_w``.
        name: Image identifier embedded in the file name.
        method_type: Method label used as the file-name prefix.
        results_dir: Directory part of the destination path.
        run_id: Optional run identifier. When given, the file is written via
            ``save_image`` as ``<method_type>_<name>_<run_id>.jpg``; when
            omitted, it is written as ``<method_type>_<name>.jpg``.
    """
    im = get_image_from_w(w, G)
    im = Image.fromarray(im, mode='RGB')
    if run_id is None:
        # No run identifier available: save without the suffix rather than
        # embedding a literal "None" in the file name.
        im.save(f'{results_dir}/{method_type}_{name}.jpg')
        return
    save_image(name, method_type, results_dir, im, run_id)
|
|
|
|
|
def save_concat_image(base_dir, image_latents, new_inv_image_latent, new_G,
                      old_G,
                      file_name,
                      extra_image=None):
    """Save a side-by-side strip of rendered images as ``<file_name>.jpg``.

    Panel order, left to right: ``extra_image`` (if given), one render per
    latent in ``image_latents`` using ``old_G``, then the render of
    ``new_inv_image_latent`` using ``new_G``.

    Args:
        base_dir: Directory part of the destination path.
        image_latents: Iterable of latents rendered with ``old_G``.
        new_inv_image_latent: Latent rendered with ``new_G``.
        new_G: Generator used for the final panel.
        old_G: Generator used for the ``image_latents`` panels.
        file_name: Output file name without extension.
        extra_image: Optional image placed first in the strip.
    """
    panels = [] if extra_image is None else [extra_image]
    panels.extend(get_image_from_w(latent, old_G) for latent in image_latents)
    panels.append(get_image_from_w(new_inv_image_latent, new_G))

    strip = create_alongside_images(panels)
    strip.save(f'{base_dir}/{file_name}.jpg')
|
|
|
|
|
def save_single_image(base_dir, image_latent, G, file_name):
    """Render ``image_latent`` with ``G`` and save it as ``<file_name>.jpg``.

    Args:
        base_dir: Directory part of the destination path.
        image_latent: Latent tensor forwarded to ``get_image_from_w``.
        G: Generator forwarded to ``get_image_from_w``.
        file_name: Output file name without extension.
    """
    pixels = get_image_from_w(image_latent, G)
    destination = f'{base_dir}/{file_name}.jpg'
    Image.fromarray(pixels, mode='RGB').save(destination)
|
|
|
|
|
def create_alongside_images(images):
    """Join ``images`` along axis 1 and return the result as an RGB PIL image.

    Args:
        images: Iterable of array-likes (or PIL images); each is converted
            with ``np.array`` before concatenation. Presumably all share the
            same height and channel count -- TODO confirm against callers.

    Returns:
        A PIL image built from the concatenated array in ``'RGB'`` mode.
    """
    arrays = [np.array(image) for image in images]
    # Axis 1 is the second array dimension (the width for H x W x C data).
    strip = np.concatenate(arrays, axis=1)
    return Image.fromarray(strip, mode='RGB')
|
|
|
|
|
def get_image_from_w(w, G):
    """Synthesize an image from latent ``w`` and return it as a uint8 array.

    Args:
        w: Latent tensor. If it has two or fewer dimensions, a leading axis
            is added before synthesis.
        G: Object exposing ``synthesis(w, noise_mode=...)``. Its output is
            permuted with (0, 2, 3, 1), so it is presumably
            (batch, channels, height, width) -- TODO confirm.

    Returns:
        The first item of the synthesized batch as a NumPy ``uint8`` array
        with the channel axis last.
    """
    latent = w if w.dim() > 2 else w.unsqueeze(0)
    with torch.no_grad():
        synthesized = G.synthesis(latent, noise_mode='const')
        channels_last = synthesized.permute(0, 2, 3, 1)
        # Affine map to the 0..255 range, clamped before the uint8 cast.
        as_bytes = (channels_last * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        batch = as_bytes.detach().cpu().numpy()
    return batch[0]
|
|