|
"""Layout utilities for images:

- add_border: Surround an image with a solid-colored border.
- cat/hcat/vcat: Join images by arranging them in a line. Images whose sizes differ
  along the cross axis are aligned as specified (start, center, or end), and a gap
  of a given color can be inserted between neighboring images.
- overlay: Paste one image onto a copy of a larger one at an aligned position.
- resize: Resize an image to an explicit shape, or to a target width or height
  while preserving its aspect ratio.

Images are assumed to be float32 tensors with shape (channel, height, width).
"""
|
|
|
from typing import Any, Generator, Iterable, Literal, Optional, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
# Placement of an image within the free space along one axis
# (analogous to CSS flexbox alignment values).
Alignment = Literal["start", "center", "end"]

# Direction along which images are laid out.
Axis = Literal["horizontal", "vertical"]

# Accepted color specifications: a plain scalar, a sequence with one value per
# channel, or a 1-D ("#channel") / 0-D ("") tensor. All forms are normalized to
# a 1-D float32 tensor by _sanitize_color.
Color = Union[
    int,
    float,
    Iterable[int],
    Iterable[float],
    Float[Tensor, "#channel"],
    Float[Tensor, ""],
]
|
|
|
|
|
def _sanitize_color(color: Color) -> Float[Tensor, "#channel"]:
    """Normalize a color specification into a float32 tensor.

    Tensors are first turned into plain Python values via ``tolist()``. Any
    iterable is then materialized into a list, and a lone scalar is wrapped in
    a one-element list, so scalar and flat inputs produce a 1-D tensor.
    """
    values = color.tolist() if isinstance(color, torch.Tensor) else color
    channels = list(values) if isinstance(values, Iterable) else [values]
    return torch.tensor(channels, dtype=torch.float32)
|
|
|
|
|
def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]: |
|
it = iter(iterable) |
|
yield next(it) |
|
for item in it: |
|
yield delimiter |
|
yield item |
|
|
|
|
|
def _get_main_dim(main_axis: Axis) -> int:
    """Return the tensor dimension along which images advance for *main_axis*.

    With (channel, height, width) images, horizontal layouts advance along
    dim 2 and vertical layouts along dim 1. Unknown axes raise KeyError.
    """
    dim_for_axis = {"vertical": 1, "horizontal": 2}
    return dim_for_axis[main_axis]
|
|
|
|
|
def _get_cross_dim(main_axis: Axis) -> int:
    """Return the tensor dimension perpendicular to *main_axis*.

    With (channel, height, width) images, the cross dimension of a horizontal
    layout is dim 1 and that of a vertical layout is dim 2. Unknown axes
    raise KeyError.
    """
    cross_dim_for_axis = {"vertical": 2, "horizontal": 1}
    return cross_dim_for_axis[main_axis]
|
|
|
|
|
def _compute_offset(base: int, overlay: int, align: Alignment) -> slice:
    """Return the slice that places a span of length *overlay* inside *base*.

    "start" puts the span at index 0, "end" flush against the far edge, and
    "center" in the middle (any odd leftover pixel goes after the span).
    Requires ``base >= overlay``; unknown alignments raise KeyError.
    """
    assert base >= overlay
    slack = base - overlay
    start = {"start": 0, "center": slack // 2, "end": slack}[align]
    return slice(start, start + overlay)
|
|
|
|
|
def overlay(
    base: Float[Tensor, "channel base_height base_width"],
    overlay: Float[Tensor, "channel overlay_height overlay_width"],
    main_axis: Axis,
    main_axis_alignment: Alignment,
    cross_axis_alignment: Alignment,
) -> Float[Tensor, "channel base_height base_width"]:
    """Paste *overlay* onto a copy of *base* at an aligned position.

    Args:
        base: Image to paste onto. It is cloned, never modified in place.
        overlay: Image to paste; must be no taller and no wider than *base*.
        main_axis: Spatial axis to which ``main_axis_alignment`` applies.
        main_axis_alignment: Placement ("start", "center", "end") along the
            main axis.
        cross_axis_alignment: Placement along the other spatial axis.

    Returns:
        A new tensor shaped like *base* with *overlay* written into the
        aligned region.

    Raises:
        AssertionError: If *overlay* is larger than *base* in height or width.
        KeyError: If an axis or alignment value is not recognized.
    """
    _, base_height, base_width = base.shape
    _, overlay_height, overlay_width = overlay.shape
    assert base_height >= overlay_height and base_width >= overlay_width

    main_dim = _get_main_dim(main_axis)
    main_slice = _compute_offset(
        base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment
    )

    cross_dim = _get_cross_dim(main_axis)
    cross_slice = _compute_offset(
        base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment
    )

    # Build (..., row_slice, col_slice). main_dim and cross_dim are always the
    # distinct values 1 and 2, so both placeholders get replaced.
    selector = [..., None, None]
    selector[main_dim] = main_slice
    selector[cross_dim] = cross_slice

    result = base.clone()
    # Index with a tuple: using a list that mixes Ellipsis/slices relies on
    # legacy "sequence treated as tuple" indexing, which NumPy has removed.
    result[tuple(selector)] = overlay
    return result
|
|
|
|
|
def cat(
    main_axis: Axis,
    *images: Float[Tensor, "channel _ _"],
    align: Alignment = "center",
    gap: int = 8,
    gap_color: Color = 1,
) -> Float[Tensor, "channel height width"]:
    """Arrange images in a line. The interface resembles a CSS div with flexbox.

    Args:
        main_axis: "horizontal" places images side by side; "vertical" stacks
            them. Images appear in argument order.
        *images: One or more (channel, height, width) images. Padding and
            separators are created on the first image's device.
        align: Placement along the cross axis of images that are shorter than
            the longest image in that direction.
        gap: Thickness in pixels of the separator inserted between neighboring
            images. No separator is inserted when ``gap <= 0``.
        gap_color: Fill color for the separators and for the padding around
            images that are shorter along the cross axis.

    Returns:
        A single image containing all inputs laid out along ``main_axis``.

    Raises:
        ValueError: If no images are given.
    """
    if not images:
        raise ValueError("cat requires at least one image")

    device = images[0].device
    gap_color = _sanitize_color(gap_color).to(device)
    fill = gap_color[:, None, None]  # broadcast the color over (height, width)

    cross_dim = _get_cross_dim(main_axis)
    cross_axis_length = max(image.shape[cross_dim] for image in images)

    # Pad every image to the common cross-axis length using the gap color.
    padded_images = []
    for image in images:
        padded_shape = list(image.shape)
        padded_shape[cross_dim] = cross_axis_length
        base = torch.ones(padded_shape, dtype=torch.float32, device=device) * fill
        padded_images.append(overlay(base, image, main_axis, "start", align))

    if gap > 0:
        c, _, _ = images[0].shape
        # separator_size is (height, width); cross_dim - 1 maps tensor dims
        # 1/2 onto indices 0/1 of that pair.
        separator_size = [gap, gap]
        separator_size[cross_dim - 1] = cross_axis_length
        separator = (
            torch.ones((c, *separator_size), dtype=torch.float32, device=device)
            * fill
        )
        padded_images = list(_intersperse(padded_images, separator))

    return torch.cat(padded_images, dim=_get_main_dim(main_axis))
|
|
|
|
|
def hcat(
    *images: Float[Tensor, "channel _ _"],
    align: Literal["start", "center", "end", "top", "bottom"] = "start",
    gap: int = 8,
    gap_color: Color = 1,
) -> Float[Tensor, "channel height width"]:
    """Shorthand for a horizontal linear concatenation.

    Images are placed side by side, left to right, in argument order.
    ``align`` controls the vertical placement of images shorter than the
    tallest one: "top" is an alias for "start" and "bottom" for "end".
    ``gap`` and ``gap_color`` are forwarded to ``cat``.
    """
    return cat(
        "horizontal",
        *images,
        align={
            "start": "start",
            "center": "center",
            "end": "end",
            "top": "start",
            "bottom": "end",
        }[align],
        gap=gap,
        gap_color=gap_color,
    )
|
|
|
|
|
def vcat(
    *images: Float[Tensor, "channel _ _"],
    align: Literal["start", "center", "end", "left", "right"] = "start",
    gap: int = 8,
    gap_color: Color = 1,
) -> Float[Tensor, "channel height width"]:
    """Shorthand for a vertical linear concatenation.

    Images are stacked top to bottom in argument order. ``align`` controls the
    horizontal placement of images narrower than the widest one: "left" is an
    alias for "start" and "right" for "end". ``gap`` and ``gap_color`` are
    forwarded to ``cat``.
    """
    return cat(
        "vertical",
        *images,
        align={
            "start": "start",
            "center": "center",
            "end": "end",
            "left": "start",
            "right": "end",
        }[align],
        gap=gap,
        gap_color=gap_color,
    )
|
|
|
|
|
def add_border(
    image: Float[Tensor, "channel height width"],
    border: int = 8,
    color: Color = 1,
) -> Float[Tensor, "channel new_height new_width"]:
    """Surround *image* with a solid frame ``border`` pixels thick on every side.

    The frame is filled with *color*. The result is a new float32 tensor on the
    image's device with shape (channel, height + 2 * border, width + 2 * border).
    """
    fill = _sanitize_color(color).to(image)
    channels, height, width = image.shape
    framed_height = height + 2 * border
    framed_width = width + 2 * border

    framed = torch.empty(
        (channels, framed_height, framed_width),
        dtype=torch.float32,
        device=image.device,
    )
    # Paint the whole canvas with the frame color, then copy the image into
    # the interior.
    framed[:] = fill[:, None, None]
    framed[:, border : border + height, border : border + width] = image
    return framed
|
|
|
|
|
def resize(
    image: Float[Tensor, "channel height width"],
    shape: Optional[tuple[int, int]] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
) -> Float[Tensor, "channel new_height new_width"]:
    """Resize *image* using antialiased bilinear interpolation.

    Exactly one of ``shape``, ``width`` and ``height`` must be given.

    Args:
        image: (channel, height, width) image to resize.
        shape: Explicit output size as (height, width).
        width: Output width. The height is scaled to preserve the aspect
            ratio (rounded down, but never below 1 pixel).
        height: Output height. The width is scaled likewise.

    Returns:
        The resized (channel, new_height, new_width) image.

    Raises:
        ValueError: If not exactly one of ``shape``, ``width`` and ``height``
            is given.
    """
    num_specified = sum(arg is not None for arg in (shape, width, height))
    if num_specified != 1:
        raise ValueError(
            "Specify exactly one of `shape`, `width`, or `height`; "
            f"got {num_specified}."
        )
    _, h, w = image.shape

    # Integer floor division keeps the original round-down behavior without a
    # float round-trip. Clamp to 1 so extreme aspect ratios never request an
    # empty output.
    if width is not None:
        shape = (max(1, h * width // w), width)
    elif height is not None:
        shape = (height, max(1, w * height // h))

    return F.interpolate(
        image[None],  # add a batch dimension for F.interpolate
        shape,
        mode="bilinear",
        align_corners=False,
        antialias=True,  # `antialias` is a bool flag, not a mode name
    )[0]
|
|