Spaces:
Sleeping
Sleeping
""" | |
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask | |
""" | |
import numpy as np | |
import torch | |
from PIL import Image | |
from typing import Optional, Tuple, Union | |
from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale | |
from datasets.utils import GaussianBlur | |
from datasets.geometric_transforms import ( | |
random_scale, | |
random_crop, | |
random_hflip, | |
) | |
def _mean_fill(image: Image.Image) -> Union[int, Tuple[int, ...]]:
    """Return the per-channel mean of *image* as a crop ``fill`` value.

    Multi-channel input (a 3-D array once converted, e.g. RGB) yields a tuple
    with one uint8-truncated value per channel. Single-channel input (a 2-D
    array, e.g. PIL mode "L") yields a plain int.
    """
    mean = np.array(image).mean(axis=(0, 1)).astype(np.uint8)
    if mean.ndim == 0:
        # A 2-D image reduces to a scalar; wrapping it in tuple() would raise
        # TypeError, so return it as an int instead.
        return int(mean)
    return tuple(mean.tolist())


def geometric_augmentations(
    image: Image.Image,
    random_scale_range: Optional[Tuple[float, float]] = None,
    random_crop_size: Optional[int] = None,
    random_hflip_p: Optional[float] = None,
    mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[Image.Image, Optional[Union[Image.Image, np.ndarray, torch.Tensor]]]:
    """Apply optional random scale, square crop and horizontal flip.

    Each step runs only when its argument is not ``None``, in the order
    scale -> crop -> flip. When ``mask`` is given it goes through the same
    helpers as ``image``.

    Note: image and mask are assumed to be of base size, thus share a
    spatial shape.

    Args:
        image: Input image.
        random_scale_range: Forwarded as-is to ``random_scale``.
        random_crop_size: Side length of the square crop requested from
            ``random_crop``. The image crop uses the image's per-channel mean
            as ``fill``. The mask crop uses ``ignore_index`` as ``fill`` and
            reuses the image's crop offset, so the two stay aligned.
        random_hflip_p: Forwarded to ``random_hflip`` as ``p``.
        mask: Optional mask transformed alongside ``image``.
        ignore_index: ``fill`` value for the mask crop. Required when both
            ``mask`` and ``random_crop_size`` are given.

    Returns:
        ``(image, mask)`` after the enabled transforms. With no transform
        enabled, both inputs are returned unchanged.

    Raises:
        ValueError: If ``random_crop_size`` and ``mask`` are given but
            ``ignore_index`` is ``None``.
    """
    if random_scale_range is not None:
        image, mask = random_scale(
            image=image, random_scale_range=random_scale_range, mask=mask
        )

    if random_crop_size is not None:
        # Validate before cropping anything. An explicit raise, unlike
        # ``assert``, is not stripped under ``python -O``.
        if mask is not None and ignore_index is None:
            raise ValueError(
                "ignore_index must be provided to crop a mask "
                "when random_crop_size is set"
            )
        crop_size = (random_crop_size, random_crop_size)
        image, offset = random_crop(
            image=image, crop_size=crop_size, fill=_mean_fill(image)
        )
        if mask is not None:
            # Reuse the image's offset so image and mask get the same crop.
            mask = random_crop(
                image=mask, crop_size=crop_size, fill=ignore_index, offset=offset
            )[0]

    if random_hflip_p is not None:
        image, mask = random_hflip(image=image, p=random_hflip_p, mask=mask)

    return image, mask
def photometric_augmentations(
    image: Image.Image,
    random_color_jitter: bool,
    random_grayscale: bool,
    random_gaussian_blur: bool,
    proba_photometric_aug: float,
) -> Image.Image:
    """Apply optional colour jitter, grayscale conversion and Gaussian blur.

    Enabled augmentations run in the order jitter -> grayscale -> blur, each
    gated by ``proba_photometric_aug``.

    Args:
        image: Input PIL image. ``image.size`` is unpacked as ``(w, h)``, so
            a tensor input would not work here.
        random_color_jitter: If True, apply ``ColorJitter(brightness=0.8,
            contrast=0.8, saturation=0.8, hue=0.2)`` wrapped in
            ``RandomApply`` with ``p=proba_photometric_aug``.
        random_grayscale: If True, apply ``RandomGrayscale`` with
            ``p=proba_photometric_aug``.
        random_gaussian_blur: If True, blur with an odd kernel of roughly 10%
            of the shorter image side. ``proba_photometric_aug`` is passed as
            the second call argument of the project ``GaussianBlur``
            (presumably its application probability; confirm in
            ``datasets.utils``).
        proba_photometric_aug: Probability used by each enabled augmentation.

    Returns:
        The augmented image. With every flag False, ``image`` itself is
        returned unchanged. NOTE(review): annotated as a PIL image because the
        torchvision transforms keep PIL in / PIL out; this assumes the
        project ``GaussianBlur`` also returns a PIL image.
    """
    if random_color_jitter:
        color_jitter = ColorJitter(
            brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
        )
        image = RandomApply([color_jitter], p=proba_photometric_aug)(image)

    if random_grayscale:
        image = RandomGrayscale(p=proba_photometric_aug)(image)

    if random_gaussian_blur:
        w, h = image.size
        # Round 0.1 * shorter side down to an even number, then add 1 to get
        # an odd kernel size (1 for images shorter than 20 px).
        kernel_size = int((0.1 * min(w, h) // 2 * 2) + 1)
        image = GaussianBlur(kernel_size=kernel_size)(image, proba_photometric_aug)

    return image