from __future__ import print_function, division import os, random, time import torch import numpy as np from torch.utils.data import Dataset from torchvision import transforms, utils import rawpy from glob import glob from PIL import Image as PILImage import numbers from scipy.misc import imread from .base_dataset import BaseDataset class FiveKDatasetTrain(BaseDataset): def __init__(self, opt): super().__init__(opt=opt) self.patch_size = 256 input_RAWs_WBs, target_RGBs = self.load(is_train=True) assert len(input_RAWs_WBs) == len(target_RGBs) self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs} def random_flip(self, input_raw, target_rgb): idx = np.random.randint(2) input_raw = np.flip(input_raw, axis=idx).copy() target_rgb = np.flip(target_rgb, axis=idx).copy() return input_raw, target_rgb def random_rotate(self, input_raw, target_rgb): idx = np.random.randint(4) input_raw = np.rot90(input_raw, k=idx) target_rgb = np.rot90(target_rgb, k=idx) return input_raw, target_rgb def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False): H, W, _ = input_raw.shape rnd_h = random.randint(0, max(0, H - patch_size)) rnd_w = random.randint(0, max(0, W - patch_size)) patch_input_raw = input_raw[ rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, : ] if flow or demos: patch_target_rgb = target_rgb[ rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, : ] else: patch_target_rgb = target_rgb[ rnd_h * 2 : rnd_h * 2 + patch_size * 2, rnd_w * 2 : rnd_w * 2 + patch_size * 2, :, ] return patch_input_raw, patch_target_rgb def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False): input_raw, target_rgb = self.random_crop( patch_size, input_raw, target_rgb, flow=flow, demos=demos ) input_raw, target_rgb = self.random_rotate(input_raw, target_rgb) input_raw, target_rgb = self.random_flip(input_raw, target_rgb) return input_raw, target_rgb def __len__(self): return len(self.data["input_RAWs_WBs"]) def __getitem__(self, idx): input_raw_wb_path = self.data["input_RAWs_WBs"][idx] target_rgb_path = self.data["target_RGBs"][idx] target_rgb_img = imread(target_rgb_path) input_raw_wb = np.load(input_raw_wb_path) input_raw_img = input_raw_wb["raw"] wb = input_raw_wb["wb"] wb = wb / wb.max() input_raw_img = input_raw_img * wb[:-1] self.patch_size = 256 input_raw_img, target_rgb_img = self.aug( self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True ) if self.gamma: norm_value = ( np.power(4095, 1 / 2.2) if self.camera_name == "Canon_EOS_5D" else np.power(16383, 1 / 2.2) ) input_raw_img = np.power(input_raw_img, 1 / 2.2) else: norm_value = 4095 if self.camera_name == "Canon_EOS_5D" else 16383 target_rgb_img = self.norm_img(target_rgb_img, max_value=255) input_raw_img = self.norm_img(input_raw_img, max_value=norm_value) target_raw_img = input_raw_img.copy() input_raw_img = self.np2tensor(input_raw_img).float() target_rgb_img = self.np2tensor(target_rgb_img).float() target_raw_img = self.np2tensor(target_raw_img).float() sample = { "input_raw": input_raw_img, "target_rgb": target_rgb_img, "target_raw": target_raw_img, "file_name": input_raw_wb_path.split("/")[-1].split(".")[0], } return sample class FiveKDatasetTest(BaseDataset): def __init__(self, opt): super().__init__(opt=opt) self.patch_size = 256 input_RAWs_WBs, target_RGBs = self.load(is_train=False) assert len(input_RAWs_WBs) == len(target_RGBs) self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs} def __len__(self): return len(self.data["input_RAWs_WBs"]) def __getitem__(self, idx): input_raw_wb_path = self.data["input_RAWs_WBs"][idx] target_rgb_path = self.data["target_RGBs"][idx] target_rgb_img = imread(target_rgb_path) input_raw_wb = np.load(input_raw_wb_path) input_raw_img = input_raw_wb["raw"] wb = input_raw_wb["wb"] wb = wb / wb.max() input_raw_img = input_raw_img * wb[:-1] if self.gamma: norm_value = ( np.power(4095, 1 / 2.2) if self.camera_name == "Canon_EOS_5D" else np.power(16383, 1 / 2.2) ) input_raw_img = np.power(input_raw_img, 1 / 2.2) else: norm_value = 4095 if self.camera_name == "Canon_EOS_5D" else 16383 target_rgb_img = self.norm_img(target_rgb_img, max_value=255) input_raw_img = self.norm_img(input_raw_img, max_value=norm_value) target_raw_img = input_raw_img.copy() input_raw_img = self.np2tensor(input_raw_img).float() target_rgb_img = self.np2tensor(target_rgb_img).float() target_raw_img = self.np2tensor(target_raw_img).float() sample = { "input_raw": input_raw_img, "target_rgb": target_rgb_img, "target_raw": target_raw_img, "file_name": input_raw_wb_path.split("/")[-1].split(".")[0], } return sample