Spaces:
Paused
Paused
from typing import List, Optional

import PIL.Image
import PIL.ImageOps
from packaging import version
from PIL import Image
# Lookup table from interpolation name to a Pillow resampling filter.
# From Pillow 9.1.0 on, the filters are read from the `PIL.Image.Resampling`
# namespace; older releases use the module-level constants on `PIL.Image`.
# The installed version is reduced to its base release (pre/post/dev suffixes
# dropped) before being compared against 9.1.0.
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    PIL_INTERPOLATION = dict(
        linear=PIL.Image.LINEAR,
        bilinear=PIL.Image.BILINEAR,
        bicubic=PIL.Image.BICUBIC,
        lanczos=PIL.Image.LANCZOS,
        nearest=PIL.Image.NEAREST,
    )
else:
    # "linear" is mapped to the bilinear filter here.
    PIL_INTERPOLATION = dict(
        linear=PIL.Image.Resampling.BILINEAR,
        bilinear=PIL.Image.Resampling.BILINEAR,
        bicubic=PIL.Image.Resampling.BICUBIC,
        lanczos=PIL.Image.Resampling.LANCZOS,
        nearest=PIL.Image.Resampling.NEAREST,
    )
def pt_to_pil(images):
    """
    Convert a torch image or a batch of torch images to a list of PIL images.

    Args:
        images: A channels-first tensor, either a single image of shape
            ``(C, H, W)`` or a batch of shape ``(B, C, H, W)``. Values are
            expected in ``[-1, 1]`` (the ``/ 2 + 0.5`` below maps that range
            onto ``[0, 1]``); anything outside is clamped.

    Returns:
        A list of ``PIL.Image.Image``, one per image in the batch, as produced
        by ``numpy_to_pil``.
    """
    # Accept a single unbatched image by adding a batch axis, matching
    # numpy_to_pil, which already accepts single images.
    if images.ndim == 3:
        images = images.unsqueeze(0)
    # detach() so tensors that require grad can still be converted to NumPy.
    images = (images.detach() / 2 + 0.5).clamp(0, 1)
    # Move to host, switch to channels-last (B, H, W, C), and upcast to
    # float32 before the NumPy conversion.
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    return numpy_to_pil(images)
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of numpy images to a list of PIL images.

    Args:
        images: A channels-last array, either a single image of shape
            ``(H, W, C)`` or a batch of shape ``(B, H, W, C)``. Values are
            expected in ``[0, 1]``; anything outside is clipped.

    Returns:
        A list of ``PIL.Image.Image``, one per image. Single-channel inputs
        become mode ``"L"`` images; other channel counts are passed to
        ``Image.fromarray`` unchanged.
    """
    if images.ndim == 3:
        # Promote a single image to a batch of one.
        images = images[None, ...]
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (e.g. 1.01 -> 258 -> 2) instead of saturating at 0 / 255.
    images = (images.clip(0, 1) * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Special case for grayscale (single channel) images. Drop only the
        # channel axis: a bare squeeze() would also collapse a height or width
        # of 1 and hand fromarray an array of the wrong rank.
        pil_images = [Image.fromarray(image[..., 0], mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
def make_image_grid(
    images: List[PIL.Image.Image], rows: int, cols: int, resize: Optional[int] = None
) -> PIL.Image.Image:
    """
    Paste a list of images into a single ``rows x cols`` grid image.

    Useful for visualization purposes. Images are placed in row-major order.
    Every cell is sized like the first image (after optional resizing); the
    other images are assumed to share that size, which is not checked.

    Args:
        images: Exactly ``rows * cols`` PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.
        resize: If given, every image is first resized to ``resize x resize``
            pixels.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of the first image.

    Raises:
        ValueError: If ``len(images) != rows * cols``.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would let a mismatched list silently produce a wrong grid.
    if len(images) != rows * cols:
        raise ValueError(
            f"Expected rows * cols = {rows * cols} images, got {len(images)}."
        )
    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid