import os |
from numpy.lib.function_base import interp |
import torch |
import torch.nn as nn |
if not os.path.exists('README.md'): |
os.chdir('..') |
from processing.pipeline_numpy import processing as default_processing |
from utils.base import np2torch, torch2np |
import segmentation_models_pytorch as smp |
# 3x3 interpolation kernel written into the green channel of `Debayer`.
K_G = torch.Tensor([[0, 1, 0],
                    [1, 4, 1],
                    [0, 1, 0]]) / 4

# 3x3 interpolation kernel written into the red and blue channels of `Debayer`.
K_RB = torch.Tensor([[1, 2, 1],
                     [2, 4, 2],
                     [1, 2, 1]]) / 4

# RGB -> YUV conversion matrix; the first row holds the luma weights.
M_RGB_2_YUV = torch.Tensor([[0.299, 0.587, 0.114],
                            [-0.14714119, -0.28886916, 0.43601035],
                            [0.61497538, -0.51496512, -0.10001026]])

# YUV -> RGB conversion matrix (numerical inverse of M_RGB_2_YUV).
M_YUV_2_RGB = torch.Tensor([[1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00],
                            [1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01],
                            [1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09]])

# 5x5 blur kernel loaded into the `gaussian_blur` convolution of
# `ParametrizedProcessing`.
K_BLUR = torch.Tensor([[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
                       [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
                       [2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04],
                       [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
                       [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08]])

# 3x3 sharpening kernel loaded into the `sharpening_filter` convolution of
# `ParametrizedProcessing`.
K_SHARP = torch.Tensor([[0, -1, 0],
                        [-1, 5, -1],
                        [0, -1, 0]])

# Fallback (black_level, white_balance, colour_matrix) used by
# `ParametrizedProcessing` when no camera parameters are supplied:
# zero black level, unit white balance and an identity 3x3 colour matrix
# (stored flat, reshaped to (3, 3) by the consumer).
DEFAULT_CAMERA_PARAMS = (
    [0., 0., 0., 0.],
    [1., 1., 1.],
    [1., 0., 0., 0., 1., 0., 0., 0., 1.],
)
class RawToRGB(nn.Module):
    """Module wrapper around `raw2rgb`: splits a 1-channel raw mosaic into colour channels.

    Args:
        reduce_size (bool, optional): if True, the output is half the raw resolution,
            (B, C, H//2, W//2). If False, the output keeps the raw height and width,
            (B, C, H, W), and positions without a sample are filled with zeros.
        out_channels (int, optional): number of output channels. One of {3, 4}.
            For 3 channels, the two green channels are averaged.
        track_stages (bool, optional): whether or not to retain the gradients of
            intermediary steps in processing
        normalize_mosaic (function, optional): normalization applied to the
            demosaiced rgb image before it is returned

    Attributes:
        stages (dict): intermediate tensors of the last forward pass, keyed by stage name.
        buffer (dict): holds the final output of the last forward pass under 'processed_rgb'.
    """

    def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
        super().__init__()
        # Both dicts are (re)created on every forward pass.
        self.stages = None
        self.buffer = None

        self.reduce_size = reduce_size
        self.out_channels = out_channels
        self.track_stages = track_stages
        self.normalize_mosaic = normalize_mosaic

    def forward(self, raw):
        self.stages, self.buffer = {}, {}

        output = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels)
        # The recorded stage is the un-normalized mosaic.
        self.stages['demosaic'] = output

        if self.normalize_mosaic:
            output = self.normalize_mosaic(output)

        # Gradients of intermediates can only be retained when the input takes part in autograd.
        if self.track_stages and raw.requires_grad:
            for tensor in self.stages.values():
                tensor.retain_grad()

        self.buffer['processed_rgb'] = output
        return output
class NNProcessing(nn.Module):
    """Learned raw-to-rgb processing: demosaics the raw image and feeds it through a UnetPlusPlus.

    Args:
        track_stages (bool, optional): whether or not to retain the gradients of
            intermediary steps in processing
        normalize_mosaic (function, optional): normalization applied to the
            demosaiced rgb image before it enters the network
        batch_norm_output (bool, optional): adds a non-affine BatchNorm layer to the
            end of the processing

    Attributes:
        stages (dict): intermediate tensors of the last forward pass
            ('demosaic' and 'rgb').
        buffer (dict): holds the final output of the last forward pass under 'processed_rgb'.
    """

    def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
        super().__init__()
        # Both dicts are (re)created on every forward pass.
        self.stages = None
        self.buffer = None
        self.track_stages = track_stages

        unet_config = dict(
            encoder_name='resnet34',
            encoder_depth=3,
            decoder_channels=[256, 128, 64],
            in_channels=3,
            classes=3,
        )
        self.model = smp.UnetPlusPlus(**unet_config)

        self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None
        self.normalize_mosaic = normalize_mosaic

    def forward(self, raw):
        self.stages, self.buffer = {}, {}

        mosaic = raw2rgb(raw)
        if self.normalize_mosaic:
            mosaic = self.normalize_mosaic(mosaic)
        # Here the recorded 'demosaic' stage is the (optionally) normalized network input.
        self.stages['demosaic'] = mosaic

        processed = self.model(mosaic)
        if self.batch_norm is not None:
            processed = self.batch_norm(processed)
        self.stages['rgb'] = processed

        # Gradients of intermediates are only retained when the input takes part in autograd.
        if self.track_stages and raw.requires_grad:
            for tensor in self.stages.values():
                tensor.retain_grad()

        self.buffer['processed_rgb'] = processed
        return processed
def append_additive_layer(processor, shape=(1, 3, 256, 256)):
    """Attach a zero-initialised, learnable additive layer to a processing module.

    The layer is stored as ``processor.additive_layer``. When `processor` is an
    ``nn.Module``, assigning an ``nn.Parameter`` registers it as a parameter of
    that module. `ParametrizedProcessing.forward` adds this tensor to the
    gamma-corrected image when it is not None.

    Args:
        processor (nn.Module): module that receives the ``additive_layer`` attribute.
        shape (tuple(int), optional): shape of the additive layer. It must be
            broadcastable against the image it is added to. Defaults to
            (1, 3, 256, 256), the previously hard-coded size.

    Returns:
        None. `processor` is modified in place.
    """
    processor.additive_layer = nn.Parameter(torch.zeros(tuple(shape)))
class ParametrizedProcessing(nn.Module):
    """Differentiable processing pipeline via torch transformations.

    The forward pass applies, in order: black-level subtraction and mosaic
    separation (`raw2rgb`), debayering (`Debayer`), white balance, colour
    correction, sharpening and blurring of the first YUV channel, clipping to
    [1e-5, 1], gamma correction, an optional additive layer and an optional
    batch norm. Black level, white balance, colour matrix, gamma and both
    filter kernels are learnable parameters.

    Args:
        camera_parameters (tuple(list), optional): ``(black_level, white_balance,
            colour_matrix)`` applied in processing. `white_balance` must hold 3
            values and `colour_matrix` 9 values. Defaults to DEFAULT_CAMERA_PARAMS.
        track_stages (bool, optional): whether or not to retain intermediary steps in processing
        batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing

    Attributes:
        stages (dict): intermediate tensors of the last forward pass, keyed by stage name.
        buffer (dict): holds the final output of the last forward pass under 'processed_rgb'.
        additive_layer: None by default; see `append_additive_layer`.
    """

    def __init__(self, camera_parameters=None, track_stages=False, batch_norm_output=True):
        super().__init__()
        # Both dicts are (re)created on every forward pass.
        self.stages = None
        self.buffer = None
        self.track_stages = track_stages

        if camera_parameters is None:
            camera_parameters = DEFAULT_CAMERA_PARAMS
        black_level, white_balance, colour_matrix = camera_parameters

        # Cast to the default float dtype. Camera parameters may arrive as
        # integers or float64 (lists, numpy arrays), but nn.Parameter rejects
        # integer tensors and the rest of the pipeline (image tensor, filter
        # weights, conversion matrices) is created in the default dtype.
        param_dtype = torch.get_default_dtype()
        self.black_level = nn.Parameter(torch.as_tensor(black_level, dtype=param_dtype))
        self.white_balance = nn.Parameter(
            torch.as_tensor(white_balance, dtype=param_dtype).reshape(1, 3))
        self.colour_correction = nn.Parameter(
            torch.as_tensor(colour_matrix, dtype=param_dtype).reshape(3, 3))
        self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))

        self.debayer = Debayer()

        # Single-channel filters, initialised from the module-level kernels.
        self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.sharpening_filter.weight.data[0][0] = K_SHARP.clone()

        self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
        self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()

        self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None

        # Buffers (not parameters): they follow the module across devices but are not trained.
        self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
        self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())

        self.additive_layer = None

    def forward(self, raw):
        """Process a batch of raw images.

        Args:
            raw (Tensor): raw images of shape (B, H, W).

        Returns:
            Tensor: processed rgb images; 3 channels at the raw resolution.
        """
        assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"

        self.stages = {}
        self.buffer = {}

        # Full-resolution mosaic with zeros at unsampled positions.
        rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False)
        rgb = rgb.contiguous()
        self.stages['demosaic'] = rgb

        rgb = self.debayer(rgb)

        # White balance: k has size 1 and is summed out, so channel c is scaled by white_balance[0, c].
        rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous()
        # Colour correction: 3x3 matrix product along the channel axis.
        rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
        self.stages['color_correct'] = rgb

        # Sharpen only the first YUV channel.
        yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
        yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...])

        if self.track_stages:
            # Round-trip through RGB so the sharpened image exists as its own tensor in the graph.
            rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous()
            self.stages['sharpening'] = rgb
            yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()

        # Blur only the first YUV channel, then convert back to RGB.
        yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...])
        rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous()
        self.stages['gaussian'] = rgb

        # The lower bound keeps the logarithm below finite.
        rgb = torch.clip(rgb, 1e-5, 1)
        self.stages['clipped'] = rgb

        # Gamma correction, rgb ** (1 / gamma), written via exp/log.
        rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb))
        self.stages['gamma_correct'] = rgb

        if self.additive_layer is not None:
            rgb = rgb + self.additive_layer
            self.stages['noise'] = rgb

        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)

        # Gradients of intermediates are only retained when the input takes part in autograd.
        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                stage.retain_grad()

        self.buffer['processed_rgb'] = rgb

        return rgb
class Debayer(nn.Conv2d):
    """Interpolating 3x3 convolution over a 3-channel mosaic. Output is of same size as input.

    The weight starts as all zeros, then each output channel gets a kernel on
    its own input channel only: K_RB for channels 0 and 2, K_G for channel 1.
    Reflect padding is used at the borders and there is no bias.
    """

    def __init__(self):
        super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False)

        weight = self.weight.data
        weight.zero_()  # no cross-channel mixing
        for channel, kernel in enumerate((K_RB, K_G, K_RB)):
            weight[channel, channel] = kernel.clone()
def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
    """Split a 1-channel raw mosaic into colour channels.

    The 2x2 mosaic cell is read as: top-left -> R, top-right -> G1,
    bottom-left -> G2, bottom-right -> B.

    Args:
        raw (Tensor): raw Tensor of shape (B, H, W)
        black_level (iterable, optional): four black level values, subtracted from
            the R, G1, G2 and B samples respectively. Defaults to zeros.
        reduce_size (bool, optional): if True, each channel is packed densely and the
            output has shape (B, C, H//2, W//2). If False, the output has shape
            (B, C, H, W) with every sample at its original position and zeros elsewhere.
        out_channels (int, optional): number of output channels. One of {3, 4}.
            With 3 channels, G1 and G2 are averaged (reduced size) or share
            channel 1 (full size). With 4 channels they stay separate.

    Returns:
        Tensor: the colour image, allocated on the device of `raw`.
    """
    assert out_channels in [3, 4]
    if black_level is None:
        black_level = [0, 0, 0, 0]

    Bch, H, W = raw.shape

    # (row offset, column offset) of the R, G1, G2, B samples inside a 2x2 cell.
    sites = ((0, 0), (0, 1), (1, 0), (1, 1))
    planes = [raw[:, row::2, col::2] - black_level[idx]
              for idx, (row, col) in enumerate(sites)]

    if reduce_size:
        rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), device=raw.device)
        if out_channels == 3:
            red, green_1, green_2, blue = planes
            rgb[:, 0] = red
            rgb[:, 1] = (green_1 + green_2) / 2
            rgb[:, 2] = blue
        elif out_channels == 4:
            for channel, plane in enumerate(planes):
                rgb[:, channel] = plane
        return rgb

    rgb = torch.zeros((Bch, out_channels, H, W), device=raw.device)
    # Output channel of each site: both greens share channel 1 in the 3-channel layout.
    if out_channels == 3:
        targets = (0, 1, 1, 2)
    else:
        targets = (0, 1, 2, 3)
    for (row, col), channel, plane in zip(sites, targets, planes):
        rgb[:, channel, row::2, col::2] = plane
    return rgb
if __name__ == "__main__":
    # Manual check: run one raw sample through the torch pipeline, compare it with
    # the numpy reference pipeline, and plot every representation side by side.
    import torch
    import numpy as np

    # Project-relative imports below expect the repository root as working directory.
    if not os.path.exists('README.md'):
        os.chdir('..')

    import matplotlib.pyplot as plt
    from dataset import get_dataset
    from utils.base import np2torch, torch2np
    from utils.debug import debug
    from processing.pipeline_numpy import processing as default_processing

    raw_dataset = get_dataset('DS')
    loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
    batch_raw, batch_mask = next(iter(loader))

    camera_parameters = raw_dataset.camera_parameters
    black_level = camera_parameters[0]

    # Torch pipeline on the batch; only the first sample is inspected.
    proc = ParametrizedProcessing(camera_parameters)
    batch_rgb = proc(batch_raw)
    rgb = batch_rgb[0]

    # Reference result from the numpy pipeline on the same sample.
    raw_img = batch_raw[0]
    numpy_raw = torch2np(raw_img)
    default_rgb = default_processing(numpy_raw, *camera_parameters,
                                     sharpening='sharpening_filter', denoising='gaussian_denoising')
    rgb_valid = np2torch(default_rgb)

    print("pipeline norm difference:", (rgb - rgb_valid).norm().item())

    rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze()
    rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze()

    # One row of five panels: (subplot position, title, tensor to show).
    panels = (
        (151, 'Raw', raw_img),
        (152, 'RGB Mosaic', rgb_mosaic),
        (153, 'RGB Reduced', rgb_reduced),
        (154, 'Torch Pipeline', rgb),
        (155, 'Default Pipeline', rgb_valid),
    )
    plt.figure(figsize=(16, 8))
    for position, title, image in panels:
        plt.subplot(position)
        plt.title(title)
        plt.imshow(torch2np(image))
    plt.show()