alexnasa's picture
Upload 243 files
2568013 verified
"""This file contains useful layout utilities for images. They are:
- add_border: Add a border to an image.
- cat/hcat/vcat: Join images by arranging them in a line. If the images have different
sizes, they are aligned as specified (start, end, center). Allows you to specify a gap
between images.
Images are assumed to be float32 tensors with shape (channel, height, width).
"""
from typing import Any, Generator, Iterable, Literal, Optional, Union
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
Alignment = Literal["start", "center", "end"]
Axis = Literal["horizontal", "vertical"]
Color = Union[
int,
float,
Iterable[int],
Iterable[float],
Float[Tensor, "#channel"],
Float[Tensor, ""],
]
def _sanitize_color(color: Color) -> Float[Tensor, "#channel"]:
# Convert tensor to list (or individual item).
if isinstance(color, torch.Tensor):
color = color.tolist()
# Turn iterators and individual items into lists.
if isinstance(color, Iterable):
color = list(color)
else:
color = [color]
return torch.tensor(color, dtype=torch.float32)
def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]:
it = iter(iterable)
yield next(it)
for item in it:
yield delimiter
yield item
def _get_main_dim(main_axis: Axis) -> int:
return {
"horizontal": 2,
"vertical": 1,
}[main_axis]
def _get_cross_dim(main_axis: Axis) -> int:
return {
"horizontal": 1,
"vertical": 2,
}[main_axis]
def _compute_offset(base: int, overlay: int, align: Alignment) -> slice:
assert base >= overlay
offset = {
"start": 0,
"center": (base - overlay) // 2,
"end": base - overlay,
}[align]
return slice(offset, offset + overlay)
def overlay(
base: Float[Tensor, "channel base_height base_width"],
overlay: Float[Tensor, "channel overlay_height overlay_width"],
main_axis: Axis,
main_axis_alignment: Alignment,
cross_axis_alignment: Alignment,
) -> Float[Tensor, "channel base_height base_width"]:
# The overlay must be smaller than the base.
_, base_height, base_width = base.shape
_, overlay_height, overlay_width = overlay.shape
assert base_height >= overlay_height and base_width >= overlay_width
# Compute spacing on the main dimension.
main_dim = _get_main_dim(main_axis)
main_slice = _compute_offset(
base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment
)
# Compute spacing on the cross dimension.
cross_dim = _get_cross_dim(main_axis)
cross_slice = _compute_offset(
base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment
)
# Combine the slices and paste the overlay onto the base accordingly.
selector = [..., None, None]
selector[main_dim] = main_slice
selector[cross_dim] = cross_slice
result = base.clone()
result[selector] = overlay
return result
def cat(
main_axis: Axis,
*images: Iterable[Float[Tensor, "channel _ _"]],
align: Alignment = "center",
gap: int = 8,
gap_color: Color = 1,
) -> Float[Tensor, "channel height width"]:
"""Arrange images in a line. The interface resembles a CSS div with flexbox."""
device = images[0].device
gap_color = _sanitize_color(gap_color).to(device)
# Find the maximum image side length in the cross axis dimension.
cross_dim = _get_cross_dim(main_axis)
cross_axis_length = max(image.shape[cross_dim] for image in images)
# Pad the images.
padded_images = []
for image in images:
# Create an empty image with the correct size.
padded_shape = list(image.shape)
padded_shape[cross_dim] = cross_axis_length
base = torch.ones(padded_shape, dtype=torch.float32, device=device)
base = base * gap_color[:, None, None]
padded_images.append(overlay(base, image, main_axis, "start", align))
# Intersperse separators if necessary.
if gap > 0:
# Generate a separator.
c, _, _ = images[0].shape
separator_size = [gap, gap]
separator_size[cross_dim - 1] = cross_axis_length
separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device)
separator = separator * gap_color[:, None, None]
# Intersperse the separator between the images.
padded_images = list(_intersperse(padded_images, separator))
return torch.cat(padded_images, dim=_get_main_dim(main_axis))
def hcat(
*images: Iterable[Float[Tensor, "channel _ _"]],
align: Literal["start", "center", "end", "top", "bottom"] = "start",
gap: int = 8,
gap_color: Color = 1,
):
"""Shorthand for a horizontal linear concatenation."""
return cat(
"horizontal",
*images,
align={
"start": "start",
"center": "center",
"end": "end",
"top": "start",
"bottom": "end",
}[align],
gap=gap,
gap_color=gap_color,
)
def vcat(
*images: Iterable[Float[Tensor, "channel _ _"]],
align: Literal["start", "center", "end", "left", "right"] = "start",
gap: int = 8,
gap_color: Color = 1,
):
"""Shorthand for a horizontal linear concatenation."""
return cat(
"vertical",
*images,
align={
"start": "start",
"center": "center",
"end": "end",
"left": "start",
"right": "end",
}[align],
gap=gap,
gap_color=gap_color,
)
def add_border(
image: Float[Tensor, "channel height width"],
border: int = 8,
color: Color = 1,
) -> Float[Tensor, "channel new_height new_width"]:
color = _sanitize_color(color).to(image)
c, h, w = image.shape
result = torch.empty(
(c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device
)
result[:] = color[:, None, None]
result[:, border : h + border, border : w + border] = image
return result
def resize(
image: Float[Tensor, "channel height width"],
shape: Optional[tuple[int, int]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> Float[Tensor, "channel new_height new_width"]:
assert (shape is not None) + (width is not None) + (height is not None) == 1
_, h, w = image.shape
if width is not None:
shape = (int(h * width / w), width)
elif height is not None:
shape = (height, int(w * height / h))
return F.interpolate(
image[None],
shape,
mode="bilinear",
align_corners=False,
antialias="bilinear",
)[0]