File size: 4,605 Bytes
8ec10cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import random
import numpy as np
from skimage.color import rgb2hsv, hsv2rgb
from skimage.transform import pyramid_gaussian

import torch

def _apply(func, x):
    """Apply ``func`` to every leaf of a nested list/tuple/dict structure.

    Lists and tuples are both rebuilt as plain lists; dicts are rebuilt as
    dicts with the same keys in the same order. Anything else is treated as
    a leaf and handed to ``func`` directly.
    """
    if isinstance(x, dict):
        return {name: _apply(func, item) for name, item in x.items()}

    if isinstance(x, (list, tuple)):
        # NOTE: tuples come back as lists
        return [_apply(func, item) for item in x]

    return func(x)

def crop(*args, ps=256):    # patch_size
    """Cut the same randomly placed ``ps`` x ``ps`` patch out of every image.

    The patch location is drawn once (row offset first, then column offset,
    both from the ``random`` module) using the height/width of the first
    image found in ``args``, and is then applied to every leaf of ``args``.

    Args:
        *args: images, or nested lists/tuples/dicts of images (e.g. input
            and target). Each leaf must expose ``shape`` and ``ndim`` and
            support slicing.
        ps: side length of the square patch.

    Returns:
        A list mirroring ``args`` (see ``_apply``) holding the cropped
        images. 2-D images gain a trailing singleton channel axis.

    Raises:
        ValueError: if the reference image is smaller than ``ps`` along
            height or width.
    """
    def _first_shape(item):
        # Walk into the first element of nested containers until a leaf is hit
        if isinstance(item, (list, tuple)):
            return _first_shape(item[0])
        if isinstance(item, dict):
            return _first_shape(list(item.values())[0])
        return item.shape

    # Only the two leading axes matter; slicing (instead of unpacking three
    # values) keeps 2-D images working, which _crop below explicitly supports
    h, w = _first_shape(args)[:2]

    if h < ps or w < ps:
        raise ValueError(
            'patch size {} does not fit into image of size {}x{}'.format(ps, h, w)
        )

    py = random.randrange(0, h - ps + 1)
    px = random.randrange(0, w - ps + 1)

    def _crop(img):
        if img.ndim == 2:
            # promote grayscale to H x W x 1
            return img[py:py+ps, px:px+ps, np.newaxis]
        return img[py:py+ps, px:px+ps, :]

    return _apply(_crop, args)

def add_noise(*args, sigma_sigma=2, rgb_range=255):
    """Add Gaussian noise of one randomly drawn strength to every image.

    A single noise scale is sampled per call: a standard normal draw
    multiplied by ``sigma_sigma`` and rescaled by ``rgb_range/255``. Each
    image then receives its own float32 noise field at that scale and the
    sum is clipped to ``[0, rgb_range]``.

    When exactly one positional argument is given it is unwrapped first, so
    the result mirrors that argument's own structure instead of being
    wrapped in an outer list.
    """
    # usually there is only a single input
    targets = args[0] if len(args) == 1 else args

    # may be negative; that only flips the sign of the symmetric noise
    noise_scale = np.random.normal() * sigma_sigma * rgb_range/255

    def _perturb(image):
        perturbation = np.random.randn(*image.shape).astype(np.float32) * noise_scale
        noisy = image + perturbation
        return noisy.clip(0, rgb_range)

    return _apply(_perturb, targets)

def augment(*args, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=255):
    """Apply one randomly chosen augmentation identically to input and target.

    All random decisions are made once per call and then reused for every
    image in ``args``:

    * ``hflip``: with probability 1/2, reverse axis 1.
    * ``rot``: independently with probability 1/2 each, reverse axis 0 and
      swap axes 0 and 1.
    * ``shuffle``: permute the channels of 3-channel images with a random
      order (skipped when the drawn order is the identity).
    * ``change_saturation``: multiply the HSV saturation channel by a factor
      drawn uniformly from [0.5, 1.5), convert back to RGB, clip to [0, 1]
      and multiply by ``rgb_range``.

    Returns a list mirroring ``args`` (see ``_apply``) of float32 arrays.
    """
    coin = (False, True)

    # The draw order below is kept fixed so seeded runs stay reproducible.
    # A disabled option short-circuits and consumes no random draw.
    flip_horizontal = hflip and random.choice(coin)
    flip_vertical = rot and random.choice(coin)
    swap_axes = rot and random.choice(coin)

    channel_order = None
    if shuffle:
        channel_order = list(range(3))
        random.shuffle(channel_order)
        if channel_order == list(range(3)):
            # identity permutation: nothing to do
            channel_order = None

    saturation_gain = np.random.uniform(0.5, 1.5) if change_saturation else None

    def _transform(image):
        if flip_horizontal:
            image = image[:, ::-1, :]
        if flip_vertical:
            image = image[::-1, :, :]
        if swap_axes:
            image = image.transpose(1, 0, 2)

        # RGB image only
        if channel_order is not None and image.ndim > 2 and image.shape[-1] == 3:
            image = image[..., channel_order]

        if change_saturation:
            hsv = rgb2hsv(image)
            hsv[..., 1] *= saturation_gain
            # NOTE(review): the clip to [0, 1] presumes rgb2hsv/hsv2rgb work on
            # data normalised to [0, 1] (e.g. uint8 input) -- confirm with callers
            image = hsv2rgb(hsv).clip(0, 1) * rgb_range

        return image.astype(np.float32)

    return _apply(_transform, args)

def pad(img, divisor=4, pad_width=None, negative=False):
    """Pad (or un-pad) an image so its height and width are multiples of ``divisor``.

    Args:
        img: an ``H x W x C`` ``np.ndarray`` or an ``N x C x H x W`` tensor.
        divisor: when ``pad_width`` is None, height and width are extended at
            the bottom/right up to the next multiple of this value.
        pad_width: explicit padding. For numpy input, anything ``np.pad``
            accepts. For tensors, either the numpy-style per-axis pairs
            ``((0, pad_h), (0, pad_w), ...)`` (entries may be one-element
            tensors) or an already flat ``(left, right, top, bottom)``
            sequence as used by ``torch.nn.functional.pad``.
        negative: only used together with an explicit ``pad_width``; the
            padding is removed instead of added, undoing an earlier call.

    Returns:
        ``(img, pad_width)``. numpy input is padded with ``mode='edge'`` and
        gets back the numpy-style ``pad_width``; tensors are padded with
        ``'reflect'`` and get back the flat torch-style ``pad_width``
        (negated when ``negative`` is set).
    """

    def _pad_numpy(img, divisor=4, pad_width=None, negative=False):
        if pad_width is None:
            (h, w, _) = img.shape
            pad_h = -h % divisor
            pad_w = -w % divisor
            pad_width = ((0, pad_h), (0, pad_w), (0, 0))
        elif negative:
            # Undo a previous padding by slicing it away; this mirrors what the
            # tensor branch does with negated pad values.
            pairs = np.broadcast_to(np.asarray(pad_width), (img.ndim, 2))
            window = tuple(
                slice(int(before), dim - int(after))
                for (before, after), dim in zip(pairs, img.shape)
            )
            # basic slicing: the result is a view of the input
            return img[window], pad_width

        img = np.pad(img, pad_width, mode='edge')

        return img, pad_width

    def _pad_tensor(img, divisor=4, pad_width=None, negative=False):

        _, _, h, w = img.shape
        if pad_width is None:
            pad_h = -h % divisor
            pad_w = -w % divisor
            pad_width = (0, pad_w, 0, pad_h)
        else:
            try:
                # numpy-style nested layout: ((0, pad_h), (0, pad_w), ...)
                pad_h = pad_width[0][1]
                pad_w = pad_width[1][1]
            except (TypeError, IndexError):
                # entries are not indexable: pad_width is taken to be in the
                # flat torch layout already and is used as given
                pass
            else:
                if isinstance(pad_h, torch.Tensor):
                    pad_h = pad_h.item()
                if isinstance(pad_w, torch.Tensor):
                    pad_w = pad_w.item()

                pad_width = (0, pad_w, 0, pad_h)

            if negative:
                pad_width = [-val for val in pad_width]

        img = torch.nn.functional.pad(img, pad_width, 'reflect')

        return img, pad_width

    if isinstance(img, np.ndarray):
        return _pad_numpy(img, divisor, pad_width, negative)
    else:   # torch.Tensor
        return _pad_tensor(img, divisor, pad_width, negative)

def generate_pyramid(*args, n_scales):
    """Build a Gaussian image pyramid for every image in ``args``.

    Args:
        *args: images, or nested lists/tuples/dicts of images, with the
            channels on the last axis. Non-float32 images are cast to
            float32 first.
        n_scales: requested number of pyramid levels; ``n_scales-1`` is
            passed to ``pyramid_gaussian`` as its maximum layer index.

    Returns:
        A list mirroring ``args`` (see ``_apply``) in which each image is
        replaced by the list of levels produced by ``pyramid_gaussian``.
    """

    def _generate_pyramid(img):
        if img.dtype != np.float32:
            img = img.astype(np.float32)

        # Newer scikit-image releases take ``channel_axis`` and reject the
        # legacy ``multichannel`` keyword; older releases only know
        # ``multichannel``. An unknown keyword raises TypeError at call time.
        try:
            levels = pyramid_gaussian(img, n_scales-1, channel_axis=-1)
        except TypeError:
            levels = pyramid_gaussian(img, n_scales-1, multichannel=True)

        return list(levels)

    return _apply(_generate_pyramid, args)

def np2tensor(*args):
    """Convert every array in ``args`` into a torch tensor with axes (2, 0, 1).

    Each leaf array has its third axis moved to the front, is made
    C-contiguous, and is wrapped with ``torch.from_numpy``. Returns a list
    mirroring ``args`` (see ``_apply``).
    """
    def _to_channel_first_tensor(array):
        channel_first = array.transpose(2, 0, 1)
        # from_numpy needs a well-behaved memory layout after the transpose
        return torch.from_numpy(np.ascontiguousarray(channel_first))

    return _apply(_to_channel_first_tensor, args)

def to(*args, device=None, dtype=torch.float):
    """Move/cast every item in ``args`` via its ``.to`` method.

    Each leaf is converted with ``non_blocking=True`` and ``copy=False``
    using the given ``device`` and ``dtype``. Returns a list mirroring
    ``args`` (see ``_apply``).
    """
    transfer_options = dict(device=device, dtype=dtype, non_blocking=True, copy=False)

    def _transfer(item):
        return item.to(**transfer_options)

    return _apply(_transfer, args)