# Source: FOUND/datasets/geometric_transforms.py
# Uploaded by osimeoni, commit 25cae60 ("FOUND - second"), 5.18 kB.
"""
Code adapted from SelfMask: https://github.com/NoelShin/selfmask
"""
from random import randint, random, uniform
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.transforms.functional import InterpolationMode as IM
def random_crop(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    crop_size: Tuple[int, int],  # (h, w)
    fill: Union[int, Tuple[int, int, int]],  # an unsigned integer or RGB
    offset: Optional[Tuple[int, int]] = None,  # (top, left) coordinate of a crop
):
    """Cut a ``crop_size`` window out of ``image``.

    If the image is smaller than the requested crop, it is first padded on the
    right and bottom with the constant ``fill`` so a full-size crop always exists.

    :param image: PIL image, numpy array, or tensor. A numpy array is copied into
        a tensor with ``torch.tensor`` before processing.
    :param crop_size: crop size as (h, w); must be a tuple or list of length 2.
    :param fill: constant padding value (a single int or an RGB triple).
    :param offset: (top, left) corner of the crop. When None, it is drawn with
        ``randint`` so that the crop lies inside the (padded) image.
    :return: ``(cropped_image, offset)``. The offset actually used is returned so
        the same crop can be reapplied elsewhere (e.g. to a matching mask).
    :raises TypeError: if ``image`` is not a PIL image, numpy array, or tensor.
    """
    assert type(crop_size) in (tuple, list) and len(crop_size) == 2

    def _height_width(img):
        # PIL reports (width, height); tensors keep (H, W) in their last two dims.
        if isinstance(img, Image.Image):
            return img.size[1], img.size[0]
        return tuple(img.shape[-2:])

    if isinstance(image, np.ndarray):
        image = torch.tensor(image)
    elif not isinstance(image, (Image.Image, torch.Tensor)):
        raise TypeError(type(image))
    height, width = _height_width(image)

    crop_h, crop_w = crop_size[0], crop_size[1]
    # Pad only the right/bottom edges so existing pixel coordinates are unchanged.
    pad_bottom = max(crop_h - height, 0)
    pad_right = max(crop_w - width, 0)
    image = TF.pad(
        image, [0, 0, pad_right, pad_bottom], fill=fill, padding_mode="constant"
    )
    height, width = _height_width(image)

    if offset is None:
        top = randint(0, height - crop_h)
        left = randint(0, width - crop_w)
        offset = (top, left)

    cropped = TF.crop(image, top=offset[0], left=offset[1], height=crop_h, width=crop_w)
    return cropped, offset
def compute_size(
    input_size: Tuple[int, int],  # (h, w)
    output_size: int,
    edge: str,
) -> Tuple[int, int]:
    """Compute an aspect-preserving (h, w) whose selected edge equals ``output_size``.

    :param input_size: original size as (h, w).
    :param output_size: target length, in pixels, for the edge chosen by ``edge``.
    :param edge: "longer" sets the longer edge to ``output_size`` (the result fits
        inside an ``output_size`` square); "shorter" sets the shorter edge to
        ``output_size`` (the result covers an ``output_size`` square).
    :return: the new size as (h, w). The rescaled edge is floored; in "longer"
        mode it is never allowed to reach 0.
    """
    assert edge in ["shorter", "longer"]
    h, w = input_size

    # Integer floor division instead of int(float(a) / b * output_size): the float
    # form can land just below an exact integer (29 / 100 * 100 == 28.999999999999996)
    # and silently lose a pixel.
    def _scaled(a: int, b: int) -> int:
        return int(a * output_size // b)

    if edge == "longer":
        # The short edge shrinks; clamp to 1 so extreme aspect ratios never yield 0 px.
        if w > h:
            h, w = max(1, _scaled(h, w)), output_size
        else:
            h, w = output_size, max(1, _scaled(w, h))
        assert w <= output_size and h <= output_size
    else:
        if w > h:
            h, w = output_size, _scaled(w, h)
        else:
            h, w = _scaled(h, w), output_size
        assert w >= output_size and h >= output_size
    return h, w
def resize(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    size: Union[int, Tuple[int, int]],
    interpolation: str,
    edge: str = "both",
) -> Union[Image.Image, torch.Tensor]:
    """Resize an image or mask.

    :param image: a PIL image, numpy array, or tensor. Arrays are wrapped with
        ``torch.from_numpy`` and tensors are cloned, so a tensor is returned for both.
        2-D (H, W) tensors are supported for every ``size``/``edge`` combination
        and come back 2-D.
    :param size: target size. A tuple/list ``(h, w)`` is used as-is and ``edge`` is
        ignored; an int is applied according to ``edge``.
    :param interpolation: sampling mode, one of "nearest", "bilinear", "bicubic".
    :param edge: how an int ``size`` is applied. Default: "both".
        "both": resize both height and width to ``size``.
        "shorter": resize the shorter edge to ``size``, keeping the aspect ratio.
        "longer": resize the longer edge to ``size``, keeping the aspect ratio.
    :return: the resized image.
    :raises ValueError: for an unknown ``interpolation`` or ``edge``.
    """
    modes = {
        "nearest": IM.NEAREST,
        "bilinear": IM.BILINEAR,
        "bicubic": IM.BICUBIC,
    }
    if interpolation not in modes:
        raise ValueError(interpolation)
    if edge not in ("both", "shorter", "longer"):
        raise ValueError(edge)
    mode = modes[interpolation]

    # Work on a copy so the caller's tensor is never aliased by the result.
    # NOTE(review): numpy inputs are wrapped without transposing, so an (H, W, C)
    # array would be treated as channel-first — confirm callers pass (C, H, W) or (H, W).
    if isinstance(image, torch.Tensor):
        image = image.clone().detach()
    elif isinstance(image, np.ndarray):
        image = torch.from_numpy(image)

    if isinstance(size, (tuple, list)):
        target = list(size)
    elif edge == "both":
        target = [size, size]
    else:
        if isinstance(image, Image.Image):
            w, h = image.size
        else:
            h, w = image.shape[-2:]
        target = list(compute_size(input_size=(h, w), output_size=size, edge=edge))

    # TF.resize is given a leading channel dim for bare (H, W) tensors; it is
    # removed again afterwards so masks keep their original rank.
    is_2d = isinstance(image, torch.Tensor) and image.ndim == 2
    if is_2d:
        image = image.unsqueeze(dim=0)
    image = TF.resize(image, size=target, interpolation=mode)
    if is_2d:
        image = image.squeeze(dim=0)
    return image
def random_scale(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    random_scale_range: Tuple[float, float],
    mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
):
    """Rescale ``image`` (and ``mask``) by one factor drawn from ``random_scale_range``.

    :param image: PIL image, numpy array, or tensor; resized bilinearly.
    :param random_scale_range: (low, high) bounds passed to ``random.uniform``.
    :param mask: optional mask resized to the same (h, w) with nearest-neighbour
        sampling so label values are not blended.
    :return: ``(scaled_image, scaled_mask)``; ``scaled_mask`` is None when no mask is given.
    """
    factor = uniform(*random_scale_range)

    if isinstance(image, Image.Image):
        orig_w, orig_h = image.size
    else:
        orig_h, orig_w = image.shape[-2:]

    # New size is truncated to whole pixels; image and mask share the same target.
    target = (int(orig_h * factor), int(orig_w * factor))

    scaled_image = resize(image, size=target, interpolation="bilinear")
    scaled_mask = (
        None if mask is None else resize(mask, size=target, interpolation="nearest")
    )
    return scaled_image, scaled_mask
def random_hflip(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    p: float,
    mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
):
    """Horizontally flip ``image`` (and ``mask``, if given) with probability ``p``.

    :param image: image to (possibly) flip.
    :param p: probability of flipping, in [0, 1]. ``p=0`` never flips and
        ``p=1`` always flips.
    :param mask: optional mask, flipped together with ``image`` so both stay aligned.
    :return: ``(image, mask)``; ``mask`` is passed through as-is (possibly None)
        when no flip happens.
    :raises ValueError: if ``p`` is outside [0, 1] (or NaN).
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(p)
    # random() is uniform on [0.0, 1.0), so `random() < p` holds with probability p.
    # The previous `random() > p` flipped with probability 1 - p.
    # NOTE(review): TF.hflip may not accept np.ndarray despite the type hint — confirm.
    if random() < p:
        image = TF.hflip(image)
        if mask is not None:
            mask = TF.hflip(mask)
    return image, mask