|
""" |
|
Raw Image Pipeline |
|
""" |
|
__author__ = "Marco Aversa" |
|
|
|
import numpy as np |
|
|
|
from rawpy import * |
|
from scipy import ndimage |
|
from scipy import fftpack |
|
from scipy.signal import convolve2d |
|
|
|
from skimage.filters import unsharp_mask |
|
from skimage.color import rgb2yuv, yuv2rgb, rgb2hsv, hsv2rgb |
|
from skimage.restoration import denoise_tv_chambolle, denoise_tv_bregman, denoise_nl_means, denoise_bilateral, denoise_wavelet, estimate_sigma |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear, |
|
demosaicing_CFA_Bayer_Malvar2004, |
|
demosaicing_CFA_Bayer_Menon2007) |
|
|
|
import torch |
|
import numpy as np |
|
|
|
from dataset import Subset |
|
from torch.utils.data import DataLoader |
|
|
|
from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear, |
|
demosaicing_CFA_Bayer_Malvar2004, |
|
demosaicing_CFA_Bayer_Menon2007) |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class RawProcessingPipeline(object):
    """Callable that runs ``processing`` with a fixed camera configuration.

    The camera parameters and the algorithm names are stored once at
    construction time; calling the instance on a raw image returns the
    processed image as a channel-first tensor.
    """

    def __init__(self, camera_parameters, debayer='bilinear', sharpening='unsharp_masking', denoising='gaussian'):
        """Store the pipeline configuration.

        Args:
            camera_parameters (tuple): (black_level, white_balance, colour_matrix)
            debayer (str): demosaicing algorithm; one of
                {'bilinear', 'malvar2004', 'menon2007'}
            sharpening (str): sharpening algorithm; one of
                {'sharpening_filter', 'unsharp_masking'}
            denoising (str): denoising algorithm; ``processing`` recognises
                {'median_denoising', 'gaussian_denoising', 'fft_denoising',
                'tv_chambolle', 'tv_bregman', 'bilateral'}

        NOTE(review): the default ``'gaussian'`` is not one of the names that
        ``processing`` compares against, so with the default no denoising
        branch is selected - confirm whether 'gaussian_denoising' was meant.
        """
        self.camera_parameters = camera_parameters
        self.debayer, self.sharpening, self.denoising = debayer, sharpening, denoising

    def __call__(self, img):
        """Process one raw image.

        Args:
            img (ndarray of dtype float32): raw image of size (H, W)

        Returns:
            tensor of dtype float: processed image of size (3, H, W)
        """
        black_level, white_balance, colour_matrix = self.camera_parameters
        processed = processing(
            img, black_level, white_balance, colour_matrix,
            debayer=self.debayer,
            sharpening=self.sharpening,
            denoising=self.denoising,
        )
        # Move the channel axis to the front before building the tensor.
        channel_first = processed.transpose(2, 0, 1)
        return torch.Tensor(channel_first)
|
|
|
|
|
def processing(img, black_level, white_balance, colour_matrix, debayer="bilinear", sharpening="unsharp_masking",
               sharp_radius=1.0, sharp_amount=1.0, denoising="median_filter", median_kernel_size=3,
               gaussian_sigma=0.5, fft_fraction=0.3, weight_chambolle=0.01, weight_bregman=100,
               sigma_bilateral=0.6, gamma=2.2, bits=16):
    """Apply the raw-processing pipeline to a raw image.

    Stages, in order: black-level removal, demosaicing, white balance,
    colour correction, sharpening, denoising, clipping to [0, 1] and gamma
    correction. An algorithm name that matches none of the known options
    leaves the image unchanged for that stage.

    Args:
        img (ndarray): raw image
        black_level: black levels handed to ``remove_blacklv``
        white_balance: white balance factors handed to ``wb_correction``
        colour_matrix: colour matrix handed to ``colour_correction`` (3x3)
        debayer (str): one of {'bilinear', 'malvar2004', 'menon2007'}
        sharpening (str): one of {'sharpening_filter', 'unsharp_masking'}
        sharp_radius (float): radius used by 'unsharp_masking'
        sharp_amount (float): amount used by 'unsharp_masking'
        denoising (str): one of {'median_denoising', 'gaussian_denoising',
            'fft_denoising', 'tv_chambolle', 'tv_bregman', 'bilateral'}
        median_kernel_size (int): size used by 'median_denoising'
        gaussian_sigma (float): sigma used by 'gaussian_denoising'
        fft_fraction (float): keep_fraction used by 'fft_denoising'
        weight_chambolle (float): weight used by 'tv_chambolle'
        weight_bregman (float): weight used by 'tv_bregman'
        sigma_bilateral (float): sigma_spatial used by 'bilateral'
        gamma (float): exponent for the non linear gamma correction
        bits (int): accepted but not used by this function

    Returns:
        img (ndarray): post-processed image

    NOTE(review): the default ``denoising="median_filter"`` matches none of
    the names tested below, so the default call performs no denoising -
    confirm whether 'median_denoising' was intended.
    """
    mosaic = remove_blacklv(img, black_level)

    rgb = _processing_demosaic(mosaic, debayer)
    rgb = wb_correction(rgb, white_balance)
    rgb = colour_correction(rgb, colour_matrix)
    rgb = _processing_sharpen(rgb, sharpening, sharp_radius, sharp_amount)
    rgb = _processing_denoise(
        rgb, denoising,
        median_kernel_size=median_kernel_size,
        gaussian_sigma=gaussian_sigma,
        fft_fraction=fft_fraction,
        weight_chambolle=weight_chambolle,
        weight_bregman=weight_bregman,
        sigma_bilateral=sigma_bilateral,
    )

    # Clip before the gamma curve so the power is only applied to [0, 1].
    return adjust_gamma(np.clip(rgb, 0, 1), gamma=gamma)


def _processing_demosaic(mosaic, debayer):
    """Demosaic with the algorithm named by ``debayer``; unknown names pass through."""
    if debayer == "bilinear":
        return demosaicing_CFA_Bayer_bilinear(mosaic)
    elif debayer == "malvar2004":
        return demosaicing_CFA_Bayer_Malvar2004(mosaic)
    elif debayer == "menon2007":
        return demosaicing_CFA_Bayer_Menon2007(mosaic)
    return mosaic


def _processing_sharpen(rgb, sharpening, sharp_radius, sharp_amount):
    """Sharpen with the algorithm named by ``sharpening``; unknown names pass through."""
    if sharpening == "sharpening_filter":
        return sharpening_filter(rgb)
    elif sharpening == "unsharp_masking":
        return unsharp_masking(rgb, radius=sharp_radius, amount=sharp_amount, multichannel=True)
    return rgb


def _processing_denoise(rgb, denoising, median_kernel_size, gaussian_sigma, fft_fraction,
                        weight_chambolle, weight_bregman, sigma_bilateral):
    """Denoise with the algorithm named by ``denoising``; unknown names pass through."""
    if denoising == "median_denoising":
        return median_denoising(rgb, size=median_kernel_size)
    elif denoising == "gaussian_denoising":
        return gaussian_denoising(rgb, sigma=gaussian_sigma)
    elif denoising == "fft_denoising":
        return fft_denoising(rgb, keep_fraction=fft_fraction, row_cut=False, column_cut=True)
    elif denoising == "tv_chambolle":
        return denoise_tv_chambolle(rgb, weight=weight_chambolle, eps=0.0002, n_iter_max=200, multichannel=True)
    elif denoising == "tv_bregman":
        return denoise_tv_bregman(rgb, weight=weight_bregman, max_iter=100,
                                  eps=0.001, isotropic=True, multichannel=True)
    elif denoising == "bilateral":
        return denoise_bilateral(rgb, win_size=None, sigma_color=None, sigma_spatial=sigma_bilateral,
                                 bins=10000, mode='constant', cval=0, multichannel=True)
    return rgb
|
|
|
|
|
def get_camera_parameters(rawpyImg):
    """Collect the camera parameters the pipeline needs from a raw image object.

    Args:
        rawpyImg: object exposing ``black_level_per_channel``,
            ``camera_whitebalance`` and ``color_matrix``

    Returns:
        tuple: (black_level, white_balance, colour_matrix) where
            white_balance is the first three entries of
            ``camera_whitebalance`` and colour_matrix is the first three
            columns of ``color_matrix`` flattened row by row into a list
    """
    return (
        rawpyImg.black_level_per_channel,
        rawpyImg.camera_whitebalance[:3],
        rawpyImg.color_matrix[:, :3].flatten().tolist(),
    )
|
|
|
|
|
def remove_blacklv(rawImg, black_level):
    """Subtract the black level from each position of the 2x2 mosaic pattern.

    The subtraction is done on a copy, so the caller's array is left
    untouched and processing the same raw image repeatedly does not
    subtract the black level more than once.

    Args:
        rawImg (ndarray): raw mosaic image of size (H, W)
        black_level (sequence): four values, applied in order to the pixel
            positions (even row, even col), (even row, odd col),
            (odd row, even col), (odd row, odd col)

    Returns:
        ndarray: new array with the same dtype as ``rawImg`` and the black
            level removed
    """
    corrected = np.array(rawImg, copy=True)

    # NOTE(review): the subtraction keeps the input dtype; an unsigned integer
    # input would wrap below zero - presumably callers pass floats, confirm.
    offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    for index, (row_start, col_start) in enumerate(offsets):
        corrected[row_start::2, col_start::2] -= black_level[index]

    return corrected
|
|
|
|
|
def wb_correction(img, white_balance):
    """Apply white balance by multiplying the image with the given factors.

    Args:
        img: image to correct
        white_balance: multiplication factors; combined with ``img`` through
            the ``*`` operator

    Returns:
        The product ``img * white_balance``.
    """
    balanced = img * white_balance
    return balanced
|
|
|
|
|
def colour_correction(img, colour_matrix):
    """Multiply every pixel's channel vector by a 3x3 colour matrix.

    Args:
        img (ndarray): image of size (H, W, 3)
        colour_matrix: nine values (any layout that reshapes to 3x3)

    Returns:
        ndarray: image of size (H, W, 3) where each output channel ``o`` is
            the sum over input channels ``c`` of ``img[h, w, c] * matrix[o, c]``
    """
    matrix = np.array(colour_matrix).reshape(3, 3)
    corrected = np.einsum('hwc,oc->hwo', img, matrix)
    return corrected
|
|
|
|
|
def unsharp_masking(img, radius=1.0, amount=1.0,
                    multichannel=False, preserve_range=True):
    """Sharpen an RGB image by unsharp-masking its luminance channel.

    The image is converted to YUV, only the Y plane is sharpened, and the
    result is converted back to RGB.

    Args:
        img (ndarray): RGB image of size (H, W, 3)
        radius (float): radius forwarded to ``unsharp_mask``
        amount (float): amount forwarded to ``unsharp_mask``
        multichannel (bool): kept for backward compatibility and not
            forwarded; the plane handed to ``unsharp_mask`` is always a
            single 2-D channel
        preserve_range (bool): forwarded to ``unsharp_mask``

    Returns:
        ndarray: sharpened RGB image
    """
    yuv = rgb2yuv(img)

    # The Y plane is 2-D. Telling unsharp_mask it is multichannel would make
    # it treat every column as its own channel and sharpen along one axis
    # only, so the flag is deliberately not passed on. Omitting it also
    # avoids a keyword that newer scikit-image releases deprecate in favour
    # of channel_axis.
    yuv[:, :, 0] = unsharp_mask(yuv[:, :, 0], radius=radius, amount=amount,
                                preserve_range=preserve_range)

    return yuv2rgb(yuv)
|
|
|
|
|
def sharpening_filter(image, iterations=1, kernel=np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])):
    """Sharpen an RGB image by convolving its luminance channel with a kernel.

    Args:
        image (ndarray): RGB image of size (H, W, 3)
        iterations (int): how many times the convolution is applied
        kernel (ndarray): 2-D convolution kernel

    Returns:
        ndarray: sharpened RGB image
    """
    yuv = rgb2yuv(image)

    for _ in range(iterations):
        luma = yuv[:, :, 0]
        # 'same' keeps the plane size; the border is padded with zeros.
        yuv[:, :, 0] = convolve2d(luma, kernel, mode='same', boundary='fill', fillvalue=0)

    return yuv2rgb(yuv)
|
|
|
|
|
def median_denoising(img, size=3):
    """Denoise an RGB image with a median filter on its luminance channel.

    Args:
        img (ndarray): RGB image of size (H, W, 3)
        size (int): size handed to ``ndimage.median_filter``

    Returns:
        ndarray: denoised RGB image
    """
    yuv = rgb2yuv(img)
    luma = yuv[:, :, 0]
    yuv[:, :, 0] = ndimage.median_filter(luma, size=size)
    return yuv2rgb(yuv)
|
|
|
|
|
def gaussian_denoising(img, sigma=0.5):
    """Denoise an RGB image with a Gaussian filter on its luminance channel.

    Args:
        img (ndarray): RGB image of size (H, W, 3)
        sigma (float): sigma handed to ``ndimage.gaussian_filter``

    Returns:
        ndarray: denoised RGB image
    """
    yuv = rgb2yuv(img)
    luma = yuv[:, :, 0]
    yuv[:, :, 0] = ndimage.gaussian_filter(luma, sigma=sigma)
    return yuv2rgb(yuv)
|
|
|
|
|
def fft_denoising(img, keep_fraction=0.3, row_cut=False, column_cut=True):
    """Low-pass filter an image by zeroing a band of its 2-D spectrum.

    The transform is taken over the two spatial axes (0 and 1), so a
    trailing channel axis is filtered channel by channel and plain 2-D
    images are accepted as well.

    keep_fraction = 0.5 --> same image as input
    keep_fraction --> 0 --> remove all details

    Args:
        img (ndarray): image of size (H, W) or (H, W, C)
        keep_fraction (float): fraction of coefficients kept at each end of
            a spectrum axis
        row_cut (bool): zero the middle band along axis 0 of the spectrum
        column_cut (bool): zero the middle band along axis 1 of the spectrum

    Returns:
        ndarray: real part of the inverse transform, same shape as ``img``
    """
    # fft2 defaults to the last two axes, which for an (H, W, C) image would
    # be width and channel; the spatial axes have to be named explicitly.
    spatial_axes = (0, 1)
    spectrum = fftpack.fft2(img, axes=spatial_axes)

    rows, cols = spectrum.shape[:2]

    if row_cut:
        spectrum[int(rows * keep_fraction):int(rows * (1 - keep_fraction))] = 0

    if column_cut:
        spectrum[:, int(cols * keep_fraction):int(cols * (1 - keep_fraction))] = 0

    return fftpack.ifft2(spectrum, axes=spatial_axes).real
|
|
|
|
|
def adjust_gamma(img, gamma=1.0):
    """Apply gamma correction by raising the image to the power ``1 / gamma``.

    Args:
        img: image values to correct
        gamma (float): gamma value; the exponent used is its reciprocal

    Returns:
        ``img ** (1.0 / gamma)``
    """
    return img ** (1.0 / gamma)
|
|
|
|
|
def show_img(img, title="no_title", size=12, histo=True, bins=300, bits=16, x_range=-1):
    """Plot an image next to the histogram of its intensities.

    Args:
        img (ndarray): image to plot
        title (str): title of the plot ("no_title" draws no title)
        size (int): figure size
        histo (bool): True - plot bar histograms (one per channel for
            multi-channel images). False - plot the histogram counts as a
            continuous curve
        bins (int): number of bins of the histogram
        bits (int): number of bits per pixel in the ndarray
        x_range (list): [min, max] x limits of the bar histogram (if -1 all
            x values are shown); only applied when ``histo`` is True
    """
    is_multichannel = len(img.shape) > 2
    channel_styles = (("Channel1", "red"), ("Channel2", "green"), ("Channel3", "blue"))

    fig = plt.figure(figsize=(size, size))

    # Left panel: the image itself.
    fig.add_subplot(221)
    if is_multichannel and img.max() > 255:
        # Rescale from the `bits`-deep range down to 0..255 for display.
        preview = (img.copy() * 255. / (2**bits - 1)).astype(int)
    else:
        # NOTE(review): astype(int) truncates, so a float image in [0, 1]
        # ends up as zeros and ones here - confirm this is intended.
        preview = img.copy().astype(int)
    plt.imshow(preview)
    if title != "no_title":
        plt.title(title)

    # Right panel: intensity distribution.
    fig.add_subplot(222)
    if is_multichannel:
        if histo == True:
            for index, (label, colour) in enumerate(channel_styles):
                plt.hist(img[:, :, index].flatten(), bins=bins, label=label, color=colour, alpha=0.5)
            if x_range != -1:
                plt.xlim([x_range[0], x_range[1]])
        else:
            for index, (label, colour) in enumerate(channel_styles):
                counts, edges = np.histogram(img[:, :, index].flatten(), bins=bins)
                plt.plot(edges[:-1], counts, label=label, color=colour, alpha=0.5)

        plt.legend()
    else:
        if histo == True:
            plt.hist(img.flatten(), bins=bins)
            if x_range != -1:
                plt.xlim([x_range[0], x_range[1]])
        else:
            counts, edges = np.histogram(img.flatten(), bins=bins)
            plt.plot(edges[:-1], counts)

    plt.xlabel("Intensities")
    plt.ylabel("Counts")

    plt.show()
|
|
|
|
|
def get_statistics(dataset, train_indices, transform=None):
    """Calculate the mean and standard deviation of a training subset.

    The whole subset is loaded as one batch and reduced in a single step.

    Args:
        dataset: dataset to draw the subset from
        train_indices (tensor): indices corresponding to a subset of the dataset
        transform (Compose): transformations applied before the calculation

    Returns:
        mean (tensor of dtype float): scalar when the image batch has three
            dimensions, otherwise one value per channel with size (C, 1, 1)
        std (tensor of dtype float): same layout as ``mean``
    """
    trainset = Subset(dataset, indices=train_indices, transform=transform)
    # One batch holding every sample, so the statistics cover the full subset.
    dataloader = DataLoader(trainset, batch_size=len(trainset), shuffle=False)

    # Builtin next() works on every PyTorch version; the iterator's own
    # .next() alias is not available on recent releases.
    images, _ = next(iter(dataloader))

    if images.dim() == 3:
        # No channel axis: reduce over batch and both spatial axes.
        reduce_dims = (0, 1, 2)
        return torch.mean(images, dim=reduce_dims), torch.std(images, dim=reduce_dims)

    # Channel axis at position 1: keep it and restore broadcastable (C, 1, 1).
    reduce_dims = (0, 2, 3)
    mean = torch.mean(images, dim=reduce_dims)[:, None, None]
    std = torch.std(images, dim=reduce_dims)[:, None, None]
    return mean, std
|
|