raw2logit / processing /pipeline_torch.py
marco.aversa
updated libraries
4f8704e
import os
from numpy.lib.function_base import interp
import torch
import torch.nn as nn
# Import-time side effect: if README.md is not in the current working directory
# (e.g. the module is run from inside processing/), move one directory up.
# NOTE(review): presumably so that the project-local imports below
# (processing.*, utils.*) and relative data paths resolve from the repository
# root -- confirm; this changes the cwd for every importer of this module.
if not os.path.exists('README.md'):
    os.chdir('..')
from processing.pipeline_numpy import processing as default_processing
from utils.base import np2torch, torch2np
import segmentation_models_pytorch as smp
# --- Bilinear demosaicing kernels (used by Debayer) --------------------------
# Green kernel: the centre sample is kept as-is, and a missing green value is the
# mean of its 4 direct neighbours.
K_G = torch.Tensor([
    [0, 1, 0],
    [1, 4, 1],
    [0, 1, 0],
]) / 4
# Red/blue kernel: a missing value is interpolated from its 2 or 4 nearest
# same-colour samples.
K_RB = torch.Tensor([
    [1, 2, 1],
    [2, 4, 2],
    [1, 2, 1],
]) / 4

# --- Colour-space conversion matrices ----------------------------------------
# Applied per pixel as out[k] = sum_c M[k, c] * in[c] (rows are output channels).
M_RGB_2_YUV = torch.Tensor([
    [0.299, 0.587, 0.114],
    [-0.14714119, -0.28886916, 0.43601035],
    [0.61497538, -0.51496512, -0.10001026],
])
M_YUV_2_RGB = torch.Tensor([
    [1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00],
    [1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01],
    [1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09],
])

# --- Luma filters (initial weights of the ParametrizedProcessing convolutions) -
# 5x5 Gaussian-like kernel used for denoising.
K_BLUR = torch.Tensor([
    [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
    [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
    [2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04],
    [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
    [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
])
# 3x3 sharpening kernel (identity plus negative 4-neighbour Laplacian).
K_SHARP = torch.Tensor([
    [0, -1, 0],
    [-1, 5, -1],
    [0, -1, 0],
])

# (black_level[RGGB], white_balance[RGB], flattened 3x3 colour matrix):
# no black-level offset, unit gains and an identity colour matrix.
DEFAULT_CAMERA_PARAMS = (
    [0.0] * 4,
    [1.0] * 3,
    [1., 0., 0.,
     0., 1., 0.,
     0., 0., 1.],
)
class RawToRGB(nn.Module):
    """Transforms a raw image with 1 channel to rgb with 3 (or 4) channels.

    Module wrapper around ``raw2rgb`` (without black-level subtraction) plus an
    optional normalization step. After each forward pass ``self.stages['demosaic']``
    holds the ``raw2rgb`` output (before normalization) and
    ``self.buffer['processed_rgb']`` holds the returned tensor.

    Args:
        reduce_size (bool, optional): if False, the output image will have the same height and width
            as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
            if True, the output dimensions are reduced by half (B, C, H//2, W//2).
        out_channels (int, optional): number of output channels. One of {3, 4}.
            For 3 channels, the two green channels share one output channel
            (averaged when reduce_size is True).
        track_stages (bool, optional): whether or not to retain gradients of the intermediary
            stages (only done when the raw input requires grad)
        normalize_mosaic (callable, optional): normalization transformation applied to the rgb image
    """

    def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
        super().__init__()
        self.stages = None  # dict of intermediate tensors, rebuilt by every forward()
        self.buffer = None  # dict holding the last output, rebuilt by every forward()
        self.reduce_size = reduce_size
        self.out_channels = out_channels
        self.track_stages = track_stages
        self.normalize_mosaic = normalize_mosaic

    def forward(self, raw):
        self.stages = {}
        self.buffer = {}
        rgb = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels)
        self.stages['demosaic'] = rgb
        if self.normalize_mosaic is not None:
            rgb = self.normalize_mosaic(rgb)
        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                # retain_grad() raises on tensors outside the autograd graph
                # (e.g. when the forward pass runs under torch.no_grad()).
                if stage.requires_grad:
                    stage.retain_grad()
        self.buffer['processed_rgb'] = rgb
        return rgb
class NNProcessing(nn.Module):
    """Transforms raw images to processed rgb via a segmentation Unet.

    The raw input is first reduced to a half-resolution 3-channel image with
    ``raw2rgb`` (defaults: the two greens averaged, no black-level subtraction),
    optionally normalized, then processed by a ``UnetPlusPlus`` (ResNet-34
    encoder, depth 3, 3 input / 3 output channels), optionally followed by a
    non-affine BatchNorm. After each forward pass ``self.stages`` holds the
    ``'demosaic'`` (network input, after normalization) and ``'rgb'`` tensors,
    and ``self.buffer['processed_rgb']`` holds the returned tensor.

    Args:
        track_stages (bool, optional): whether or not to retain gradients of the intermediary
            stages (only done when the raw input requires grad)
        normalize_mosaic (callable, optional): normalization transformation applied to the rgb image
            before it enters the network
        batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing
    """

    def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
        super().__init__()
        self.stages = None  # dict of intermediate tensors, rebuilt by every forward()
        self.buffer = None  # dict holding the last output, rebuilt by every forward()
        self.track_stages = track_stages
        self.model = smp.UnetPlusPlus(
            encoder_name='resnet34',
            encoder_depth=3,
            decoder_channels=[256, 128, 64],
            in_channels=3,
            classes=3,
        )
        self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3, affine=False)
        self.normalize_mosaic = normalize_mosaic

    def forward(self, raw):
        self.stages = {}
        self.buffer = {}
        rgb = raw2rgb(raw)
        if self.normalize_mosaic is not None:
            rgb = self.normalize_mosaic(rgb)
        # recorded after normalization, i.e. exactly what the network sees
        self.stages['demosaic'] = rgb
        rgb = self.model(rgb)
        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)
        self.stages['rgb'] = rgb
        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                # retain_grad() raises on tensors outside the autograd graph
                # (e.g. when the forward pass runs under torch.no_grad()).
                if stage.requires_grad:
                    stage.retain_grad()
        self.buffer['processed_rgb'] = rgb
        return rgb
def append_additive_layer(processor, shape=(1, 3, 256, 256), init_std=0.0):
    """Attach a learnable additive layer to ``processor`` as ``processor.additive_layer``.

    ``ParametrizedProcessing.forward`` adds this tensor to its gamma-corrected
    output, so ``shape`` must broadcast against that output, e.g. (1, 3, H, W)
    for H x W images.

    Args:
        processor (nn.Module): module that receives the new parameter.
        shape (tuple, optional): shape of the additive layer. Defaults to (1, 3, 256, 256).
        init_std (float, optional): standard deviation of a gaussian initialisation.
            The default 0 initialises the layer with zeros, so the processing output is
            unchanged until the layer is trained.
    """
    if init_std:
        init = init_std * torch.randn(shape)
    else:
        init = torch.zeros(shape)
    processor.additive_layer = nn.Parameter(init)
class ParametrizedProcessing(nn.Module):
    """Differentiable processing pipeline via torch transformations.

    Stages, in order:
        1. black-level subtraction and RGGB mosaic split at full resolution (``raw2rgb``),
        2. bilinear debayering,
        3. white balance and colour correction,
        4. luma sharpening and luma gaussian blur (in YUV space),
        5. clipping to [1e-5, 1] and gamma correction,
        6. optional additive layer and optional output BatchNorm.

    After each forward pass ``self.stages`` holds the intermediate tensors and
    ``self.buffer['processed_rgb']`` holds the returned tensor.

    Args:
        camera_parameters (tuple(list), optional): ``(black_level, white_balance, colour_matrix)``
            with 4 RGGB black levels, 3 white-balance gains and a 3x3 (or 9-element) colour
            matrix. Defaults to ``DEFAULT_CAMERA_PARAMS``. Values are cast to the default
            floating-point dtype.
        track_stages (bool, optional): whether or not to retain intermediary steps in processing
        batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing
    """

    def __init__(self, camera_parameters=None, track_stages=False, batch_norm_output=True):
        super().__init__()
        self.stages = None  # dict of intermediate tensors, rebuilt by every forward()
        self.buffer = None  # dict holding the last output, rebuilt by every forward()
        self.track_stages = track_stages

        if camera_parameters is None:
            camera_parameters = DEFAULT_CAMERA_PARAMS
        black_level, white_balance, colour_matrix = camera_parameters
        # Cast to the default float dtype. Integer values (e.g. integer black levels)
        # cannot require gradients, and float64 arrays would not match the dtype of
        # the conv weights created below.
        dtype = torch.get_default_dtype()
        self.black_level = nn.Parameter(torch.as_tensor(black_level, dtype=dtype))
        self.white_balance = nn.Parameter(torch.as_tensor(white_balance, dtype=dtype).reshape(1, 3))
        self.colour_correction = nn.Parameter(torch.as_tensor(colour_matrix, dtype=dtype).reshape(3, 3))
        self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))

        self.debayer = Debayer()

        self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.sharpening_filter.weight.data[0][0] = K_SHARP.clone()

        self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
        self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()

        self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None

        self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
        self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())

        self.additive_layer = None  # this can be added in later

    def forward(self, raw):
        assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"
        self.stages = {}
        self.buffer = {}

        # black-level subtraction and RGGB split into a zero-filled full-size mosaic
        rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False)
        rgb = rgb.contiguous()
        self.stages['demosaic'] = rgb

        # interpolate the missing colour samples
        rgb = self.debayer(rgb)

        # per-channel white-balance gains (the size-1 'k' axis of the (1, 3) parameter is summed away)
        rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous()
        # 3x3 colour-correction matrix applied to every pixel
        rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
        self.stages['color_correct'] = rgb

        # sharpen the luma (Y) channel only
        yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
        yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...])

        if self.track_stages:  # keep stage in computational graph for grad information
            rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous()
            self.stages['sharpening'] = rgb
            yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()

        # denoise the luma channel, then convert back to RGB
        yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...])
        rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous()
        self.stages['gaussian'] = rgb

        # the 1e-5 floor keeps log() finite in the gamma step below
        rgb = torch.clip(rgb, 1e-5, 1)
        self.stages['clipped'] = rgb

        # gamma correction rgb ** (1 / gamma), written via exp/log
        rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb))
        self.stages['gamma_correct'] = rgb

        if self.additive_layer is not None:
            rgb = rgb + self.additive_layer
            self.stages['noise'] = rgb

        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)

        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                # retain_grad() raises on tensors outside the autograd graph
                # (e.g. when the forward pass runs under torch.no_grad()).
                if stage.requires_grad:
                    stage.retain_grad()

        self.buffer['processed_rgb'] = rgb
        return rgb
class Debayer(nn.Conv2d):
    """Bilinear demosaicing implemented as a 3x3 channel-wise convolution.

    Takes a zero-filled, full-resolution 3-channel mosaic (as produced by
    ``raw2rgb(..., reduce_size=False)``) and fills in each channel's missing
    samples from its same-colour neighbours. Padding keeps the output the same
    size as the input. The kernels are ordinary Conv2d weights, so they remain
    trainable.
    """

    def __init__(self):
        # NOTE: pipeline_numpy uses 'replicate' padding; 'reflect' is used here.
        super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False)
        weight = self.weight.data
        # no cross-channel mixing: only the diagonal (c -> c) kernels are non-zero
        weight.zero_()
        for channel, kernel in ((0, K_RB), (1, K_G), (2, K_RB)):
            weight[channel, channel] = kernel
def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
    """Transforms a raw RGGB mosaic with 1 channel into a 3- or 4-channel image.

    The Bayer layout read here is R at (even row, even col), G1 at (even, odd),
    G2 at (odd, even) and B at (odd, odd).

    Args:
        raw (Tensor): raw Tensor of shape (B, H, W)
        black_level (iterable, optional): RGGB black level values to subtract,
            indexed as (R, G1, G2, B). Defaults to no subtraction.
        reduce_size (bool, optional): if False, the output image will have the same height and width
            as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
            if True, the output dimensions are reduced by half (B, C, H//2, W//2); a trailing
            odd row/column of the raw input is dropped.
        out_channels (int, optional): number of output channels. One of {3, 4}.
            With 3 channels the two greens share channel 1 (averaged when reduce_size is True).

    Returns:
        Tensor: tensor of the default float dtype on ``raw.device`` with shape
        (B, out_channels, H//2, W//2) if reduce_size else (B, out_channels, H, W).
    """
    assert out_channels in [3, 4]
    if black_level is None:
        black_level = [0, 0, 0, 0]
    if reduce_size:
        # Crop to even dimensions so every Bayer plane has exactly (H//2, W//2)
        # samples; with an odd H or W the R/G planes would not fit the output.
        raw = raw[:, :raw.shape[1] // 2 * 2, :raw.shape[2] // 2 * 2]
    Bch, H, W = raw.shape

    R = raw[:, 0::2, 0::2] - black_level[0]  # R
    G1 = raw[:, 0::2, 1::2] - black_level[1]  # G
    G2 = raw[:, 1::2, 0::2] - black_level[2]  # G
    B = raw[:, 1::2, 1::2] - black_level[3]  # B

    if reduce_size:
        rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), device=raw.device)
        if out_channels == 3:
            rgb[:, 0, :, :] = R
            rgb[:, 1, :, :] = (G1 + G2) / 2
            rgb[:, 2, :, :] = B
        elif out_channels == 4:
            rgb[:, 0, :, :] = R
            rgb[:, 1, :, :] = G1
            rgb[:, 2, :, :] = G2
            rgb[:, 3, :, :] = B
    else:
        # full-resolution mosaic: each sample stays at its Bayer position, others are zero
        rgb = torch.zeros((Bch, out_channels, H, W), device=raw.device)
        if out_channels == 3:
            rgb[:, 0, 0::2, 0::2] = R
            rgb[:, 1, 0::2, 1::2] = G1
            rgb[:, 1, 1::2, 0::2] = G2
            rgb[:, 2, 1::2, 1::2] = B
        elif out_channels == 4:
            rgb[:, 0, 0::2, 0::2] = R
            rgb[:, 1, 0::2, 1::2] = G1
            rgb[:, 2, 1::2, 0::2] = G2
            rgb[:, 3, 1::2, 1::2] = B
    return rgb
# pipeline validation: compare the torch pipeline against the numpy reference pipeline
if __name__ == "__main__":
    # torch, np2torch/torch2np, default_processing and the cwd change come from module level
    import matplotlib.pyplot as plt
    from dataset import get_dataset

    raw_dataset = get_dataset('DS')
    loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
    batch_raw, batch_mask = next(iter(loader))
    camera_parameters = raw_dataset.camera_parameters

    # Gradients are not needed for validation; this also keeps the tensors handed
    # to torch2np free of autograd history.
    with torch.no_grad():
        # torch proc -- the output BatchNorm is disabled: it standardizes each channel,
        # which the numpy reference presumably does not do, so a normalized result could
        # never be compared against it.
        proc = ParametrizedProcessing(camera_parameters, batch_norm_output=False)
        batch_rgb = proc(batch_raw)
        rgb = batch_rgb[0]

        # numpy proc
        raw_img = batch_raw[0]
        numpy_raw = torch2np(raw_img)
        default_rgb = default_processing(numpy_raw, *camera_parameters,
                                         sharpening='sharpening_filter', denoising='gaussian_denoising')
        rgb_valid = np2torch(default_rgb)
        print("pipeline norm difference:", (rgb - rgb_valid).norm().item())

        rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze()
        rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze()

    plt.figure(figsize=(16, 8))
    plt.subplot(151)
    plt.title('Raw')
    plt.imshow(torch2np(raw_img))
    plt.subplot(152)
    plt.title('RGB Mosaic')
    plt.imshow(torch2np(rgb_mosaic))
    plt.subplot(153)
    plt.title('RGB Reduced')
    plt.imshow(torch2np(rgb_reduced))
    plt.subplot(154)
    plt.title('Torch Pipeline')
    plt.imshow(torch2np(rgb))
    plt.subplot(155)
    plt.title('Default Pipeline')
    plt.imshow(torch2np(rgb_valid))
    plt.show()

    # assert rgb.allclose(rgb_valid)