| """Utils functions for visualization.""" |
|
|
| import torch |
| import torchvision.transforms.functional as F |
| from einops import rearrange |
| from PIL import Image, ImageDraw, ImageFont |
|
|
def make_viz_from_samples(
    original_images,
    reconstructed_images
):
    """Generates side-by-side comparison images for a reconstruction batch.

    For every batch element the original image, its reconstruction and the
    absolute per-pixel difference between the two are placed next to each
    other horizontally (original | reconstruction | |original - reconstruction|).

    Args:
        original_images: A torch.Tensor of original images, expected to be
            (B, C, H, W) with values in [0, 1]. Values outside [0, 1] are
            clamped.
        reconstructed_images: A torch.Tensor of reconstructed images with the
            same shape as ``original_images``. Values outside [0, 1] are
            clamped.

    Returns:
        A tuple ``(images_for_saving, images_for_logging)``:
            images_for_saving: a list with one PIL image per batch element.
            images_for_logging: a uint8 CPU tensor of shape (B, C, H, 3 * W)
                holding the same comparison strips.
    """
    # Visualization only: detach so no autograd graph is built for the
    # clamp/scale/diff ops below. Scaling is done out-of-place for both inputs.
    original = torch.clamp(original_images.detach(), 0.0, 1.0) * 255.0
    reconstructed = torch.clamp(reconstructed_images.detach(), 0.0, 1.0) * 255.0
    original = original.cpu()
    reconstructed = reconstructed.cpu()

    # The difference is taken in float, before the uint8 cast, so it cannot wrap around.
    diff_img = torch.abs(original - reconstructed)

    # Concatenate along the width axis to form one strip per batch element.
    # .byte() truncates fractional values toward zero (it does not round).
    images_for_logging = torch.cat(
        [original, reconstructed, diff_img], dim=-1
    ).byte()
    images_for_saving = [F.to_pil_image(image) for image in images_for_logging]

    return images_for_saving, images_for_logging
|
|
|
|
def make_viz_from_samples_generation(
    generated_images,
    *,
    num_rows=2,
):
    """Tiles a batch of generated images into a single grid image.

    Args:
        generated_images: A torch.Tensor of shape (B, C, H, W), values
            expected in [0, 1]; values outside that range are clamped.
        num_rows: Number of rows in the grid (default 2, the previously
            hard-coded layout). B must be divisible by ``num_rows``, otherwise
            einops raises an error. The grid has ``B // num_rows`` columns and
            is filled row by row.

    Returns:
        A tuple ``(images_for_saving, images_for_logging)``:
            images_for_saving: the grid as a single PIL image.
            images_for_logging: the same grid as a uint8 CPU tensor of shape
                (C, num_rows * H, (B // num_rows) * W).
    """
    generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0
    # "(l1 l2)" with l1 as the outer factor: image i lands in row
    # i // (B // num_rows), column i % (B // num_rows).
    grid = rearrange(
        generated,
        "(l1 l2) c h w -> c (l1 h) (l2 w)",
        l1=num_rows,
    )

    # .byte() truncates fractional values toward zero (it does not round).
    images_for_logging = grid.cpu().byte()
    images_for_saving = F.to_pil_image(images_for_logging)

    return images_for_saving, images_for_logging
|
|
|
|
def make_viz_from_samples_t2i_generation(
    generated_images,
    captions,
    *,
    num_rows=2,
):
    """Tiles generated images into a grid and renders captions beneath it.

    Args:
        generated_images: A torch.Tensor of shape (B, C, H, W), values
            expected in [0, 1]; values outside that range are clamped.
        captions: A sequence of strings. Each one is drawn on its own line
            below the grid, in white on a black band. No wrapping is done, so
            text wider than the grid is cut off at the image edge.
        num_rows: Number of rows in the grid (default 2, the previously
            hard-coded layout). B must be divisible by ``num_rows``, otherwise
            einops raises an error. The grid is filled row by row.

    Returns:
        A tuple ``(captioned_image, images_for_logging)``:
            captioned_image: an RGB PIL image of the grid with the caption
                band appended below it.
            images_for_logging: the grid alone (no captions) as a uint8 CPU
                tensor of shape (C, num_rows * H, (B // num_rows) * W).
    """
    generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0
    grid = rearrange(
        generated,
        "(l1 l2) c h w -> c (l1 h) (l2 w)",
        l1=num_rows,
    )

    # .byte() truncates fractional values toward zero (it does not round).
    images_for_logging = grid.cpu().byte()
    grid_image = F.to_pil_image(images_for_logging)

    # Caption band: one fixed-height line per caption plus a top margin.
    # The margin doubles as the left indent of the text.
    line_height = 20
    margin = 10
    width, height = grid_image.size
    caption_height = line_height * len(captions) + margin
    new_image = Image.new("RGB", (width, height + caption_height), "black")
    new_image.paste(grid_image, (0, 0))

    draw = ImageDraw.Draw(new_image)
    font = ImageFont.load_default()
    for i, caption in enumerate(captions):
        draw.text(
            (margin, height + margin + i * line_height),
            caption,
            fill="white",
            font=font,
        )

    return new_image, images_for_logging