PKUWilliamYang commited on
Commit
01ad5b5
1 Parent(s): ac1883f

Upload 5 files

Browse files
configs/__init__.py ADDED
File without changes
configs/data_configs.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from configs import transforms_config
2
+ from configs.paths_config import dataset_paths
3
+
4
+
5
+ DATASETS = {
6
+ 'ffhq_encode': {
7
+ 'transforms': transforms_config.EncodeTransforms,
8
+ 'train_source_root': dataset_paths['ffhq'],
9
+ 'train_target_root': dataset_paths['ffhq'],
10
+ 'test_source_root': dataset_paths['ffhq_test'],
11
+ 'test_target_root': dataset_paths['ffhq_test'],
12
+ },
13
+ 'ffhq_sketch_to_face': {
14
+ 'transforms': transforms_config.SketchToImageTransforms,
15
+ 'train_source_root': dataset_paths['ffhq_train_sketch'],
16
+ 'train_target_root': dataset_paths['ffhq'],
17
+ 'test_source_root': dataset_paths['ffhq_test_sketch'],
18
+ 'test_target_root': dataset_paths['ffhq_test'],
19
+ },
20
+ 'ffhq_seg_to_face': {
21
+ 'transforms': transforms_config.SegToImageTransforms,
22
+ 'train_source_root': dataset_paths['ffhq_train_segmentation'],
23
+ 'train_target_root': dataset_paths['ffhq'],
24
+ 'test_source_root': dataset_paths['ffhq_test_segmentation'],
25
+ 'test_target_root': dataset_paths['ffhq_test'],
26
+ },
27
+ 'ffhq_super_resolution': {
28
+ 'transforms': transforms_config.SuperResTransforms,
29
+ 'train_source_root': dataset_paths['ffhq'],
30
+ 'train_target_root': dataset_paths['ffhq1280'],
31
+ 'test_source_root': dataset_paths['ffhq_test'],
32
+ 'test_target_root': dataset_paths['ffhq1280_test'],
33
+ },
34
+ 'toonify': {
35
+ 'transforms': transforms_config.ToonifyTransforms,
36
+ 'train_source_root': dataset_paths['toonify_in'],
37
+ 'train_target_root': dataset_paths['toonify_out'],
38
+ 'test_source_root': dataset_paths['toonify_test_in'],
39
+ 'test_target_root': dataset_paths['toonify_test_out'],
40
+ },
41
+ 'ffhq_edit': {
42
+ 'transforms': transforms_config.EditingTransforms,
43
+ 'train_source_root': dataset_paths['ffhq'],
44
+ 'train_target_root': dataset_paths['ffhq'],
45
+ 'test_source_root': dataset_paths['ffhq_test'],
46
+ 'test_target_root': dataset_paths['ffhq_test'],
47
+ },
48
+ }
configs/dataset_config.yml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset and data loader settings
2
+ datasets:
3
+ train:
4
+ name: FFHQ
5
+ type: FFHQDegradationDataset
6
+ # dataroot_gt: datasets/ffhq/ffhq_512.lmdb
7
+ dataroot_gt: ../../../../share/shuaiyang/ffhq/realign1280x1280test/
8
+ io_backend:
9
+ # type: lmdb
10
+ type: disk
11
+
12
+ use_hflip: true
13
+ mean: [0.5, 0.5, 0.5]
14
+ std: [0.5, 0.5, 0.5]
15
+ out_size: 1280
16
+ scale: 4
17
+
18
+ blur_kernel_size: 41
19
+ kernel_list: ['iso', 'aniso']
20
+ kernel_prob: [0.5, 0.5]
21
+ blur_sigma: [0.1, 10]
22
+ downsample_range: [4, 40]
23
+ noise_range: [0, 20]
24
+ jpeg_range: [60, 100]
25
+
26
+ # color jitter and gray
27
+ #color_jitter_prob: 0.3
28
+ #color_jitter_shift: 20
29
+ #color_jitter_pt_prob: 0.3
30
+ #gray_prob: 0.01
31
+
32
+ # If you do not want colorization, please set
33
+ color_jitter_prob: ~
34
+ color_jitter_pt_prob: ~
35
+ gray_prob: 0.01
36
+ gt_gray: True
37
+
38
+ crop_components: true
39
+ component_path: ./pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
40
+ eye_enlarge_ratio: 1.4
41
+
42
+ # data loader
43
+ use_shuffle: true
44
+ num_worker_per_gpu: 6
45
+ batch_size_per_gpu: 4
46
+ dataset_enlarge_ratio: 1
47
+ prefetch_mode: ~
48
+
49
+ val:
50
+ # Please modify accordingly to use your own validation
51
+ # Or comment the val block if do not need validation during training
52
+ name: validation
53
+ type: PairedImageDataset
54
+ dataroot_lq: datasets/faces/validation/input
55
+ dataroot_gt: datasets/faces/validation/reference
56
+ io_backend:
57
+ type: disk
58
+ mean: [0.5, 0.5, 0.5]
59
+ std: [0.5, 0.5, 0.5]
60
+ scale: 1
configs/paths_config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_paths = {
2
+ 'ffhq': 'data/train/ffhq/realign320x320/',
3
+ 'ffhq_test': 'data/train/ffhq/realign320x320test/',
4
+ 'ffhq1280': 'data/train/ffhq/realign1280x1280/',
5
+ 'ffhq1280_test': 'data/train/ffhq/realign1280x1280test/',
6
+ 'ffhq_train_sketch': 'data/train/ffhq/realign640x640sketch/',
7
+ 'ffhq_test_sketch': 'data/train/ffhq/realign640x640sketchtest/',
8
+ 'ffhq_train_segmentation': 'data/train/ffhq/realign320x320mask/',
9
+ 'ffhq_test_segmentation': 'data/train/ffhq/realign320x320masktest/',
10
+ 'toonify_in': 'data/train/pixar/trainA/',
11
+ 'toonify_out': 'data/train/pixar/trainB/',
12
+ 'toonify_test_in': 'data/train/pixar/testA/',
13
+ 'toonify_test_out': 'data/train/testB/',
14
+ }
15
+
16
+ model_paths = {
17
+ 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
18
+ 'ir_se50': 'pretrained_models/model_ir_se50.pth',
19
+ 'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
20
+ 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
21
+ 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
22
+ 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
23
+ 'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
24
+ 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar'
25
+ }
configs/transforms_config.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import torchvision.transforms as transforms
3
+ from datasets import augmentations
4
+
5
+
6
+ class TransformsConfig(object):
7
+
8
+ def __init__(self, opts):
9
+ self.opts = opts
10
+
11
+ @abstractmethod
12
+ def get_transforms(self):
13
+ pass
14
+
15
+
16
+ class EncodeTransforms(TransformsConfig):
17
+
18
+ def __init__(self, opts):
19
+ super(EncodeTransforms, self).__init__(opts)
20
+
21
+ def get_transforms(self):
22
+ transforms_dict = {
23
+ 'transform_gt_train': transforms.Compose([
24
+ transforms.Resize((320, 320)),
25
+ transforms.RandomHorizontalFlip(0.5),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
28
+ 'transform_source': None,
29
+ 'transform_test': transforms.Compose([
30
+ transforms.Resize((320, 320)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
33
+ 'transform_inference': transforms.Compose([
34
+ transforms.Resize((320, 320)),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
37
+ }
38
+ return transforms_dict
39
+
40
+
41
+ class FrontalizationTransforms(TransformsConfig):
42
+
43
+ def __init__(self, opts):
44
+ super(FrontalizationTransforms, self).__init__(opts)
45
+
46
+ def get_transforms(self):
47
+ transforms_dict = {
48
+ 'transform_gt_train': transforms.Compose([
49
+ transforms.Resize((256, 256)),
50
+ transforms.RandomHorizontalFlip(0.5),
51
+ transforms.ToTensor(),
52
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
53
+ 'transform_source': transforms.Compose([
54
+ transforms.Resize((256, 256)),
55
+ transforms.RandomHorizontalFlip(0.5),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
58
+ 'transform_test': transforms.Compose([
59
+ transforms.Resize((256, 256)),
60
+ transforms.ToTensor(),
61
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
62
+ 'transform_inference': transforms.Compose([
63
+ transforms.Resize((256, 256)),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
66
+ }
67
+ return transforms_dict
68
+
69
+
70
+ class SketchToImageTransforms(TransformsConfig):
71
+
72
+ def __init__(self, opts):
73
+ super(SketchToImageTransforms, self).__init__(opts)
74
+
75
+ def get_transforms(self):
76
+ transforms_dict = {
77
+ 'transform_gt_train': transforms.Compose([
78
+ transforms.Resize((320, 320)),
79
+ transforms.ToTensor(),
80
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
81
+ 'transform_source': transforms.Compose([
82
+ transforms.Resize((320, 320)),
83
+ transforms.ToTensor()]),
84
+ 'transform_test': transforms.Compose([
85
+ transforms.Resize((320, 320)),
86
+ transforms.ToTensor(),
87
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
88
+ 'transform_inference': transforms.Compose([
89
+ transforms.Resize((320, 320)),
90
+ transforms.ToTensor()]),
91
+ }
92
+ return transforms_dict
93
+
94
+
95
+ class SegToImageTransforms(TransformsConfig):
96
+
97
+ def __init__(self, opts):
98
+ super(SegToImageTransforms, self).__init__(opts)
99
+
100
+ def get_transforms(self):
101
+ transforms_dict = {
102
+ 'transform_gt_train': transforms.Compose([
103
+ transforms.Resize((320, 320)),
104
+ transforms.ToTensor(),
105
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
106
+ 'transform_source': transforms.Compose([
107
+ transforms.Resize((320, 320)),
108
+ augmentations.ToOneHot(self.opts.label_nc),
109
+ transforms.ToTensor()]),
110
+ 'transform_test': transforms.Compose([
111
+ transforms.Resize((320, 320)),
112
+ transforms.ToTensor(),
113
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
114
+ 'transform_inference': transforms.Compose([
115
+ transforms.Resize((320, 320)),
116
+ augmentations.ToOneHot(self.opts.label_nc),
117
+ transforms.ToTensor()])
118
+ }
119
+ return transforms_dict
120
+
121
+
122
+ class SuperResTransforms(TransformsConfig):
123
+
124
+ def __init__(self, opts):
125
+ super(SuperResTransforms, self).__init__(opts)
126
+
127
+ def get_transforms(self):
128
+ if self.opts.resize_factors is None:
129
+ self.opts.resize_factors = '1,2,4,8,16,32'
130
+ factors = [int(f) for f in self.opts.resize_factors.split(",")]
131
+ print("Performing down-sampling with factors: {}".format(factors))
132
+ transforms_dict = {
133
+ 'transform_gt_train': transforms.Compose([
134
+ transforms.Resize((1280, 1280)),
135
+ transforms.ToTensor(),
136
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
137
+ 'transform_source': transforms.Compose([
138
+ transforms.Resize((320, 320)),
139
+ augmentations.BilinearResize(factors=factors),
140
+ transforms.Resize((320, 320)),
141
+ transforms.ToTensor(),
142
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
143
+ 'transform_test': transforms.Compose([
144
+ transforms.Resize((1280, 1280)),
145
+ transforms.ToTensor(),
146
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
147
+ 'transform_inference': transforms.Compose([
148
+ transforms.Resize((320, 320)),
149
+ augmentations.BilinearResize(factors=factors),
150
+ transforms.Resize((320, 320)),
151
+ transforms.ToTensor(),
152
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
153
+ }
154
+ return transforms_dict
155
+
156
+
157
+ class SuperResTransforms_320(TransformsConfig):
158
+
159
+ def __init__(self, opts):
160
+ super(SuperResTransforms_320, self).__init__(opts)
161
+
162
+ def get_transforms(self):
163
+ if self.opts.resize_factors is None:
164
+ self.opts.resize_factors = '1,2,4,8,16,32'
165
+ factors = [int(f) for f in self.opts.resize_factors.split(",")]
166
+ print("Performing down-sampling with factors: {}".format(factors))
167
+ transforms_dict = {
168
+ 'transform_gt_train': transforms.Compose([
169
+ transforms.Resize((320, 320)),
170
+ transforms.ToTensor(),
171
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
172
+ 'transform_source': transforms.Compose([
173
+ transforms.Resize((320, 320)),
174
+ augmentations.BilinearResize(factors=factors),
175
+ transforms.Resize((320, 320)),
176
+ transforms.ToTensor(),
177
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
178
+ 'transform_test': transforms.Compose([
179
+ transforms.Resize((320, 320)),
180
+ transforms.ToTensor(),
181
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
182
+ 'transform_inference': transforms.Compose([
183
+ transforms.Resize((320, 320)),
184
+ augmentations.BilinearResize(factors=factors),
185
+ transforms.Resize((320, 320)),
186
+ transforms.ToTensor(),
187
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
188
+ }
189
+ return transforms_dict
190
+
191
+
192
+ class ToonifyTransforms(TransformsConfig):
193
+
194
+ def __init__(self, opts):
195
+ super(ToonifyTransforms, self).__init__(opts)
196
+
197
+ def get_transforms(self):
198
+ transforms_dict = {
199
+ 'transform_gt_train': transforms.Compose([
200
+ transforms.Resize((1024, 1024)),
201
+ transforms.ToTensor(),
202
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
203
+ 'transform_source': transforms.Compose([
204
+ transforms.Resize((256, 256)),
205
+ transforms.ToTensor(),
206
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
207
+ 'transform_test': transforms.Compose([
208
+ transforms.Resize((1024, 1024)),
209
+ transforms.ToTensor(),
210
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
211
+ 'transform_inference': transforms.Compose([
212
+ transforms.Resize((256, 256)),
213
+ transforms.ToTensor(),
214
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
215
+ }
216
+ return transforms_dict
217
+
218
+ class EditingTransforms(TransformsConfig):
219
+
220
+ def __init__(self, opts):
221
+ super(EditingTransforms, self).__init__(opts)
222
+
223
+ def get_transforms(self):
224
+ transforms_dict = {
225
+ 'transform_gt_train': transforms.Compose([
226
+ transforms.Resize((1280, 1280)),
227
+ transforms.ToTensor(),
228
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
229
+ 'transform_source': transforms.Compose([
230
+ transforms.Resize((320, 320)),
231
+ transforms.ToTensor(),
232
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
233
+ 'transform_test': transforms.Compose([
234
+ transforms.Resize((1280, 1280)),
235
+ transforms.ToTensor(),
236
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
237
+ 'transform_inference': transforms.Compose([
238
+ transforms.Resize((320, 320)),
239
+ transforms.ToTensor(),
240
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
241
+ }
242
+ return transforms_dict