test / cldm /transformations.py
Tu Bui
first commit
6142a25
import os
from . import utils
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from tools.augment_imagenetc import RandomImagenetC
from PIL import Image
import kornia as ko
# from kornia.augmentation import RandomHorizontalFlip, RandomCrop
class TransformNet(nn.Module):
def __init__(self, rnd_bri=0.3, rnd_hue=0.1, do_jpeg=False, jpeg_quality=50, rnd_noise=0.02, rnd_sat=1.0, rnd_trans=0.1,contrast=[0.5, 1.5], rnd_flip=False, ramp=1000, imagenetc_level=0, crop_mode='crop') -> None:
super().__init__()
self.rnd_bri = rnd_bri
self.rnd_hue = rnd_hue
self.jpeg_quality = jpeg_quality
self.rnd_noise = rnd_noise
self.rnd_sat = rnd_sat
self.rnd_trans = rnd_trans
self.contrast_low, self.contrast_high = contrast
self.do_jpeg = do_jpeg
p_flip = 0.5 if rnd_flip else 0
self.rnd_flip = ko.augmentation.RandomHorizontalFlip(p_flip)
self.ramp = ramp
self.register_buffer('step0', torch.tensor(0)) # large number
assert crop_mode in ['crop', 'resized_crop']
if crop_mode == 'crop':
self.rnd_crop = ko.augmentation.RandomCrop((224,224), cropping_mode="resample")
elif crop_mode == 'resized_crop':
self.rnd_crop = ko.augmentation.RandomResizedCrop(size=(224,224), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3), cropping_mode='resample')
if imagenetc_level > 0:
self.imagenetc = ImagenetCTransform(max_severity=imagenetc_level)
def activate(self, global_step):
if self.step0 == 0:
print(f'[TRAINING] Activating TransformNet at step {global_step}')
self.step0 = torch.tensor(global_step)
def is_activated(self):
return self.step0 > 0
def forward(self, x, global_step, p=0.9):
# x: [batch_size, 3, H, W] in range [-1, 1]
x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
# flip
x = self.rnd_flip(x)
# random crop
x = self.rnd_crop(x)
if isinstance(x, tuple):
x = x[0] # weird bug in kornia 0.6.0 that returns transform matrix occasionally
if torch.rand(1)[0] >= p:
return x * 2 - 1 # [0, 1] -> [-1, 1]
if hasattr(self, 'imagenetc') and torch.rand(1)[0] < 0.5:
x = self.imagenetc(x * 2 - 1) # [0, 1] -> [-1, 1])
return x
batch_size, sh, device = x.shape[0], x.size(), x.device
# x0 = x.clone().detach()
ramp_fn = lambda ramp: np.min([(global_step-self.step0.cpu().item()) / ramp, 1.])
rnd_bri = ramp_fn(self.ramp) * self.rnd_bri
rnd_hue = ramp_fn(self.ramp) * self.rnd_hue
rnd_brightness = utils.get_rnd_brightness_torch(rnd_bri, rnd_hue, batch_size).to(device) # [batch_size, 3, 1, 1]
rnd_noise = torch.rand(1)[0] * ramp_fn(self.ramp) * self.rnd_noise
contrast_low = 1. - (1. - self.contrast_low) * ramp_fn(self.ramp)
contrast_high = 1. + (self.contrast_high - 1.) * ramp_fn(self.ramp)
contrast_params = [contrast_low, contrast_high]
# blur
N_blur = 7
f = utils.random_blur_kernel(probs=[.25, .25], N_blur=N_blur, sigrange_gauss=[1., 3.], sigrange_line=[.25, 1.],
wmin_line=3).to(device)
x = F.conv2d(x, f, bias=None, padding=int((N_blur - 1) / 2))
# noise
noise = torch.normal(mean=0, std=rnd_noise, size=x.size(), dtype=torch.float32).to(device)
x = x + noise
x = torch.clamp(x, 0, 1)
# contrast & brightness
contrast_scale = torch.Tensor(x.size()[0]).uniform_(contrast_params[0], contrast_params[1])
contrast_scale = contrast_scale.reshape(x.size()[0], 1, 1, 1).to(device)
x = x * contrast_scale
x = x + rnd_brightness
x = torch.clamp(x, 0, 1)
# saturation
# rnd_sat = torch.rand(1)[0] * ramp_fn(self.ramp) * self.rnd_sat
# sat_weight = torch.FloatTensor([.3, .6, .1]).reshape(1, 3, 1, 1).to(device)
# encoded_image_lum = torch.mean(x * sat_weight, dim=1).unsqueeze_(1)
# x = (1 - rnd_sat) * x + rnd_sat * encoded_image_lum
rnd_sat = (torch.rand(1)[0]*2.0 - 1.0)*ramp_fn(self.ramp) * self.rnd_sat + 1.0
x = ko.enhance.adjust.adjust_saturation(x, rnd_sat)
# jpeg
x = x.reshape(sh)
if self.do_jpeg:
jpeg_quality = 100. - torch.rand(1)[0] * ramp_fn(self.ramp) * (100. - self.jpeg_quality)
x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
x = x * 2 - 1 # [0, 1] -> [-1, 1]
return x
class ImagenetCTransform(nn.Module):
def __init__(self, max_severity=5) -> None:
super().__init__()
self.max_severity = max_severity
self.tform = RandomImagenetC(max_severity=max_severity, phase='train')
def forward(self, x):
# x: [batch_size, 3, H, W] in range [-1, 1]
img0 = x.detach().cpu().numpy()
img = img0 * 127.5 + 127.5 # [-1, 1] -> [0, 255]
img = img.transpose(0, 2, 3, 1).astype(np.uint8)
img = [Image.fromarray(i) for i in img]
img = [self.tform(i) for i in img]
img = np.array([np.array(i) for i in img], dtype=np.float32)
img = img.transpose(0, 3, 1, 2) / 127.5 - 1. # [0, 255] -> [-1, 1]
residual = torch.from_numpy(img - img0).to(x.device)
x = x + residual
return x