import cv2 import itertools import numpy as np import random import torch import torch.nn.functional as F import torch.nn as nn from PIL import Image, ImageOps import matplotlib.pyplot as plt def random_blur_kernel(probs, N_blur, sigrange_gauss, sigrange_line, wmin_line): N = N_blur coords = torch.from_numpy(np.stack(np.meshgrid(range(N_blur), range(N_blur), indexing='ij'), axis=-1)) - (0.5 * (N-1)) # (7,7,2) manhat = torch.sum(torch.abs(coords), dim=-1) # (7, 7) # nothing, default vals_nothing = (manhat < 0.5).float() # (7, 7) # gauss sig_gauss = torch.rand(1)[0] * (sigrange_gauss[1] - sigrange_gauss[0]) + sigrange_gauss[0] vals_gauss = torch.exp(-torch.sum(coords ** 2, dim=-1) /2. / sig_gauss ** 2) # line theta = torch.rand(1)[0] * 2.* np.pi v = torch.FloatTensor([torch.cos(theta), torch.sin(theta)]) # (2) dists = torch.sum(coords * v, dim=-1) # (7, 7) sig_line = torch.rand(1)[0] * (sigrange_line[1] - sigrange_line[0]) + sigrange_line[0] w_line = torch.rand(1)[0] * (0.5 * (N-1) + 0.1 - wmin_line) + wmin_line vals_line = torch.exp(-dists ** 2 / 2. / sig_line ** 2) * (manhat < w_line) # (7, 7) t = torch.rand(1)[0] vals = vals_nothing if t < (probs[0] + probs[1]): vals = vals_line else: vals = vals if t < probs[0]: vals = vals_gauss else: vals = vals v = vals / torch.sum(vals) # 归一化 (7, 7) z = torch.zeros_like(v) f = torch.stack([v,z,z, z,v,z, z,z,v], dim=0).reshape([3, 3, N, N]) return f def get_rand_transform_matrix(image_size, d, batch_size): Ms = np.zeros((batch_size, 2, 3, 3)) for i in range(batch_size): tl_x = random.uniform(-d, d) # Top left corner, top tl_y = random.uniform(-d, d) # Top left corner, left bl_x = random.uniform(-d, d) # Bot left corner, bot bl_y = random.uniform(-d, d) # Bot left corner, left tr_x = random.uniform(-d, d) # Top right corner, top tr_y = random.uniform(-d, d) # Top right corner, right br_x = random.uniform(-d, d) # Bot right corner, bot br_y = random.uniform(-d, d) # Bot right corner, right rect = np.array([ [tl_x, tl_y], [tr_x + image_size, tr_y], [br_x + image_size, br_y + image_size], [bl_x, bl_y + image_size]], dtype = "float32") dst = np.array([ [0, 0], [image_size, 0], [image_size, image_size], [0, image_size]], dtype = "float32") M = cv2.getPerspectiveTransform(rect, dst) M_inv = np.linalg.inv(M) Ms[i, 0, :, :] = M_inv Ms[i, 1, :, :] = M Ms = torch.from_numpy(Ms).float() return Ms def get_rnd_brightness_torch(rnd_bri, rnd_hue, batch_size): rnd_hue = torch.FloatTensor(batch_size, 3, 1, 1).uniform_(-rnd_hue, rnd_hue) rnd_brightness = torch.FloatTensor(batch_size, 1, 1, 1).uniform_(-rnd_bri, rnd_bri) return rnd_hue + rnd_brightness # reference: https://github.com/mlomnitz/DiffJPEG.git y_table = np.array( [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], dtype=np.float32).T y_table = nn.Parameter(torch.from_numpy(y_table)) c_table = np.empty((8, 8), dtype=np.float32) c_table.fill(99) c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T c_table = nn.Parameter(torch.from_numpy(c_table)) # 1. RGB -> YCbCr class rgb_to_ycbcr_jpeg(nn.Module): """ Converts RGB image to YCbCr Input: image(tensor): batch x 3 x height x width Outpput: result(tensor): batch x height x width x 3 """ def __init__(self): super(rgb_to_ycbcr_jpeg, self).__init__() matrix = np.array( [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], dtype=np.float32).T self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) self.matrix = nn.Parameter(torch.from_numpy(matrix)) def forward(self, image): image = image.permute(0, 2, 3, 1) result = torch.tensordot(image, self.matrix, dims=1) + self.shift result.view(image.shape) return result # 2. Chroma subsampling class chroma_subsampling(nn.Module): """ Chroma subsampling on CbCv channels Input: image(tensor): batch x height x width x 3 Output: y(tensor): batch x height x width cb(tensor): batch x height/2 x width/2 cr(tensor): batch x height/2 x width/2 """ def __init__(self): super(chroma_subsampling, self).__init__() def forward(self, image): image_2 = image.permute(0, 3, 1, 2).clone() avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2), count_include_pad=False) cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1)) cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1)) cb = cb.permute(0, 2, 3, 1) cr = cr.permute(0, 2, 3, 1) return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) # 3. Block splitting class block_splitting(nn.Module): """ Splitting image into patches Input: image(tensor): batch x height x width Output: patch(tensor): batch x h*w/64 x h x w """ def __init__(self): super(block_splitting, self).__init__() self.k = 8 def forward(self, image): height, width = image.shape[1:3] batch_size = image.shape[0] image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) # 4. DCT class dct_8x8(nn.Module): """ Discrete Cosine Transformation Input: image(tensor): batch x height x width Output: dcp(tensor): batch x height x width """ def __init__(self): super(dct_8x8, self).__init__() tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) for x, y, u, v in itertools.product(range(8), repeat=4): tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos( (2 * y + 1) * v * np.pi / 16) alpha = np.array([1. / np.sqrt(2)] + [1] * 7) # self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() ) def forward(self, image): image = image - 128 result = self.scale * torch.tensordot(image, self.tensor, dims=2) result.view(image.shape) return result # 5. Quantization class y_quantize(nn.Module): """ JPEG Quantization for Y channel Input: image(tensor): batch x height x width rounding(function): rounding function to use factor(float): Degree of compression Output: image(tensor): batch x height x width """ def __init__(self, rounding, factor=1): super(y_quantize, self).__init__() self.rounding = rounding self.factor = factor self.y_table = y_table def forward(self, image): image = image.float() / (self.y_table * self.factor) image = self.rounding(image) return image class c_quantize(nn.Module): """ JPEG Quantization for CrCb channels Input: image(tensor): batch x height x width rounding(function): rounding function to use factor(float): Degree of compression Output: image(tensor): batch x height x width """ def __init__(self, rounding, factor=1): super(c_quantize, self).__init__() self.rounding = rounding self.factor = factor self.c_table = c_table def forward(self, image): image = image.float() / (self.c_table * self.factor) image = self.rounding(image) return image class compress_jpeg(nn.Module): """ Full JPEG compression algortihm Input: imgs(tensor): batch x 3 x height x width rounding(function): rounding function to use factor(float): Compression factor Ouput: compressed(dict(tensor)): batch x h*w/64 x 8 x 8 """ def __init__(self, rounding=torch.round, factor=1): super(compress_jpeg, self).__init__() self.l1 = nn.Sequential( rgb_to_ycbcr_jpeg(), chroma_subsampling() ) self.l2 = nn.Sequential( block_splitting(), dct_8x8() ) self.c_quantize = c_quantize(rounding=rounding, factor=factor) self.y_quantize = y_quantize(rounding=rounding, factor=factor) def forward(self, image): y, cb, cr = self.l1(image*255) components = {'y': y, 'cb': cb, 'cr': cr} for k in components.keys(): comp = self.l2(components[k]) if k in ('cb', 'cr'): comp = self.c_quantize(comp) else: comp = self.y_quantize(comp) components[k] = comp return components['y'], components['cb'], components['cr'] # -5. Dequantization class y_dequantize(nn.Module): """ Dequantize Y channel Inputs: image(tensor): batch x height x width factor(float): compression factor Outputs: image(tensor): batch x height x width """ def __init__(self, factor=1): super(y_dequantize, self).__init__() self.y_table = y_table self.factor = factor def forward(self, image): return image * (self.y_table * self.factor) class c_dequantize(nn.Module): """ Dequantize CbCr channel Inputs: image(tensor): batch x height x width factor(float): compression factor Outputs: image(tensor): batch x height x width """ def __init__(self, factor=1): super(c_dequantize, self).__init__() self.factor = factor self.c_table = c_table def forward(self, image): return image * (self.c_table * self.factor) # -4. Inverse DCT class idct_8x8(nn.Module): """ Inverse discrete Cosine Transformation Input: dcp(tensor): batch x height x width Output: image(tensor): batch x height x width """ def __init__(self): super(idct_8x8, self).__init__() alpha = np.array([1. / np.sqrt(2)] + [1] * 7) self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) for x, y, u, v in itertools.product(range(8), repeat=4): tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos( (2 * v + 1) * y * np.pi / 16) self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) def forward(self, image): image = image * self.alpha result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 result.view(image.shape) return result # -3. Block joining class block_merging(nn.Module): """ Merge pathces into image Inputs: patches(tensor) batch x height*width/64, height x width height(int) width(int) Output: image(tensor): batch x height x width """ def __init__(self): super(block_merging, self).__init__() def forward(self, patches, height, width): k = 8 batch_size = patches.shape[0] image_reshaped = patches.view(batch_size, height//k, width//k, k, k) image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) return image_transposed.contiguous().view(batch_size, height, width) # -2. Chroma upsampling class chroma_upsampling(nn.Module): """ Upsample chroma layers Input: y(tensor): y channel image cb(tensor): cb channel cr(tensor): cr channel Ouput: image(tensor): batch x height x width x 3 """ def __init__(self): super(chroma_upsampling, self).__init__() def forward(self, y, cb, cr): def repeat(x, k=2): height, width = x.shape[1:3] x = x.unsqueeze(-1) x = x.repeat(1, 1, k, k) x = x.view(-1, height * k, width * k) return x cb = repeat(cb) cr = repeat(cr) return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) # -1: YCbCr -> RGB class ycbcr_to_rgb_jpeg(nn.Module): """ Converts YCbCr image to RGB JPEG Input: image(tensor): batch x height x width x 3 Outpput: result(tensor): batch x 3 x height x width """ def __init__(self): super(ycbcr_to_rgb_jpeg, self).__init__() matrix = np.array( [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) self.matrix = nn.Parameter(torch.from_numpy(matrix)) def forward(self, image): result = torch.tensordot(image + self.shift, self.matrix, dims=1) result.view(image.shape) return result.permute(0, 3, 1, 2) class decompress_jpeg(nn.Module): """ Full JPEG decompression algortihm Input: compressed(dict(tensor)): batch x h*w/64 x 8 x 8 rounding(function): rounding function to use factor(float): Compression factor Ouput: image(tensor): batch x 3 x height x width """ def __init__(self, height, width, rounding=torch.round, factor=1): super(decompress_jpeg, self).__init__() self.c_dequantize = c_dequantize(factor=factor) self.y_dequantize = y_dequantize(factor=factor) self.idct = idct_8x8() self.merging = block_merging() self.chroma = chroma_upsampling() self.colors = ycbcr_to_rgb_jpeg() self.height, self.width = height, width def forward(self, y, cb, cr): components = {'y': y, 'cb': cb, 'cr': cr} for k in components.keys(): if k in ('cb', 'cr'): comp = self.c_dequantize(components[k]) height, width = int(self.height/2), int(self.width/2) else: comp = self.y_dequantize(components[k]) height, width = self.height, self.width comp = self.idct(comp) components[k] = self.merging(comp, height, width) # image = self.chroma(components['y'], components['cb'], components['cr']) image = self.colors(image) image = torch.min(255*torch.ones_like(image), torch.max(torch.zeros_like(image), image)) return image/255 def diff_round(x): """ Differentiable rounding function Input: x(tensor) Output: x(tensor) """ return torch.round(x) + (x - torch.round(x))**3 def round_only_at_0(x): cond = (torch.abs(x) < 0.5).float() return cond * (x ** 3) + (1 - cond) * x def quality_to_factor(quality): """ Calculate factor corresponding to quality Input: quality(float): Quality for jpeg compression Output: factor(float): Compression factor """ if quality < 50: quality = 5000. / quality else: quality = 200. - quality*2 return quality / 100. def jpeg_compress_decompress(image, # downsample_c=True, rounding=round_only_at_0, quality=80): # image_r = image * 255 height, width = image.shape[2:4] # orig_height, orig_width = height, width # if height % 16 != 0 or width % 16 != 0: # # Round up to next multiple of 16 # height = ((height - 1) // 16 + 1) * 16 # width = ((width - 1) // 16 + 1) * 16 # vpad = height - orig_height # wpad = width - orig_width # top = vpad // 2 # bottom = vpad - top # left = wpad // 2 # right = wpad - left # #image = tf.pad(image, [[0, 0], [top, bottom], [left, right], [0, 0]], 'SYMMETRIC') # image = torch.pad(image, [[0, 0], [0, vpad], [0, wpad], [0, 0]], 'reflect') factor = quality_to_factor(quality) compress = compress_jpeg(rounding=rounding, factor=factor).to(image.device) decompress = decompress_jpeg(height, width, rounding=rounding, factor=factor).to(image.device) y, cb, cr = compress(image) recovered = decompress(y, cb, cr) return recovered.contiguous() if __name__ == '__main__': ''' test JPEG compress and decompress''' # img = Image.open('house.jpg') # img = np.array(img) / 255. # img_r = np.transpose(img, [2, 0, 1]) # img_tensor = torch.from_numpy(img_r).unsqueeze(0).float() # recover = jpeg_compress_decompress(img_tensor) # recover_arr = recover.detach().squeeze(0).numpy() # recover_arr = np.transpose(recover_arr, [1, 2, 0]) # plt.subplot(121) # plt.imshow(img) # plt.subplot(122) # plt.imshow(recover_arr) # plt.show() ''' test blur ''' # blur img = Image.open('house.jpg') img = np.array(img) / 255. img_r = np.transpose(img, [2, 0, 1]) img_tensor = torch.from_numpy(img_r).unsqueeze(0).float() print(img_tensor.shape) N_blur=7 f = random_blur_kernel(probs=[.25, .25], N_blur=N_blur, sigrange_gauss=[1., 3.], sigrange_line=[.25, 1.], wmin_line=3) # print(f.shape) # print(type(f)) encoded_image = F.conv2d(img_tensor, f, bias=None, padding=int((N_blur-1)/2)) encoded_image = encoded_image.detach().squeeze(0).numpy() encoded_image = np.transpose(encoded_image, [1, 2, 0]) plt.subplot(121) plt.imshow(img) plt.subplot(122) plt.imshow(encoded_image) plt.show()