|
|
|
|
|
""" |
|
use kornia and albumentations for transformations |
|
@author: Tu Bui @University of Surrey |
|
""" |
|
import os |
|
from . import utils |
|
import torch |
|
import numpy as np |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
import kornia as ko |
|
import albumentations as ab |
|
|
|
|
|
class IdentityAugment(nn.Module):
    """No-op augmentation layer: hands its input back untouched."""

    def __init__(self):
        """Initialise the underlying ``nn.Module``; there is no state to set up."""
        super().__init__()

    def forward(self, x, **kwargs):
        """Return ``x`` unchanged.

        Any keyword arguments are accepted and ignored.
        """
        return x
|
|
|
|
|
class RandomCompress(nn.Module):
    """Randomly apply JPEG compress/decompress via ``utils.jpeg_compress_decompress``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            lowest JPEG quality that can be sampled (70, 50 or 40).
        p: probability of applying the compression on a given call.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> lowest JPEG quality reachable at full ramp
    _MIN_QUALITY = {'low': 70, 'medium': 50, 'high': 40}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Fail at construction time; otherwise a bad severity would only show
        # up later as a missing `jpeg_quality` attribute inside forward().
        if severity not in self._MIN_QUALITY:
            raise ValueError(
                f'severity must be one of {sorted(self._MIN_QUALITY)}, got {severity!r}'
            )
        self.p = p
        self.jpeg_quality = self._MIN_QUALITY[severity]

    def forward(self, x, ramp=1.):
        """Compress ``x`` with probability ``p``; otherwise return it unchanged.

        Args:
            x: input handed to ``utils.jpeg_compress_decompress``.
            ramp: scale on the quality drop; with ``ramp`` in [0, 1] the
                quality is drawn uniformly between
                ``100 - ramp * (100 - jpeg_quality)`` and 100.

        Returns:
            The (possibly) compressed input.
        """
        # One random draw gates the whole call (not per sample).
        if torch.rand(1)[0] >= self.p:
            return x
        jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
        x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
        return x
|
|
|
|
|
class RandomBoxBlur(nn.Module):
    """Random box blur with a severity-dependent square kernel.

    Thin wrapper around ``kornia.augmentation.RandomBoxBlur``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects a
            kernel side of 3, 5 or 7.
        border_type: border mode forwarded to kornia.
        normalize: normalisation flag forwarded to kornia.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> side length of the square blur kernel
    _KERNEL_SIZE = {'low': 3, 'medium': 5, 'high': 7}

    def __init__(self, severity='medium', border_type='reflect', normalize=True, p=0.5):
        super().__init__()
        if severity not in self._KERNEL_SIZE:
            raise ValueError(
                f'severity must be one of {sorted(self._KERNEL_SIZE)}, got {severity!r}'
            )
        self.p = p
        kernel_size = self._KERNEL_SIZE[severity]
        # kernel_size / border_type / normalisation flag are passed
        # positionally: kornia spells the third keyword `normalized`, so
        # `normalize=` would be rejected as an unexpected keyword argument.
        self.tform = ko.augmentation.RandomBoxBlur(
            (kernel_size, kernel_size), border_type, normalize, p=self.p
        )

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomMedianBlur(nn.Module):
    """Random median blur with a fixed 3x3 kernel.

    Thin wrapper around ``kornia.augmentation.RandomMedianBlur``.

    Args:
        severity: accepted but not used here; the kernel is always 3x3.
        p: probability forwarded to kornia.
    """

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        window = (3, 3)
        self.p = p
        self.tform = ko.augmentation.RandomMedianBlur(kernel_size=window, p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        blurred = self.tform(x)
        return blurred
|
|
|
|
|
class RandomBrightness(nn.Module):
    """Random brightness change with a severity-dependent range.

    Thin wrapper around ``kornia.augmentation.RandomBrightness``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``brightness`` range handed to kornia.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> `brightness` range passed to kornia
    _BRIGHTNESS = {
        'low': (0.9, 1.1),
        'medium': (0.75, 1.25),
        'high': (0.5, 1.5),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local range variable.
        if severity not in self._BRIGHTNESS:
            raise ValueError(
                f'severity must be one of {sorted(self._BRIGHTNESS)}, got {severity!r}'
            )
        self.p = p
        self.tform = ko.augmentation.RandomBrightness(brightness=self._BRIGHTNESS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
|
|
class RandomContrast(nn.Module):
    """Random contrast change with a severity-dependent range.

    Thin wrapper around ``kornia.augmentation.RandomContrast``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``contrast`` range handed to kornia.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> `contrast` range passed to kornia
    _CONTRAST = {
        'low': (0.9, 1.1),
        'medium': (0.75, 1.25),
        'high': (0.5, 1.5),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local range variable.
        if severity not in self._CONTRAST:
            raise ValueError(
                f'severity must be one of {sorted(self._CONTRAST)}, got {severity!r}'
            )
        self.p = p
        self.tform = ko.augmentation.RandomContrast(contrast=self._CONTRAST[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
|
|
class RandomSaturation(nn.Module):
    """Random saturation change with a severity-dependent range.

    Thin wrapper around ``kornia.augmentation.RandomSaturation``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``saturation`` range handed to kornia.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> `saturation` range passed to kornia
    _SATURATION = {
        'low': (0.9, 1.1),
        'medium': (0.75, 1.25),
        'high': (0.5, 1.5),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local range variable.
        if severity not in self._SATURATION:
            raise ValueError(
                f'severity must be one of {sorted(self._SATURATION)}, got {severity!r}'
            )
        self.p = p
        self.tform = ko.augmentation.RandomSaturation(saturation=self._SATURATION[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomSharpness(nn.Module):
    """Random sharpening with a severity-dependent strength.

    Thin wrapper around ``kornia.augmentation.RandomSharpness``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``sharpness`` value handed to kornia (0.5, 1.0 or 2.5).
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> `sharpness` value passed to kornia
    _SHARPNESS = {'low': 0.5, 'medium': 1.0, 'high': 2.5}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local strength variable.
        if severity not in self._SHARPNESS:
            raise ValueError(
                f'severity must be one of {sorted(self._SHARPNESS)}, got {severity!r}'
            )
        self.p = p
        self.tform = ko.augmentation.RandomSharpness(sharpness=self._SHARPNESS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomColorJiggle(nn.Module):
    """Random colour jiggle with severity-dependent factors.

    Thin wrapper around ``kornia.augmentation.ColorJiggle``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            four factors handed positionally to kornia.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> the four leading positional arguments of kornia's
    # ColorJiggle (brightness, contrast, saturation, hue in kornia's signature)
    _FACTORS = {
        'low': (0.05, 0.05, 0.05, 0.01),
        'medium': (0.1, 0.1, 0.1, 0.02),
        'high': (0.1, 0.1, 0.1, 0.05),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local factor variable.
        if severity not in self._FACTORS:
            raise ValueError(
                f'severity must be one of {sorted(self._FACTORS)}, got {severity!r}'
            )
        self.p = p
        self.tform = ko.augmentation.ColorJiggle(*self._FACTORS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomHue(nn.Module):
    """Random hue shift with a severity-dependent symmetric range.

    Thin wrapper around ``kornia.augmentation.RandomHue``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            half-width ``h`` (0.01, 0.02 or 0.05) of the ``(-h, h)`` hue
            range handed to kornia.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> half-width of the symmetric hue range
    _HUE = {'low': 0.01, 'medium': 0.02, 'high': 0.05}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local hue variable.
        if severity not in self._HUE:
            raise ValueError(
                f'severity must be one of {sorted(self._HUE)}, got {severity!r}'
            )
        self.p = p
        hue = self._HUE[severity]
        self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomGamma(nn.Module):
    """Random gamma adjustment with severity-dependent gamma and gain ranges.

    Thin wrapper around ``kornia.augmentation.RandomGamma``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``(gamma, gain)`` range pair handed positionally to kornia.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> (gamma range, gain range)
    _GAMMA_GAIN = {
        'low': ((0.9, 1.1), (0.9, 1.1)),
        'medium': ((0.75, 1.25), (0.75, 1.25)),
        'high': ((0.5, 1.5), (0.5, 1.5)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local range variables.
        if severity not in self._GAMMA_GAIN:
            raise ValueError(
                f'severity must be one of {sorted(self._GAMMA_GAIN)}, got {severity!r}'
            )
        self.p = p
        gamma, gain = self._GAMMA_GAIN[severity]
        self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomGaussianBlur(nn.Module):
    """Random Gaussian blur with severity-dependent kernel size and sigma range.

    Thin wrapper around ``kornia.augmentation.RandomGaussianBlur``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            square kernel side (3, 5 or 7) and the ``sigma`` range.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> (kernel side, sigma range)
    _KERNEL_SIGMA = {
        'low': (3, (0.1, 1.0)),
        'medium': (5, (0.1, 1.5)),
        'high': (7, (0.1, 2.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local kernel/sigma variables.
        if severity not in self._KERNEL_SIGMA:
            raise ValueError(
                f'severity must be one of {sorted(self._KERNEL_SIGMA)}, got {severity!r}'
            )
        self.p = p
        kernel_size, sigma = self._KERNEL_SIGMA[severity]
        self.tform = ko.augmentation.RandomGaussianBlur(
            kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p
        )

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomGaussianNoise(nn.Module):
    """Random zero-mean Gaussian noise with a severity-dependent std.

    Thin wrapper around ``kornia.augmentation.RandomGaussianNoise``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``std`` handed to kornia (0.02, 0.04 or 0.08).
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> noise standard deviation
    _STD = {'low': 0.02, 'medium': 0.04, 'high': 0.08}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local std variable.
        if severity not in self._STD:
            raise ValueError(
                f'severity must be one of {sorted(self._STD)}, got {severity!r}'
            )
        self.p = p
        self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=self._STD[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomMotionBlur(nn.Module):
    """Random motion blur with severity-dependent kernel, angle and direction.

    Thin wrapper around ``kornia.augmentation.RandomMotionBlur``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``(kernel_size, angle, direction)`` triple handed positionally to
            kornia.
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> (kernel_size, angle, direction), each a 2-tuple
    _SETTINGS = {
        'low': ((3, 5), (-25, 25), (-0.25, 0.25)),
        'medium': ((3, 7), (-45, 45), (-0.5, 0.5)),
        'high': ((3, 9), (-90, 90), (-1.0, 1.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local setting variables.
        if severity not in self._SETTINGS:
            raise ValueError(
                f'severity must be one of {sorted(self._SETTINGS)}, got {severity!r}'
            )
        self.p = p
        kernel_size, angle, direction = self._SETTINGS[severity]
        self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomPosterize(nn.Module):
    """Random posterisation with a severity-dependent bit setting.

    Thin wrapper around ``kornia.augmentation.RandomPosterize``.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``bits`` value handed to kornia (5, 4 or 3).
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> `bits` value passed to kornia
    _BITS = {'low': 5, 'medium': 4, 'high': 3}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local bits variable.
        if severity not in self._BITS:
            raise ValueError(
                f'severity must be one of {sorted(self._BITS)}, got {severity!r}'
            )
        self.p = p
        self.tform = ko.augmentation.RandomPosterize(bits=self._BITS[severity], p=p)

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
class RandomRGBShift(nn.Module):
    """Random per-channel RGB shift with a severity-dependent limit.

    Thin wrapper around ``kornia.augmentation.RandomRGBShift``; the same
    limit is used for the red, green and blue shift limits.

    Args:
        severity: one of ``'low'``, ``'medium'`` or ``'high'``; selects the
            shift limit (0.02, 0.05 or 0.1).
        p: probability forwarded to kornia.

    Raises:
        ValueError: if ``severity`` is not a known level.
    """

    # severity -> shift limit shared by the three channels
    _SHIFT_LIMIT = {'low': 0.02, 'medium': 0.05, 'high': 0.1}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        # Explicit check: an unknown level used to surface as an opaque
        # UnboundLocalError on the local limit variable.
        if severity not in self._SHIFT_LIMIT:
            raise ValueError(
                f'severity must be one of {sorted(self._SHIFT_LIMIT)}, got {severity!r}'
            )
        self.p = p
        rgb = self._SHIFT_LIMIT[severity]
        self.tform = ko.augmentation.RandomRGBShift(
            r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p
        )

    def forward(self, x, **kwargs):
        """Apply the kornia transform to ``x``; keyword arguments are ignored."""
        return self.tform(x)
|
|
|
|
|
|
|
class TransformNet(nn.Module):
    """Random augmentation pipeline built from kornia and local transforms.

    Each forward pass applies the fixed transforms (random horizontal flip,
    then a 224x224 crop) followed by ``n_optional`` transforms drawn at
    random, with replacement, from the enabled optional ones.

    Args:
        flip: enable the horizontal flip (probability 0.5 when enabled).
        crop_mode: ``'random_crop'`` or ``'resized_crop'``.
        compress, brightness, contrast, color_jiggle, gamma, grayscale,
        gaussian_blur, gaussian_noise, hue, motion_blur, posterize,
        rgb_shift, saturation, sharpness, median_blur: switches that add the
            matching transform to the optional pool.
        severity: severity level handed to the optional transforms.
        n_optional: number of optional transforms drawn per forward pass.
        ramp: number of steps over which the ramp factor grows to 1.
        p: probability handed to each optional transform (grayscale gets
            ``p / 4``).

    Raises:
        ValueError: if ``crop_mode`` is not a supported mode.
    """

    def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True, contrast=True, color_jiggle=True, gamma=True, grayscale=True, gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True, posterize=True, rgb_shift=True, saturation=True, sharpness=True, median_blur=True, severity='medium', n_optional=2, ramp=1000, p=0.5):
        super().__init__()
        # Raise instead of assert so the check survives `python -O`.
        if crop_mode not in ('random_crop', 'resized_crop'):
            raise ValueError(
                f"crop_mode must be 'random_crop' or 'resized_crop', got {crop_mode!r}"
            )
        self.n_optional = n_optional
        self.p = p
        self.ramp = ramp
        # Step at which the net was activated; 0 means "not activated yet".
        self.register_buffer('step0', torch.tensor(0))

        p_flip = 0.5 if flip else 0
        rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p_flip)
        if crop_mode == 'random_crop':
            rnd_crop_layer = ko.augmentation.RandomCrop((224, 224), cropping_mode="resample")
        else:
            rnd_crop_layer = ko.augmentation.RandomResizedCrop(size=(224, 224), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3), cropping_mode='resample')

        # NOTE: kept as plain Python lists (as before), so these transforms
        # are not registered as submodules of this module.
        self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]

        # (enabled, builder) pairs; the order fixes the index -> transform
        # mapping used by the random draw in forward().
        optional = [
            (compress, lambda: RandomCompress(severity, p=p)),
            (brightness, lambda: RandomBrightness(severity, p=p)),
            (contrast, lambda: RandomContrast(severity, p=p)),
            (color_jiggle, lambda: RandomColorJiggle(severity, p=p)),
            (gamma, lambda: RandomGamma(severity, p=p)),
            (grayscale, lambda: ko.augmentation.RandomGrayscale(p=p/4)),
            (gaussian_blur, lambda: RandomGaussianBlur(severity, p=p)),
            (gaussian_noise, lambda: RandomGaussianNoise(severity, p=p)),
            (hue, lambda: RandomHue(severity, p=p)),
            (motion_blur, lambda: RandomMotionBlur(severity, p=p)),
            (posterize, lambda: RandomPosterize(severity, p=p)),
            (rgb_shift, lambda: RandomRGBShift(severity, p=p)),
            (saturation, lambda: RandomSaturation(severity, p=p)),
            (sharpness, lambda: RandomSharpness(severity, p=p)),
            (median_blur, lambda: RandomMedianBlur(severity, p=p)),
        ]
        self.optional_transforms = [build() for enabled, build in optional if enabled]

    def activate(self, global_step):
        """Record ``global_step`` as the activation step, once.

        Does nothing if an activation step is already stored. Because 0 is
        the "not activated" marker, activating at step 0 leaves the net
        reported as not activated.
        """
        if self.step0 == 0:
            print(f'[TRAINING] Activating TransformNet at step {global_step}')
            self.step0 = torch.tensor(global_step)

    def is_activated(self):
        """Return ``step0 > 0`` (a boolean tensor)."""
        return self.step0 > 0

    def forward(self, x, global_step, p=0.9):
        """Augment ``x`` and return the result.

        Args:
            x: input batch; it is mapped with ``x * 0.5 + 0.5`` before the
                transforms and with ``x * 2 - 1`` afterwards (presumably
                [-1, 1] <-> [0, 1] - confirm against callers).
            global_step: current training step, used for the ramp factor.
            p: unused; kept for interface compatibility.

        Returns:
            The augmented batch.

        Raises:
            RuntimeError: if an optional transform fails; the original
                exception is chained as the cause.
        """
        x = x * 0.5 + 0.5

        for tform in self.fixed_transforms:
            x = tform(x)
            # Some transforms may hand back a tuple; keep the first element.
            if isinstance(x, tuple):
                x = x[0]

        # Fraction of the ramp period elapsed since activation, clamped to
        # [0, 1] so a step earlier than `step0` cannot yield a negative ramp.
        progress = (global_step - self.step0.cpu().item()) / self.ramp
        ramp = min(max(progress, 0.0), 1.0)

        if self.optional_transforms:
            tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).tolist()
            for tform_id in tform_ids:
                tform = self.optional_transforms[tform_id]
                try:
                    x = tform(x, ramp=ramp)
                except Exception as err:
                    # Surface the failure with context instead of stopping in
                    # a debugger and then silently returning.
                    raise RuntimeError(
                        f'optional transform #{tform_id} '
                        f'({type(tform).__name__}) failed with ramp={ramp}'
                    ) from err
                if isinstance(x, tuple):
                    x = x[0]

        return x * 2 - 1
|
|
|
|
|
# Placeholder script entry point: the guarded body is intentionally empty.
if __name__ == '__main__':
    pass