#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
use kornia and albumentations for transformations
@author: Tu Bui @University of Surrey
"""
import os
from . import utils
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from PIL import Image
import kornia as ko
import albumentations as ab


def _by_severity(options, severity):
    """Return the setting registered for ``severity`` in ``options``.

    Args:
        options: dict mapping a severity name to the setting for that level.
        severity: requested severity name.

    Raises:
        ValueError: if ``severity`` is not a key of ``options``.
    """
    try:
        return options[severity]
    except KeyError:
        raise ValueError(
            f'Unknown severity {severity!r}; expected one of {list(options)}'
        ) from None


class IdentityAugment(nn.Module):
    """No-op augmentation: returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, **kwargs):
        """Return ``x`` untouched; extra keyword arguments are ignored."""
        return x


class RandomCompress(nn.Module):
    """Apply JPEG compression/decompression with probability ``p``.

    ``severity`` selects the lowest JPEG quality that can be drawn.
    """

    _MIN_QUALITY = {'low': 70, 'medium': 50, 'high': 40}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        self.jpeg_quality = _by_severity(self._MIN_QUALITY, severity)

    def forward(self, x, ramp=1.):
        """Compress ``x`` at a randomly drawn JPEG quality.

        Args:
            x: images (B, C, H, W) in range [0, 1].
            ramp: scales how far below 100 the quality may drop; 1.0 means
                the minimum quality is ``self.jpeg_quality``.

        Returns:
            ``x`` unchanged with probability ``1 - p``, otherwise the output
            of ``utils.jpeg_compress_decompress``.
        """
        if torch.rand(1)[0] >= self.p:
            return x
        # Quality is drawn uniformly from (100 - ramp * (100 - min_quality), 100].
        jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
        x = utils.jpeg_compress_decompress(
            x, rounding=utils.round_only_at_0, quality=jpeg_quality)
        return x


class RandomBoxBlur(nn.Module):
    """Wrapper around ``ko.augmentation.RandomBoxBlur``; severity sets the kernel size."""

    _KERNEL_SIZE = {'low': 3, 'medium': 5, 'high': 7}

    def __init__(self, severity='medium', border_type='reflect', normalize=True, p=0.5):
        super().__init__()
        self.p = p
        kernel_size = _by_severity(self._KERNEL_SIZE, severity)
        self.tform = ko.augmentation.RandomBoxBlur(
            kernel_size=(kernel_size, kernel_size), border_type=border_type,
            normalize=normalize, p=self.p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomMedianBlur(nn.Module):
    """Wrapper around ``ko.augmentation.RandomMedianBlur`` with a fixed 3x3 kernel.

    ``severity`` is accepted for signature parity with the other wrappers but
    is not used.
    """

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        self.tform = ko.augmentation.RandomMedianBlur(kernel_size=(3, 3), p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomBrightness(nn.Module):
    """Wrapper around ``ko.augmentation.RandomBrightness``; severity sets the factor range."""

    _BRIGHTNESS = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        brightness = _by_severity(self._BRIGHTNESS, severity)
        self.tform = ko.augmentation.RandomBrightness(brightness=brightness, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomContrast(nn.Module):
    """Wrapper around ``ko.augmentation.RandomContrast``; severity sets the factor range."""

    _CONTRAST = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        contrast = _by_severity(self._CONTRAST, severity)
        self.tform = ko.augmentation.RandomContrast(contrast=contrast, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomSaturation(nn.Module):
    """Wrapper around ``ko.augmentation.RandomSaturation``; severity sets the factor range."""

    _SATURATION = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        sat = _by_severity(self._SATURATION, severity)
        self.tform = ko.augmentation.RandomSaturation(saturation=sat, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomSharpness(nn.Module):
    """Wrapper around ``ko.augmentation.RandomSharpness``; severity sets the sharpness value."""

    _SHARPNESS = {'low': 0.5, 'medium': 1.0, 'high': 2.5}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        sharpness = _by_severity(self._SHARPNESS, severity)
        self.tform = ko.augmentation.RandomSharpness(sharpness=sharpness, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomColorJiggle(nn.Module):
    """Wrapper around ``ko.augmentation.ColorJiggle``.

    Severity selects the four positional factors passed to ``ColorJiggle``.
    """

    _FACTOR = {
        'low': (0.05, 0.05, 0.05, 0.01),
        'medium': (0.1, 0.1, 0.1, 0.02),
        'high': (0.1, 0.1, 0.1, 0.05),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        factor = _by_severity(self._FACTOR, severity)
        self.tform = ko.augmentation.ColorJiggle(*factor, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomHue(nn.Module):
    """Wrapper around ``ko.augmentation.RandomHue``; severity sets a symmetric hue range."""

    _HUE = {'low': 0.01, 'medium': 0.02, 'high': 0.05}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        hue = _by_severity(self._HUE, severity)
        self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomGamma(nn.Module):
    """Wrapper around ``ko.augmentation.RandomGamma``; severity sets the gamma and gain ranges."""

    _GAMMA_GAIN = {
        'low': ((0.9, 1.1), (0.9, 1.1)),
        'medium': ((0.75, 1.25), (0.75, 1.25)),
        'high': ((0.5, 1.5), (0.5, 1.5)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        gamma, gain = _by_severity(self._GAMMA_GAIN, severity)
        self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomGaussianBlur(nn.Module):
    """Wrapper around ``ko.augmentation.RandomGaussianBlur``.

    Severity sets the (square) kernel size and the sigma range.
    """

    _KERNEL_SIGMA = {
        'low': (3, (0.1, 1.0)),
        'medium': (5, (0.1, 1.5)),
        'high': (7, (0.1, 2.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        kernel_size, sigma = _by_severity(self._KERNEL_SIGMA, severity)
        self.tform = ko.augmentation.RandomGaussianBlur(
            kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomGaussianNoise(nn.Module):
    """Wrapper around ``ko.augmentation.RandomGaussianNoise``; severity sets the noise std."""

    _STD = {'low': 0.02, 'medium': 0.04, 'high': 0.08}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        std = _by_severity(self._STD, severity)
        self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=std, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomMotionBlur(nn.Module):
    """Wrapper around ``ko.augmentation.RandomMotionBlur``.

    Severity sets the kernel-size range, angle range and direction range.
    """

    _SETTINGS = {
        'low': ((3, 5), (-25, 25), (-0.25, 0.25)),
        'medium': ((3, 7), (-45, 45), (-0.5, 0.5)),
        'high': ((3, 9), (-90, 90), (-1.0, 1.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        kernel_size, angle, direction = _by_severity(self._SETTINGS, severity)
        self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomPosterize(nn.Module):
    """Wrapper around ``ko.augmentation.RandomPosterize``; severity sets ``bits``."""

    _BITS = {'low': 5, 'medium': 4, 'high': 3}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        bits = _by_severity(self._BITS, severity)
        self.tform = ko.augmentation.RandomPosterize(bits=bits, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class RandomRGBShift(nn.Module):
    """Wrapper around ``ko.augmentation.RandomRGBShift``.

    Severity sets one shift limit that is shared by the R, G and B channels.
    """

    _SHIFT_LIMIT = {'low': 0.02, 'medium': 0.05, 'high': 0.1}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        rgb = _by_severity(self._SHIFT_LIMIT, severity)
        self.tform = ko.augmentation.RandomRGBShift(
            r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)

    def forward(self, x, **kwargs):
        """Apply the wrapped transform; extra keyword arguments are ignored."""
        return self.tform(x)


class TransformNet(nn.Module):
    """Augmentation pipeline: fixed flip + crop, then randomly chosen optional transforms.

    Args:
        flip: enable the random horizontal flip (probability 0.5, else 0).
        crop_mode: ``'random_crop'`` or ``'resized_crop'``; both produce 224x224.
        compress ... median_blur: switches enabling each optional transform.
        severity: ``'low'``, ``'medium'`` or ``'high'``; forwarded to the wrappers.
        n_optional: number of optional transforms drawn (with replacement) per call.
        ramp: number of steps after ``step0`` over which the ramp grows to 1.
        p: probability given to each optional transform (grayscale uses ``p / 4``).

    Raises:
        ValueError: on an unknown ``crop_mode`` or ``severity``.
    """

    def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True,
                 contrast=True, color_jiggle=True, gamma=True, grayscale=True,
                 gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True,
                 posterize=True, rgb_shift=True, saturation=True, sharpness=True,
                 median_blur=True, severity='medium', n_optional=2, ramp=1000, p=0.5):
        super().__init__()
        self.n_optional = n_optional
        self.p = p
        p_flip = 0.5 if flip else 0
        rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p_flip)
        self.ramp = ramp
        # step0 is a buffer so the activation step is stored in the state dict.
        self.register_buffer('step0', torch.tensor(0))

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if crop_mode == 'random_crop':
            rnd_crop_layer = ko.augmentation.RandomCrop(
                (224, 224), cropping_mode="resample")
        elif crop_mode == 'resized_crop':
            rnd_crop_layer = ko.augmentation.RandomResizedCrop(
                size=(224, 224), scale=(0.7, 1.0), ratio=(3.0 / 4, 4.0 / 3),
                cropping_mode='resample')
        else:
            raise ValueError(
                f"Unknown crop_mode {crop_mode!r}; "
                "expected 'random_crop' or 'resized_crop'")

        # NOTE: kept as plain Python lists (as in the original), so these
        # transforms are not registered as submodules of TransformNet.
        self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]

        # (enabled, factory) pairs, in the order the transforms are indexed.
        candidates = [
            (compress, RandomCompress),
            (brightness, RandomBrightness),
            (contrast, RandomContrast),
            (color_jiggle, RandomColorJiggle),
            (gamma, RandomGamma),
            # Grayscale has no severity and uses a quarter of the probability.
            (grayscale, lambda _severity, p: ko.augmentation.RandomGrayscale(p=p / 4)),
            (gaussian_blur, RandomGaussianBlur),
            (gaussian_noise, RandomGaussianNoise),
            (hue, RandomHue),
            (motion_blur, RandomMotionBlur),
            (posterize, RandomPosterize),
            (rgb_shift, RandomRGBShift),
            (saturation, RandomSaturation),
            (sharpness, RandomSharpness),
            (median_blur, RandomMedianBlur),
        ]
        self.optional_transforms = [
            factory(severity, p=p) for enabled, factory in candidates if enabled
        ]

    def activate(self, global_step):
        """Record ``global_step`` as the ramp start, unless one is already recorded."""
        if self.step0 == 0:
            print(f'[TRAINING] Activating TransformNet at step {global_step}')
            # Keep the buffer on whatever device it currently lives on.
            self.step0 = torch.tensor(global_step, device=self.step0.device)

    def is_activated(self):
        """Return whether a positive activation step has been recorded."""
        return self.step0 > 0

    def forward(self, x, global_step, p=0.9):
        """Augment a batch of images.

        Args:
            x: images [batch_size, 3, H, W] in range [-1, 1].
            global_step: current training step, used to compute the ramp.
            p: unused; kept for interface compatibility.

        Returns:
            The augmented batch, mapped back to range [-1, 1].

        Raises:
            Exception: whatever an optional transform raises is re-raised after
                the failing transform index and ramp value are printed.
        """
        x = x * 0.5 + 0.5  # [-1, 1] -> [0, 1]

        # fixed transforms
        for tform in self.fixed_transforms:
            x = tform(x)
            if isinstance(x, tuple):
                x = x[0]

        # optional transforms
        progress = (global_step - self.step0.cpu().item()) / self.ramp
        # Clamp on both sides: a step earlier than step0 must not give a negative ramp.
        ramp = float(np.clip(progress, 0., 1.))
        tform_id = None
        try:
            if len(self.optional_transforms) > 0:
                tform_ids = torch.randint(
                    len(self.optional_transforms), (self.n_optional,)).tolist()
                for tform_id in tform_ids:
                    tform = self.optional_transforms[tform_id]
                    x = tform(x, ramp=ramp)
                    if isinstance(x, tuple):
                        x = x[0]
        except Exception:
            # Report which transform failed, then propagate instead of dropping
            # into a debugger and silently returning partially augmented data.
            print(f'[TransformNet] optional transform {tform_id} failed (ramp={ramp})')
            raise

        return x * 2 - 1  # [0, 1] -> [-1, 1]


if __name__ == '__main__':
    pass