import collections.abc as collections
from pathlib import Path
from typing import Optional, Tuple

import cv2
import kornia
import numpy as np
import torch
from omegaconf import OmegaConf


class ImagePreprocessor:
    """Resize (and optionally square-pad) image tensors, tracking the scaling applied."""

    default_conf = {
        "resize": None,  # target edge length, None for no resizing
        "edge_divisible_by": None,  # round resized edges down to a multiple of this
        "side": "long",  # which side `resize` refers to: short/long/vert/horz
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
        "square_pad": False,  # zero-pad the resized image to a square canvas
        "add_padding_mask": False,  # also emit a boolean mask of valid (non-pad) pixels
    }

    def __init__(self, conf) -> None:
        """Merge the user conf into the struct-locked defaults (unknown keys raise)."""
        super().__init__()
        default_conf = OmegaConf.create(self.default_conf)
        OmegaConf.set_struct(default_conf, True)
        self.conf = OmegaConf.merge(default_conf, conf)

    def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
        """Resize and preprocess an image, return image and resize scale.

        Args:
            img: image tensor with shape (..., H, W).
            interpolation: optional override of ``conf.interpolation``.

        Returns:
            dict with keys "image", "scales" (new/old per-axis scale, on img's
            device/dtype), "image_size" (W, H after resize), "transform" (3x3
            homogeneous scaling matrix), "original_image_size" (W, H), and
            optionally "padding_mask".
        """
        h, w = img.shape[-2:]
        size = h, w
        if self.conf.resize is not None:
            if interpolation is None:
                interpolation = self.conf.interpolation
            size = self.get_new_image_size(h, w)
            img = kornia.geometry.transform.resize(
                img,
                size,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
                interpolation=interpolation,
            )
        # Compute the scale factors as plain Python floats first: building the
        # numpy transform from indexed tensor elements would raise for CUDA
        # inputs (numpy cannot convert CUDA tensors).
        sx, sy = img.shape[-1] / w, img.shape[-2] / h
        scale = torch.Tensor([sx, sy]).to(img)
        T = np.diag([sx, sy, 1.0])  # 3x3 homogeneous scaling transform
        data = {
            "scales": scale,
            "image_size": np.array(size[::-1]),  # (W, H) after resizing
            "transform": T,
            "original_image_size": np.array([w, h]),
        }
        if self.conf.square_pad:
            sl = max(img.shape[-2:])
            data["image"] = torch.zeros(
                *img.shape[:-2], sl, sl, device=img.device, dtype=img.dtype
            )
            # Index with `...` so the write targets the spatial dims for any
            # leading batch/channel layout (a plain `[:, :h, :w]` is only
            # correct for 3-D (C, H, W) inputs).
            data["image"][..., : img.shape[-2], : img.shape[-1]] = img
            if self.conf.add_padding_mask:
                data["padding_mask"] = torch.zeros(
                    *img.shape[:-3], 1, sl, sl, device=img.device, dtype=torch.bool
                )
                data["padding_mask"][..., : img.shape[-2], : img.shape[-1]] = True
        else:
            data["image"] = img
        return data

    def load_image(self, image_path: Path) -> dict:
        """Load an image from disk and run the preprocessing pipeline on it."""
        return self(load_image(image_path))

    def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]:
        """Compute the (height, width) an h-by-w image will be resized to."""
        side = self.conf.side
        # An explicit (h, w) pair in conf.resize takes precedence over
        # side-based resizing. (`collections` is aliased to collections.abc.)
        if isinstance(self.conf.resize, collections.Iterable):
            assert len(self.conf.resize) == 2
            return tuple(self.conf.resize)
        side_size = self.conf.resize
        aspect_ratio = w / h
        if side not in ("short", "long", "vert", "horz"):
            raise ValueError(
                f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
            )
        if side == "vert":
            size = side_size, int(side_size * aspect_ratio)
        elif side == "horz":
            size = int(side_size / aspect_ratio), side_size
        elif (side == "short") ^ (aspect_ratio < 1.0):
            # Resizing the short side of a landscape image (or the long side
            # of a portrait one) fixes the height.
            size = side_size, int(side_size * aspect_ratio)
        else:
            size = int(side_size / aspect_ratio), side_size
        if self.conf.edge_divisible_by is not None:
            # Round each edge down to the nearest multiple of the divisor;
            # keep the tuple return type promised by the annotation.
            df = self.conf.edge_divisible_by
            size = tuple(int(x // df * df) for x in size)
        return size


def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Read an image from path as RGB or grayscale.

    Raises:
        FileNotFoundError: if no file exists at ``path``.
        IOError: if OpenCV fails to decode the file.
    """
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise IOError(f"Could not read image at {path}.")
    if not grayscale:
        image = image[..., ::-1]  # OpenCV loads BGR; flip channels to RGB
    return image


def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Normalize the image tensor and reorder the dimensions.

    Maps uint8 [0, 255] values to float [0, 1] and returns a CxHxW tensor.
    """
    if image.ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f"Not an image: {image.shape}")
    # torch.tensor copies, so the negative strides from the BGR->RGB flip
    # in read_image are safe here (torch.from_numpy would reject them).
    return torch.tensor(image / 255.0, dtype=torch.float)


def load_image(path: Path, grayscale=False) -> torch.Tensor:
    """Read an image from disk and convert it to a normalized float tensor."""
    image = read_image(path, grayscale=grayscale)
    return numpy_image_to_torch(image)