|
|
|
|
|
""" |
|
Image augmentation modules implemented with kornia transforms (albumentations is also imported).
|
@author: Tu Bui @University of Surrey |
|
""" |
|
import os |
|
from . import utils |
|
import torch |
|
import numpy as np |
|
from torch import nn |
|
import torch.nn.functional as thf |
|
from PIL import Image |
|
import kornia as ko |
|
import albumentations as ab |
|
from torchvision import transforms |
|
|
|
|
|
class IdentityAugment(nn.Module):
    """No-op augmentation that hands its input back untouched.

    Any keyword arguments (e.g. ``ramp``) are accepted and discarded, so this
    module can stand in wherever the other augmentations in this file are used.
    """

    def __init__(self):
        super(IdentityAugment, self).__init__()

    def forward(self, x, **kwargs):
        """Return ``x`` itself (same object, no copy)."""
        del kwargs  # intentionally unused
        return x
|
|
|
|
|
class RandomCompress(nn.Module):
    """Random JPEG compress/decompress whose strength is scaled by ``ramp``.

    Args:
        severity: 'low' | 'medium' | 'high' -> lowest reachable JPEG quality
            70 | 50 | 40 (at ``ramp == 1``).
        p: probability of applying the compression on a given call.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': 70, 'medium': 50, 'high': 40}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # Fail fast: previously an unknown level left `jpeg_quality` unset and
        # only crashed (AttributeError) on the random calls that applied it.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        self.jpeg_quality = self._SEVERITY_PARAMS[severity]

    def forward(self, x, ramp=1., **kwargs):
        """Compress and decompress ``x`` with probability ``self.p``.

        The quality is ``100 - u * ramp * (100 - self.jpeg_quality)`` with ``u``
        drawn from ``torch.rand``, so ``ramp`` (typically in [0, 1]) controls how
        far below 100 the quality can fall. Extra keyword arguments are ignored.
        """
        if torch.rand(1)[0] >= self.p:
            return x
        jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
        # `utils` is the package-local helper module imported at file level.
        x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
        return x
|
|
|
|
|
class RandomBoxBlur(nn.Module):
    """Random box blur via ``kornia.augmentation.RandomBoxBlur``.

    Args:
        severity: 'low' | 'medium' | 'high' -> square kernel of size 3 | 5 | 7.
        border_type: padding mode, forwarded unchanged to kornia.
        normalized: forwarded unchanged to kornia.
        p: probability of applying the blur.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': 3, 'medium': 5, 'high': 7}

    def __init__(self, severity='medium', border_type='reflect', normalized=True, p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `kernel_size`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        kernel_size = self._SEVERITY_PARAMS[severity]
        self.tform = ko.augmentation.RandomBoxBlur(
            kernel_size=(kernel_size, kernel_size), border_type=border_type,
            normalized=normalized, p=self.p)

    def forward(self, x, **kwargs):
        """Apply the blur to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomMedianBlur(nn.Module):
    """Median blur with a fixed 3x3 kernel, applied with probability ``p``.

    ``severity`` is accepted for signature parity with the other augmentations
    but has no effect: the kernel is always 3x3.
    """

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        kernel = (3, 3)  # independent of `severity`
        self.tform = ko.augmentation.RandomMedianBlur(kernel_size=kernel, p=self.p)

    def forward(self, x, **kwargs):
        """Blur ``x``; keyword arguments such as ``ramp`` are ignored."""
        blur = self.tform
        return blur(x)
|
|
|
|
|
class RandomBrightness(nn.Module):
    """Random brightness change via ``kornia.augmentation.RandomBrightness``.

    Args:
        severity: 'low' | 'medium' | 'high' -> brightness range (0.9, 1.1) |
            (0.75, 1.25) | (0.5, 1.5), forwarded to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `brightness`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        self.tform = ko.augmentation.RandomBrightness(
            brightness=self._SEVERITY_PARAMS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
|
|
class RandomContrast(nn.Module):
    """Random contrast change via ``kornia.augmentation.RandomContrast``.

    Args:
        severity: 'low' | 'medium' | 'high' -> contrast range (0.9, 1.1) |
            (0.75, 1.25) | (0.5, 1.5), forwarded to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `contrast`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        self.tform = ko.augmentation.RandomContrast(
            contrast=self._SEVERITY_PARAMS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
|
|
class RandomSaturation(nn.Module):
    """Random saturation change via ``kornia.augmentation.RandomSaturation``.

    Args:
        severity: 'low' | 'medium' | 'high' -> saturation range (0.9, 1.1) |
            (0.75, 1.25) | (0.5, 1.5), forwarded to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `sat`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        self.tform = ko.augmentation.RandomSaturation(
            saturation=self._SEVERITY_PARAMS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomSharpness(nn.Module):
    """Random sharpening via ``kornia.augmentation.RandomSharpness``.

    Args:
        severity: 'low' | 'medium' | 'high' -> ``sharpness`` 0.5 | 1.0 | 2.5,
            forwarded to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': 0.5, 'medium': 1.0, 'high': 2.5}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `sharpness`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        self.tform = ko.augmentation.RandomSharpness(
            sharpness=self._SEVERITY_PARAMS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomColorJiggle(nn.Module):
    """Random colour jitter via ``kornia.augmentation.ColorJiggle``.

    Args:
        severity: 'low' | 'medium' | 'high' -> (brightness, contrast, saturation,
            hue) factors (0.05, 0.05, 0.05, 0.01) | (0.1, 0.1, 0.1, 0.02) |
            (0.1, 0.1, 0.1, 0.05), passed positionally to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {
        'low': (0.05, 0.05, 0.05, 0.01),
        'medium': (0.1, 0.1, 0.1, 0.02),
        'high': (0.1, 0.1, 0.1, 0.05),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `factor`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        self.tform = ko.augmentation.ColorJiggle(*self._SEVERITY_PARAMS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomHue(nn.Module):
    """Random hue shift via ``kornia.augmentation.RandomHue``.

    Args:
        severity: 'low' | 'medium' | 'high' -> hue range (-h, h) with
            h = 0.01 | 0.02 | 0.05, forwarded to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': 0.01, 'medium': 0.02, 'high': 0.05}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `hue`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        hue = self._SEVERITY_PARAMS[severity]
        self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomGamma(nn.Module):
    """Random gamma correction via ``kornia.augmentation.RandomGamma``.

    Args:
        severity: 'low' | 'medium' | 'high' -> gamma and gain ranges both
            (0.9, 1.1) | (0.75, 1.25) | (0.5, 1.5), passed positionally to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    # severity -> (gamma range, gain range)
    _SEVERITY_PARAMS = {
        'low': ((0.9, 1.1), (0.9, 1.1)),
        'medium': ((0.75, 1.25), (0.75, 1.25)),
        'high': ((0.5, 1.5), (0.5, 1.5)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `gamma`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        gamma, gain = self._SEVERITY_PARAMS[severity]
        self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomGaussianBlur(nn.Module):
    """Random Gaussian blur via ``kornia.augmentation.RandomGaussianBlur``.

    Args:
        severity: 'low' | 'medium' | 'high' -> square kernel 3 | 5 | 7 with
            sigma range (0.1, 1.0) | (0.1, 1.5) | (0.1, 2.0).
        p: probability of applying the blur.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    # severity -> (kernel size, sigma range)
    _SEVERITY_PARAMS = {
        'low': (3, (0.1, 1.0)),
        'medium': (5, (0.1, 1.5)),
        'high': (7, (0.1, 2.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `kernel_size`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        kernel_size, sigma = self._SEVERITY_PARAMS[severity]
        self.tform = ko.augmentation.RandomGaussianBlur(
            kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)

    def forward(self, x, **kwargs):
        """Apply the blur to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomGaussianNoise(nn.Module):
    """Additive zero-mean Gaussian noise via ``kornia.augmentation.RandomGaussianNoise``.

    Args:
        severity: 'low' | 'medium' | 'high' -> noise std 0.02 | 0.04 | 0.08.
        p: probability of applying the noise.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': 0.02, 'medium': 0.04, 'high': 0.08}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `std`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        self.tform = ko.augmentation.RandomGaussianNoise(
            mean=0., std=self._SEVERITY_PARAMS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the noise to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomMotionBlur(nn.Module):
    """Random motion blur via ``kornia.augmentation.RandomMotionBlur``.

    Args:
        severity: 'low' | 'medium' | 'high' -> (kernel size range, angle range,
            direction range) passed positionally to kornia:
            ((3, 5), (-25, 25), (-0.25, 0.25)) | ((3, 7), (-45, 45), (-0.5, 0.5)) |
            ((3, 9), (-90, 90), (-1.0, 1.0)).
        p: probability of applying the blur.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {
        'low': ((3, 5), (-25, 25), (-0.25, 0.25)),
        'medium': ((3, 7), (-45, 45), (-0.5, 0.5)),
        'high': ((3, 9), (-90, 90), (-1.0, 1.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `kernel_size`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        kernel_size, angle, direction = self._SEVERITY_PARAMS[severity]
        self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)

    def forward(self, x, **kwargs):
        """Apply the blur to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomPosterize(nn.Module):
    """Random posterization via ``kornia.augmentation.RandomPosterize``.

    Args:
        severity: 'low' | 'medium' | 'high' -> ``bits`` 5 | 4 | 3, forwarded to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': 5, 'medium': 4, 'high': 3}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `bits`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        self.tform = ko.augmentation.RandomPosterize(bits=self._SEVERITY_PARAMS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
class RandomRGBShift(nn.Module):
    """Random per-channel RGB shift via ``kornia.augmentation.RandomRGBShift``.

    Args:
        severity: 'low' | 'medium' | 'high' -> the same shift limit 0.02 | 0.05 | 0.1
            for the r, g and b channels.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the supported levels.
    """

    _SEVERITY_PARAMS = {'low': 0.02, 'medium': 0.05, 'high': 0.1}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # An unknown level used to surface as an UnboundLocalError on `rgb`.
        if severity not in self._SEVERITY_PARAMS:
            raise ValueError(
                f"severity must be one of {list(self._SEVERITY_PARAMS)}, got {severity!r}")
        rgb = self._SEVERITY_PARAMS[severity]
        self.tform = ko.augmentation.RandomRGBShift(
            r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)

    def forward(self, x, **kwargs):
        """Apply the transform to ``x``; extra keyword arguments (e.g. ``ramp``) are ignored."""
        return self.tform(x)
|
|
|
|
|
|
|
class TransformNet(nn.Module):
    """Augmentation pipeline: fixed flip + crop, then randomly chosen optional transforms.

    Images are mapped with ``x * 0.5 + 0.5`` before the transforms and back with
    ``x * 2 - 1`` afterwards, so inputs are presumably in [-1, 1].

    Args:
        flip: if True, apply a random horizontal flip with probability 0.5.
        crop_mode: 'random_crop' (random 224x224 crop) or 'resized_crop'
            (random resized crop to 224x224, scale 0.7-1.0, ratio 3/4-4/3).
        compress, brightness, contrast, color_jiggle, gamma, grayscale,
        gaussian_blur, gaussian_noise, hue, motion_blur, posterize, rgb_shift,
        saturation, sharpness, median_blur, box_blur: enable the matching
            optional transform.
        severity: strength level forwarded to every optional transform.
        n_optional: number of optional transforms sampled (with replacement)
            per ``forward`` call.
        ramp: number of steps after :meth:`activate` over which the ramp factor
            handed to the optional transforms grows from 0 to 1; ``<= 0``
            means full strength immediately.
        p: application probability passed to each optional transform.
    """

    def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True,
                 contrast=True, color_jiggle=True, gamma=False, grayscale=True,
                 gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True,
                 posterize=True, rgb_shift=True, saturation=True, sharpness=True,
                 median_blur=True, box_blur=True, severity='medium', n_optional=2,
                 ramp=1000, p=0.5):
        super().__init__()
        self.n_optional = n_optional
        self.p = p
        p_flip = 0.5 if flip else 0
        rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p_flip)
        self.ramp = ramp
        # Activation step (0 = not activated yet); a buffer so it is stored in state_dict.
        self.register_buffer('step0', torch.tensor(0))

        self.crop_mode = crop_mode
        assert crop_mode in ['random_crop', 'resized_crop']
        if crop_mode == 'random_crop':
            rnd_crop_layer = ko.augmentation.RandomCrop((224, 224), cropping_mode="resample")
        elif crop_mode == 'resized_crop':
            rnd_crop_layer = ko.augmentation.RandomResizedCrop(
                size=(224, 224), scale=(0.7, 1.0), ratio=(3.0 / 4, 4.0 / 3),
                cropping_mode='resample')
        self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]

        # Created up front so forward() / transform_by_name() keep working when
        # every optional transform is disabled.
        self.optional_transforms = []
        self.optional_names = []
        if compress:
            self.register(RandomCompress(severity, p=p), 'Random Compress')
        if brightness:
            self.register(RandomBrightness(severity, p=p), 'Random Brightness')
        if contrast:
            self.register(RandomContrast(severity, p=p), 'Random Contrast')
        if color_jiggle:
            self.register(RandomColorJiggle(severity, p=p), 'Random Color')
        if gamma:
            self.register(RandomGamma(severity, p=p), 'Random Gamma')
        if grayscale:
            self.register(ko.augmentation.RandomGrayscale(p=p), 'Grayscale')
        if gaussian_blur:
            self.register(RandomGaussianBlur(severity, p=p), 'Random Gaussian Blur')
        if gaussian_noise:
            self.register(RandomGaussianNoise(severity, p=p), 'Random Gaussian Noise')
        if hue:
            self.register(RandomHue(severity, p=p), 'Random Hue')
        if motion_blur:
            self.register(RandomMotionBlur(severity, p=p), 'Random Motion Blur')
        if posterize:
            self.register(RandomPosterize(severity, p=p), 'Random Posterize')
        if rgb_shift:
            self.register(RandomRGBShift(severity, p=p), 'Random RGB Shift')
        if saturation:
            self.register(RandomSaturation(severity, p=p), 'Random Saturation')
        if sharpness:
            self.register(RandomSharpness(severity, p=p), 'Random Sharpness')
        if median_blur:
            self.register(RandomMedianBlur(severity, p=p), 'Random Median Blur')
        if box_blur:
            self.register(RandomBoxBlur(severity, p=p), 'Random Box Blur')

    def register(self, tform, name):
        """Append optional transform ``tform`` under the display name ``name``."""
        if not hasattr(self, 'optional_transforms'):
            self.optional_transforms = []
            self.optional_names = []
        self.optional_transforms.append(tform)
        self.optional_names.append(name)

    def activate(self, global_step):
        """Record ``global_step`` as the activation step; has effect only while ``step0`` is 0."""
        if self.step0 == 0:
            print(f'[TRAINING] Activating TransformNet at step {global_step}')
            # Keep the buffer on whatever device the module currently lives on.
            self.step0 = torch.tensor(global_step, device=self.step0.device)

    def is_activated(self):
        """Return a boolean tensor: True once :meth:`activate` set a non-zero step."""
        return self.step0 > 0

    @staticmethod
    def _unwrap(out):
        """Return the image from a transform output that may be an ``(image, ...)`` tuple."""
        return out[0] if isinstance(out, tuple) else out

    def _apply_fixed(self, x):
        """Run every fixed transform (flip, crop) on ``x`` in order."""
        for tform in self.fixed_transforms:
            x = self._unwrap(tform(x))
        return x

    def _ramp_factor(self, global_step):
        """Return the ramp factor in [0, 1]: steps elapsed since activation / ``self.ramp``."""
        if self.ramp <= 0:
            # No ramp-up requested (also avoids dividing by zero).
            return 1.0
        elapsed = global_step - self.step0.cpu().item()
        # Clamp below too: a step earlier than step0 must not yield a negative strength.
        return min(max(elapsed / self.ramp, 0.0), 1.0)

    @staticmethod
    def _to_uint8_hwc(x):
        """Convert a (1, C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array.

        Values are rounded and clipped to [0, 255]; a bare ``astype(np.uint8)``
        wraps out-of-range values (e.g. after noise or brightness changes) and
        truncates values such as 254.9999 down to 254.
        """
        arr = x.detach().cpu().squeeze(0).permute(1, 2, 0).numpy() * 255.
        return np.clip(np.rint(arr), 0, 255).astype(np.uint8)

    def forward(self, x, global_step, p=0.9):
        """Augment ``x``: fixed transforms, then ``n_optional`` sampled optional ones.

        Args:
            x: image batch, presumably in [-1, 1].
            global_step: current training step, used for the ramp factor.
            p: unused; kept for backward compatibility.

        Returns:
            The augmented batch mapped back with ``* 2 - 1``.
        """
        x = self._apply_fixed(x * 0.5 + 0.5)
        ramp = self._ramp_factor(global_step)
        if self.optional_transforms:
            tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).tolist()
            for tform_id in tform_ids:
                x = self._unwrap(self.optional_transforms[tform_id](x, ramp=ramp))
        return x * 2 - 1

    def transform_by_id(self, x, tform_id):
        """Apply the fixed transforms, then optional transform ``tform_id`` (called without ``ramp``)."""
        x = self._apply_fixed(x * 0.5 + 0.5)
        x = self._unwrap(self.optional_transforms[tform_id](x))
        return x * 2 - 1

    def transform_by_name(self, x, tform_name):
        """Like :meth:`transform_by_id`, selecting the optional transform by its registered name."""
        assert tform_name in self.optional_names
        tform_id = self.optional_names.index(tform_name)
        return self.transform_by_id(x, tform_id)

    def apply_transform_on_pil_image(self, x, tform_name):
        """Apply one transform to a PIL image and return a PIL image.

        The image is resized to 256x256 and scaled to [0, 1]; it is presumably
        RGB (a 3-D array is required by the channel permute).

        Args:
            x: input PIL image.
            tform_name: a registered optional transform name, or 'Fixed Augment'
                for the fixed flip + crop.

        Returns:
            The transformed image resized back to the input size, or to 224/256
            of it for 'Fixed Augment' with ``crop_mode == 'random_crop'`` so the
            cropped region keeps the original pixel scale.
        """
        assert tform_name in self.optional_names + ['Fixed Augment']
        w, h = x.size
        x = x.resize((256, 256), Image.BILINEAR)
        x = np.array(x).astype(np.float32) / 255.
        x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)
        if tform_name == 'Fixed Augment':
            x = self._apply_fixed(x)
            if self.crop_mode == 'random_crop':
                # The 224x224 crop covers 224/256 of the resized image; scale the
                # output size the same way instead of stretching it back.
                w, h = int(224 / 256 * w), int(224 / 256 * h)
        else:
            tform = self.optional_transforms[self.optional_names.index(tform_name)]
            x = self._unwrap(tform(x))
        x = Image.fromarray(self._to_uint8_hwc(x))
        return x.resize((w, h), Image.BILINEAR)
|
|
|
|
|
if __name__ == '__main__':
    # This module only defines augmentation classes; running it directly does nothing.
    ...