File size: 6,284 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
# Standard libraries
import itertools
import numpy as np
# PyTorch
import torch
import torch.nn as nn
# Local
from . import JPEG_utils as utils
class y_dequantize(nn.Module):
    """Undo quantization of the luma (Y) coefficients.

    Inputs:
        image(tensor): quantized Y coefficients; must broadcast against
            ``utils.y_table`` (presumably 8 x 8 blocks -- confirm in JPEG_utils)
        factor(float): compression factor multiplied into the table
    Outputs:
        image(tensor): ``image * (y_table * factor)``
    """

    def __init__(self, factor=1):
        super().__init__()
        # Shared table object from JPEG_utils; assigned, not copied.
        self.y_table = utils.y_table
        self.factor = factor

    def forward(self, image):
        # Scale the table first so the grouping stays image * (table * factor).
        scaled_table = self.y_table * self.factor
        return image * scaled_table
class c_dequantize(nn.Module):
    """Undo quantization of a chroma (Cb or Cr) coefficient plane.

    Inputs:
        image(tensor): quantized chroma coefficients; must broadcast against
            ``utils.c_table`` (presumably 8 x 8 blocks -- confirm in JPEG_utils)
        factor(float): compression factor multiplied into the table
    Outputs:
        image(tensor): ``image * (c_table * factor)``
    """

    def __init__(self, factor=1):
        super().__init__()
        self.factor = factor
        # Shared table object from JPEG_utils; assigned, not copied.
        self.c_table = utils.c_table

    def forward(self, image):
        # Scale the table first so the grouping stays image * (table * factor).
        scaled_table = self.c_table * self.factor
        return image * scaled_table
class idct_8x8(nn.Module):
    """Inverse discrete cosine transform of 8 x 8 coefficient blocks.

    Input:
        image(tensor): DCT coefficients, ... x 8 x 8 (the last two dims are
            the frequency axes; any leading dims are carried through)
    Output:
        result(tensor): reconstructed samples, same shape as the input, with
            the +128 level shift already applied
    """

    def __init__(self):
        super().__init__()
        # Normalisation weights: 1/sqrt(2) for the DC row/column, 1 elsewhere.
        alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())

        # cos_1d[x, u] = cos((2u + 1) * x * pi / 16): frequency x, position u.
        freq = np.arange(8).reshape(8, 1)
        pos = np.arange(8).reshape(1, 8)
        cos_1d = np.cos((2 * pos + 1) * freq * np.pi / 16)
        # The 2-D basis is separable: basis[x, y, u, v] = cos_1d[x, u] * cos_1d[y, v].
        basis = np.einsum("xu,yv->xyuv", cos_1d, cos_1d).astype(np.float32)
        self.tensor = nn.Parameter(torch.from_numpy(basis))

    def forward(self, image):
        image = image * self.alpha
        # Contracting the two frequency axes against the basis already yields
        # a tensor shaped like the input, so no reshape is needed afterwards.
        return 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
class block_merging(nn.Module):
    """Reassemble 8 x 8 patches into a full image plane.

    Inputs:
        patches(tensor): batch x height*width/64 x 8 x 8, patches listed
            row by row (left to right, then top to bottom)
        height(int): height of the merged plane
        width(int): width of the merged plane
    Output:
        image(tensor): batch x height x width
    """

    def __init__(self):
        super().__init__()

    def forward(self, patches, height, width):
        block = 8
        n_images = patches.shape[0]
        block_rows, block_cols = height // block, width // block
        # (batch, block row, block col, row in block, col in block)
        grid = patches.view(n_images, block_rows, block_cols, block, block)
        # Put each block's rows next to its block-row index so that the final
        # flattening walks the image in ordinary row-major order.
        interleaved = grid.permute(0, 1, 3, 2, 4).contiguous()
        return interleaved.view(n_images, height, width)
class chroma_upsampling(nn.Module):
    """Bring the chroma planes up to luma resolution and stack the channels.

    Input:
        y(tensor): luma plane, batch x height x width
        cb(tensor): Cb plane, batch x height/2 x width/2
        cr(tensor): Cr plane, batch x height/2 x width/2
    Output:
        image(tensor): batch x height x width x 3
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _upsample(plane, k=2):
        """Enlarge a batch x rows x cols plane k times by pixel replication."""
        rows, cols = plane.shape[1:3]
        # Tiling the width axis and a new trailing axis k times each, then
        # reading the contiguous result back as (rows*k, cols*k), makes every
        # source pixel cover a k x k square in the output.
        tiled = plane.unsqueeze(-1).repeat(1, 1, k, k)
        return tiled.view(-1, rows * k, cols * k)

    def forward(self, y, cb, cr):
        planes = (y, self._upsample(cb), self._upsample(cr))
        # Channels go on a new last axis, in Y, Cb, Cr order.
        return torch.cat([plane.unsqueeze(3) for plane in planes], dim=3)
class ycbcr_to_rgb_jpeg(nn.Module):
    """Convert a YCbCr image to RGB using the JPEG conversion matrix.

    Input:
        image(tensor): batch x height x width x 3, channels ordered Y, Cb, Cr
    Output:
        result(tensor): batch x 3 x height x width, channels ordered R, G, B
    """

    def __init__(self):
        super().__init__()
        # Rows are the R, G, B equations; transposed so that a channels-last
        # pixel can be right-multiplied by it.
        matrix = np.array(
            [[1.0, 0.0, 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
            dtype=np.float32,
        ).T
        # Cb and Cr are stored offset by 128; Y is not shifted.
        self.shift = nn.Parameter(torch.tensor([0, -128.0, -128.0]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        # Contracting the channel axis keeps the batch x height x width x 3
        # layout, so the result only needs to be moved to channels-first.
        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        return result.permute(0, 3, 1, 2)
class decompress_jpeg(nn.Module):
    """Full JPEG decompression algorithm

    Constructor inputs:
        rounding(function): accepted for interface compatibility; this class
            does not use it
        factor(float): compression factor passed to both dequantizers
    Forward inputs:
        y(tensor): quantized luma coefficients, batch x h*w/64 x 8 x 8
        cb(tensor): quantized Cb coefficients, batch x (h/2)*(w/2)/64 x 8 x 8
        cr(tensor): quantized Cr coefficients, same shape as cb
        height(int): height of the decoded image
        width(int): width of the decoded image
    Output:
        image(tensor): batch x 3 x height x width, scaled to [0, 1]

    Chroma is assumed to be subsampled by 2 in both directions. For input
    without subsampling, decode cb/cr at full size and stack the three planes
    on a new last axis instead of calling ``self.chroma``.
    """

    def __init__(self, rounding=torch.round, factor=1):
        super().__init__()
        self.c_dequantize = c_dequantize(factor=factor)
        self.y_dequantize = y_dequantize(factor=factor)
        self.idct = idct_8x8()
        self.merging = block_merging()
        self.chroma = chroma_upsampling()
        self.colors = ycbcr_to_rgb_jpeg()

    def _decode_plane(self, coeffs, dequantize, height, width):
        """Dequantize, inverse-transform and reassemble one channel plane."""
        return self.merging(self.idct(dequantize(coeffs)), height, width)

    def forward(self, y, cb, cr, height, width):
        # Recorded on the module as before; nothing in this class reads them back.
        self.height = height
        self.width = width

        # Chroma planes are half size; floor division keeps this in integers.
        chroma_height, chroma_width = int(height) // 2, int(width) // 2

        y_plane = self._decode_plane(y, self.y_dequantize, height, width)
        cb_plane = self._decode_plane(
            cb, self.c_dequantize, chroma_height, chroma_width
        )
        cr_plane = self._decode_plane(
            cr, self.c_dequantize, chroma_height, chroma_width
        )

        image = self.colors(self.chroma(y_plane, cb_plane, cr_plane))
        # Clip to the valid 8-bit range before normalising. Kept as min/max
        # so gradients at the clip boundaries are unchanged.
        image = torch.min(
            255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)
        )
        return image / 255
|