# test / cldm /tmp.py
# Tu Bui
# first commit
# 6142a25
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
use kornia and albumentations for transformations
@author: Tu Bui @University of Surrey
"""
import os
from . import utils
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from PIL import Image
import kornia as ko
import albumentations as ab
class IdentityAugment(nn.Module):
    """Pass-through augmentation that hands its input back unchanged.

    Can stand in for any of the random augmentations in this module; extra
    keyword arguments (such as ``ramp``) are accepted and ignored.
    """

    def __init__(self):
        super(IdentityAugment, self).__init__()

    def forward(self, x, **kwargs):
        output = x
        return output
class RandomCompress(nn.Module):
    """Randomly JPEG-compress a batch of images.

    Args:
        severity: 'low', 'medium' or 'high'; selects the lowest JPEG quality
            reachable at full ramp (70, 50 and 40 respectively).
        p: probability that a forward call applies the compression.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    # minimum JPEG quality reachable when ramp == 1, per severity level
    _JPEG_QUALITY = {'low': 70, 'medium': 50, 'high': 40}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._JPEG_QUALITY:
            # Previously an unknown level left ``jpeg_quality`` unset and the
            # module only failed later, inside forward(), with AttributeError.
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._JPEG_QUALITY)}')
        self.jpeg_quality = self._JPEG_QUALITY[severity]

    def forward(self, x, ramp=1.):
        """Compress ``x`` with probability ``self.p``, otherwise return it as is.

        Args:
            x: (B, C, H, W) in range [0, 1] (per the original author's note).
            ramp: scales how far below 100 the sampled quality may go;
                1.0 means the minimum quality is ``self.jpeg_quality``.
        """
        if torch.rand(1)[0] >= self.p:
            return x
        # quality = 100 - u * ramp * (100 - jpeg_quality), with u ~ U[0, 1)
        jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
        # presumably a JPEG encode/decode round trip implemented in the project utils
        return utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
class RandomBoxBlur(nn.Module):
    """Randomly box-blur images with kornia's ``RandomBoxBlur``.

    Args:
        severity: 'low', 'medium' or 'high' -> square kernel of size 3, 5 or 7.
        border_type: padding mode forwarded to kornia.
        normalize: forwarded to kornia unchanged.
        p: probability of applying the blur.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _KERNEL_SIZES = {'low': 3, 'medium': 5, 'high': 7}

    def __init__(self, severity='medium', border_type='reflect', normalize=True, p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._KERNEL_SIZES:
            # Without this check an unknown level surfaced as UnboundLocalError.
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._KERNEL_SIZES)}')
        kernel_size = self._KERNEL_SIZES[severity]
        self.tform = ko.augmentation.RandomBoxBlur(
            kernel_size=(kernel_size, kernel_size), border_type=border_type,
            normalize=normalize, p=self.p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomMedianBlur(nn.Module):
    """Apply kornia's 3x3 ``RandomMedianBlur`` with probability ``p``.

    ``severity`` is accepted only so the signature matches the other
    augmentations in this module; the kernel size is always 3x3.
    """

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        kernel = (3, 3)
        self.tform = ko.augmentation.RandomMedianBlur(kernel_size=kernel, p=p)
        self.p = p

    def forward(self, x, **kwargs):
        blur = self.tform
        return blur(x)
class RandomBrightness(nn.Module):
    """Randomly scale image brightness with kornia's ``RandomBrightness``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the brightness range
            (0.9-1.1, 0.75-1.25 or 0.5-1.5) passed to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _BRIGHTNESS = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._BRIGHTNESS:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._BRIGHTNESS)}')
        self.tform = ko.augmentation.RandomBrightness(brightness=self._BRIGHTNESS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomContrast(nn.Module):
    """Randomly change image contrast with kornia's ``RandomContrast``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the contrast range
            (0.9-1.1, 0.75-1.25 or 0.5-1.5) passed to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _CONTRAST = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._CONTRAST:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._CONTRAST)}')
        self.tform = ko.augmentation.RandomContrast(contrast=self._CONTRAST[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomSaturation(nn.Module):
    """Randomly change colour saturation with kornia's ``RandomSaturation``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the saturation range
            (0.9-1.1, 0.75-1.25 or 0.5-1.5) passed to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _SATURATION = {'low': (0.9, 1.1), 'medium': (0.75, 1.25), 'high': (0.5, 1.5)}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._SATURATION:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._SATURATION)}')
        self.tform = ko.augmentation.RandomSaturation(saturation=self._SATURATION[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomSharpness(nn.Module):
    """Randomly sharpen images with kornia's ``RandomSharpness``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the ``sharpness``
            value (0.5, 1.0 or 2.5) passed to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _SHARPNESS = {'low': 0.5, 'medium': 1.0, 'high': 2.5}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._SHARPNESS:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._SHARPNESS)}')
        self.tform = ko.augmentation.RandomSharpness(sharpness=self._SHARPNESS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomColorJiggle(nn.Module):
    """Randomly jitter colours with kornia's ``ColorJiggle``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the four factors passed
            positionally to kornia's ColorJiggle (brightness, contrast,
            saturation, hue).
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _FACTORS = {
        'low': (0.05, 0.05, 0.05, 0.01),
        'medium': (0.1, 0.1, 0.1, 0.02),
        'high': (0.1, 0.1, 0.1, 0.05),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._FACTORS:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._FACTORS)}')
        self.tform = ko.augmentation.ColorJiggle(*self._FACTORS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomHue(nn.Module):
    """Randomly shift hue with kornia's ``RandomHue``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the symmetric hue range
            (+-0.01, +-0.02 or +-0.05) passed to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _HUE = {'low': 0.01, 'medium': 0.02, 'high': 0.05}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._HUE:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._HUE)}')
        hue = self._HUE[severity]
        self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomGamma(nn.Module):
    """Randomly apply gamma correction with kornia's ``RandomGamma``.

    Args:
        severity: 'low', 'medium' or 'high'; selects identical gamma and gain
            ranges (0.9-1.1, 0.75-1.25 or 0.5-1.5) passed to kornia.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    # severity -> (gamma range, gain range)
    _GAMMA_GAIN = {
        'low': ((0.9, 1.1), (0.9, 1.1)),
        'medium': ((0.75, 1.25), (0.75, 1.25)),
        'high': ((0.5, 1.5), (0.5, 1.5)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._GAMMA_GAIN:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._GAMMA_GAIN)}')
        gamma, gain = self._GAMMA_GAIN[severity]
        self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomGaussianBlur(nn.Module):
    """Randomly Gaussian-blur images with kornia's ``RandomGaussianBlur``.

    Args:
        severity: 'low', 'medium' or 'high'; selects a square kernel of size
            3, 5 or 7 and a sigma range of (0.1, 1.0), (0.1, 1.5) or (0.1, 2.0).
        p: probability of applying the blur.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    # severity -> (kernel size, sigma range)
    _PARAMS = {
        'low': (3, (0.1, 1.0)),
        'medium': (5, (0.1, 1.5)),
        'high': (7, (0.1, 2.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PARAMS:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._PARAMS)}')
        kernel_size, sigma = self._PARAMS[severity]
        self.tform = ko.augmentation.RandomGaussianBlur(
            kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomGaussianNoise(nn.Module):
    """Randomly add zero-mean Gaussian noise with kornia's ``RandomGaussianNoise``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the noise ``std``
            (0.02, 0.04 or 0.08) passed to kornia.
        p: probability of applying the noise.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _STD = {'low': 0.02, 'medium': 0.04, 'high': 0.08}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._STD:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._STD)}')
        self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=self._STD[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomMotionBlur(nn.Module):
    """Randomly motion-blur images with kornia's ``RandomMotionBlur``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the
            (kernel_size, angle, direction) ranges passed positionally to kornia.
        p: probability of applying the blur.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    # severity -> (kernel_size range, angle range, direction range)
    _PARAMS = {
        'low': ((3, 5), (-25, 25), (-0.25, 0.25)),
        'medium': ((3, 7), (-45, 45), (-0.5, 0.5)),
        'high': ((3, 9), (-90, 90), (-1.0, 1.0)),
    }

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._PARAMS:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._PARAMS)}')
        kernel_size, angle, direction = self._PARAMS[severity]
        self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomPosterize(nn.Module):
    """Randomly posterize images with kornia's ``RandomPosterize``.

    Args:
        severity: 'low', 'medium' or 'high'; selects ``bits`` (5, 4 or 3)
            passed to kornia (fewer bits = coarser quantization).
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _BITS = {'low': 5, 'medium': 4, 'high': 3}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._BITS:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._BITS)}')
        self.tform = ko.augmentation.RandomPosterize(bits=self._BITS[severity], p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class RandomRGBShift(nn.Module):
    """Randomly shift RGB channels with kornia's ``RandomRGBShift``.

    Args:
        severity: 'low', 'medium' or 'high'; selects the shift limit
            (0.02, 0.05 or 0.1) used for all three channels.
        p: probability of applying the transform.

    Raises:
        ValueError: if ``severity`` is not one of the known levels.
    """

    _SHIFT_LIMIT = {'low': 0.02, 'medium': 0.05, 'high': 0.1}

    def __init__(self, severity='medium', p=0.5):
        super().__init__()
        self.p = p
        if severity not in self._SHIFT_LIMIT:
            raise ValueError(
                f'Unknown severity {severity!r}; expected one of {sorted(self._SHIFT_LIMIT)}')
        rgb = self._SHIFT_LIMIT[severity]
        self.tform = ko.augmentation.RandomRGBShift(
            r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)

    def forward(self, x, **kwargs):
        # extra keyword arguments (e.g. ``ramp``) are accepted and ignored
        return self.tform(x)
class TransformNet(nn.Module):
    """Random image-degradation pipeline applied to a batch of images.

    Every forward call applies the fixed transforms (random horizontal flip,
    then a random crop), followed by ``n_optional`` transforms drawn uniformly
    *with replacement* from the enabled optional transforms. A ramp factor in
    [0, 1] is passed to each optional transform as ``ramp=``. Among this
    module's wrappers, only ``RandomCompress`` uses it.

    Args:
        flip: enable the random horizontal flip (probability 0.5).
        crop_mode: 'random_crop' or 'resized_crop'.
        compress, brightness, contrast, color_jiggle, gamma, grayscale,
        gaussian_blur, gaussian_noise, hue, motion_blur, posterize, rgb_shift,
        saturation, sharpness, median_blur: enable that optional transform.
        severity: 'low' | 'medium' | 'high', forwarded to the optional transforms.
        n_optional: number of optional transforms sampled per call.
        ramp: number of steps after ``activate()`` over which the ramp factor
            grows linearly from 0 to 1; a value <= 0 disables ramping (factor 1).
        p: application probability of each optional transform (grayscale uses p/4).
        crop_size: (H, W) output size of the crop; defaults to (224, 224).
    """

    def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True,
                 contrast=True, color_jiggle=True, gamma=True, grayscale=True,
                 gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True,
                 posterize=True, rgb_shift=True, saturation=True, sharpness=True,
                 median_blur=True, severity='medium', n_optional=2, ramp=1000, p=0.5,
                 crop_size=(224, 224)):
        super().__init__()
        self.n_optional = n_optional
        self.p = p
        self.ramp = ramp
        # step at which activate() was first called; 0 means "not activated yet"
        self.register_buffer('step0', torch.tensor(0))

        assert crop_mode in ['random_crop', 'resized_crop']
        # p is passed by keyword so it cannot bind to a different leading
        # positional parameter in other kornia versions.
        rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p=0.5 if flip else 0)
        if crop_mode == 'random_crop':
            rnd_crop_layer = ko.augmentation.RandomCrop(tuple(crop_size), cropping_mode="resample")
        else:
            rnd_crop_layer = ko.augmentation.RandomResizedCrop(
                size=tuple(crop_size), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3),
                cropping_mode='resample')

        # NOTE(review): these are plain Python lists, not nn.ModuleList, so the
        # transforms are not registered submodules (.to(), .train()/.eval() and
        # state_dict() do not reach them). Switching to nn.ModuleList could
        # change state_dict keys of existing checkpoints — confirm before changing.
        self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]

        # (enabled, factory) pairs. The order defines the indices sampled in
        # forward(), so keep it stable for reproducible augmentation streams.
        candidates = [
            (compress, lambda: RandomCompress(severity, p=p)),
            (brightness, lambda: RandomBrightness(severity, p=p)),
            (contrast, lambda: RandomContrast(severity, p=p)),
            (color_jiggle, lambda: RandomColorJiggle(severity, p=p)),
            (gamma, lambda: RandomGamma(severity, p=p)),
            (grayscale, lambda: ko.augmentation.RandomGrayscale(p=p/4)),
            (gaussian_blur, lambda: RandomGaussianBlur(severity, p=p)),
            (gaussian_noise, lambda: RandomGaussianNoise(severity, p=p)),
            (hue, lambda: RandomHue(severity, p=p)),
            (motion_blur, lambda: RandomMotionBlur(severity, p=p)),
            (posterize, lambda: RandomPosterize(severity, p=p)),
            (rgb_shift, lambda: RandomRGBShift(severity, p=p)),
            (saturation, lambda: RandomSaturation(severity, p=p)),
            (sharpness, lambda: RandomSharpness(severity, p=p)),
            (median_blur, lambda: RandomMedianBlur(severity, p=p)),
        ]
        self.optional_transforms = [make() for enabled, make in candidates if enabled]

    def activate(self, global_step):
        """Record ``global_step`` as the start of the ramp (first call only)."""
        if self.step0 == 0:
            print(f'[TRAINING] Activating TransformNet at step {global_step}')
            # In-place update keeps the buffer's device and dtype; rebinding it
            # to torch.tensor(...) would silently move it to the CPU.
            self.step0.fill_(global_step)

    def is_activated(self):
        return self.step0 > 0

    def _ramp_factor(self, global_step):
        """Linear ramp clamped to [0, 1] over ``self.ramp`` steps since ``step0``."""
        if self.ramp <= 0:
            # avoid division by zero: no ramping means full strength
            return 1.
        progress = (global_step - self.step0.item()) / self.ramp
        # steps before step0 would otherwise give a negative factor
        return min(max(progress, 0.), 1.)

    def forward(self, x, global_step, p=0.9):
        """Augment a batch.

        Args:
            x: images in [-1, 1]; per the original note, shaped [batch_size, 3, H, W].
            global_step: current training step, used to compute the ramp factor.
            p: unused; kept for backward compatibility.

        Returns:
            The augmented batch mapped back to [-1, 1] (values are not clamped).

        Exceptions raised by a transform now propagate to the caller instead
        of dropping into an interactive debugger.
        """
        x = x * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        for tform in self.fixed_transforms:
            x = tform(x)
            # presumably some kornia versions return (output, transform); keep the image
            if isinstance(x, tuple):
                x = x[0]

        ramp = self._ramp_factor(global_step)
        if len(self.optional_transforms) > 0:
            # sampled with replacement: the same transform may be applied twice
            tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).tolist()
            for tform_id in tform_ids:
                x = self.optional_transforms[tform_id](x, ramp=ramp)
                if isinstance(x, tuple):
                    x = x[0]
        return x * 2 - 1  # [0, 1] -> [-1, 1]
if __name__ == '__main__':
    # This module only defines augmentation layers; there is nothing to run.
    pass