File size: 2,302 Bytes
1803579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
"""

import numpy as np
import torch
from PIL import Image
from typing import Optional, Tuple, Union
from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale

from datasets.utils import GaussianBlur
from datasets.geometric_transforms import (
    random_scale,
    random_crop,
    random_hflip,
)


def geometric_augmentations(
    image: Image.Image,
    random_scale_range: Optional[Tuple[float, float]] = None,
    random_crop_size: Optional[int] = None,
    random_hflip_p: Optional[float] = None,
    mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[
    Image.Image, Optional[Union[Image.Image, np.ndarray, torch.Tensor]]
]:
    """Apply the enabled geometric augmentations jointly to an image and its mask.

    Each augmentation is skipped when its controlling argument is ``None``.
    The image and the mask are assumed to be of base size, i.e. to share the
    same spatial shape, so that the same transform can be applied to both.

    Args:
        image: Input image.
        random_scale_range: Forwarded to ``random_scale`` together with the
            image and the mask. ``None`` disables random scaling.
        random_crop_size: Side length of the square crop. ``None`` disables
            random cropping.
        random_hflip_p: Forwarded to ``random_hflip`` as ``p``. ``None``
            disables random horizontal flipping.
        mask: Optional mask transformed alongside the image.
        ignore_index: Fill value passed to ``random_crop`` for the mask.
            Required whenever both ``random_crop_size`` and ``mask`` are given.

    Returns:
        The ``(image, mask)`` pair after the enabled augmentations. ``mask`` is
        ``None`` when no mask was supplied. NOTE(review): the concrete type of
        the returned mask depends on the helper functions; the annotation
        mirrors the accepted input types - confirm against the helpers.

    Raises:
        ValueError: If a mask is to be cropped but ``ignore_index`` is ``None``.
    """
    # Validate up front with a real exception: an ``assert`` is stripped under
    # ``python -O`` and would otherwise only fire after the image was cropped.
    if random_crop_size is not None and mask is not None and ignore_index is None:
        raise ValueError(
            "ignore_index must be provided when cropping a mask "
            "(random_crop_size and mask are both set)."
        )

    if random_scale_range is not None:
        image, mask = random_scale(
            image=image, random_scale_range=random_scale_range, mask=mask
        )

    if random_crop_size is not None:
        crop_size = (random_crop_size, random_crop_size)

        # Fill value for the image is its mean colour. For a single-channel
        # (2-D) image the mean over both axes is a 0-d scalar, which cannot be
        # turned into a tuple, so an int fill is used in that case.
        channel_means = np.asarray(image).mean(axis=(0, 1)).astype(np.uint8)
        if channel_means.ndim == 0:
            fill = int(channel_means)
        else:
            fill = tuple(channel_means.tolist())
        image, offset = random_crop(image=image, crop_size=crop_size, fill=fill)

        if mask is not None:
            # Reuse the offset returned for the image so that both crops
            # cover the same region.
            mask = random_crop(
                image=mask, crop_size=crop_size, fill=ignore_index, offset=offset
            )[0]

    if random_hflip_p is not None:
        image, mask = random_hflip(image=image, p=random_hflip_p, mask=mask)
    return image, mask


def photometric_augmentations(
    image: Image.Image,
    random_color_jitter: bool,
    random_grayscale: bool,
    random_gaussian_blur: bool,
    proba_photometric_aug: float,
) -> torch.Tensor:
    """Run the enabled photometric augmentations on ``image``, in a fixed order.

    The order is colour jitter, then grayscale conversion, then Gaussian blur.
    Every enabled step receives the same ``proba_photometric_aug`` value.

    Args:
        image: Input image; it must expose ``size`` as ``(width, height)`` when
            the blur step is enabled.
        random_color_jitter: Enable the ``ColorJitter`` step, wrapped in
            ``RandomApply`` with ``p=proba_photometric_aug``.
        random_grayscale: Enable the ``RandomGrayscale`` step.
        random_gaussian_blur: Enable the ``GaussianBlur`` step.
        proba_photometric_aug: Probability value handed to each enabled step.

    Returns:
        The image after the enabled steps, or the input unchanged when all
        three flags are false. NOTE(review): the signature advertises
        ``torch.Tensor``, but nothing here converts to a tensor; the actual
        type is whatever the last applied transform yields - confirm against
        ``GaussianBlur``.
    """
    if random_color_jitter:
        maybe_jitter = RandomApply(
            [ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)],
            p=proba_photometric_aug,
        )
        image = maybe_jitter(image)

    if random_grayscale:
        maybe_gray = RandomGrayscale(proba_photometric_aug)
        image = maybe_gray(image)

    if random_gaussian_blur:
        width, height = image.size
        shorter_side = min(width, height)
        # Largest even number not above 10% of the shorter side, plus one,
        # which always gives an odd kernel size (1 for small images).
        kernel_size = int(0.1 * shorter_side // 2 * 2 + 1)
        blur = GaussianBlur(kernel_size=kernel_size)
        # presumably the helper applies the blur with this probability -
        # verify against datasets.utils.GaussianBlur
        image = blur(image, proba_photometric_aug)

    return image