# python3.7
"""Contains utility functions for image processing.

The module is primarily built on `cv2`. But, differently, we assume all colorful
images are with `RGB` channel order by default. Also, we assume all gray-scale
images to be with shape [height, width, 1].
"""

import os

import cv2
import numpy as np

from .misc import IMAGE_EXTENSIONS
from .misc import check_file_ext

__all__ = [
    'get_blank_image', 'load_image', 'save_image', 'resize_image',
    'add_text_to_image', 'preprocess_image', 'postprocess_image',
    'parse_image_size', 'get_grid_shape', 'list_images_from_dir'
]


def _check_2d_image(image):
    """Checks whether a given image is valid.

    A valid image is expected to be a `np.ndarray` with dtype `uint8`. Also, it
    should have shape like:

    (1) (height, width, 1)  # gray-scale image.
    (2) (height, width, 3)  # colorful image.
    (3) (height, width, 4)  # colorful image with transparency (RGBA)

    Args:
        image: The image to check.

    Raises:
        AssertionError: If the image violates any of the requirements above.
    """
    assert isinstance(image, np.ndarray)
    assert image.dtype == np.uint8
    assert image.ndim == 3 and image.shape[2] in [1, 3, 4]


def get_blank_image(height, width, channels=3, use_black=True):
    """Gets a blank image, either white or black.

    NOTE: This function will always return an image with `RGB` channel order
    for color image and pixel range [0, 255].

    Args:
        height: Height of the returned image.
        width: Width of the returned image.
        channels: Number of channels. (default: 3)
        use_black: Whether to return a black image. (default: True)

    Returns:
        A `np.ndarray` with dtype `uint8` and shape (height, width, channels),
            filled with 0 if `use_black` is set, otherwise with 255.
    """
    shape = (height, width, channels)
    if use_black:
        return np.zeros(shape, dtype=np.uint8)
    return np.full(shape, 255, dtype=np.uint8)


def load_image(path):
    """Loads an image from disk.

    NOTE: This function will always return an image with `RGB` channel order
    for color image and pixel range [0, 255].

    Args:
        path: Path to load the image from.

    Returns:
        An image with dtype `np.ndarray`, or `None` if the image cannot be
            read from `path` (e.g., `path` does not exist).
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        return None

    # Gray-scale images are decoded as 2D arrays; add the channel axis so that
    # every image handled by this module is 3D.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]

    _check_2d_image(image)

    # `cv2` decodes color images as BGR(A); convert to the RGB(A) convention.
    if image.shape[2] == 3:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    return image


def save_image(path, image):
    """Saves an image to disk.

    NOTE: The input image (if colorful) is assumed to be with `RGB` channel
    order and pixel range [0, 255].

    Args:
        path: Path to save the image to.
        image: Image to save. Nothing is written if `image` is `None`.
    """
    if image is None:
        return

    _check_2d_image(image)

    # `cv2` expects BGR(A) order when writing, hence convert back from RGB(A).
    if image.shape[2] == 1:
        cv2.imwrite(path, image)
    elif image.shape[2] == 3:
        cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    elif image.shape[2] == 4:
        cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA))


def resize_image(image, *args, **kwargs):
    """Resizes image.

    This is a wrap of `cv2.resize()`.

    NOTE: The channel order of the input image will not be changed.

    Args:
        image: Image to resize.
        *args: Additional positional arguments, forwarded to `cv2.resize()`.
        **kwargs: Additional keyword arguments, forwarded to `cv2.resize()`.

    Returns:
        An image with dtype `np.ndarray`, or `None` if `image` is empty.
    """
    if image is None:
        return None

    _check_2d_image(image)

    resized = cv2.resize(image, *args, **kwargs)
    if image.shape[2] == 1:
        # Re-expand the squeezed dim of gray-scale image.
        return resized[:, :, np.newaxis]
    return resized


def add_text_to_image(image,
                      text='',
                      position=None,
                      font=cv2.FONT_HERSHEY_TRIPLEX,
                      font_size=1.0,
                      line_type=cv2.LINE_8,
                      line_width=1,
                      color=(255, 255, 255)):
    """Overlays text on given image.

    NOTE: The input image is assumed to be with `RGB` channel order. The text
    is drawn onto the input image in place.

    Args:
        image: The image to overlay text on.
        text: Text content to overlay on the image. (default: empty)
        position: Target position (bottom-left corner), in format `(x, y)`, to
            add text. If not set, center of the image will be used by default.
            (default: None)
        font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
        font_size: Font size of the text added. (default: 1.0)
        line_type: Line type used to depict the text. (default: cv2.LINE_8)
        line_width: Line width used to depict the text. (default: 1)
        color: Color of the text added in `RGB` channel order. (default:
            (255, 255, 255))

    Returns:
        An image with target text overlaid on. The input is returned untouched
            if it is `None` or if `text` is empty.
    """
    if image is None or not text:
        return image

    _check_2d_image(image)

    # `cv2.putText()` cannot take `None` as the origin, so resolve the
    # documented default (center of the image) explicitly.
    if position is None:
        height, width = image.shape[:2]
        position = (width // 2, height // 2)

    cv2.putText(img=image,
                text=text,
                org=position,
                fontFace=font,
                fontScale=font_size,
                color=color,
                thickness=line_width,
                lineType=line_type,
                bottomLeftOrigin=False)
    return image


def preprocess_image(image, min_val=-1.0, max_val=1.0):
    """Pre-processes image by adjusting the pixel range and to dtype `float64`.

    This function is particularly used to convert an image or a batch of images
    to `NCHW` format, which matches the data layout commonly used in deep
    models.

    NOTE: The input image is assumed to be with pixel range [0, 255] and with
    format `HWC` or `NHWC`. The returned image will always be with format
    `NCHW`.

    Args:
        image: The input image for pre-processing.
        min_val: Minimum value of the output image.
        max_val: Maximum value of the output image.

    Returns:
        The pre-processed image, with dtype `float64` and format `NCHW`.

    Raises:
        AssertionError: If the input is not a `np.ndarray`, or cannot be
            interpreted as `HWC` / `NHWC` with 1, 3, or 4 channels.
    """
    assert isinstance(image, np.ndarray)

    # Linearly map [0, 255] to [min_val, max_val].
    image = image.astype(np.float64)
    image = image / 255.0 * (max_val - min_val) + min_val

    # Treat a single image as a batch of size one.
    if image.ndim == 3:
        image = image[np.newaxis]
    assert image.ndim == 4 and image.shape[3] in [1, 3, 4]
    return image.transpose(0, 3, 1, 2)


def postprocess_image(image, min_val=-1.0, max_val=1.0):
    """Post-processes image to pixel range [0, 255] with dtype `uint8`.

    This function is particularly used to handle the results produced by deep
    models.

    NOTE: The input image is assumed to be with format `NCHW`, and the returned
    image will always be with format `NHWC`.

    Args:
        image: The input image for post-processing.
        min_val: Expected minimum value of the input image.
        max_val: Expected maximum value of the input image.

    Returns:
        The post-processed image, with dtype `uint8` and format `NHWC`.

    Raises:
        AssertionError: If the input is not a `np.ndarray`, or is not with
            format `NCHW` with 1, 3, or 4 channels.
    """
    assert isinstance(image, np.ndarray)

    # Linearly map [min_val, max_val] to [0, 255].
    image = image.astype(np.float64)
    image = (image - min_val) / (max_val - min_val) * 255
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    image = np.clip(image + 0.5, 0, 255).astype(np.uint8)

    assert image.ndim == 4 and image.shape[1] in [1, 3, 4]
    return image.transpose(0, 2, 3, 1)


def parse_image_size(obj):
    """Parses an object to a pair of image size, i.e., (height, width).

    Supported inputs are:

    (1) `None` or an empty string, which gives (0, 0).
    (2) An integer, which is used as both height and width.
    (3) A comma-separated string, e.g., `'256'` or `'256, 512'`.
    (4) A list, tuple, or `np.ndarray` with at most two elements.

    Negative values are clamped to 0.

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple, indicating image height and width respectively.

    Raises:
        ValueError: If the input is with an unsupported type, contains more
            than two elements, or contains a string element that cannot be
            converted to an integer.
    """
    if obj is None:
        return (0, 0)

    # Dispatch on type first, such that a string comparison is never applied
    # to an array (which would be evaluated element-wise).
    if isinstance(obj, str):
        compact = obj.replace(' ', '')
        if not compact:
            return (0, 0)
        numbers = tuple(map(int, compact.split(',')))
    elif isinstance(obj, (int, np.integer)):
        numbers = (int(obj),)
    elif isinstance(obj, (list, tuple, np.ndarray)):
        numbers = tuple(obj)
    else:
        raise ValueError(f'Invalid type of input: `{type(obj)}`!')

    if len(numbers) > 2:
        raise ValueError('At most two elements for image size.')

    if len(numbers) == 0:
        height = 0
        width = 0
    elif len(numbers) == 1:
        height = int(numbers[0])
        width = int(numbers[0])
    else:
        height = int(numbers[0])
        width = int(numbers[1])
    return (max(0, height), max(0, width))


def get_grid_shape(size, height=0, width=0, is_portrait=False):
    """Gets the shape of a grid based on the size.

    This function makes greatest effort on making the output grid square if
    neither `height` nor `width` is set. If `is_portrait` is set as `False`, the
    height will always be equal to or smaller than the width. For example, if
    input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`,
    output shape will be (3, 5). Otherwise, the height will always be equal to
    or larger than the width.

    Args:
        size: Size (height * width) of the target grid.
        height: Expected height. If `size % height != 0`, this field will be
            ignored. (default: 0)
        width: Expected width. If `size % width != 0`, this field will be
            ignored. (default: 0)
        is_portrait: Whether to return a portrait size or a landscape size.
            Only takes effect when the shape is not determined by `height` or
            `width`. (default: False)

    Returns:
        A two-element tuple, representing height and width respectively.

    Raises:
        AssertionError: If `size`, `height`, or `width` is not an integer.
    """
    assert isinstance(size, int)
    assert isinstance(height, int)
    assert isinstance(width, int)
    if size <= 0:
        return (0, 0)

    if height > 0 and width > 0:
        if height * width == size:
            return (height, width)
        # The two hints contradict `size`, hence ignore both of them.
        height = 0
        width = 0

    if height > 0 and size % height == 0:
        return (height, size // height)
    if width > 0 and size % width == 0:
        return (size // width, width)

    # Search downwards from sqrt(size) for the largest divisor, which gives the
    # most square-like factorization. The loop always terminates with a valid
    # pair since 1 divides any positive size.
    height = int(np.sqrt(size))
    while height > 0:
        if size % height == 0:
            width = size // height
            break
        height = height - 1

    return (width, height) if is_portrait else (height, width)


def list_images_from_dir(directory):
    """Lists all images from the given directory.

    NOTE: Do NOT support finding images recursively.

    Args:
        directory: The directory to find images from.

    Returns:
        A list of sorted filenames, with the directory as prefix.
    """
    return sorted(
        os.path.join(directory, filename)
        for filename in os.listdir(directory)
        if check_file_ext(filename, *IMAGE_EXTENSIONS)
    )