PKUWilliamYang committed on
Commit
3b98894
1 Parent(s): c3f9de8

Upload 6 files

datasets/__init__.py ADDED
(empty file)
datasets/augmentations.py ADDED
@@ -0,0 +1,110 @@
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+
+
+ class ToOneHot(object):
+     """ Convert the input PIL image to a one-hot torch tensor """
+     def __init__(self, n_classes=None):
+         self.n_classes = n_classes
+
+     def onehot_initialization(self, a):
+         if self.n_classes is None:
+             self.n_classes = len(np.unique(a))
+         out = np.zeros(a.shape + (self.n_classes, ), dtype=int)
+         out[self.__all_idx(a, axis=2)] = 1
+         return out
+
+     def __all_idx(self, idx, axis):
+         grid = np.ogrid[tuple(map(slice, idx.shape))]
+         grid.insert(axis, idx)
+         return tuple(grid)
+
+     def __call__(self, img):
+         img = np.array(img)
+         one_hot = self.onehot_initialization(img)
+         return one_hot
+
+
+ class BilinearResize(object):
+     def __init__(self, factors=[1, 2, 4, 8, 16, 32]):
+         self.factors = factors
+
+     def __call__(self, image):
+         factor = np.random.choice(self.factors, size=1)[0]
+         D = BicubicDownSample(factor=factor, cuda=False)
+         img_tensor = transforms.ToTensor()(image).unsqueeze(0)
+         img_tensor_lr = D(img_tensor)[0].clamp(0, 1)
+         img_low_res = transforms.ToPILImage()(img_tensor_lr)
+         return img_low_res
+
+
+ class BicubicDownSample(nn.Module):
+     def bicubic_kernel(self, x, a=-0.50):
+         """
+         This equation is exactly copied from the website below:
+         https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
+         """
+         abs_x = torch.abs(x)
+         if abs_x <= 1.:
+             return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
+         elif 1. < abs_x < 2.:
+             return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
+         else:
+             return 0.0
+
+     def __init__(self, factor=4, cuda=True, padding='reflect'):
+         super().__init__()
+         self.factor = factor
+         size = factor * 4
+         k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
+                           for i in range(size)], dtype=torch.float32)
+         k = k / torch.sum(k)
+         k1 = torch.reshape(k, shape=(1, 1, size, 1))
+         self.k1 = torch.cat([k1, k1, k1], dim=0)
+         k2 = torch.reshape(k, shape=(1, 1, 1, size))
+         self.k2 = torch.cat([k2, k2, k2], dim=0)
+         self.cuda = '.cuda' if cuda else ''
+         self.padding = padding
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
+         filter_height = self.factor * 4
+         filter_width = self.factor * 4
+         stride = self.factor
+
+         pad_along_height = max(filter_height - stride, 0)
+         pad_along_width = max(filter_width - stride, 0)
+         filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
+         filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
+
+         # compute actual padding values for each side
+         pad_top = pad_along_height // 2
+         pad_bottom = pad_along_height - pad_top
+         pad_left = pad_along_width // 2
+         pad_right = pad_along_width - pad_left
+
+         # apply mirror padding
+         if nhwc:
+             x = torch.transpose(torch.transpose(x, 2, 3), 1, 2)  # NHWC to NCHW
+
+         # downscaling performed by 1-d convolution
+         x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
+         x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
+         if clip_round:
+             x = torch.clamp(torch.round(x), 0.0, 255.)
+
+         x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
+         x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
+         if clip_round:
+             x = torch.clamp(torch.round(x), 0.0, 255.)
+
+         if nhwc:
+             x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
+         if byte_output:
+             return x.type('torch{}.ByteTensor'.format(self.cuda))
+         else:
+             return x
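For quick reference, a minimal usage sketch of the two augmentations above; the 256x256 input size and the 'face.png' filename are placeholders for illustration, not part of this commit.

    # Usage sketch (assumes a local RGB image 'face.png'; not part of this commit).
    from PIL import Image
    from torchvision import transforms
    from datasets.augmentations import BilinearResize, BicubicDownSample

    img = Image.open('face.png').convert('RGB').resize((256, 256))

    # Random low-resolution augmentation on a PIL image.
    lr_img = BilinearResize(factors=[2, 4, 8])(img)

    # Direct use of the bicubic downsampler on a batched tensor in [0, 1].
    x = transforms.ToTensor()(img).unsqueeze(0)   # (1, 3, 256, 256)
    down = BicubicDownSample(factor=4, cuda=False)
    y = down(x).clamp(0, 1)                       # (1, 3, 64, 64)
    print(y.shape)
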
datasets/ffhq_degradation_dataset.py ADDED
@@ -0,0 +1,235 @@
+ import cv2
+ import math
+ import numpy as np
+ import os.path as osp
+ import torch
+ import torch.utils.data as data
+ from basicsr.data import degradations as degradations
+ from basicsr.data.data_util import paths_from_folder
+ from basicsr.data.transforms import augment
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+ from basicsr.utils.registry import DATASET_REGISTRY
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
+                                                 normalize)
+
+
+ @DATASET_REGISTRY.register()
+ class FFHQDegradationDataset(data.Dataset):
+     """FFHQ dataset for GFPGAN.
+     It reads high-resolution images and generates low-quality (LQ) images on-the-fly.
+     Args:
+         opt (dict): Config for train datasets. It contains the following keys:
+             dataroot_gt (str): Data root path for gt.
+             io_backend (dict): IO backend type and other kwargs.
+             mean (list | tuple): Image mean.
+             std (list | tuple): Image std.
+             use_hflip (bool): Whether to horizontally flip.
+         Please see more options in the code.
+     """
+
+     def __init__(self, opt):
+         super(FFHQDegradationDataset, self).__init__()
+         self.opt = opt
+         # file client (io backend)
+         self.file_client = None
+         self.io_backend_opt = opt['io_backend']
+
+         self.gt_folder = opt['dataroot_gt']
+         self.mean = opt['mean']
+         self.std = opt['std']
+         self.out_size = opt['out_size']
+
+         self.crop_components = opt.get('crop_components', False)  # facial components
+         self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether to enlarge eye regions
+
+         if self.crop_components:
+             # load the component list from a pre-processed pth file
+             self.components_list = torch.load(opt.get('component_path'))
+
+         # file client (lmdb io backend)
+         if self.io_backend_opt['type'] == 'lmdb':
+             self.io_backend_opt['db_paths'] = self.gt_folder
+             if not self.gt_folder.endswith('.lmdb'):
+                 raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+             with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+                 self.paths = [line.split('.')[0] for line in fin]
+         else:
+             # disk backend: scan file list from a folder
+             self.paths = paths_from_folder(self.gt_folder)
+
+         # degradation configurations
+         self.blur_kernel_size = opt['blur_kernel_size']
+         self.kernel_list = opt['kernel_list']
+         self.kernel_prob = opt['kernel_prob']
+         self.blur_sigma = opt['blur_sigma']
+         self.downsample_range = opt['downsample_range']
+         self.noise_range = opt['noise_range']
+         self.jpeg_range = opt['jpeg_range']
+
+         # color jitter
+         self.color_jitter_prob = opt.get('color_jitter_prob')
+         self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
+         self.color_jitter_shift = opt.get('color_jitter_shift', 20)
+         # to gray
+         self.gray_prob = opt.get('gray_prob')
+
+         logger = get_root_logger()
+         logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
+         logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
+         logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+         logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
+
+         if self.color_jitter_prob is not None:
+             logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
+         if self.gray_prob is not None:
+             logger.info(f'Use random gray. Prob: {self.gray_prob}')
+         self.color_jitter_shift /= 255.
+
+     @staticmethod
+     def color_jitter(img, shift):
+         """Jitter color: randomly jitter the RGB values, in numpy format"""
+         jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
+         img = img + jitter_val
+         img = np.clip(img, 0, 1)
+         return img
+
+     @staticmethod
+     def color_jitter_pt(img, brightness, contrast, saturation, hue):
+         """Jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor format"""
+         fn_idx = torch.randperm(4)
+         for fn_id in fn_idx:
+             if fn_id == 0 and brightness is not None:
+                 brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+                 img = adjust_brightness(img, brightness_factor)
+
+             if fn_id == 1 and contrast is not None:
+                 contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+                 img = adjust_contrast(img, contrast_factor)
+
+             if fn_id == 2 and saturation is not None:
+                 saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+                 img = adjust_saturation(img, saturation_factor)
+
+             if fn_id == 3 and hue is not None:
+                 hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+                 img = adjust_hue(img, hue_factor)
+         return img
+
+     def get_component_coordinates(self, index, status):
+         """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
+         components_bbox = self.components_list[f'{index:08d}']
+         if status[0]:  # hflip
+             # exchange right and left eye
+             tmp = components_bbox['left_eye']
+             components_bbox['left_eye'] = components_bbox['right_eye']
+             components_bbox['right_eye'] = tmp
+             # modify the width coordinate
+             components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
+             components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
+             components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
+
+         # get coordinates
+         locations = []
+         for part in ['left_eye', 'right_eye', 'mouth']:
+             mean = components_bbox[part][0:2]
+             mean[0] = mean[0] * 2 + 128  ########
+             mean[1] = mean[1] * 2 + 128  ########
+             half_len = components_bbox[part][2] * 2  ########
+             if 'eye' in part:
+                 half_len *= self.eye_enlarge_ratio
+             loc = np.hstack((mean - half_len + 1, mean + half_len))
+             loc = torch.from_numpy(loc).float()
+             locations.append(loc)
+         return locations
+
+     def __getitem__(self, index):
+         if self.file_client is None:
+             self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+         # load gt image
+         # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+         gt_path = self.paths[index]
+         img_bytes = self.file_client.get(gt_path)
+         img_gt = imfrombytes(img_bytes, float32=True)
+
+         # random horizontal flip
+         img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
+         h, w, _ = img_gt.shape
+
+         # get facial component coordinates
+         if self.crop_components:
+             locations = self.get_component_coordinates(index, status)
+             loc_left_eye, loc_right_eye, loc_mouth = locations
+
+         # ------------------------ generate lq image ------------------------ #
+         # blur
+         kernel = degradations.random_mixed_kernels(
+             self.kernel_list,
+             self.kernel_prob,
+             self.blur_kernel_size,
+             self.blur_sigma,
+             self.blur_sigma, [-math.pi, math.pi],
+             noise_range=None)
+         img_lq = cv2.filter2D(img_gt, -1, kernel)
+         # downsample
+         scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
+         img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
+         # noise
+         if self.noise_range is not None:
+             img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
+         # jpeg compression
+         if self.jpeg_range is not None:
+             img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
+
+         # resize to original size
+         img_lq = cv2.resize(img_lq, (int(w // self.opt['scale']), int(h // self.opt['scale'])), interpolation=cv2.INTER_LINEAR)
+
+         # random color jitter (only for lq)
+         if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
+             img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
+         # random to gray (only for lq)
+         if self.gray_prob and np.random.uniform() < self.gray_prob:
+             img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
+             img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
+             if self.opt.get('gt_gray'):  # whether to convert GT to gray images
+                 img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
+                 img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels
+
+         # BGR to RGB, HWC to CHW, numpy to tensor
+         # img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+         img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+         img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
+
+         # random color jitter (pytorch version) (only for lq)
+         if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
+             brightness = self.opt.get('brightness', (0.5, 1.5))
+             contrast = self.opt.get('contrast', (0.5, 1.5))
+             saturation = self.opt.get('saturation', (0, 1.5))
+             hue = self.opt.get('hue', (-0.1, 0.1))
+             img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
+
+         # round and clip
+         img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
+
+         # normalize
+         normalize(img_gt, self.mean, self.std, inplace=True)
+         normalize(img_lq, self.mean, self.std, inplace=True)
+
+         '''
+         if self.crop_components:
+             return_dict = {
+                 'lq': img_lq,
+                 'gt': img_gt,
+                 'gt_path': gt_path,
+                 'loc_left_eye': loc_left_eye,
+                 'loc_right_eye': loc_right_eye,
+                 'loc_mouth': loc_mouth
+             }
+             return return_dict
+         else:
+             return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
+         '''
+         return img_lq, img_gt
+
+     def __len__(self):
+         return len(self.paths)
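A hedged sketch of the kind of opt dict this dataset expects: the keys mirror the ones read in __init__ and __getitem__ above, while the concrete path, ranges, and probabilities are illustrative placeholders, not the values used to train this model.

    # Illustrative config only; keys follow the code above, values are placeholders.
    from datasets.ffhq_degradation_dataset import FFHQDegradationDataset

    opt = {
        'dataroot_gt': 'data/ffhq_512',            # placeholder path to GT images
        'io_backend': {'type': 'disk'},
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'out_size': 512,
        'scale': 1,                                # used when resizing the LQ image back
        'use_hflip': True,
        'blur_kernel_size': 41,
        'kernel_list': ['iso', 'aniso'],
        'kernel_prob': [0.5, 0.5],
        'blur_sigma': [0.1, 10],
        'downsample_range': [0.8, 8],
        'noise_range': [0, 20],
        'jpeg_range': [60, 100],
        'color_jitter_prob': 0.3,
        'color_jitter_shift': 20,
        'color_jitter_pt_prob': 0.3,
        'gray_prob': 0.01,
    }

    dataset = FFHQDegradationDataset(opt)
    img_lq, img_gt = dataset[0]   # both are normalized CHW float tensors
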
datasets/gt_res_dataset.py ADDED
@@ -0,0 +1,32 @@
+ #!/usr/bin/python
+ # encoding: utf-8
+ import os
+ from torch.utils.data import Dataset
+ from PIL import Image
+
+
+ class GTResDataset(Dataset):
+
+     def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
+         self.pairs = []
+         for f in os.listdir(root_path):
+             image_path = os.path.join(root_path, f)
+             gt_path = os.path.join(gt_dir, f)
+             if f.endswith(".jpg") or f.endswith(".png"):
+                 self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
+         self.transform = transform
+         self.transform_train = transform_train
+
+     def __len__(self):
+         return len(self.pairs)
+
+     def __getitem__(self, index):
+         from_path, to_path, _ = self.pairs[index]
+         from_im = Image.open(from_path).convert('RGB')
+         to_im = Image.open(to_path).convert('RGB')
+
+         if self.transform:
+             to_im = self.transform(to_im)
+             from_im = self.transform(from_im)
+
+         return from_im, to_im
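A short usage sketch, assuming a results folder and a ground-truth folder with matching filenames; both directory names and the 256-pixel resize below are placeholders.

    # Usage sketch with placeholder directories.
    from torchvision import transforms
    from datasets.gt_res_dataset import GTResDataset

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    dataset = GTResDataset(root_path='results/inference_results',   # placeholder
                           gt_dir='results/gt_images',              # placeholder
                           transform=transform)
    res_img, gt_img = dataset[0]   # both are 3x256x256 tensors
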
datasets/images_dataset.py ADDED
@@ -0,0 +1,33 @@
+ from torch.utils.data import Dataset
+ from PIL import Image
+ from utils import data_utils
+
+
+ class ImagesDataset(Dataset):
+
+     def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
+         self.source_paths = sorted(data_utils.make_dataset(source_root))
+         self.target_paths = sorted(data_utils.make_dataset(target_root))
+         self.source_transform = source_transform
+         self.target_transform = target_transform
+         self.opts = opts
+
+     def __len__(self):
+         return len(self.source_paths)
+
+     def __getitem__(self, index):
+         from_path = self.source_paths[index]
+         from_im = Image.open(from_path)
+         from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
+
+         to_path = self.target_paths[index]
+         to_im = Image.open(to_path).convert('RGB')
+         if self.target_transform:
+             to_im = self.target_transform(to_im)
+
+         if self.source_transform:
+             from_im = self.source_transform(from_im)
+         else:
+             from_im = to_im
+
+         return from_im, to_im
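This class relies on utils.data_utils.make_dataset (part of the repo but not included in this commit) to collect image paths, and on an opts object exposing label_nc. A minimal sketch with a stand-in Namespace and placeholder folders:

    # Sketch only: folders are placeholders and opts is a minimal stand-in object.
    from argparse import Namespace
    from torchvision import transforms
    from datasets.images_dataset import ImagesDataset

    opts = Namespace(label_nc=0)   # 0 -> source images are loaded as RGB
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

    dataset = ImagesDataset(source_root='data/source',   # placeholder
                            target_root='data/target',   # placeholder
                            opts=opts,
                            source_transform=transform,
                            target_transform=transform)
    from_im, to_im = dataset[0]
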
datasets/inference_dataset.py ADDED
@@ -0,0 +1,22 @@
+ from torch.utils.data import Dataset
+ from PIL import Image
+ from utils import data_utils
+
+
+ class InferenceDataset(Dataset):
+
+     def __init__(self, root, opts, transform=None):
+         self.paths = sorted(data_utils.make_dataset(root))
+         self.transform = transform
+         self.opts = opts
+
+     def __len__(self):
+         return len(self.paths)
+
+     def __getitem__(self, index):
+         from_path = self.paths[index]
+         from_im = Image.open(from_path)
+         from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
+         if self.transform:
+             from_im = self.transform(from_im)
+         return from_im
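A corresponding inference-time sketch that wraps the dataset in a DataLoader; the image folder, resize size, and batch size are placeholders.

    # Inference usage sketch; folder and transform are placeholders.
    from argparse import Namespace
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from datasets.inference_dataset import InferenceDataset

    opts = Namespace(label_nc=0)
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

    dataset = InferenceDataset(root='data/test_images', opts=opts, transform=transform)   # placeholder path
    loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)

    for batch in loader:
        print(batch.shape)   # (4, 3, 256, 256) for a full batch
        break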