Spaces:
Runtime error
Runtime error
import math | |
from typing import List, Optional, Union | |
import numpy as np | |
import torch | |
from PIL import Image | |
def tensor_to_image(
    data: Union[Image.Image, torch.Tensor, np.ndarray],
    batched: bool = False,
    format: str = "HWC",
) -> Union[Image.Image, List[Image.Image]]:
    """
    Convert a tensor or array into one PIL image, or a list of PIL images.

    Args:
        data: A PIL image (returned unchanged, regardless of ``batched``), a
            torch tensor, or a numpy array. Floating-point input is scaled
            by 255 after being clipped to [0, 1]; boolean input maps
            True -> 255 and False -> 0; uint8 input is used as-is.
        batched: If True, iterate over the first axis of ``data`` and return
            one image per entry.
        format: Axis layout of ``data``. "CHW" moves the channel axis last
            (only for 4-D input when batched, 3-D input otherwise); any other
            value leaves the axes untouched.

    Returns:
        A single PIL image, or a list of PIL images when ``batched`` is True.

    Raises:
        TypeError: If the dtype is not floating point, boolean, or uint8.
    """
    if isinstance(data, Image.Image):
        return data

    if isinstance(data, torch.Tensor):
        data = data.detach().cpu()
        # numpy has no bfloat16 dtype, so widen before converting.
        if data.dtype == torch.bfloat16:
            data = data.float()
        data = data.numpy()

    if np.issubdtype(data.dtype, np.floating):
        # Clip first: out-of-range values would otherwise wrap around when
        # cast to uint8 instead of saturating at black/white.
        data = (np.clip(data, 0.0, 1.0) * 255).astype(np.uint8)
    elif data.dtype == np.bool_:
        data = data.astype(np.uint8) * 255

    # Explicit raise instead of assert so the check survives `python -O`.
    if data.dtype != np.uint8:
        raise TypeError(
            f"Unsupported dtype {data.dtype}; expected floating point, bool or uint8"
        )

    if format == "CHW":
        if batched and data.ndim == 4:
            data = data.transpose((0, 2, 3, 1))
        elif not batched and data.ndim == 3:
            data = data.transpose((1, 2, 0))

    if batched:
        return [Image.fromarray(d) for d in data]
    return Image.fromarray(data)
def largest_factor_near_sqrt(n: int) -> int:
    """
    Finds the largest factor of n that does not exceed the square root of n.

    Args:
        n (int): A positive integer.

    Returns:
        int: The largest divisor d of n with d * d <= n. This is 1 when n is
            1 or prime, and exactly sqrt(n) when n is a perfect square.

    Raises:
        ValueError: If n is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # math.isqrt is exact for arbitrarily large ints; int(math.sqrt(n)) goes
    # through a float and can be off for large n.
    # Walking downwards from the integer square root also covers the
    # perfect-square case, and always terminates because 1 divides every n.
    for candidate in range(math.isqrt(n), 0, -1):
        if n % candidate == 0:
            return candidate
    return 1  # Unreachable for n >= 1; kept as a defensive fallback.
def make_image_grid(
    images: List[Image.Image],
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    resize: Optional[int] = None,
) -> Image.Image:
    """
    Prepares a single grid of images. Useful for visualization purposes.

    Images are placed in row-major order. The cell size is taken from the
    first image (after the optional resize).

    Args:
        images: Non-empty list of PIL images to tile.
        rows: Number of grid rows. Derived from ``cols`` when omitted.
        cols: Number of grid columns. Derived from ``rows`` when omitted.
            When both are omitted, a near-square layout is chosen via
            ``largest_factor_near_sqrt``.
        resize: If given, every image is resized to ``(resize, resize)``
            before being pasted.

    Returns:
        A new RGB image containing all input images.

    Raises:
        ValueError: If ``images`` is empty, if ``rows`` or ``cols`` is not
            positive, or if ``rows * cols`` does not equal ``len(images)``.
    """
    count = len(images)
    if count == 0:
        raise ValueError("make_image_grid requires at least one image")
    # Reject non-positive dimensions up front; they would otherwise surface
    # as a ZeroDivisionError further down.
    for label, value in (("rows", rows), ("cols", cols)):
        if value is not None and value <= 0:
            raise ValueError(f"{label} must be a positive integer, got {value}")

    if rows is None and cols is None:
        rows = largest_factor_near_sqrt(count)
        cols = count // rows
    elif rows is None:
        rows = count // cols
    elif cols is None:
        cols = count // rows

    # A single check covers both an explicit rows/cols mismatch and a derived
    # dimension that does not divide the image count evenly.
    if rows * cols != count:
        raise ValueError(
            f"Cannot arrange {count} images in a {rows}x{cols} grid"
        )

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid