| """ |
| Visualization utilities for evaluation. |
| |
| Functions: |
| - make_image_grid: Create and optionally save a grid of images |
| - visualize_denoising: Visualize the denoising process |
| - format_prompt_caption: Format prompts for display in image captions |
| """ |
|
|
| import os |
| from typing import List, Optional |
|
|
| import numpy as np |
| import torch |
| import torchvision |
|
|
|
|
def make_image_grid(
    images: torch.Tensor,
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    save_path: Optional[str] = None,
    normalize: bool = True,
    value_range: Optional[tuple] = None,
) -> torch.Tensor:
    """
    Create a grid of images and optionally save it.

    Args:
        images: Tensor of shape [B, C, H, W]
        rows: Number of rows (optional; ignored when ``cols`` is given)
        cols: Number of columns (optional; takes precedence over ``rows``)
        save_path: Path to save the grid image; parent directories are
            created as needed
        normalize: Whether to normalize images to [0, 1]
        value_range: Range of values in input images (min, max)

    Returns:
        Grid tensor as produced by ``torchvision.utils.make_grid``
    """
    # torchvision's make_grid takes images-per-row; derive it from the
    # requested layout (default: roughly square grid).
    if rows is None and cols is None:
        nrow = int(np.ceil(np.sqrt(images.shape[0])))
    elif cols is not None:
        nrow = cols
    else:
        nrow = int(np.ceil(images.shape[0] / rows))

    grid = torchvision.utils.make_grid(
        images, nrow=nrow, normalize=normalize, value_range=value_range, padding=2
    )

    if save_path:
        # os.path.dirname("") is "" and os.makedirs("") raises
        # FileNotFoundError, so only create directories when the path
        # actually contains one (bare filenames save to the cwd).
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torchvision.utils.save_image(grid, save_path)

    return grid
|
|
|
|
def visualize_denoising(
    intermediate_steps: List[torch.Tensor], save_path: str, num_steps_to_show: int = 10
):
    """
    Visualize the denoising process by selecting a subset of steps.

    Only the first sample of each batch is shown, laid out as a single
    row from the earliest to the latest selected step.

    Args:
        intermediate_steps: List of tensors [B, C, H, W] from the sampling process
        save_path: Path to save the visualization
        num_steps_to_show: Number of steps to display

    Raises:
        ValueError: If ``intermediate_steps`` is empty.
    """
    # Guard explicitly: an empty list would otherwise surface as an
    # obscure IndexError from the negative linspace indices below.
    if not intermediate_steps:
        raise ValueError("intermediate_steps must contain at least one tensor")

    total_steps = len(intermediate_steps)
    if total_steps < num_steps_to_show:
        indices = list(range(total_steps))
    else:
        # Evenly spaced indices, always including first and last step.
        indices = np.linspace(0, total_steps - 1, num_steps_to_show, dtype=int).tolist()

    selected_steps = [intermediate_steps[i] for i in indices]

    # Visualize only the first sample from each batch.
    first_sample_steps = [step[0] for step in selected_steps]

    stacked = torch.stack(first_sample_steps)

    # Model outputs are assumed to live in [-1, 1] — TODO confirm with sampler.
    make_image_grid(
        stacked,
        rows=1,
        cols=len(selected_steps),
        save_path=save_path,
        normalize=True,
        value_range=(-1, 1),
    )
|
|
|
|
def format_prompt_caption(prompts: List[str], limit: int = 32) -> str:
    """
    Build a numbered caption string from a list of prompts.

    Newlines inside each prompt are flattened to spaces, at most ``limit``
    prompts are shown, and a trailing "(+N more)" marker accounts for any
    prompts that were cut off.

    Args:
        prompts: List of prompt strings
        limit: Maximum number of prompts to include

    Returns:
        Formatted caption string (empty when ``prompts`` is empty)
    """
    if not prompts:
        return ""

    shown = prompts[:limit]
    parts = []
    for number, raw in enumerate(shown, start=1):
        cleaned = raw.replace("\n", " ").strip()
        parts.append(f"{number}. {cleaned}")

    hidden = len(prompts) - len(shown)
    if hidden > 0:
        parts.append(f"... (+{hidden} more)")

    return "\n\n".join(parts)
|
|
|
|