import random
from typing import Any, List, Optional, Union

import blobfile as bf
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def center_crop(
    img: Union[Image.Image, torch.Tensor, np.ndarray]
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
    """
    Center crops an image to the largest square that fits inside it.

    :param img: a PIL image, or an array/tensor whose first two axes are
        (height, width), e.g. HWC or HW.
    :return: the cropped image, of the same type as ``img``. Arrays/tensors are
        cropped by slicing, so the result is a view of the input, not a copy.
    """
    is_array = isinstance(img, (np.ndarray, torch.Tensor))
    if is_array:
        height, width = img.shape[:2]
    else:
        width, height = img.size
    size = min(width, height)
    left, top = (width - size) // 2, (height - size) // 2
    right, bottom = left + size, top + size
    if is_array:
        return img[top:bottom, left:right]
    return img.crop((left, top, right, bottom))


def resize(
    img: Union[Image.Image, torch.Tensor, np.ndarray],
    *,
    height: int,
    width: int,
    min_value: Optional[Any] = None,
    max_value: Optional[Any] = None,
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
    """
    Resizes an image using ``F.interpolate(..., mode="area")``.

    Area interpolation averages over input regions, so this function is
    intended for downsampling.

    :param img: image in HWC (or HW) order, as a PIL image, numpy array or
        torch tensor.
    :param height: output height in pixels.
    :param width: output width in pixels.
    :param min_value: optional lower bound the interpolated values are clamped to.
    :param max_value: optional upper bound the interpolated values are clamped to.
    :return: the resized image, with the same type, dtype and number of
        dimensions as ``img``.
    """
    orig = img
    if isinstance(img, Image.Image):
        img = np.array(img)
    dtype = img.dtype
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img)

    ndim = img.ndim
    if img.ndim == 2:
        # Give HW input a channel axis so it can be handled as HWC below.
        img = img.unsqueeze(-1)
    if min_value is None and max_value is None:
        # .clamp throws an error when both are None
        min_value = -np.inf

    # F.interpolate expects (N, C, H, W) floating-point input.
    img = img.permute(2, 0, 1)
    img = (
        F.interpolate(img[None].float(), size=(height, width), mode="area")[0]
        .clamp(min_value, max_value)
        # NOTE: casting back to an integer dtype truncates the fractional part
        # instead of rounding it.
        .to(img.dtype)
        .permute(1, 2, 0)
    )
    if ndim < img.ndim:
        img = img.squeeze(-1)

    if isinstance(orig, torch.Tensor):
        # .to(img.dtype) above already restored the input dtype; torch tensors
        # have no .astype, so there is nothing left to convert.
        return img
    img = img.numpy().astype(dtype)
    if isinstance(orig, Image.Image):
        img = Image.fromarray(img)
    return img


def _has_alpha_channel(img_arr: np.ndarray) -> bool:
    """Return True if ``img_arr`` is an HWC array with 4 channels (RGBA)."""
    # Single-channel images (e.g. PIL mode "L") convert to 2-D arrays with no
    # channel axis, so check ndim before indexing shape[2].
    return img_arr.ndim == 3 and img_arr.shape[2] == 4


def get_alpha(img: Image.Image) -> Image.Image:
    """
    :return: the alpha channel separated out as a grayscale image. Images
        without a 4th (alpha) channel, including single-channel images, get a
        fully opaque alpha of 255 everywhere.
    """
    img_arr = np.asarray(img)
    if _has_alpha_channel(img_arr):
        alpha = img_arr[:, :, 3]
    else:
        alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8)
    return Image.fromarray(alpha)


def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image:
    """
    Composites an RGBA image onto an opaque background, removing its alpha.

    No op if the image doesn't have an alpha channel (including single-channel
    images).

    :param img: the image to flatten.
    :param mode: "random" (default) pastes the image onto a background picked at
        random from solid black, a random gray level, a random checkerboard or
        random noise. "black" or "white" alpha-blend it over a solid black or
        white background. Any other value leaves the image unchanged.
    :return: image with alpha removed
    """
    img_arr = np.asarray(img)
    if not _has_alpha_channel(img_arr):
        return img

    if mode == "random":
        height, width = img_arr.shape[:2]
        make_bg = random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])
        bg = Image.fromarray(make_bg(height, width))
        # Using the RGBA image as its own mask makes paste() blend by its alpha.
        bg.paste(img, mask=img)
        return bg

    if mode in ("black", "white"):
        img_arr = img_arr.astype(float)
        rgb, alpha = img_arr[:, :, :3], img_arr[:, :, -1:] / 255
        background = np.zeros((1, 1, 3)) if mode == "black" else np.full((1, 1, 3), 255)
        rgb = rgb * alpha + background * (1 - alpha)
        return Image.fromarray(np.round(rgb).astype(np.uint8))

    return img


def _black_bg(h: int, w: int) -> np.ndarray:
    """Solid black (h, w, 3) uint8 background."""
    return np.zeros([h, w, 3], dtype=np.uint8)


def _gray_bg(h: int, w: int) -> np.ndarray:
    """Solid (h, w, 3) uint8 background of one random gray level in [0, 255]."""
    return (np.zeros([h, w, 3]) + np.random.randint(low=0, high=256)).astype(np.uint8)


def _checker_bg(h: int, w: int) -> np.ndarray:
    """
    Random two-color (gray level) checkerboard as an (h, w, 3) uint8 array.

    The square size is sampled log-uniformly from [1, min(h, w)] and the pattern
    is shifted by a random offset along each axis.
    """
    # min(h, w) ** u for u in [0, 1), rounded up to an integer size >= 1.
    checker_size = int(np.ceil(np.exp(np.random.uniform() * np.log(min(h, w)))))
    c1 = np.random.randint(low=0, high=256)
    c2 = np.random.randint(low=0, high=256)
    xs = np.arange(w)[None, :, None] + np.random.randint(low=0, high=checker_size + 1)
    ys = np.arange(h)[:, None, None] + np.random.randint(low=0, high=checker_size + 1)
    fields = np.logical_xor((xs // checker_size) % 2 == 0, (ys // checker_size) % 2 == 0)
    return np.where(fields, np.array([c1] * 3), np.array([c2] * 3)).astype(np.uint8)


def _noise_bg(h: int, w: int) -> np.ndarray:
    """Uniform random per-pixel, per-channel noise as an (h, w, 3) uint8 array."""
    return np.random.randint(low=0, high=256, size=[h, w, 3]).astype(np.uint8)


def load_image(image_path: str) -> Image.Image:
    """
    Opens ``image_path`` through ``bf.BlobFile`` and returns it as a PIL image.

    The pixel data is loaded eagerly (``img.load()``) while the file is still
    open, so the returned image remains usable after the file is closed.
    """
    with bf.BlobFile(image_path, "rb") as thefile:
        img = Image.open(thefile)
        img.load()
    return img


def make_tile(images: List[Union[np.ndarray, Image.Image]], columns: int = 8) -> Image.Image:
    """
    Arranges images into a grid, filling rows left to right, top to bottom.

    ``images`` must be non-empty and every image must have the same shape
    (HW or HWC). The last row is padded with zero-valued (black) tiles when
    ``len(images)`` is not a multiple of ``columns``.

    Example (in a notebook where ``display`` is available)::

        display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8)
                           for c in np.linspace(0, 255, 15)]))

    :param images: the tiles, as numpy arrays or PIL images.
    :param columns: number of tiles per row.
    :return: the grid as a single PIL image.
    """
    arrays = [np.array(image) for image in images]
    tile_shape = arrays[0].shape
    tile_height, tile_width = tile_shape[:2]
    channels = tile_shape[2:]  # () for grayscale tiles, (C,) for HWC tiles

    n = round_up(len(arrays), columns)
    arrays.extend([np.zeros(tile_shape, dtype=np.uint8)] * (n - len(arrays)))
    rows = n // columns

    grid = (
        np.array(arrays)
        .reshape(rows, columns, tile_height, tile_width, *channels)
        # Interleave so each grid row's tiles sit side by side in pixel rows.
        .swapaxes(1, 2)
        .reshape(rows * tile_height, columns * tile_width, *channels)
    )
    return Image.fromarray(grid)


def round_up(n: int, b: int) -> int:
    """Round ``n`` up to the nearest multiple of ``b``."""
    return (n + b - 1) // b * b