|
|
"""Utility helpers for data processing and visualization.""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import os |
|
|
import matplotlib.pyplot as plt |
|
|
from config import vae_plots_path, unet_plots_path, cfg, device |
|
|
from diffusers import DDPMScheduler |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def create_path_if_not_exists(path: str) -> None:
    """Create the directory ``path``, including missing parents, if needed.

    The directory is created with ``exist_ok=True`` instead of a separate
    existence check, so a directory that appears between "check" and
    "create" (e.g. made by another worker) does not raise.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
|
|
def revert_images(imgs: torch.Tensor) -> np.ndarray:
    """Rescale a batch of image tensors to the 0-255 range for display.

    Every sample is min-max scaled independently over axes 1-3. The scale
    is inverted: a sample's largest value maps to 0 and its smallest to 255.

    Args:
        imgs: 4-D tensor; axis 0 indexes samples and axis 1 is treated as
            the channel axis.

    Returns:
        For single-channel input, an integer array of shape ``(N, H, W)``
        (values truncated toward zero). Otherwise a floating-point array
        with the same shape as the input.
    """
    arr = imgs.cpu().detach().numpy()

    lo = arr.min(axis=(1, 2, 3), keepdims=True)
    hi = arr.max(axis=(1, 2, 3), keepdims=True)

    # A constant sample has zero range; dividing by it would yield NaN (and
    # garbage after the integer cast). Use 1 so such samples come out as 0.
    span = hi - lo
    span[span == 0] = 1

    # NOTE(review): `hi - arr` inverts intensities. This looks deliberate,
    # but confirm it is not meant to be `arr - lo`.
    arr = ((hi - arr) / span) * 255

    if arr.shape[1] == 1:
        # Drop the channel axis using the real height and width so that
        # non-square images keep their shape.
        n, _, height, width = arr.shape
        arr = arr.astype(int).reshape(n, height, width)

    return arr
|
|
|
|
|
|
|
|
def plot_side_by_side(
    images_y: torch.Tensor, images_pred: torch.Tensor, latents: torch.Tensor, epoch: int
) -> None:
    """Save an input/output comparison and a latent-channel plot for one sample.

    One random sample index is drawn from the batch. Two PNG files are
    written into ``vae_plots_path``:

    * ``epoch_{epoch}_input_output.png`` - the input image next to the
      corresponding output image.
    * ``epoch_{epoch}_latent_channels.png`` - one panel per latent channel
      of the same sample.

    Args:
        images_y: Batch of input images.
        images_pred: Batch of output images, same batch size as ``images_y``.
        latents: Batch of latents; axis 1 is treated as the channel axis.
        epoch: Epoch number used in the output file names.
    """
    images_y, images_pred = revert_images(images_y), revert_images(images_pred)
    latents = revert_images(latents)
    if latents.ndim == 3:
        # revert_images drops the channel axis for single-channel input;
        # restore it so the per-channel indexing below still works.
        latents = latents[:, np.newaxis]

    idx = np.random.randint(0, images_y.shape[0])

    fig, axs = plt.subplots(1, 2)
    panels = ((images_y[idx], "Input"), (images_pred[idx], "Output"))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image, cmap="gray")
        ax.axis("off")
        ax.set_title(title)
    fig.savefig(os.path.join(vae_plots_path, f"epoch_{epoch}_input_output.png"))
    # Close instead of clf(): clf() leaves the figure registered with pyplot,
    # so figures would pile up across epochs.
    plt.close(fig)

    latent_channels = latents.shape[1]
    # One column per actual latent channel; squeeze=False keeps `axs` 2-D
    # even when there is only a single channel.
    fig, axs = plt.subplots(1, latent_channels, squeeze=False)
    for i, ax in enumerate(axs[0]):
        ax.imshow(latents[idx, i, :, :], cmap="gray")
        ax.axis("off")
        ax.set_title(f"Latent channel: {i}", fontsize=8)
    fig.savefig(os.path.join(vae_plots_path, f"epoch_{epoch}_latent_channels.png"))
    plt.close(fig)
|
|
|
|
|
|
|
|
def generate(
    vae: torch.nn.Module,
    unet: torch.nn.Module,
    noise_scheduler: DDPMScheduler,
    epoch: int,
) -> None:
    """Run the reverse diffusion loop and save intermediate sample grids.

    Ten latents of shape ``(cfg.latent_channels, 8, 8)`` are drawn from a
    standard normal, one for each class label 0-9, and denoised over
    ``noise_scheduler.timesteps``. At timestep 999 and at every timestep
    divisible by 100 the current latents are decoded with ``vae`` and saved
    as a 2x5 grid under ``unet_plots_path/epoch_{epoch}/``.

    Args:
        vae: Model whose ``decode`` output exposes a ``.sample`` attribute.
        unet: Noise-prediction model, called with the latents, the timestep,
            ``class_labels`` and ``encoder_hidden_states=None``; its output
            exposes ``.sample``.
        noise_scheduler: Scheduler providing ``timesteps`` and ``step``.
        epoch: Epoch number used in the output directory name.
    """

    def plot(recon_imgs, timesteps: int, epoch: int) -> None:
        """Save the ten decoded samples as a 2x5 grid, titled by class index."""
        out_dir = os.path.join(unet_plots_path, f"epoch_{epoch}")
        create_path_if_not_exists(out_dir)
        # `recon_imgs` is the decoder output object, not a bare tensor.
        images = revert_images(recon_imgs.sample)
        fig, axs = plt.subplots(2, 5)
        for i, ax in enumerate(axs.flat):
            ax.imshow(images[i], cmap="gray")
            ax.axis("off")
            ax.set_title(str(i))
        fig.suptitle(f"Timesteps: {timesteps}")
        fig.savefig(os.path.join(out_dir, f"plot {timesteps}.png"))
        # Close instead of clf() so figures do not accumulate across the
        # many plots produced per epoch.
        plt.close(fig)

    latents = torch.randn((10, cfg.latent_channels, 8, 8)).to(device)
    labels = torch.arange(10).to(device)

    for t in tqdm(noise_scheduler.timesteps):
        # Pure inference: nothing here needs an autograd graph.
        with torch.no_grad():
            noise_pred = unet(
                latents, t, class_labels=labels, encoder_hidden_states=None
            ).sample
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # Decode only on the steps that are actually plotted rather than
            # on every timestep.
            # NOTE(review): 999 presumably is the first (noisiest) timestep
            # of a 1000-step schedule - confirm against the scheduler config.
            step = int(t)
            if step == 999 or step % 100 == 0:
                plot(vae.decode(latents), step, epoch)
|
|
|