Spaces:
Runtime error
Runtime error
| import math | |
| from typing import List, Optional, Union | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
def tensor_to_image(
    data: Union[Image.Image, torch.Tensor, np.ndarray],
    batched: bool = False,
    format: str = "HWC",
) -> Union[Image.Image, List[Image.Image]]:
    """Convert a tensor/array into a PIL image (or a list of PIL images).

    Args:
        data: A PIL image (returned unchanged), a torch tensor, or a NumPy
            array. Floating-point data is clipped to ``[0, 1]`` and scaled
            by 255; boolean data maps ``True`` to 255 and ``False`` to 0;
            ``uint8`` data is used as-is.
        batched: When true, the first axis of ``data`` is iterated and one
            image is produced per entry.
        format: ``"CHW"`` moves the channel axis to the end before the
            conversion; any other value leaves the axis order untouched.
            The transpose only happens when the rank matches the
            ``batched`` flag (4-D when batched, 3-D otherwise).

    Returns:
        A single image, or a list of images when ``batched`` is true.

    Raises:
        TypeError: If the dtype is not floating point, boolean or ``uint8``.
    """
    if isinstance(data, Image.Image):
        return data

    if isinstance(data, torch.Tensor):
        data = data.detach().cpu()
        if data.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so ``.numpy()`` would raise.
            data = data.float()
        data = data.numpy()

    if np.issubdtype(data.dtype, np.floating):
        # Clip first: without it, values outside [0, 1] wrap around when
        # cast to uint8 instead of saturating.
        data = (np.clip(data, 0.0, 1.0) * 255).astype(np.uint8)
    elif data.dtype == np.bool_:
        data = data.astype(np.uint8) * 255
    elif data.dtype != np.uint8:
        raise TypeError(
            f"Unsupported dtype {data.dtype}; expected floating point, bool or uint8."
        )

    if format == "CHW":
        if batched and data.ndim == 4:
            data = data.transpose((0, 2, 3, 1))
        elif not batched and data.ndim == 3:
            data = data.transpose((1, 2, 0))

    if batched:
        return [Image.fromarray(d) for d in data]
    return Image.fromarray(data)
def largest_factor_near_sqrt(n: int) -> int:
    """Return the largest factor of ``n`` that does not exceed ``sqrt(n)``.

    Args:
        n (int): The integer to factor. Must be non-negative.

    Returns:
        int: The largest ``f`` with ``n % f == 0`` and ``f * f <= n``.
        For ``n == 0`` the result is 0.

    Raises:
        ValueError: If ``n`` is negative.
    """
    # math.isqrt is exact for arbitrarily large integers, whereas
    # int(math.sqrt(n)) goes through a double and can be off by one.
    root = math.isqrt(n)

    # A perfect square's root is itself the answer (this also covers n == 0).
    if root * root == n:
        return root

    # Otherwise walk downwards; 1 divides everything, so this terminates.
    for candidate in range(root, 0, -1):
        if n % candidate == 0:
            return candidate

    # Not reachable for n >= 1; kept so every path returns an int.
    return 1
def make_image_grid(
    images: List[Image.Image],
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    resize: Optional[int] = None,
) -> Image.Image:
    """Paste a list of images into a single RGB grid, row by row.

    Args:
        images: The images to lay out. Must not be empty.
        rows: Number of grid rows. Derived from ``cols`` when omitted.
        cols: Number of grid columns. Derived from ``rows`` when omitted.
            When both are omitted, ``rows`` is the largest factor of
            ``len(images)`` not exceeding its square root.
        resize: When given, every image is resized to ``(resize, resize)``
            before being pasted.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)``
        is the size of the first (possibly resized) image.

    Raises:
        ValueError: If ``images`` is empty, or if ``rows``/``cols`` are not
            positive or do not multiply to ``len(images)``.
    """
    count = len(images)
    if count == 0:
        raise ValueError("make_image_grid requires at least one image.")

    if rows is None and cols is None:
        rows = largest_factor_near_sqrt(count)
        cols = count // rows
    elif rows is None:
        if cols <= 0 or count % cols != 0:
            raise ValueError(
                f"Cannot arrange {count} images into {cols} columns."
            )
        rows = count // cols
    elif cols is None:
        if rows <= 0 or count % rows != 0:
            raise ValueError(
                f"Cannot arrange {count} images into {rows} rows."
            )
        cols = count // rows

    if rows <= 0 or cols <= 0 or rows * cols != count:
        raise ValueError(
            f"Grid of {rows}x{cols} does not match {count} images."
        )

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    # Cell size comes from the first image; the others are pasted at
    # multiples of it regardless of their own size.
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for index, img in enumerate(images):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid