|
import os |
|
from numpy.lib.function_base import interp |
|
import torch |
|
import torch.nn as nn |
|
# If launched from a subdirectory (no README.md in the cwd), hop up one level so
# the package-relative imports below resolve.  NOTE(review): assumes README.md
# exists only at the repository root -- confirm.
if not os.path.exists('README.md'):
    os.chdir('..')
|
|
|
from processing.pipeline_numpy import processing as default_processing |
|
from utils.base import np2torch, torch2np |
|
|
|
import segmentation_models_pytorch as smp |
|
|
|
# Bilinear demosaicing stencils, used as fixed convolution weights in Debayer.
# Green is sampled at two sites per 2x2 Bayer cell, red/blue at one site each,
# hence the different interpolation kernels (both normalized by / 4).
K_G = torch.Tensor([[0, 1, 0],
                    [1, 4, 1],
                    [0, 1, 0]]) / 4

K_RB = torch.Tensor([[1, 2, 1],
                     [2, 4, 2],
                     [1, 2, 1]]) / 4

# RGB <-> YUV colour-space conversion matrices.  The first row of M_RGB_2_YUV
# holds the standard luma weights (0.299, 0.587, 0.114); M_YUV_2_RGB is
# presumably its numerical inverse (note the ~1e-9 round-off entries) -- verify.
M_RGB_2_YUV = torch.Tensor([[0.299, 0.587, 0.114],
                            [-0.14714119, -0.28886916, 0.43601035],
                            [0.61497538, -0.51496512, -0.10001026]])
M_YUV_2_RGB = torch.Tensor([[1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00],
                            [1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01],
                            [1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09]])

# 5x5 low-pass kernel used by ParametrizedProcessing.gaussian_blur on the luma
# channel.  The values suggest a Gaussian profile -- not re-derived here.
K_BLUR = torch.Tensor([[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
                       [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
                       [2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04],
                       [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
                       [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08]])
# Classic 3x3 sharpening kernel (identity plus negative Laplacian), used by
# ParametrizedProcessing.sharpening_filter on the luma channel.
K_SHARP = torch.Tensor([[0, -1, 0],
                        [-1, 5, -1],
                        [0, -1, 0]])
# Neutral camera parameters in the order ParametrizedProcessing unpacks them:
# RGGB black-level offsets, per-channel white-balance gains, and a flattened
# 3x3 identity colour-correction matrix.
DEFAULT_CAMERA_PARAMS = (
    [0., 0., 0., 0.],
    [1., 1., 1.],
    [1., 0., 0., 0., 1., 0., 0., 0., 1.],
)
|
|
|
|
|
class RawToRGB(nn.Module):
    """Demosaics a 1-channel raw image into an rgb tensor with 3 (or 4) channels.

    Args:
        reduce_size (bool, optional): if False, the output keeps the raw input's
            height and width, i.e. (B, C, H, W), with unsampled positions zeroed.
            If True, height and width are halved: (B, C, H//2, W//2).
        out_channels (int, optional): number of output channels, one of {3, 4}.
            With 3 channels the two green planes are averaged.
        track_stages (bool, optional): whether or not to retain intermediary steps in processing
        normalize_mosaic (function, optional): applies normalization transformation to rgb image
    """

    def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
        super().__init__()
        self.reduce_size = reduce_size
        self.out_channels = out_channels
        self.track_stages = track_stages
        self.normalize_mosaic = normalize_mosaic
        # populated anew on every forward pass
        self.stages = None
        self.buffer = None

    def forward(self, raw):
        self.stages = {}
        self.buffer = {}

        demosaiced = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels)
        # the recorded stage is the raw demosaic, before any normalization
        self.stages['demosaic'] = demosaiced
        rgb = self.normalize_mosaic(demosaiced) if self.normalize_mosaic else demosaiced

        # keep gradients of intermediates for later inspection
        if self.track_stages and raw.requires_grad:
            for intermediate in self.stages.values():
                intermediate.retain_grad()

        self.buffer['processed_rgb'] = rgb

        return rgb
|
|
|
|
|
class NNProcessing(nn.Module):
    """Transforms raw images to processed rgb via a segmentation Unet

    Args:
        track_stages (bool, optional): whether or not to retain intermediary steps in processing
        normalize_mosaic (function, optional): applies normalization transformation to rgb image
        batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing
    """

    def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
        super().__init__()
        self.track_stages = track_stages
        self.normalize_mosaic = normalize_mosaic
        # populated anew on every forward pass
        self.stages = None
        self.buffer = None
        # U-Net++ operating on the 3-channel demosaiced image
        self.model = smp.UnetPlusPlus(
            encoder_name='resnet34',
            encoder_depth=3,
            decoder_channels=[256, 128, 64],
            in_channels=3,
            classes=3,
        )
        self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None

    def forward(self, raw):
        self.stages = {}
        self.buffer = {}

        demosaiced = raw2rgb(raw)
        if self.normalize_mosaic:
            demosaiced = self.normalize_mosaic(demosaiced)
        # unlike RawToRGB, the recorded demosaic here is post-normalization
        self.stages['demosaic'] = demosaiced

        rgb = self.model(demosaiced)
        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)
        self.stages['rgb'] = rgb

        # keep gradients of intermediates for later inspection
        if self.track_stages and raw.requires_grad:
            for intermediate in self.stages.values():
                intermediate.retain_grad()

        self.buffer['processed_rgb'] = rgb

        return rgb
|
|
|
|
|
def append_additive_layer(processor, shape=(1, 3, 256, 256)):
    """Attach a trainable, zero-initialized additive layer to a processor.

    When set, ParametrizedProcessing.forward adds this tensor to the image
    after gamma correction, so training it learns an additive perturbation.

    Args:
        processor (nn.Module): module to attach the parameter to, typically a
            ParametrizedProcessing instance
        shape (tuple, optional): shape of the additive parameter; the default
            (1, 3, 256, 256) preserves the previously hard-coded behavior and
            broadcasts over the batch dimension
    """
    processor.additive_layer = nn.Parameter(torch.zeros(shape))
|
|
|
|
|
|
|
class ParametrizedProcessing(nn.Module):
    """Differentiable processing pipeline via torch transformations

    Mirrors a classic ISP chain (black-level subtraction, demosaic, white
    balance, colour correction, luma sharpening, luma blur, clipping, gamma)
    with every stage expressed as differentiable torch ops, so the camera
    parameters can be optimized end-to-end.

    Args:
        camera_parameters (tuple(list), optional): (black_level, white_balance,
            colour_matrix) used to initialize the learnable parameters;
            defaults to DEFAULT_CAMERA_PARAMS (zero offsets, unit gains, identity)
        track_stages (bool, optional): whether or not to retain intermediary steps in processing
        batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing
    """

    def __init__(self, camera_parameters=None, track_stages=False, batch_norm_output=True):
        super().__init__()
        self.stages = None  # dict of intermediate tensors, rebuilt per forward pass
        self.buffer = None  # holds the final output under 'processed_rgb'
        self.track_stages = track_stages

        if camera_parameters is None:
            camera_parameters = DEFAULT_CAMERA_PARAMS

        black_level, white_balance, colour_matrix = camera_parameters

        # camera parameters are trainable so the pipeline can be tuned by gradient descent
        self.black_level = nn.Parameter(torch.as_tensor(black_level))
        self.white_balance = nn.Parameter(torch.as_tensor(white_balance).reshape(1, 3))
        self.colour_correction = nn.Parameter(torch.as_tensor(colour_matrix).reshape(3, 3))

        self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))

        self.debayer = Debayer()

        # convolutions initialized from the module-level kernels; both are
        # applied to the 1-channel luma (Y) plane only in forward()
        self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.sharpening_filter.weight.data[0][0] = K_SHARP.clone()

        self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
        self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()

        self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None

        # buffers (not Parameters): move with .to()/.cuda() but are not trained
        self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
        self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())

        # optional trainable residual, installed via append_additive_layer()
        self.additive_layer = None

    def forward(self, raw):
        # raw: batch of Bayer mosaics; output: rgb images of shape (B, 3, H, W)
        assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"

        self.stages = {}
        self.buffer = {}

        # scatter the mosaic into 3 full-resolution channels (zeros at unsampled sites)
        rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False)
        rgb = rgb.contiguous()
        self.stages['demosaic'] = rgb

        # bilinear interpolation of the missing mosaic positions
        rgb = self.debayer(rgb)

        # white balance: k indexes a size-1 dim and is summed out, so this just
        # scales channel c by white_balance[0, c]
        rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous()
        # colour correction: 3x3 matrix applied across the channel dimension
        rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
        self.stages['color_correct'] = rgb

        # sharpen only the luma channel in YUV space
        yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
        yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...])

        if self.track_stages:
            # round-trip through RGB purely to record the post-sharpening image;
            # the clone keeps the recorded stage clear of the later in-place blur
            rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous()
            self.stages['sharpening'] = rgb
            yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()

        # denoise the luma channel, then convert back to RGB
        yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...])
        rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous()
        self.stages['gaussian'] = rgb

        # clamp to (1e-5, 1] so the log below is defined
        rgb = torch.clip(rgb, 1e-5, 1)
        self.stages['clipped'] = rgb

        # gamma correction: exp(log(x)/g) is equivalent to x ** (1/g)
        rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb))
        self.stages['gamma_correct'] = rgb

        # optional learned additive perturbation (see append_additive_layer)
        if self.additive_layer is not None:
            rgb = rgb + self.additive_layer
            self.stages['noise'] = rgb

        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)

        # keep gradients of intermediate stages for inspection/debugging
        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                stage.retain_grad()

        self.buffer['processed_rgb'] = rgb

        return rgb
|
|
|
|
|
class Debayer(nn.Conv2d):
    """Separates the mosaiced raw image into its channels and interpolates bilinearly. Output is of same size as input.
    """

    def __init__(self):
        super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False)
        self.weight.data.zero_()
        # diagonal kernels: each channel is convolved only with itself,
        # green using K_G and red/blue using K_RB
        for channel, kernel in enumerate((K_RB, K_G, K_RB)):
            self.weight.data[channel, channel] = kernel.clone()
|
|
|
|
|
def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
    """Transforms a raw image with 1 channel to rgb with 3 (or 4) channels

    Args:
        raw (Tensor): raw Tensor of shape (B, H, W), assumed RGGB Bayer layout
        black_level (iterable, optional): RGGB black level values to subtract
        reduce_size (bool, optional): if False, the output image will have the same height and width
            as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
            if True, the output dimensions are reduced by half (B, C, H//2, W//2);
            H and W must then be even
        out_channels (int, optional): number of output channels. One of {3, 4}.
            The two green channels are averaged if out_channels == 3.

    Returns:
        Tensor: demosaiced image on raw's device. Floating-point dtypes are
        preserved (fix: previously float64 input was silently downcast to
        float32); integer raws are promoted to the default float dtype.
    """
    assert out_channels in (3, 4)
    if black_level is None:
        black_level = [0, 0, 0, 0]
    Bch, H, W = raw.shape

    # split the RGGB quad, removing the per-channel black offset
    R = raw[:, 0::2, 0::2] - black_level[0]
    G1 = raw[:, 0::2, 1::2] - black_level[1]
    G2 = raw[:, 1::2, 0::2] - black_level[2]
    B = raw[:, 1::2, 1::2] - black_level[3]

    # keep float inputs' precision; promote integer raws to float as before
    out_dtype = raw.dtype if raw.is_floating_point() else torch.get_default_dtype()

    if reduce_size:
        # one output pixel per 2x2 Bayer cell; fail early with a clear message
        # instead of an opaque shape-mismatch on assignment
        assert H % 2 == 0 and W % 2 == 0, f"reduce_size needs even H, W, got {(H, W)}"
        rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), dtype=out_dtype, device=raw.device)
        if out_channels == 3:
            planes = (R, (G1 + G2) / 2, B)
        else:
            planes = (R, G1, G2, B)
        for c, plane in enumerate(planes):
            rgb[:, c, :, :] = plane
    else:
        # keep full resolution: scatter each sample back to its mosaic site,
        # leaving the unsampled positions zero
        rgb = torch.zeros((Bch, out_channels, H, W), dtype=out_dtype, device=raw.device)
        # with 3 channels both greens share channel 1 (their pixel sites are disjoint)
        g2_channel = 1 if out_channels == 3 else 2
        rgb[:, 0, 0::2, 0::2] = R
        rgb[:, 1, 0::2, 1::2] = G1
        rgb[:, g2_channel, 1::2, 0::2] = G2
        rgb[:, out_channels - 1, 1::2, 1::2] = B
    return rgb
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run the differentiable torch pipeline on one raw image,
    # compare it numerically against the reference numpy pipeline, and
    # visualize the raw input, demosaic variants, and both pipeline outputs.

    import torch
    import numpy as np

    # same working-directory fixup as at module import time (run from repo root)
    if not os.path.exists('README.md'):
        os.chdir('..')

    import matplotlib.pyplot as plt
    from dataset import get_dataset
    from utils.base import np2torch, torch2np

    from utils.debug import debug
    from processing.pipeline_numpy import processing as default_processing

    # grab a single (raw, mask) batch from the project dataset
    raw_dataset = get_dataset('DS')
    loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
    batch_raw, batch_mask = next(iter(loader))

    # camera parameters recorded with the dataset; element [0] is the black level
    camera_parameters = raw_dataset.camera_parameters
    black_level = camera_parameters[0]

    proc = ParametrizedProcessing(camera_parameters)

    batch_rgb = proc(batch_raw)
    rgb = batch_rgb[0]

    raw_img = batch_raw[0]
    numpy_raw = torch2np(raw_img)

    # reference result from the (non-differentiable) numpy pipeline,
    # configured to match the torch pipeline's sharpening/denoising stages
    default_rgb = default_processing(numpy_raw, *camera_parameters,
                                     sharpening='sharpening_filter', denoising='gaussian_denoising')

    rgb_valid = np2torch(default_rgb)

    # numeric distance between the two implementations (smaller is better)
    print("pipeline norm difference:", (rgb - rgb_valid).norm().item())

    # demosaic-only outputs at full and half resolution, for visual comparison
    rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze()
    rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze()

    # side-by-side visualization of raw input and the processing variants
    plt.figure(figsize=(16, 8))
    plt.subplot(151)
    plt.title('Raw')
    plt.imshow(torch2np(raw_img))
    plt.subplot(152)
    plt.title('RGB Mosaic')
    plt.imshow(torch2np(rgb_mosaic))
    plt.subplot(153)
    plt.title('RGB Reduced')
    plt.imshow(torch2np(rgb_reduced))
    plt.subplot(154)
    plt.title('Torch Pipeline')
    plt.imshow(torch2np(rgb))
    plt.subplot(155)
    plt.title('Default Pipeline')
    plt.imshow(torch2np(rgb_valid))
    plt.show()
|
|
|
|
|
|