Spaces:
Running
Running
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
import torch | |
import numpy as np | |
import os, time, random | |
import argparse | |
from torch.utils.data import Dataset, DataLoader | |
from PIL import Image as PILImage | |
from glob import glob | |
from tqdm import tqdm | |
import rawpy | |
import colour_demosaicing | |
from .InvISP.model.model import InvISPNet | |
from .utils.common import Notify | |
from datasets.noise import ( | |
camera_params, | |
addGStarNoise, | |
addPStarNoise, | |
addQuantNoise, | |
addRowNoise, | |
sampleK, | |
) | |
class NoiseSimulator:
    """Synthesize noisy raw / sRGB data from clean sRGB images.

    Pipeline: sRGB -> (InvISP reverse pass, inverse gamma, inverse white
    balance, black level, re-mosaic) -> clean Bayer raw -> (noise at a random
    exposure ratio) -> noisy raw -> back to sRGB through InvISP, or to a
    packed / demosaiced raw array.
    """

    def __init__(self, device, ckpt_path="./datasets/InvISP/pretrained/canon.pth"):
        self.device = device
        # Invertible ISP network: rev=True maps sRGB -> linear, forward maps back.
        self.net = (
            InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
        )
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # strict=False is kept: missing/unexpected keys are silently ignored.
        self.net.load_state_dict(
            torch.load(ckpt_path, map_location=self.device), strict=False
        )
        print(
            Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC
        )
        # White-balance values; the first three entries are applied to the
        # R, G, B channels (see _rgbWbGains).
        self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])
        # Use Canon EOS 5D4 noise parameters provided by ELD.
        self.camera_params = camera_params
        # Exposure-time ratio is drawn uniformly from [ratio_min, ratio_max].
        self.ratio_min = 50
        self.ratio_max = 150

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _rgbWbGains(self):
        """Return the (R, G, B) white-balance gains normalised so the max is 1."""
        return self.wb[:-1] / self.wb.max()

    def _demosaicAll(self, raw, batched):
        """Demosaic one [H, W] plane, or every plane of an [N, H, W] batch."""
        if not batched:
            return self.demosaic(raw)
        return np.stack([self.demosaic(frame) for frame in raw], axis=0)

    def _subtractBlackLevel(self, raw):
        """Return a NEW array with the black level removed, clipped to the sensor range."""
        black = self.camera_params["black_level"]
        return np.clip(raw - black, 0, self.camera_params["max_value"] - black)

    def _addNoise(self, raw, ratio_dec=1):
        """Apply the noise helpers to one [H, W] raw frame and return a new array."""
        params = self.camera_params
        profile = params["Profile-1"]
        # ratio_dec scales (ratio - 1), so ratio_dec < 1 pulls the ratio toward 1.
        ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1
        # Darken the frame to simulate a shorter exposure.
        # NOTE(review): raw still contains the black-level offset here, so the
        # offset is divided by the ratio as well -- confirm this is intended.
        noisy = raw / ratio
        K = sampleK(params["Kmin"], params["Kmax"])
        # NOTE(review): this step matches data normalised by
        # (max_value - black_level), while `noisy` is in raw units here;
        # confirm against addQuantNoise.
        q = 1 / (params["max_value"] - params["black_level"])
        noisy = addPStarNoise(noisy, K)
        noisy = addGStarNoise(noisy, K, params["G_shape"], profile["G_scale"])
        noisy = addRowNoise(noisy, K, profile["R_scale"])
        noisy = addQuantNoise(noisy, q)
        # Brighten back by the same ratio, which scales the added noise up too.
        return noisy * ratio

    # ------------------------------------------------------------------ #
    # Bayer (re)mosaicing
    # ------------------------------------------------------------------ #
    def invDemosaic(self, img):
        """Re-mosaic an [H, W, 3] RGB image into an [H, W] RGGB Bayer plane.

        Each pixel keeps only the channel its Bayer site samples: R at
        (even, even), G at (even, odd) and (odd, even), B at (odd, odd).
        """
        raw_img = np.ones(img.shape[:2])
        raw_img[::2, ::2] = img[::2, ::2, 0]
        raw_img[::2, 1::2] = img[::2, 1::2, 1]
        raw_img[1::2, ::2] = img[1::2, ::2, 1]
        raw_img[1::2, 1::2] = img[1::2, 1::2, 2]
        return raw_img

    def demosaicNearest(self, img):
        """Nearest-neighbour demosaic of an [H, W] RGGB plane into [H, W, 3].

        Every pixel of a 2x2 Bayer cell receives that cell's R and B samples.
        The cell's top row receives the (even, odd) green sample and its
        bottom row the (odd, even) green sample.
        """
        rgb = np.ones((img.shape[0], img.shape[1], 3))
        red = img[::2, ::2]
        green_top = img[::2, 1::2]
        green_bottom = img[1::2, ::2]
        blue = img[1::2, 1::2]
        for dy, green in ((0, green_top), (1, green_bottom)):
            for dx in (0, 1):
                rgb[dy::2, dx::2, 0] = red
                rgb[dy::2, dx::2, 1] = green
                rgb[dy::2, dx::2, 2] = blue
        return rgb

    def demosaic(self, img):
        """Bilinear demosaic of an [H, W] RGGB plane into [H, W, 3]."""
        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, "RGGB")

    # ------------------------------------------------------------------ #
    # Conversions
    # ------------------------------------------------------------------ #
    def path2rgb(self, path):
        """Load an image file as an [H, W, 3] float64 tensor scaled by 1/255.

        The file is closed before returning. Non-RGB modes (grayscale, RGBA,
        palette) are converted to 3-channel RGB, which rgb2raw requires.
        """
        with PILImage.open(path) as img:
            rgb = np.array(img.convert("RGB"))
        return torch.from_numpy(rgb / 255.0)

    def rgb2raw(self, rgb, batched=False):
        """Invert the ISP: map an sRGB image in [0, 1] to a clean Bayer mosaic.

        Args:
            rgb: [H, W, 3] image ([N, H, W, 3] when ``batched``), as a torch
                tensor or anything ``torch.as_tensor`` accepts.
            batched: whether ``rgb`` has a leading batch dimension.

        Returns:
            numpy array of shape [H, W] ([N, H, W] when ``batched``) with the
            black level added, i.e. in the units ``raw2rgb`` expects.
        """
        rgb = torch.as_tensor(rgb)
        if not batched:
            rgb = rgb.unsqueeze(0)
        rgb = rgb.permute(0, 3, 1, 2).float().to(self.device)
        # 1. sRGB -> linear RGB via the reverse pass of the invertible network.
        with torch.no_grad():
            reconstruct_raw = self.net(rgb, rev=True)
        pred_raw = torch.clamp(reconstruct_raw.permute(0, 2, 3, 1), 0, 1)
        if not batched:
            pred_raw = pred_raw[0]
        pred_raw = pred_raw.cpu().numpy()
        # 2. Inverse gamma: [0, 1] -> [0, 16383].
        pred_raw *= np.power(16383, 1 / 2.2)
        pred_raw = np.power(pred_raw, 2.2)
        # 3. Undo white balance.
        pred_raw = pred_raw / self._rgbWbGains()
        # 4. Add black level.
        pred_raw = pred_raw + self.camera_params["black_level"]
        # 5. Re-mosaic into a single Bayer plane.
        if not batched:
            return self.invDemosaic(pred_raw)
        return np.stack([self.invDemosaic(frame) for frame in pred_raw], axis=0)

    def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
        """Add simulated low-light noise to clean raw mosaic(s).

        For each frame an exposure ratio
        ``1 + (U(ratio_min, ratio_max) - 1) * ratio_dec`` is drawn. The frame
        is divided by it, noise is added, and the result is multiplied back.

        Args:
            raw: [H, W] raw mosaic, or [N, H, W] when ``batched``.
            ratio_dec: scale applied to ``ratio - 1`` (1 keeps the full range).
            batched: whether ``raw`` has a leading batch dimension.

        Returns:
            A new array of the same shape; ``raw`` itself is not modified.
        """
        if not batched:
            return self._addNoise(raw, ratio_dec)
        # Each frame gets its own ratio and K. ratio_dec applies here exactly
        # as in the single-frame path (it was previously ignored).
        noisy = raw.copy()
        for i, frame in enumerate(raw):
            noisy[i] = self._addNoise(frame, ratio_dec)
        return noisy

    def raw2rgb(self, raw, batched=False):
        """Render Bayer raw mosaic(s) to sRGB through the forward ISP network.

        Args:
            raw: [H, W] mosaic including black level ([N, H, W] when ``batched``).
            batched: whether ``raw`` has a leading batch dimension.

        Returns:
            numpy array [H, W, 3] ([N, H, W, 3] when ``batched``) clamped to [0, 1].
        """
        # 1. Demosaic, 2. subtract black level and clip (new arrays; input untouched).
        rgb = self._subtractBlackLevel(self._demosaicAll(raw, batched))
        # 3. White balance.
        rgb = rgb * self._rgbWbGains()
        # 4. Gamma: [0, 16383] -> [0, 1].
        rgb = np.power(rgb, 1 / 2.2) / np.power(16383, 1 / 2.2)
        # 5. Forward pass of the ISP network (expects NCHW float32).
        tensor = torch.as_tensor(rgb, dtype=torch.float32, device=self.device)
        if batched:
            tensor = tensor.permute(0, 3, 1, 2)
        else:
            tensor = tensor.permute(2, 0, 1).unsqueeze(0)
        with torch.no_grad():
            reconstruct_rgb = torch.clamp(self.net(tensor), 0, 1)
        pred_rgb = reconstruct_rgb.permute(0, 2, 3, 1)
        if not batched:
            pred_rgb = pred_rgb[0]
        return pred_rgb.cpu().numpy()

    def raw2packedRaw(self, raw, batched=False):
        """Pack RGGB mosaic(s) into 4-channel half-resolution arrays.

        The black level is removed, values are clipped to
        [0, max_value - black_level] and divided by ``max_value``. The input
        array is NOT modified, and integer input is accepted.

        Args:
            raw: [H, W] mosaic ([N, H, W] when ``batched``); H and W must be even.
            batched: whether ``raw`` has a leading batch dimension.

        Returns:
            [H/2, W/2, 4] ([N, H/2, W/2, 4] when ``batched``) with channels
            ordered R (0,0), G (0,1), B (1,1), G (1,0).

        Raises:
            ValueError: wrong number of dimensions, or odd height/width.
        """
        expected_ndim = 3 if batched else 2
        if raw.ndim != expected_ndim:
            raise ValueError(
                "expected a {}-D raw array, got shape {}".format(
                    expected_ndim, tuple(raw.shape)
                )
            )
        if raw.shape[-2] % 2 or raw.shape[-1] % 2:
            raise ValueError(
                "raw height and width must be even, got shape {}".format(
                    tuple(raw.shape)
                )
            )
        # 1. Subtract black level on a copy, then normalise.
        # NOTE(review): this divides by max_value, not (max_value - black_level),
        # the same as raw2demosaicRaw; kept to preserve the output scale.
        norm = self._subtractBlackLevel(raw) / self.camera_params["max_value"]
        # 2. Pack the four Bayer sites into the last axis.
        return np.stack(
            (
                norm[..., 0::2, 0::2],
                norm[..., 0::2, 1::2],
                norm[..., 1::2, 1::2],
                norm[..., 1::2, 0::2],
            ),
            axis=-1,
        )

    def raw2demosaicRaw(self, raw, batched=False):
        """Demosaic raw mosaic(s) and normalise to [0, (max - black) / max].

        Returns [H, W, 3] ([N, H, W, 3] when ``batched``); the input is not modified.
        """
        demosaiced = self._subtractBlackLevel(self._demosaicAll(raw, batched))
        return demosaiced / self.camera_params["max_value"]