|
import random
|
|
import numpy as np
|
|
from skimage.color import rgb2hsv, hsv2rgb
|
|
|
|
|
|
import torch
|
|
|
|
def _apply(func, x):
|
|
|
|
if isinstance(x, (list, tuple)):
|
|
return [_apply(func, x_i) for x_i in x]
|
|
elif isinstance(x, dict):
|
|
y = {}
|
|
for key, value in x.items():
|
|
y[key] = _apply(func, value)
|
|
return y
|
|
else:
|
|
return func(x)
|
|
|
|
def crop(*args, ps=256):
    """Cut the same random ``ps`` x ``ps`` window out of every image in ``args``.

    The image size is read from the first leaf found in ``args`` (descending
    into the first element of lists/tuples and the first value of dicts), and
    one window position is drawn with ``random.randrange`` and reused for all
    images.

    Args:
        *args: images, or nested lists/tuples/dicts of images. Each image
            must expose ``.shape`` and ``.ndim`` and support NumPy-style
            indexing.
        ps: side length of the square patch.

    Returns:
        A list mirroring ``args`` (as produced by ``_apply``) holding the
        cropped images. Images with ``ndim == 2`` gain a trailing axis of
        length 1.

    Raises:
        ValueError: if the reference image is neither 2-D nor 3-D, or if it
            is smaller than ``ps`` along height or width.
    """

    def _get_shape(*args):
        first = args[0]
        if isinstance(first, (list, tuple)):
            return _get_shape(first[0])
        elif isinstance(first, dict):
            return _get_shape(list(first.values())[0])
        else:
            return first.shape

    shape = _get_shape(args)
    # ``_crop`` below supports 2-D images, so the size lookup must not insist
    # on a third (channel) entry in the shape.
    if len(shape) not in (2, 3):
        raise ValueError(
            'crop expects 2-D or 3-D images, got shape {}'.format(tuple(shape))
        )
    h, w = shape[:2]

    if h < ps or w < ps:
        raise ValueError(
            'patch size {} exceeds image size {}x{}'.format(ps, h, w)
        )

    # Draw order (row offset first, then column offset) is kept so that
    # seeded runs select the same window as before.
    py = random.randrange(0, h - ps + 1)
    px = random.randrange(0, w - ps + 1)

    def _crop(img):
        if img.ndim == 2:
            return img[py:py + ps, px:px + ps, np.newaxis]
        else:
            return img[py:py + ps, px:px + ps, :]

    return _apply(_crop, args)
|
|
|
|
def add_noise(*args, sigma_sigma=2, rgb_range=255):
    """Add Gaussian noise of one randomly drawn strength to the given images.

    A single noise scale is sampled as a standard normal draw multiplied by
    ``sigma_sigma * rgb_range / 255``. For every image a fresh float32 noise
    field of that scale is added and the sum is clipped to
    ``[0, rgb_range]``.

    Args:
        *args: images, or nested containers of images. When exactly one
            positional argument is given it is unwrapped, so the return
            value mirrors that argument rather than a one-element list.
        sigma_sigma: multiplier applied to the standard normal draw that
            determines the noise scale.
        rgb_range: upper clipping bound; also rescales the noise scale.

    Returns:
        The noisy images, structured as produced by ``_apply``.
    """
    targets = args[0] if len(args) == 1 else args

    # One scale per call, shared by all images.
    noise_scale = np.random.normal() * sigma_sigma * rgb_range/255

    def _perturb(img):
        field = np.random.randn(*img.shape).astype(np.float32) * noise_scale
        noisy = img + field
        return noisy.clip(0, rgb_range)

    return _apply(_perturb, targets)
|
|
|
|
def augment(*args, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=255):
    """Apply one randomly chosen augmentation identically to all images.

    The random decisions are made once per call and then applied to every
    image in ``args`` so that inputs and targets stay aligned.

    Args:
        *args: images, or nested containers of images, indexed as
            (row, column, channel).
        hflip: allow a random flip along axis 1.
        rot: allow a random flip along axis 0 and a random swap of axes 0
            and 1 (each decided independently).
        shuffle: allow a random permutation of the channels of 3-channel
            images.
        change_saturation: scale the HSV saturation of 3-channel images by a
            factor drawn uniformly from [0.5, 1.5).
        rgb_range: multiplier applied to the ``hsv2rgb`` output (clipped to
            [0, 1]) after a saturation change.

    Returns:
        The augmented images as float32 arrays, structured as produced by
        ``_apply``.
    """
    choices = (False, True)

    # ``and`` short-circuits: no random draw is consumed for a disabled option.
    hflip = hflip and random.choice(choices)
    vflip = rot and random.choice(choices)
    rot90 = rot and random.choice(choices)

    if shuffle:
        rgb_order = list(range(3))
        random.shuffle(rgb_order)
        if rgb_order == list(range(3)):
            # Identity permutation: skip the indexing step altogether.
            shuffle = False

    if change_saturation:
        amp_factor = np.random.uniform(0.5, 1.5)

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)

        if shuffle and img.ndim > 2:
            if img.shape[-1] == 3:
                img = img[..., rgb_order]

        # The HSV conversion only accepts 3-channel input, so images with any
        # other channel count are passed through (as the channel shuffle
        # already does) instead of raising inside rgb2hsv.
        if change_saturation and img.shape[-1] == 3:
            hsv_img = rgb2hsv(img)
            hsv_img[..., 1] *= amp_factor

            # NOTE(review): this rescales by rgb_range on the assumption that
            # rgb2hsv/hsv2rgb work in the [0, 1] range for the input dtype
            # used by callers -- confirm for float inputs.
            img = hsv2rgb(hsv_img).clip(0, 1) * rgb_range

        return img.astype(np.float32)

    return _apply(_augment, args)
|
|
|
|
def pad(img, divisor=4, pad_width=None, negative=False):
    """Pad an image at its bottom/right edges, or crop that padding off again.

    NumPy arrays are unpacked as (H, W, C) and padded with ``np.pad`` in
    ``'edge'`` mode. Any other input is handled as a torch tensor unpacked as
    (N, C, H, W) and padded with ``torch.nn.functional.pad`` in ``'reflect'``
    mode.

    Args:
        img: ``np.ndarray`` or tensor to pad.
        divisor: when ``pad_width`` is None, H and W are extended up to the
            next multiple of this value.
        pad_width: explicit padding. For arrays, per-axis
            ``(before, after)`` pairs as accepted by ``np.pad``. For tensors,
            either that nested form (only ``[0][1]`` and ``[1][1]`` are read,
            as the H and W padding) or a flat sequence that is handed to
            ``torch.nn.functional.pad`` unchanged.
        negative: when True the padding amounts are negated, i.e. the margins
            are removed instead of added.

    Returns:
        ``(result, pad_width)`` where ``pad_width`` is the padding that was
        applied (negated values when ``negative`` is True). For arrays with
        ``negative=True`` the result is a sliced view of ``img``.
    """

    def _pad_numpy(img, divisor=4, pad_width=None, negative=False):
        if pad_width is None:
            (h, w, _) = img.shape
            pad_h = -h % divisor
            pad_w = -w % divisor
            pad_width = ((0, pad_h), (0, pad_w), (0, 0))

        if negative:
            # np.pad rejects negative widths, so the margins are removed by
            # slicing. Requires per-axis (before, after) pairs.
            pairs = [(int(before), int(after)) for before, after in pad_width]
            region = tuple(
                slice(before, size - after)
                for size, (before, after) in zip(img.shape, pairs)
            )
            applied = tuple((-before, -after) for before, after in pairs)
            return img[region], applied

        img = np.pad(img, pad_width, mode='edge')

        return img, pad_width

    def _pad_tensor(img, divisor=4, pad_width=None, negative=False):
        _, _, h, w = img.shape

        if pad_width is None:
            pad_h = -h % divisor
            pad_w = -w % divisor
            pad_width = (0, pad_w, 0, pad_h)
        else:
            try:
                # Nested, np.pad-style layout: ((0, pad_h), (0, pad_w), ...).
                pad_h = pad_width[0][1]
                pad_w = pad_width[1][1]
            except (TypeError, IndexError):
                # Not nested (e.g. already a flat sequence): use it as given.
                pass
            else:
                if isinstance(pad_h, torch.Tensor):
                    pad_h = pad_h.item()
                if isinstance(pad_w, torch.Tensor):
                    pad_w = pad_w.item()

                pad_width = (0, pad_w, 0, pad_h)

        if negative:
            pad_width = [-val for val in pad_width]

        img = torch.nn.functional.pad(img, pad_width, 'reflect')

        return img, pad_width

    if isinstance(img, np.ndarray):
        return _pad_numpy(img, divisor, pad_width, negative)
    else:
        return _pad_tensor(img, divisor, pad_width, negative)
|
|
|
|
def generate_pyramid(*args, n_scales):
    """Wrap every image in ``args`` in a one-element float32 list.

    Each image is converted to float32 (only when its dtype differs) and
    returned as ``[image]``.

    Args:
        *args: images, or nested containers of images, exposing ``.dtype``
            and ``.astype``.
        n_scales: required keyword argument; it is not read by the code in
            this function, so only a single level is ever produced.

    Returns:
        The per-image lists, structured as produced by ``_apply``.
    """

    def _single_level(img):
        base = img if img.dtype == np.float32 else img.astype(np.float32)
        return [base]

    return _apply(_single_level, args)
|
|
|
|
def np2tensor(*args, rgb_range=255):
    """Convert (H, W, C)-indexed arrays to float tensors in (C, H, W) order.

    Each array is transposed with axes (2, 0, 1), copied into a fresh
    contiguous buffer, converted to a float tensor and multiplied by
    ``rgb_range / 255``.

    Args:
        *args: arrays, or nested containers of arrays, with three axes.
        rgb_range: target value range; the data is scaled by
            ``rgb_range / 255``.

    Returns:
        The tensors, structured as produced by ``_apply``. They never share
        memory with the input arrays.
    """
    def _np2tensor(x):
        # Always copy. np.ascontiguousarray returns its argument unchanged
        # when the transposed view is already contiguous (e.g. a single
        # channel), and torch.from_numpy shares that memory, so the in-place
        # scaling below would otherwise modify the caller's array.
        chw = np.array(x.transpose(2, 0, 1), order='C')
        tensor = torch.from_numpy(chw).float()
        tensor.mul_(rgb_range / 255)
        return tensor

    return _apply(_np2tensor, args)
|
|
|
|
def to(*args, device=None, dtype=torch.float):
    """Call ``.to(...)`` on every object in ``args`` with shared options.

    Args:
        *args: objects, or nested containers of objects, that provide a
            ``.to`` method accepting ``device``, ``dtype``, ``non_blocking``
            and ``copy`` keyword arguments.
        device: forwarded as ``device``.
        dtype: forwarded as ``dtype``.

    Returns:
        The results of the ``.to`` calls, structured as produced by
        ``_apply``. ``non_blocking=True`` and ``copy=False`` are always
        passed.
    """
    transfer_options = dict(
        device=device,
        dtype=dtype,
        non_blocking=True,
        copy=False,
    )

    return _apply(lambda item: item.to(**transfer_options), args)
|
|
|