from __future__ import print_function, division
import numpy as np
from torch.utils.data import Dataset
import torch


class BaseDataset(Dataset):
    # Shared helpers for RAW-to-RGB datasets; subclasses override __len__ and __getitem__.
    def __init__(self, opt):
        self.crop_size = 512
        self.debug_mode = opt.debug_mode
        self.data_path = opt.data_path  # dataset root path, e.g., ./data/
        self.camera_name = opt.camera
        self.gamma = opt.gamma

    def norm_img(self, img, max_value):
        # scale pixel values to [0, 1] given the image's maximum value
        img = img / float(max_value)
        return img

    def pack_raw(self, raw):
        # pack an H x W Bayer mosaic into an H/2 x W/2 x 4 array (one channel per Bayer site)
        im = np.expand_dims(raw, axis=2)
        H, W = raw.shape[0], raw.shape[1]
        # RGBG
        out = np.concatenate(
            (
                im[0:H:2, 0:W:2, :],
                im[0:H:2, 1:W:2, :],
                im[1:H:2, 1:W:2, :],
                im[1:H:2, 0:W:2, :],
            ),
            axis=2,
        )
        return out

    def np2tensor(self, array):
        # H x W x C numpy array -> C x H x W torch tensor
        return torch.Tensor(array).permute(2, 0, 1)

    def center_crop(self, img, crop_size=None):
        H = img.shape[0]
        W = img.shape[1]

        if crop_size is not None:
            th, tw = crop_size[0], crop_size[1]
        else:
            th, tw = self.crop_size, self.crop_size
        x1_img = int(round((W - tw) / 2.0))
        y1_img = int(round((H - th) / 2.0))
        if img.ndim == 3:
            input_patch = img[y1_img : y1_img + th, x1_img : x1_img + tw, :]
        else:
            input_patch = img[y1_img : y1_img + th, x1_img : x1_img + tw]

        return input_patch

    def load(self, is_train=True):
        # ./data
        # ./data/NIKON D700/RAW, ./data/NIKON D700/RGB
        # ./data/Canon EOS 5D/RAW,  ./data/Canon EOS 5D/RGB
        # ./data/NIKON D700_train.txt, ./data/NIKON D700_test.txt
        # ./data/NIKON D700_train.txt: a0016, ...
        input_RAWs_WBs = []
        target_RGBs = []

        data_path = self.data_path  # ./data/
        if is_train:
            txt_path = data_path + self.camera_name + "_train.txt"
        else:
            txt_path = data_path + self.camera_name + "_test.txt"

        with open(txt_path, "r") as f_read:
            # valid_camera_list = [os.path.basename(line.strip()).split('.')[0] for line in f_read.readlines()]
            valid_camera_list = [line.strip() for line in f_read.readlines()]

        if self.debug_mode:
            valid_camera_list = valid_camera_list[:10]

        for name in valid_camera_list:
            full_name = data_path + self.camera_name
            input_RAWs_WBs.append(full_name + "/RAW/" + name + ".npz")
            target_RGBs.append(full_name + "/RGB/" + name + ".jpg")

        return input_RAWs_WBs, target_RGBs

    def __len__(self):
        # placeholder; subclasses return the number of samples
        return 0

    def __getitem__(self, idx):
        # placeholder; subclasses return the sample at the given index
        return None
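

# Illustrative sketch: a minimal subclass showing how the helpers above could be
# composed into a working dataset. Everything specific here is an assumption made
# for the example: the class name, the .npz key "raw", the use of PIL to read the
# JPG target, the 16383 (14-bit) sensor maximum, and the RGB target being at full
# Bayer resolution (twice the packed RAW size).
from PIL import Image


class ExampleRAW2RGBDataset(BaseDataset):
    def __init__(self, opt, is_train=True):
        super(ExampleRAW2RGBDataset, self).__init__(opt)
        self.input_paths, self.target_paths = self.load(is_train=is_train)

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        # the .npz is assumed to store the Bayer mosaic under the key "raw"
        data = np.load(self.input_paths[idx])
        raw = self.pack_raw(data["raw"])           # H x W Bayer -> H/2 x W/2 x 4
        raw = self.norm_img(raw, max_value=16383)  # assumed 14-bit sensor maximum

        rgb = np.array(Image.open(self.target_paths[idx]))
        rgb = self.norm_img(rgb, max_value=255)

        # crop aligned patches; the RGB target is assumed to be twice the packed
        # RAW resolution, so its crop is twice as large
        raw = self.center_crop(raw)
        rgb = self.center_crop(rgb, crop_size=(self.crop_size * 2, self.crop_size * 2))

        return self.np2tensor(raw), self.np2tensor(rgb)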