# add: GIM (https://github.com/xuelunshen/gim)
import collections.abc as collections
from pathlib import Path
from typing import Optional, Tuple

import cv2
import kornia
import numpy as np
import torch
from omegaconf import OmegaConf


class ImagePreprocessor:
    default_conf = {
        "resize": None,  # target edge length, None for no resizing
        "edge_divisible_by": None,
        "side": "long",
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
        "square_pad": False,
        "add_padding_mask": False,
    }

    def __init__(self, conf) -> None:
        super().__init__()
        # Merge the user config into the defaults; set_struct rejects unknown keys.
        default_conf = OmegaConf.create(self.default_conf)
        OmegaConf.set_struct(default_conf, True)
        self.conf = OmegaConf.merge(default_conf, conf)

    def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
        """Resize and preprocess an image, return the image and the resize scale."""
        h, w = img.shape[-2:]
        size = h, w
        if self.conf.resize is not None:
            if interpolation is None:
                interpolation = self.conf.interpolation
            size = self.get_new_image_size(h, w)
            img = kornia.geometry.transform.resize(
                img,
                size,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
                interpolation=interpolation,
            )
        # Per-axis scale factors (x, y) and the corresponding 3x3 transform matrix.
        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        T = np.diag([scale[0], scale[1], 1])
        data = {
            "scales": scale,
            "image_size": np.array(size[::-1]),
            "transform": T,
            "original_image_size": np.array([w, h]),
        }
        if self.conf.square_pad:
            # Zero-pad the image to a square with side length max(h, w).
            sl = max(img.shape[-2:])
            data["image"] = torch.zeros(
                *img.shape[:-2], sl, sl, device=img.device, dtype=img.dtype
            )
            data["image"][:, : img.shape[-2], : img.shape[-1]] = img
            if self.conf.add_padding_mask:
                # Boolean mask that marks the valid (non-padded) pixels.
                data["padding_mask"] = torch.zeros(
                    *img.shape[:-3], 1, sl, sl, device=img.device, dtype=torch.bool
                )
                data["padding_mask"][:, : img.shape[-2], : img.shape[-1]] = True
        else:
            data["image"] = img
        return data

    def load_image(self, image_path: Path) -> dict:
        """Load an image from disk and preprocess it."""
        return self(load_image(image_path))

    def get_new_image_size(
        self,
        h: int,
        w: int,
    ) -> Tuple[int, int]:
        """Compute the target (height, width) from the resize configuration."""
        side = self.conf.side
        if isinstance(self.conf.resize, collections.Iterable):
            assert len(self.conf.resize) == 2
            return tuple(self.conf.resize)
        side_size = self.conf.resize
        aspect_ratio = w / h
        if side not in ("short", "long", "vert", "horz"):
            raise ValueError(
                f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
            )
        # 'vert'/'horz' pin the height/width to side_size; 'short'/'long' pin the
        # shorter or longer edge, depending on the aspect ratio.
        if side == "vert":
            size = side_size, int(side_size * aspect_ratio)
        elif side == "horz":
            size = int(side_size / aspect_ratio), side_size
        elif (side == "short") ^ (aspect_ratio < 1.0):
            size = side_size, int(side_size * aspect_ratio)
        else:
            size = int(side_size / aspect_ratio), side_size
        if self.conf.edge_divisible_by is not None:
            # Round both edges down to the nearest multiple of edge_divisible_by.
            df = self.conf.edge_divisible_by
            size = list(map(lambda x: int(x // df * df), size))
        return size
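

# Standalone I/O helpers: read an image from disk and convert it to a
# normalized torch tensor, as used by ImagePreprocessor.load_image.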
def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Read an image from a path as RGB or grayscale."""
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise IOError(f"Could not read image at {path}.")
    if not grayscale:
        # OpenCV reads BGR; flip the channel order to RGB.
        image = image[..., ::-1]
    return image


def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Normalize the image to [0, 1] and reorder the dimensions to CxHxW."""
    if image.ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f"Not an image: {image.shape}")
    return torch.tensor(image / 255.0, dtype=torch.float)


def load_image(path: Path, grayscale: bool = False) -> torch.Tensor:
    """Read an image from disk and convert it to a normalized CxHxW float tensor."""
    image = read_image(path, grayscale=grayscale)
    return numpy_image_to_torch(image)
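

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The image path and the
    # config values below are placeholders; adjust them to your data.
    preprocessor = ImagePreprocessor({"resize": 1024, "side": "long"})
    data = preprocessor.load_image(Path("assets/example.jpg"))  # hypothetical path
    print(data["image"].shape, data["scales"], data["original_image_size"])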