#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Image augmentations (mostly kornia based) used as a random "noise layer".
@author: Tu Bui @University of Surrey
"""
import os
from . import utils
import torch
import numpy as np
from torch import nn
import torch.nn.functional as thf
from PIL import Image
import kornia as ko
import albumentations as ab
from torchvision import transforms


_SEVERITIES = ('low', 'medium', 'high')


def _by_severity(severity, low, medium, high):
    """Return the setting matching ``severity``.

    Args:
        severity: one of 'low', 'medium', 'high'.
        low, medium, high: the values to pick from.

    Raises:
        ValueError: if ``severity`` is not a recognised level.
    """
    table = {'low': low, 'medium': medium, 'high': high}
    try:
        return table[severity]
    except KeyError:
        raise ValueError(
            f"severity must be one of {_SEVERITIES}, got {severity!r}") from None


class IdentityAugment(nn.Module):
    """No-op augmentation; accepts and ignores extra keyword arguments."""

    def __init__(self):
        super().__init__()

    def forward(self, x, **kwargs):
        return x


class RandomCompress(nn.Module):
    """JPEG compress/decompress the whole batch with probability ``p``."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        # Lowest JPEG quality reachable when ramp == 1.
        self.jpeg_quality = _by_severity(severity, 70, 50, 40)

    def forward(self, x, ramp=1.):
        # x (B, C, H, W) in range [0, 1]
        # ramp: adjust the ramping of the compression, 1.0 means min quality = self.jpeg_quality
        if torch.rand(1)[0] >= self.p:
            return x
        # Quality is drawn uniformly in [100 - ramp * (100 - jpeg_quality), 100].
        jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
        x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
        return x


class RandomBoxBlur(nn.Module):
    """kornia RandomBoxBlur with a square kernel of 3/5/7 for low/medium/high."""

    def __init__(self, severity='medium', border_type='reflect', normalized=True, p=0.5):
        super().__init__()
        self.p = p
        kernel_size = _by_severity(severity, 3, 5, 7)
        self.tform = ko.augmentation.RandomBoxBlur(
            kernel_size=(kernel_size, kernel_size), border_type=border_type,
            normalized=normalized, p=self.p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomMedianBlur(nn.Module):
    """kornia RandomMedianBlur with a fixed 3x3 kernel (``severity`` is ignored)."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        self.tform = ko.augmentation.RandomMedianBlur(kernel_size=(3, 3), p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomBrightness(nn.Module):
    """kornia RandomBrightness with a severity-dependent factor range."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        brightness = _by_severity(severity, (0.9, 1.1), (0.75, 1.25), (0.5, 1.5))
        self.tform = ko.augmentation.RandomBrightness(brightness=brightness, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomContrast(nn.Module):
    """kornia RandomContrast with a severity-dependent factor range."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        contrast = _by_severity(severity, (0.9, 1.1), (0.75, 1.25), (0.5, 1.5))
        self.tform = ko.augmentation.RandomContrast(contrast=contrast, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomSaturation(nn.Module):
    """kornia RandomSaturation with a severity-dependent factor range."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        sat = _by_severity(severity, (0.9, 1.1), (0.75, 1.25), (0.5, 1.5))
        self.tform = ko.augmentation.RandomSaturation(saturation=sat, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomSharpness(nn.Module):
    """kornia RandomSharpness with a severity-dependent sharpness factor."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        sharpness = _by_severity(severity, 0.5, 1.0, 2.5)
        self.tform = ko.augmentation.RandomSharpness(sharpness=sharpness, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomColorJiggle(nn.Module):
    """kornia ColorJiggle; factors are (brightness, contrast, saturation, hue)."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        factor = _by_severity(severity,
                              (0.05, 0.05, 0.05, 0.01),
                              (0.1, 0.1, 0.1, 0.02),
                              (0.1, 0.1, 0.1, 0.05))
        self.tform = ko.augmentation.ColorJiggle(*factor, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomHue(nn.Module):
    """kornia RandomHue with a symmetric range (-hue, hue)."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        hue = _by_severity(severity, 0.01, 0.02, 0.05)
        self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomGamma(nn.Module):
    """kornia RandomGamma with severity-dependent gamma and gain ranges."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        gamma, gain = _by_severity(severity,
                                   ((0.9, 1.1), (0.9, 1.1)),
                                   ((0.75, 1.25), (0.75, 1.25)),
                                   ((0.5, 1.5), (0.5, 1.5)))
        self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomGaussianBlur(nn.Module):
    """kornia RandomGaussianBlur with severity-dependent kernel size and sigma range."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        kernel_size, sigma = _by_severity(severity,
                                          (3, (0.1, 1.0)),
                                          (5, (0.1, 1.5)),
                                          (7, (0.1, 2.0)))
        self.tform = ko.augmentation.RandomGaussianBlur(
            kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomGaussianNoise(nn.Module):
    """kornia RandomGaussianNoise (mean 0) with a severity-dependent std."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        std = _by_severity(severity, 0.02, 0.04, 0.08)
        self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=std, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomMotionBlur(nn.Module):
    """kornia RandomMotionBlur with severity-dependent kernel, angle and direction ranges."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        kernel_size, angle, direction = _by_severity(
            severity,
            ((3, 5), (-25, 25), (-0.25, 0.25)),
            ((3, 7), (-45, 45), (-0.5, 0.5)),
            ((3, 9), (-90, 90), (-1.0, 1.0)))
        self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomPosterize(nn.Module):
    """kornia RandomPosterize; fewer bits for higher severity."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        bits = _by_severity(severity, 5, 4, 3)
        self.tform = ko.augmentation.RandomPosterize(bits=bits, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomRGBShift(nn.Module):
    """kornia RandomRGBShift with the same shift limit for every channel."""

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        rgb = _by_severity(severity, 0.02, 0.05, 0.1)
        self.tform = ko.augmentation.RandomRGBShift(
            r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class RandomGrayscale(nn.Module):
    """kornia RandomGrayscale wrapped to accept (and ignore) extra kwargs such as ``ramp``.

    ``severity`` is accepted for a uniform constructor signature but ignored.
    """

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        self.tform = ko.augmentation.RandomGrayscale(p=p)

    def forward(self, x, **kwargs):
        return self.tform(x)


class TransformNet(nn.Module):
    """Random augmentation pipeline.

    ``forward`` maps [-1, 1] -> [0, 1], applies the fixed transforms (horizontal
    flip + 224x224 crop), then ``n_optional`` optional transforms drawn uniformly
    at random (with replacement), and maps the result back to [-1, 1].
    """

    def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True,
                 contrast=True, color_jiggle=True, gamma=False, grayscale=True,
                 gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True,
                 posterize=True, rgb_shift=True, saturation=True, sharpness=True,
                 median_blur=True, box_blur=True, severity='medium', n_optional=2,
                 ramp=1000, p=0.5):
        super().__init__()
        self.n_optional = n_optional
        self.p = p
        p_flip = 0.5 if flip else 0
        rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p_flip)
        self.ramp = ramp
        self.register_buffer('step0', torch.tensor(0))
        self.crop_mode = crop_mode
        assert crop_mode in ['random_crop', 'resized_crop']
        if crop_mode == 'random_crop':
            rnd_crop_layer = ko.augmentation.RandomCrop((224, 224), cropping_mode="resample")
        else:  # 'resized_crop'
            rnd_crop_layer = ko.augmentation.RandomResizedCrop(
                size=(224, 224), scale=(0.7, 1.0), ratio=(3.0 / 4, 4.0 / 3),
                cropping_mode='resample')
        self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]

        # Always create the optional lists so an instance with every optional
        # transform disabled still works in forward()/transform_by_*().
        self.optional_transforms = []
        self.optional_names = []
        # Registration order defines the transform ids; keep it stable.
        candidates = [
            (compress, RandomCompress, 'Random Compress'),
            (brightness, RandomBrightness, 'Random Brightness'),
            (contrast, RandomContrast, 'Random Contrast'),
            (color_jiggle, RandomColorJiggle, 'Random Color'),
            (gamma, RandomGamma, 'Random Gamma'),
            (grayscale, RandomGrayscale, 'Grayscale'),
            (gaussian_blur, RandomGaussianBlur, 'Random Gaussian Blur'),
            (gaussian_noise, RandomGaussianNoise, 'Random Gaussian Noise'),
            (hue, RandomHue, 'Random Hue'),
            (motion_blur, RandomMotionBlur, 'Random Motion Blur'),
            (posterize, RandomPosterize, 'Random Posterize'),
            (rgb_shift, RandomRGBShift, 'Random RGB Shift'),
            (saturation, RandomSaturation, 'Random Saturation'),
            (sharpness, RandomSharpness, 'Random Sharpness'),
            (median_blur, RandomMedianBlur, 'Random Median Blur'),
            (box_blur, RandomBoxBlur, 'Random Box Blur'),
        ]
        for enabled, tform_cls, name in candidates:
            if enabled:
                self.register(tform_cls(severity, p=p), name)

    def register(self, tform, name):
        # register a new (optional) transform
        if not hasattr(self, 'optional_transforms'):
            self.optional_transforms = []
            self.optional_names = []
        self.optional_transforms.append(tform)
        self.optional_names.append(name)

    def activate(self, global_step):
        """Record ``global_step`` as the activation step (only the first time)."""
        if self.step0 == 0:
            print(f'[TRAINING] Activating TransformNet at step {global_step}')
            # keep the buffer on the device the module currently lives on
            self.step0 = torch.tensor(global_step, device=self.step0.device)

    def is_activated(self):
        return self.step0 > 0

    @staticmethod
    def _unwrap(x):
        # Some kornia versions return (output, transform_matrix); keep the output only.
        return x[0] if isinstance(x, tuple) else x

    def _apply_fixed(self, x):
        """Apply flip + crop to a [0, 1] batch."""
        for tform in self.fixed_transforms:
            x = self._unwrap(tform(x))
        return x

    def _ramp_factor(self, global_step):
        """Fraction of the ramp-up period elapsed since activation, clamped to [0, 1]."""
        if self.ramp <= 0:
            return 1.
        progress = (global_step - self.step0.cpu().item()) / self.ramp
        # Clamp below too: a negative ramp would push JPEG quality above 100.
        return float(min(max(progress, 0.), 1.))

    def forward(self, x, global_step, p=0.9):
        # x: [batch_size, 3, H, W] in range [-1, 1]
        # p: unused; kept for backward compatibility of the call signature.
        x = x * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        x = self._apply_fixed(x)
        # optional transforms
        ramp = self._ramp_factor(global_step)
        if len(self.optional_transforms) > 0:
            tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).numpy()
            for tform_id in tform_ids:
                tform = self.optional_transforms[tform_id]
                x = self._unwrap(tform(x, ramp=ramp))
        return x * 2 - 1  # [0, 1] -> [-1, 1]

    def transform_by_id(self, x, tform_id):
        # x: [batch_size, 3, H, W] in range [-1, 1]
        x = x * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        x = self._apply_fixed(x)
        # optional transform (default ramp of the transform is used)
        tform = self.optional_transforms[tform_id]
        x = self._unwrap(tform(x))
        return x * 2 - 1  # [0, 1] -> [-1, 1]

    def transform_by_name(self, x, tform_name):
        assert tform_name in self.optional_names
        tform_id = self.optional_names.index(tform_name)
        return self.transform_by_id(x, tform_id)

    def apply_transform_on_pil_image(self, x, tform_name):
        """Apply one named optional transform (or 'Fixed Augment') to a PIL image.

        The image is resized to 256x256, transformed, and resized back to its
        original size. With 'Fixed Augment' in 'random_crop' mode, the output is
        224/256 of the original size, matching the 224-of-256 crop.

        Args:
            x: PIL image (converted to RGB internally).
            tform_name: an entry of ``self.optional_names`` or 'Fixed Augment'.

        Returns:
            PIL image.
        """
        assert tform_name in self.optional_names + ['Fixed Augment']
        w, h = x.size
        # Convert so grayscale / RGBA inputs become 3-channel like the model input.
        x = x.convert('RGB').resize((256, 256), Image.BILINEAR)
        x = np.array(x).astype(np.float32) / 255.  # [0, 255] -> [0, 1]
        x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
        if tform_name == 'Fixed Augment':
            x = self._apply_fixed(x)
        else:
            tform = self.optional_transforms[self.optional_names.index(tform_name)]
            x = self._unwrap(tform(x))
        x = x.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
        # Clip before casting: values outside [0, 1] (e.g. from noise) would wrap around in uint8.
        x = np.clip(x * 255, 0, 255).round().astype(np.uint8)  # [0, 1] -> [0, 255]
        x = Image.fromarray(x)
        if (tform_name == 'Fixed Augment') and (self.crop_mode == 'random_crop'):
            # The fixed RandomCrop keeps 224 of 256 pixels; keep that proportion of the original size.
            w, h = int(224 / 256 * w), int(224 / 256 * h)
        x = x.resize((w, h), Image.BILINEAR)
        return x


if __name__ == '__main__':
    pass