File size: 5,502 Bytes
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from . import utils
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from tools.augment_imagenetc import RandomImagenetC
from PIL import Image
import kornia as ko
# from kornia.augmentation import RandomHorizontalFlip, RandomCrop


class TransformNet(nn.Module):
    """Noise layer applying random geometric and photometric distortions.

    Input and output images are ``[batch_size, 3, H, W]`` tensors in ``[-1, 1]``.
    Every call applies a random horizontal flip (if enabled) and a random
    (resized) crop to 224x224. With probability ``p`` it then additionally
    applies one of two paths:

    * an ImageNet-C corruption (only when ``imagenetc_level > 0``, half of the
      time), or
    * a chain of blur, Gaussian noise, contrast, brightness/hue, saturation and
      optional JPEG compression.

    The strength of the photometric chain ramps linearly from zero to full
    over ``ramp`` steps, counted from the step recorded by :meth:`activate`.

    Args:
        rnd_bri: brightness magnitude at full ramp, passed to
            ``utils.get_rnd_brightness_torch``.
        rnd_hue: hue magnitude at full ramp, passed to the same helper.
        do_jpeg: whether to apply ``utils.jpeg_compress_decompress``.
        jpeg_quality: lowest JPEG quality reachable at full ramp; the quality
            is drawn from ``(jpeg_quality, 100]``.
        rnd_noise: maximum std of the additive Gaussian noise at full ramp.
        rnd_sat: maximum relative saturation change; at full ramp the factor
            is drawn from ``[1 - rnd_sat, 1 + rnd_sat)``.
        rnd_trans: stored as ``self.rnd_trans``; not used inside this class.
        contrast: ``(low, high)`` range of the per-image contrast factor at
            full ramp.
        rnd_flip: if true, flip horizontally with probability 0.5.
        ramp: number of steps over which distortion strength ramps up;
            ``ramp <= 0`` applies full strength immediately.
        imagenetc_level: maximum severity for ``ImagenetCTransform``;
            ``0`` disables the ImageNet-C path.
        crop_mode: ``'crop'`` for a plain random crop or ``'resized_crop'``
            for a random resized crop (scale 0.7-1.0, ratio 3/4-4/3).

    Raises:
        ValueError: if ``crop_mode`` is not one of the supported modes.
    """

    def __init__(self, rnd_bri=0.3, rnd_hue=0.1, do_jpeg=False, jpeg_quality=50,
                 rnd_noise=0.02, rnd_sat=1.0, rnd_trans=0.1, contrast=(0.5, 1.5),
                 rnd_flip=False, ramp=1000, imagenetc_level=0, crop_mode='crop') -> None:
        super().__init__()
        self.rnd_bri = rnd_bri
        self.rnd_hue = rnd_hue
        self.jpeg_quality = jpeg_quality
        self.rnd_noise = rnd_noise
        self.rnd_sat = rnd_sat
        self.rnd_trans = rnd_trans
        self.contrast_low, self.contrast_high = contrast
        self.do_jpeg = do_jpeg
        p_flip = 0.5 if rnd_flip else 0
        self.rnd_flip = ko.augmentation.RandomHorizontalFlip(p_flip)
        self.ramp = ramp
        # Step at which the layer was activated; 0 means "not activated yet".
        # Registered as a buffer so it is saved/restored with the state dict.
        self.register_buffer('step0', torch.tensor(0))
        if crop_mode == 'crop':
            self.rnd_crop = ko.augmentation.RandomCrop((224, 224), cropping_mode="resample")
        elif crop_mode == 'resized_crop':
            self.rnd_crop = ko.augmentation.RandomResizedCrop(
                size=(224, 224), scale=(0.7, 1.0), ratio=(3.0 / 4, 4.0 / 3),
                cropping_mode='resample')
        else:
            raise ValueError(f"crop_mode must be 'crop' or 'resized_crop', got {crop_mode!r}")
        if imagenetc_level > 0:
            self.imagenetc = ImagenetCTransform(max_severity=imagenetc_level)

    def activate(self, global_step):
        """Record ``global_step`` as the start of the distortion ramp.

        Has no effect once a non-zero step has been recorded. Calling it with
        ``global_step == 0`` leaves the layer inactive.
        """
        if self.step0 == 0:
            print(f'[TRAINING] Activating TransformNet at step {global_step}')
            # Update in place so the registered buffer keeps its device
            # (e.g. after ``.cuda()``) instead of being replaced by a CPU tensor.
            self.step0.fill_(int(global_step))

    def is_activated(self):
        """Return a boolean tensor that is true once :meth:`activate` recorded a step."""
        return self.step0 > 0

    def _ramp_factor(self, global_step):
        """Return the distortion strength in ``[0, 1]`` reached at ``global_step``."""
        if self.ramp <= 0:
            return 1.0
        elapsed = global_step - self.step0.item()
        # Clamp below as well: a step before activation must not yield a
        # negative strength (it would e.g. make the noise std negative).
        return float(min(max(elapsed / self.ramp, 0.0), 1.0))

    def forward(self, x, global_step, p=0.9):
        """Distort a batch of images.

        Args:
            x: ``[batch_size, 3, H, W]`` images in ``[-1, 1]``.
            global_step: current training step, used for the strength ramp.
            p: probability of applying the distortions that follow the
                (always applied) flip and crop.

        Returns:
            The distorted batch, in ``[-1, 1]``.
        """
        x = x * 0.5 + 0.5  # [-1, 1] -> [0, 1]

        # Geometric augmentations, always applied.
        x = self.rnd_flip(x)
        x = self.rnd_crop(x)
        if isinstance(x, tuple):
            x = x[0]  # weird bug in kornia 0.6.0 that returns transform matrix occasionally

        if torch.rand(1)[0] >= p:
            return x * 2 - 1  # [0, 1] -> [-1, 1]
        if hasattr(self, 'imagenetc') and torch.rand(1)[0] < 0.5:
            return self.imagenetc(x * 2 - 1)  # ImagenetCTransform works in [-1, 1]

        batch_size, sh, device = x.shape[0], x.size(), x.device
        ramp = self._ramp_factor(global_step)

        rnd_brightness = utils.get_rnd_brightness_torch(
            ramp * self.rnd_bri, ramp * self.rnd_hue, batch_size).to(device)  # [batch_size, 3, 1, 1]
        rnd_noise = torch.rand(1)[0] * ramp * self.rnd_noise

        contrast_low = 1. - (1. - self.contrast_low) * ramp
        contrast_high = 1. + (self.contrast_high - 1.) * ramp

        # blur (same-size output: odd kernel with symmetric padding)
        N_blur = 7
        f = utils.random_blur_kernel(probs=[.25, .25], N_blur=N_blur, sigrange_gauss=[1., 3.],
                                     sigrange_line=[.25, 1.], wmin_line=3)
        f = f.to(device=device, dtype=x.dtype)
        x = F.conv2d(x, f, bias=None, padding=(N_blur - 1) // 2)

        # additive Gaussian noise, sampled directly on x's device/dtype
        x = x + torch.randn_like(x) * rnd_noise
        x = torch.clamp(x, 0, 1)

        # per-image contrast scale, then brightness/hue offset
        contrast_scale = torch.empty(batch_size, 1, 1, 1, device=device, dtype=x.dtype)
        contrast_scale.uniform_(contrast_low, contrast_high)
        x = x * contrast_scale + rnd_brightness
        x = torch.clamp(x, 0, 1)

        # saturation factor in [1 - ramp * rnd_sat, 1 + ramp * rnd_sat)
        rnd_sat = (torch.rand(1)[0] * 2.0 - 1.0) * ramp * self.rnd_sat + 1.0
        x = ko.enhance.adjust.adjust_saturation(x, rnd_sat)

        # jpeg
        x = x.reshape(sh)
        if self.do_jpeg:
            jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
            x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)

        return x * 2 - 1  # [0, 1] -> [-1, 1]


class ImagenetCTransform(nn.Module):
    """Apply a random ImageNet-C corruption to every image in a batch.

    Each image is converted to a ``uint8`` PIL image on the CPU and passed
    through ``RandomImagenetC``. The corrupted result is added back to ``x`` as
    a residual computed from a detached copy. The output therefore equals the
    corrupted images, while gradients pass straight through to ``x``
    unchanged.

    Args:
        max_severity: maximum corruption severity forwarded to
            ``RandomImagenetC``.
    """

    def __init__(self, max_severity=5) -> None:
        super().__init__()
        self.max_severity = max_severity
        self.tform = RandomImagenetC(max_severity=max_severity, phase='train')

    def forward(self, x):
        """Corrupt ``x``.

        Args:
            x: ``[batch_size, 3, H, W]`` images in ``[-1, 1]``.

        Returns:
            Corrupted images in ``[-1, 1]``, with the same dtype and device as ``x``.
        """
        img0 = x.detach().cpu().numpy()
        # [-1, 1] -> [0, 255]. Round and clip before the uint8 cast: a bare
        # astype() truncates, and out-of-range floats wrap (e.g. 256 -> 0).
        img = np.clip(np.rint(img0 * 127.5 + 127.5), 0, 255)
        img = img.transpose(0, 2, 3, 1).astype(np.uint8)  # NCHW -> NHWC
        # Assumes tform returns an image of the same size that np.asarray can
        # convert (e.g. a PIL image) — TODO confirm against RandomImagenetC.
        corrupted = [np.asarray(self.tform(Image.fromarray(i)), dtype=np.float32) for i in img]
        img = np.stack(corrupted).transpose(0, 3, 1, 2) / 127.5 - 1.  # [0, 255] -> [-1, 1]
        residual = torch.from_numpy(img - img0).to(device=x.device, dtype=x.dtype)
        return x + residual