|
""" |
|
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask |
|
""" |
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from typing import Optional, Tuple, Union |
|
from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale |
|
|
|
from datasets.utils import GaussianBlur |
|
from datasets.geometric_transforms import ( |
|
random_scale, |
|
random_crop, |
|
random_hflip, |
|
) |
|
|
|
def geometric_augmentations( |
|
image: Image.Image, |
|
random_scale_range: Optional[Tuple[float, float]] = None, |
|
random_crop_size: Optional[int] = None, |
|
random_hflip_p: Optional[float] = None, |
|
mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None, |
|
ignore_index: Optional[int] = None, |
|
) -> Tuple[Image.Image, torch.Tensor]: |
|
"""Note. image and mask are assumed to be of base size, thus share a spatial shape.""" |
|
if random_scale_range is not None: |
|
image, mask = random_scale( |
|
image=image, random_scale_range=random_scale_range, mask=mask |
|
) |
|
|
|
if random_crop_size is not None: |
|
crop_size = (random_crop_size, random_crop_size) |
|
fill = tuple(np.array(image).mean(axis=(0, 1)).astype(np.uint8).tolist()) |
|
image, offset = random_crop(image=image, crop_size=crop_size, fill=fill) |
|
|
|
if mask is not None: |
|
assert ignore_index is not None |
|
mask = random_crop( |
|
image=mask, crop_size=crop_size, fill=ignore_index, offset=offset |
|
)[0] |
|
|
|
if random_hflip_p is not None: |
|
image, mask = random_hflip(image=image, p=random_hflip_p, mask=mask) |
|
return image, mask |
|
|
|
def photometric_augmentations(
    image: Image.Image,
    random_color_jitter: bool,
    random_grayscale: bool,
    random_gaussian_blur: bool,
    proba_photometric_aug: float,
) -> torch.Tensor:
    """Apply optional photometric augmentations, each gated by one flag.

    Enabled augmentations run in the fixed order color jitter -> grayscale ->
    Gaussian blur. Each is applied with probability ``proba_photometric_aug``.

    Args:
        image: Input image. ``image.size`` is unpacked as ``(width, height)``
            for the blur step.
        random_color_jitter: Enable ``ColorJitter`` (brightness/contrast/
            saturation 0.8, hue 0.2) wrapped in ``RandomApply``.
        random_grayscale: Enable ``RandomGrayscale``.
        random_gaussian_blur: Enable ``GaussianBlur`` with an odd kernel size
            of roughly 10% of the shorter image side (at least 1).
        proba_photometric_aug: Probability shared by all enabled augmentations.

    Returns:
        The augmented image. NOTE(review): the ``torch.Tensor`` annotation
        looks inaccurate, since the input is treated as a PIL image. Confirm
        what ``GaussianBlur`` returns before relying on the output type.
    """
    prob = proba_photometric_aug
    augmented = image

    if random_color_jitter:
        jitter = ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)
        augmented = RandomApply([jitter], p=prob)(augmented)

    if random_grayscale:
        augmented = RandomGrayscale(prob)(augmented)

    if random_gaussian_blur:
        width, height = augmented.size
        shorter_side = min(width, height)
        # Round 10% of the shorter side down to an even number, then add 1
        # so the kernel size is always odd.
        kernel = int(0.1 * shorter_side // 2 * 2 + 1)
        augmented = GaussianBlur(kernel_size=kernel)(augmented, prob)

    return augmented