# Source: test/cldm/transformations2.py — author Tu Bui,
# commit 04acf84 ("fix input arg name"), 15 kB.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
use kornia and albumentations for transformations
@author: Tu Bui @University of Surrey
"""
import os
from . import utils
import torch
import numpy as np
from torch import nn
import torch.nn.functional as thf
from PIL import Image
import kornia as ko
import albumentations as ab
from torchvision import transforms
class IdentityAugment(nn.Module):
    """No-op augmentation: hands the input back untouched.

    Extra keyword arguments (e.g. ``ramp``) are accepted and ignored so this
    module can be called the same way as the other augmentations here.
    """

    def __init__(self):
        super(IdentityAugment, self).__init__()

    def forward(self, x, **kwargs):
        del kwargs  # unused; accepted only for call-signature compatibility
        return x
class RandomCompress(nn.Module):
    """Randomly JPEG-compress a batch of images.

    Args:
        severity: 'low', 'medium' or 'high'; sets the lowest JPEG quality
            (70, 50 or 40) reachable when ``ramp`` is 1.
        p: probability of compressing. A single draw is made per call, so the
            whole batch is either compressed or returned unchanged.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> minimum JPEG quality reached when ramp == 1
    _PRESETS = {'low': 70, 'medium': 50, 'high': 40}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            # fail fast here instead of with an AttributeError inside forward()
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        self.jpeg_quality = self._PRESETS[severity]

    def forward(self, x, ramp=1.):
        """Compress ``x`` with a randomly drawn JPEG quality.

        Args:
            x: image batch, (B, C, H, W) in range [0, 1].
            ramp: scales how far below 100 the quality may drop; 1.0 allows
                qualities down to ``self.jpeg_quality``, 0.0 forces quality 100.

        Returns:
            ``x`` itself when the transform is skipped, otherwise the output of
            ``utils.jpeg_compress_decompress``.
        """
        if torch.rand(1)[0] >= self.p:
            return x
        # quality is drawn uniformly from (100 - ramp * (100 - min_quality), 100]
        jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
        return utils.jpeg_compress_decompress(
            x, rounding=utils.round_only_at_0, quality=jpeg_quality)
class RandomBoxBlur(nn.Module):
    """Randomly box-blur images via kornia ``RandomBoxBlur``.

    Args:
        severity: 'low', 'medium' or 'high'; selects a square kernel of
            side 3, 5 or 7.
        border_type: padding mode forwarded to kornia.
        normalized: forwarded to kornia unchanged.
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> side length of the square box kernel
    _PRESETS = {'low': 3, 'medium': 5, 'high': 7}

    def __init__(self, severity='medium', border_type='reflect', normalized=True, p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        kernel_size = self._PRESETS[severity]
        self.tform = ko.augmentation.RandomBoxBlur(
            kernel_size=(kernel_size, kernel_size), border_type=border_type,
            normalized=normalized, p=self.p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomMedianBlur(nn.Module):
    """Randomly median-filter images with a fixed 3x3 window (kornia ``RandomMedianBlur``).

    Args:
        severity: accepted for signature parity with the other augmentations
            in this file but ignored; the window is always 3x3.
        p: probability forwarded to the kornia transform.
    """

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        window = (3, 3)
        self.tform = ko.augmentation.RandomMedianBlur(kernel_size=window, p=p)

    def forward(self, x, **kwargs):
        # ``ramp`` and any other keyword arguments are ignored
        blurred = self.tform(x)
        return blurred
class RandomBrightness(nn.Module):
    """Randomly rescale image brightness via kornia ``RandomBrightness``.

    Args:
        severity: 'low', 'medium' or 'high'; picks the brightness factor
            range (0.9-1.1, 0.75-1.25 or 0.5-1.5).
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> brightness factor range handed to kornia
    _PRESETS = {
        'low': (0.9, 1.1),
        'medium': (0.75, 1.25),
        'high': (0.5, 1.5),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        self.tform = ko.augmentation.RandomBrightness(brightness=self._PRESETS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomContrast(nn.Module):
    """Randomly rescale image contrast via kornia ``RandomContrast``.

    Args:
        severity: 'low', 'medium' or 'high'; picks the contrast factor
            range (0.9-1.1, 0.75-1.25 or 0.5-1.5).
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> contrast factor range handed to kornia
    _PRESETS = {
        'low': (0.9, 1.1),
        'medium': (0.75, 1.25),
        'high': (0.5, 1.5),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        self.tform = ko.augmentation.RandomContrast(contrast=self._PRESETS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomSaturation(nn.Module):
    """Randomly rescale color saturation via kornia ``RandomSaturation``.

    Args:
        severity: 'low', 'medium' or 'high'; picks the saturation factor
            range (0.9-1.1, 0.75-1.25 or 0.5-1.5).
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> saturation factor range handed to kornia
    _PRESETS = {
        'low': (0.9, 1.1),
        'medium': (0.75, 1.25),
        'high': (0.5, 1.5),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        self.tform = ko.augmentation.RandomSaturation(saturation=self._PRESETS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomSharpness(nn.Module):
    """Randomly sharpen images via kornia ``RandomSharpness``.

    Args:
        severity: 'low', 'medium' or 'high'; picks the sharpness value
            (0.5, 1.0 or 2.5) handed to kornia.
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> sharpness argument handed to kornia
    _PRESETS = {'low': 0.5, 'medium': 1.0, 'high': 2.5}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        self.tform = ko.augmentation.RandomSharpness(sharpness=self._PRESETS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomColorJiggle(nn.Module):
    """Randomly jitter brightness, contrast, saturation and hue via kornia ``ColorJiggle``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the jitter factors.
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> factors passed positionally to kornia ColorJiggle
    # (brightness, contrast, saturation, hue)
    _PRESETS = {
        'low': (0.05, 0.05, 0.05, 0.01),
        'medium': (0.1, 0.1, 0.1, 0.02),
        'high': (0.1, 0.1, 0.1, 0.05),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        self.tform = ko.augmentation.ColorJiggle(*self._PRESETS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomHue(nn.Module):
    """Randomly shift hue via kornia ``RandomHue``.

    Args:
        severity: 'low', 'medium' or 'high'; the hue range handed to kornia
            is symmetric, (-h, h), with h = 0.01, 0.02 or 0.05.
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> half-width h of the symmetric hue range (-h, h)
    _PRESETS = {'low': 0.01, 'medium': 0.02, 'high': 0.05}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        hue = self._PRESETS[severity]
        self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomGamma(nn.Module):
    """Randomly apply gamma correction via kornia ``RandomGamma``.

    Args:
        severity: 'low', 'medium' or 'high'; picks identical gamma and gain
            ranges (0.9-1.1, 0.75-1.25 or 0.5-1.5).
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> (gamma range, gain range), passed positionally to kornia
    _PRESETS = {
        'low': ((0.9, 1.1), (0.9, 1.1)),
        'medium': ((0.75, 1.25), (0.75, 1.25)),
        'high': ((0.5, 1.5), (0.5, 1.5)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        gamma, gain = self._PRESETS[severity]
        self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomGaussianBlur(nn.Module):
    """Randomly Gaussian-blur images via kornia ``RandomGaussianBlur``.

    Args:
        severity: 'low', 'medium' or 'high'; picks the square kernel side
            (3, 5, 7) and sigma range ((0.1, 1.0), (0.1, 1.5), (0.1, 2.0)).
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> (kernel side length, sigma range)
    _PRESETS = {
        'low': (3, (0.1, 1.0)),
        'medium': (5, (0.1, 1.5)),
        'high': (7, (0.1, 2.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        kernel_size, sigma = self._PRESETS[severity]
        self.tform = ko.augmentation.RandomGaussianBlur(
            kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomGaussianNoise(nn.Module):
    """Randomly add zero-mean Gaussian noise via kornia ``RandomGaussianNoise``.

    Args:
        severity: 'low', 'medium' or 'high'; picks the noise std
            (0.02, 0.04 or 0.08).
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> noise standard deviation
    _PRESETS = {'low': 0.02, 'medium': 0.04, 'high': 0.08}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=self._PRESETS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomMotionBlur(nn.Module):
    """Randomly motion-blur images via kornia ``RandomMotionBlur``.

    Args:
        severity: 'low', 'medium' or 'high'; picks the kernel-size, angle
            and direction ranges handed to kornia.
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> (kernel_size range, angle range, direction range),
    # passed positionally to kornia
    _PRESETS = {
        'low': ((3, 5), (-25, 25), (-0.25, 0.25)),
        'medium': ((3, 7), (-45, 45), (-0.5, 0.5)),
        'high': ((3, 9), (-90, 90), (-1.0, 1.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        kernel_size, angle, direction = self._PRESETS[severity]
        self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomPosterize(nn.Module):
    """Randomly posterize images via kornia ``RandomPosterize``.

    Args:
        severity: 'low', 'medium' or 'high'; picks the ``bits`` argument
            (5, 4 or 3) handed to kornia.
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> bits argument handed to kornia
    _PRESETS = {'low': 5, 'medium': 4, 'high': 3}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        self.tform = ko.augmentation.RandomPosterize(bits=self._PRESETS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomRGBShift(nn.Module):
    """Randomly shift the R, G and B channels via kornia ``RandomRGBShift``.

    Args:
        severity: 'low', 'medium' or 'high'; picks one shift limit
            (0.02, 0.05 or 0.1) used for all three channels.
        p: probability forwarded to the kornia transform.

    Raises:
        ValueError: if ``severity`` is not one of the levels above.
    """

    # severity -> shift limit shared by the r/g/b channels
    _PRESETS = {'low': 0.02, 'medium': 0.05, 'high': 0.1}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PRESETS:
            raise ValueError(
                f"unknown severity {severity!r}; expected one of {list(self._PRESETS)}")
        rgb = self._PRESETS[severity]
        self.tform = ko.augmentation.RandomRGBShift(
            r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class TransformNet(nn.Module):
    """Training-time augmentation pipeline for image batches in [-1, 1].

    Each ``forward`` call maps the batch to [0, 1], applies the fixed
    transforms (random horizontal flip, then a 224x224 crop), then applies
    ``n_optional`` transforms drawn at random, with replacement, from the
    enabled optional transforms, and finally maps back to [-1, 1].

    Args:
        flip: if True the horizontal flip fires with probability 0.5,
            otherwise never.
        crop_mode: 'random_crop' (kornia RandomCrop to 224x224) or
            'resized_crop' (kornia RandomResizedCrop to 224x224 with
            scale=(0.7, 1.0), ratio=(3/4, 4/3)).
        compress, brightness, contrast, color_jiggle, gamma, grayscale,
        gaussian_blur, gaussian_noise, hue, motion_blur, posterize,
        rgb_shift, saturation, sharpness, median_blur, box_blur:
            toggles for the individual optional transforms.
        severity: strength level forwarded to the optional transforms that
            take one.
        n_optional: number of optional transforms applied per ``forward``.
        ramp: number of steps, counted from the step stored by ``activate``,
            over which the ramp factor passed to the optional transforms rises
            linearly from 0 to 1. Values <= 0 disable the warm-up (factor 1).
        p: probability forwarded to every optional transform.
    """

    def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True,
                 contrast=True, color_jiggle=True, gamma=False, grayscale=True,
                 gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True,
                 posterize=True, rgb_shift=True, saturation=True, sharpness=True,
                 median_blur=True, box_blur=True, severity='medium', n_optional=2,
                 ramp=1000, p=0.5):
        super().__init__()
        self.n_optional = n_optional
        self.p = p
        # kornia's first positional parameter is not the probability, so pass
        # p by keyword; otherwise flip=False would still flip with p=0.5.
        rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p=0.5 if flip else 0.)
        self.ramp = ramp
        # Activation step; 0 doubles as "not activated yet" (see is_activated).
        self.register_buffer('step0', torch.tensor(0))
        self.crop_mode = crop_mode
        assert crop_mode in ['random_crop', 'resized_crop']
        if crop_mode == 'random_crop':
            rnd_crop_layer = ko.augmentation.RandomCrop((224, 224), cropping_mode="resample")
        else:
            rnd_crop_layer = ko.augmentation.RandomResizedCrop(
                size=(224, 224), scale=(0.7, 1.0), ratio=(3.0 / 4, 4.0 / 3),
                cropping_mode='resample')
        # NOTE(review): plain lists, so these modules are not registered as
        # submodules (not moved by .to()/.cuda(), not in state_dict). Switching
        # to nn.ModuleList could change state_dict keys — confirm checkpoint
        # compatibility first.
        self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]

        # Create the registry up front so forward() still works when every
        # optional transform is disabled.
        self.optional_transforms = []
        self.optional_names = []
        # (enabled, display name, factory). The order fixes each transform's id
        # used by transform_by_id; factories defer construction of disabled ones.
        candidates = (
            (compress, 'Random Compress', lambda: RandomCompress(severity, p=p)),
            (brightness, 'Random Brightness', lambda: RandomBrightness(severity, p=p)),
            (contrast, 'Random Contrast', lambda: RandomContrast(severity, p=p)),
            (color_jiggle, 'Random Color', lambda: RandomColorJiggle(severity, p=p)),
            (gamma, 'Random Gamma', lambda: RandomGamma(severity, p=p)),
            (grayscale, 'Grayscale', lambda: ko.augmentation.RandomGrayscale(p=p)),
            (gaussian_blur, 'Random Gaussian Blur', lambda: RandomGaussianBlur(severity, p=p)),
            (gaussian_noise, 'Random Gaussian Noise', lambda: RandomGaussianNoise(severity, p=p)),
            (hue, 'Random Hue', lambda: RandomHue(severity, p=p)),
            (motion_blur, 'Random Motion Blur', lambda: RandomMotionBlur(severity, p=p)),
            (posterize, 'Random Posterize', lambda: RandomPosterize(severity, p=p)),
            (rgb_shift, 'Random RGB Shift', lambda: RandomRGBShift(severity, p=p)),
            (saturation, 'Random Saturation', lambda: RandomSaturation(severity, p=p)),
            (sharpness, 'Random Sharpness', lambda: RandomSharpness(severity, p=p)),
            (median_blur, 'Random Median Blur', lambda: RandomMedianBlur(severity, p=p)),
            (box_blur, 'Random Box Blur', lambda: RandomBoxBlur(severity, p=p)),
        )
        for enabled, name, build in candidates:
            if enabled:
                self.register(build(), name)

    def register(self, tform, name):
        """Append ``tform`` to the optional transforms under display name ``name``."""
        if not hasattr(self, 'optional_transforms'):
            self.optional_transforms = []
            self.optional_names = []
        self.optional_transforms.append(tform)
        self.optional_names.append(name)

    def activate(self, global_step):
        """Store ``global_step`` as the activation step; later calls are ignored.

        Only takes effect while ``step0`` is still 0. Activating at step 0
        therefore leaves the module reporting "not activated".
        """
        if self.step0 == 0:
            print(f'[TRAINING] Activating TransformNet at step {global_step}')
            # update in place so the buffer keeps its device and tensor identity
            self.step0.fill_(global_step)

    def is_activated(self):
        """Return a boolean tensor that is True once a positive step was stored."""
        return self.step0 > 0

    @staticmethod
    def _unwrap(out):
        # some transforms return a tuple; keep only its first element
        return out[0] if isinstance(out, tuple) else out

    def _run_fixed(self, x):
        # apply the fixed transforms (flip, crop) in order to a batch in [0, 1]
        for tform in self.fixed_transforms:
            x = self._unwrap(tform(x))
        return x

    def _ramp_factor(self, global_step):
        # warm-up factor in [0, 1]: rises linearly over self.ramp steps after step0
        if self.ramp <= 0:
            return 1.
        elapsed = global_step - self.step0.item()
        return min(max(elapsed / self.ramp, 0.), 1.)

    def forward(self, x, global_step, p=0.9):
        """Augment a batch.

        Args:
            x: image batch, [batch_size, 3, H, W] in range [-1, 1].
            global_step: current training step; sets the ramp factor.
            p: unused; kept for backward compatibility.

        Returns:
            The augmented batch mapped back to [-1, 1].
        """
        x = x * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        x = self._run_fixed(x)
        ramp = self._ramp_factor(global_step)
        if self.optional_transforms:
            # independent draws, so the same transform may be applied twice
            tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).tolist()
            for tform_id in tform_ids:
                x = self._unwrap(self.optional_transforms[tform_id](x, ramp=ramp))
        return x * 2 - 1  # [0, 1] -> [-1, 1]

    def transform_by_id(self, x, tform_id):
        """Apply the fixed transforms and then optional transform ``tform_id``.

        ``x`` is [batch_size, 3, H, W] in [-1, 1]; the result is in [-1, 1].
        The optional transform is called without ``ramp``, so it uses its own
        default.
        """
        x = x * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        x = self._run_fixed(x)
        x = self._unwrap(self.optional_transforms[tform_id](x))
        return x * 2 - 1  # [0, 1] -> [-1, 1]

    def transform_by_name(self, x, tform_name):
        """Like ``transform_by_id`` but selects the transform by display name."""
        assert tform_name in self.optional_names
        tform_id = self.optional_names.index(tform_name)
        return self.transform_by_id(x, tform_id)

    def apply_transform_on_pil_image(self, x, tform_name):
        """Apply one named transform to a PIL image and return a PIL image.

        The image is converted to RGB if needed and resized to 256x256 first.
        'Fixed Augment' runs the fixed transforms (flip + 224x224 crop); any
        name in ``self.optional_names`` runs only that optional transform, so
        its result stays 256x256.
        """
        assert tform_name in self.optional_names + ['Fixed Augment']
        w, h = x.size
        if x.mode != 'RGB':
            # the tensor path needs an H x W x 3 array; e.g. an 'L' image has
            # no channel axis for the permute below
            x = x.convert('RGB')
        x = x.resize((256, 256), Image.BILINEAR)
        x = np.array(x).astype(np.float32) / 255.  # [0, 255] -> [0, 1]
        x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
        if tform_name == 'Fixed Augment':
            x = self._run_fixed(x)
        else:
            tform = self.optional_transforms[self.optional_names.index(tform_name)]
            x = self._unwrap(tform(x))
        x = x.detach().squeeze(0).permute(1, 2, 0).numpy() * 255  # [0, 1] -> [0, 255]
        # clip before casting: out-of-range values (e.g. after added noise)
        # would otherwise wrap around when cast to uint8
        x = Image.fromarray(np.clip(x, 0, 255).astype(np.uint8))
        # NOTE(review): no transform registered in __init__ is named
        # 'Random Crop', so this only fires if a caller registers one under that
        # name. Possibly a leftover from before 'Fixed Augment' existed — confirm.
        if (tform_name == 'Random Crop') and (self.crop_mode == 'random_crop'):
            w, h = int(224 / 256 * w), int(224 / 256 * h)
            x = x.resize((w, h), Image.BILINEAR)
        return x
if __name__ == '__main__':
    # This module only defines augmentation classes; running it directly does nothing.
    ...