toto10 committed on
Commit
ca18358
1 Parent(s): 34743a6

0975a6be0dfbdd3323a94ec491197546e4fb2f8cc2232996355251f3fece7777

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. repositories/CodeFormer/facelib/utils/__pycache__/misc.cpython-310.pyc +0 -0
  3. repositories/CodeFormer/facelib/utils/face_restoration_helper.py +455 -0
  4. repositories/CodeFormer/facelib/utils/face_utils.py +248 -0
  5. repositories/CodeFormer/facelib/utils/misc.py +141 -0
  6. repositories/CodeFormer/inference_codeformer.py +189 -0
  7. repositories/CodeFormer/inputs/cropped_faces/0143.png +0 -0
  8. repositories/CodeFormer/inputs/cropped_faces/0240.png +0 -0
  9. repositories/CodeFormer/inputs/cropped_faces/0342.png +0 -0
  10. repositories/CodeFormer/inputs/cropped_faces/0345.png +0 -0
  11. repositories/CodeFormer/inputs/cropped_faces/0368.png +0 -0
  12. repositories/CodeFormer/inputs/cropped_faces/0412.png +0 -0
  13. repositories/CodeFormer/inputs/cropped_faces/0444.png +0 -0
  14. repositories/CodeFormer/inputs/cropped_faces/0478.png +0 -0
  15. repositories/CodeFormer/inputs/cropped_faces/0500.png +0 -0
  16. repositories/CodeFormer/inputs/cropped_faces/0599.png +0 -0
  17. repositories/CodeFormer/inputs/cropped_faces/0717.png +0 -0
  18. repositories/CodeFormer/inputs/cropped_faces/0720.png +0 -0
  19. repositories/CodeFormer/inputs/cropped_faces/0729.png +0 -0
  20. repositories/CodeFormer/inputs/cropped_faces/0763.png +0 -0
  21. repositories/CodeFormer/inputs/cropped_faces/0770.png +0 -0
  22. repositories/CodeFormer/inputs/cropped_faces/0777.png +0 -0
  23. repositories/CodeFormer/inputs/cropped_faces/0885.png +0 -0
  24. repositories/CodeFormer/inputs/cropped_faces/0934.png +0 -0
  25. repositories/CodeFormer/inputs/cropped_faces/Solvay_conference_1927_0018.png +0 -0
  26. repositories/CodeFormer/inputs/cropped_faces/Solvay_conference_1927_2_16.png +0 -0
  27. repositories/CodeFormer/inputs/whole_imgs/00.jpg +0 -0
  28. repositories/CodeFormer/inputs/whole_imgs/01.jpg +0 -0
  29. repositories/CodeFormer/inputs/whole_imgs/02.png +0 -0
  30. repositories/CodeFormer/inputs/whole_imgs/03.jpg +0 -0
  31. repositories/CodeFormer/inputs/whole_imgs/04.jpg +0 -0
  32. repositories/CodeFormer/inputs/whole_imgs/05.jpg +0 -0
  33. repositories/CodeFormer/inputs/whole_imgs/06.png +0 -0
  34. repositories/CodeFormer/predict.py +188 -0
  35. repositories/CodeFormer/requirements.txt +20 -0
  36. repositories/CodeFormer/scripts/crop_align_face.py +192 -0
  37. repositories/CodeFormer/scripts/download_pretrained_models.py +40 -0
  38. repositories/CodeFormer/scripts/download_pretrained_models_from_gdrive.py +60 -0
  39. repositories/CodeFormer/weights/CodeFormer/.gitkeep +0 -0
  40. repositories/CodeFormer/weights/README.md +3 -0
  41. repositories/CodeFormer/weights/facelib/.gitkeep +0 -0
  42. repositories/generative-models/.gitignore +7 -0
  43. repositories/generative-models/LICENSE +75 -0
  44. repositories/generative-models/README.md +194 -0
  45. repositories/generative-models/assets/000.jpg +0 -0
  46. repositories/generative-models/assets/sdxl_report.pdf +3 -0
  47. repositories/generative-models/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +115 -0
  48. repositories/generative-models/configs/example_training/imagenet-f8_cond.yaml +188 -0
  49. repositories/generative-models/configs/example_training/toy/cifar10_cond.yaml +99 -0
  50. repositories/generative-models/configs/example_training/toy/mnist.yaml +80 -0
.gitattributes CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  extensions/Stable-Diffusion-Webui-Civitai-Helper/img/all_in_one.png filter=lfs diff=lfs merge=lfs -text
  extensions/addtional/models/lora/README.md filter=lfs diff=lfs merge=lfs -text
  repositories/BLIP/BLIP.gif filter=lfs diff=lfs merge=lfs -text
+ repositories/generative-models/assets/sdxl_report.pdf filter=lfs diff=lfs merge=lfs -text
repositories/CodeFormer/facelib/utils/__pycache__/misc.cpython-310.pyc ADDED
Binary file (4.61 kB)
 
repositories/CodeFormer/facelib/utils/face_restoration_helper.py ADDED
@@ -0,0 +1,455 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import torch
5
+ from torchvision.transforms.functional import normalize
6
+
7
+ from facelib.detection import init_detection_model
8
+ from facelib.parsing import init_parsing_model
9
+ from facelib.utils.misc import img2tensor, imwrite
10
+
11
+
12
+ def get_largest_face(det_faces, h, w):
13
+
14
+ def get_location(val, length):
15
+ if val < 0:
16
+ return 0
17
+ elif val > length:
18
+ return length
19
+ else:
20
+ return val
21
+
22
+ face_areas = []
23
+ for det_face in det_faces:
24
+ left = get_location(det_face[0], w)
25
+ right = get_location(det_face[2], w)
26
+ top = get_location(det_face[1], h)
27
+ bottom = get_location(det_face[3], h)
28
+ face_area = (right - left) * (bottom - top)
29
+ face_areas.append(face_area)
30
+ largest_idx = face_areas.index(max(face_areas))
31
+ return det_faces[largest_idx], largest_idx
32
+
33
+
34
+ def get_center_face(det_faces, h=0, w=0, center=None):
35
+ if center is not None:
36
+ center = np.array(center)
37
+ else:
38
+ center = np.array([w / 2, h / 2])
39
+ center_dist = []
40
+ for det_face in det_faces:
41
+ face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
42
+ dist = np.linalg.norm(face_center - center)
43
+ center_dist.append(dist)
44
+ center_idx = center_dist.index(min(center_dist))
45
+ return det_faces[center_idx], center_idx
46
+
47
+
48
+ class FaceRestoreHelper(object):
49
+ """Helper for the face restoration pipeline (base class)."""
50
+
51
+ def __init__(self,
52
+ upscale_factor,
53
+ face_size=512,
54
+ crop_ratio=(1, 1),
55
+ det_model='retinaface_resnet50',
56
+ save_ext='png',
57
+ template_3points=False,
58
+ pad_blur=False,
59
+ use_parse=False,
60
+ device=None):
61
+ self.template_3points = template_3points # improve robustness
62
+ self.upscale_factor = upscale_factor
63
+ # the cropped face ratio based on the square face
64
+ self.crop_ratio = crop_ratio # (h, w)
65
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
66
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
67
+
68
+ if self.template_3points:
69
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
70
+ else:
71
+ # standard 5 landmarks for FFHQ faces with 512 x 512
72
+ # facexlib
73
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
74
+ [201.26117, 371.41043], [313.08905, 371.15118]])
75
+
76
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
77
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
78
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
79
+
80
+
81
+ self.face_template = self.face_template * (face_size / 512.0)
82
+ if self.crop_ratio[0] > 1:
83
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
84
+ if self.crop_ratio[1] > 1:
85
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
86
+ self.save_ext = save_ext
87
+ self.pad_blur = pad_blur
88
+ if self.pad_blur is True:
89
+ self.template_3points = False
90
+
91
+ self.all_landmarks_5 = []
92
+ self.det_faces = []
93
+ self.affine_matrices = []
94
+ self.inverse_affine_matrices = []
95
+ self.cropped_faces = []
96
+ self.restored_faces = []
97
+ self.pad_input_imgs = []
98
+
99
+ if device is None:
100
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
101
+ else:
102
+ self.device = device
103
+
104
+ # init face detection model
105
+ self.face_det = init_detection_model(det_model, half=False, device=self.device)
106
+
107
+ # init face parsing model
108
+ self.use_parse = use_parse
109
+ self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
110
+
111
+ def set_upscale_factor(self, upscale_factor):
112
+ self.upscale_factor = upscale_factor
113
+
114
+ def read_image(self, img):
115
+ """img can be image path or cv2 loaded image."""
116
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
117
+ if isinstance(img, str):
118
+ img = cv2.imread(img)
119
+
120
+ if np.max(img) > 256: # 16-bit image
121
+ img = img / 65535 * 255
122
+ if len(img.shape) == 2: # gray image
123
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
124
+ elif img.shape[2] == 4: # BGRA image with alpha channel
125
+ img = img[:, :, 0:3]
126
+
127
+ self.input_img = img
128
+
129
+ if min(self.input_img.shape[:2])<512:
130
+ f = 512.0/min(self.input_img.shape[:2])
131
+ self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
132
+
133
+ def get_face_landmarks_5(self,
134
+ only_keep_largest=False,
135
+ only_center_face=False,
136
+ resize=None,
137
+ blur_ratio=0.01,
138
+ eye_dist_threshold=None):
139
+ if resize is None:
140
+ scale = 1
141
+ input_img = self.input_img
142
+ else:
143
+ h, w = self.input_img.shape[0:2]
144
+ scale = resize / min(h, w)
145
+ scale = max(1, scale) # always scale up
146
+ h, w = int(h * scale), int(w * scale)
147
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
148
+ input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
149
+
150
+ with torch.no_grad():
151
+ bboxes = self.face_det.detect_faces(input_img)
152
+
153
+ if bboxes is None or bboxes.shape[0] == 0:
154
+ return 0
155
+ else:
156
+ bboxes = bboxes / scale
157
+
158
+ for bbox in bboxes:
159
+ # remove faces with too small eye distance: side faces or too small faces
160
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
161
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
162
+ continue
163
+
164
+ if self.template_3points:
165
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
166
+ else:
167
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
168
+ self.all_landmarks_5.append(landmark)
169
+ self.det_faces.append(bbox[0:5])
170
+
171
+ if len(self.det_faces) == 0:
172
+ return 0
173
+ if only_keep_largest:
174
+ h, w, _ = self.input_img.shape
175
+ self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
176
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
177
+ elif only_center_face:
178
+ h, w, _ = self.input_img.shape
179
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
180
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
181
+
182
+ # pad blurry images
183
+ if self.pad_blur:
184
+ self.pad_input_imgs = []
185
+ for landmarks in self.all_landmarks_5:
186
+ # get landmarks
187
+ eye_left = landmarks[0, :]
188
+ eye_right = landmarks[1, :]
189
+ eye_avg = (eye_left + eye_right) * 0.5
190
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
191
+ eye_to_eye = eye_right - eye_left
192
+ eye_to_mouth = mouth_avg - eye_avg
193
+
194
+ # Get the oriented crop rectangle
195
+ # x: half width of the oriented crop rectangle
196
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
197
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
198
+ # norm with the hypotenuse: get the direction
199
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
200
+ rect_scale = 1.5
201
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
202
+ # y: half height of the oriented crop rectangle
203
+ y = np.flipud(x) * [-1, 1]
204
+
205
+ # c: center
206
+ c = eye_avg + eye_to_mouth * 0.1
207
+ # quad: (left_top, left_bottom, right_bottom, right_top)
208
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
209
+ # qsize: side length of the square
210
+ qsize = np.hypot(*x) * 2
211
+ border = max(int(np.rint(qsize * 0.1)), 3)
212
+
213
+ # get pad
214
+ # pad: (width_left, height_top, width_right, height_bottom)
215
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
216
+ int(np.ceil(max(quad[:, 1]))))
217
+ pad = [
218
+ max(-pad[0] + border, 1),
219
+ max(-pad[1] + border, 1),
220
+ max(pad[2] - self.input_img.shape[0] + border, 1),
221
+ max(pad[3] - self.input_img.shape[1] + border, 1)
222
+ ]
223
+
224
+ if max(pad) > 1:
225
+ # pad image
226
+ pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
227
+ # modify landmark coords
228
+ landmarks[:, 0] += pad[0]
229
+ landmarks[:, 1] += pad[1]
230
+ # blur pad images
231
+ h, w, _ = pad_img.shape
232
+ y, x, _ = np.ogrid[:h, :w, :1]
233
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
234
+ np.float32(w - 1 - x) / pad[2]),
235
+ 1.0 - np.minimum(np.float32(y) / pad[1],
236
+ np.float32(h - 1 - y) / pad[3]))
237
+ blur = int(qsize * blur_ratio)
238
+ if blur % 2 == 0:
239
+ blur += 1
240
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
241
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
242
+
243
+ pad_img = pad_img.astype('float32')
244
+ pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
245
+ pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
246
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
247
+ self.pad_input_imgs.append(pad_img)
248
+ else:
249
+ self.pad_input_imgs.append(np.copy(self.input_img))
250
+
251
+ return len(self.all_landmarks_5)
252
+
253
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
254
+ """Align and warp faces with face template.
255
+ """
256
+ if self.pad_blur:
257
+ assert len(self.pad_input_imgs) == len(
258
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
259
+ for idx, landmark in enumerate(self.all_landmarks_5):
260
+ # use 5 landmarks to get affine matrix
261
+ # use cv2.LMEDS method for the equivalence to skimage transform
262
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
263
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
264
+ self.affine_matrices.append(affine_matrix)
265
+ # warp and crop faces
266
+ if border_mode == 'constant':
267
+ border_mode = cv2.BORDER_CONSTANT
268
+ elif border_mode == 'reflect101':
269
+ border_mode = cv2.BORDER_REFLECT101
270
+ elif border_mode == 'reflect':
271
+ border_mode = cv2.BORDER_REFLECT
272
+ if self.pad_blur:
273
+ input_img = self.pad_input_imgs[idx]
274
+ else:
275
+ input_img = self.input_img
276
+ cropped_face = cv2.warpAffine(
277
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
278
+ self.cropped_faces.append(cropped_face)
279
+ # save the cropped face
280
+ if save_cropped_path is not None:
281
+ path = os.path.splitext(save_cropped_path)[0]
282
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
283
+ imwrite(cropped_face, save_path)
284
+
285
+ def get_inverse_affine(self, save_inverse_affine_path=None):
286
+ """Get inverse affine matrix."""
287
+ for idx, affine_matrix in enumerate(self.affine_matrices):
288
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
289
+ inverse_affine *= self.upscale_factor
290
+ self.inverse_affine_matrices.append(inverse_affine)
291
+ # save inverse affine matrices
292
+ if save_inverse_affine_path is not None:
293
+ path, _ = os.path.splitext(save_inverse_affine_path)
294
+ save_path = f'{path}_{idx:02d}.pth'
295
+ torch.save(inverse_affine, save_path)
296
+
297
+
298
+ def add_restored_face(self, face):
299
+ self.restored_faces.append(face)
300
+
301
+
302
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
303
+ h, w, _ = self.input_img.shape
304
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
305
+
306
+ if upsample_img is None:
307
+ # simply resize the background
308
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
309
+ upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
310
+ else:
311
+ upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
312
+
313
+ assert len(self.restored_faces) == len(
314
+ self.inverse_affine_matrices), ('length of restored_faces and inverse_affine_matrices are different.')
315
+
316
+ inv_mask_borders = []
317
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
318
+ if face_upsampler is not None:
319
+ restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
320
+ inverse_affine /= self.upscale_factor
321
+ inverse_affine[:, 2] *= self.upscale_factor
322
+ face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
323
+ else:
324
+ # Add an offset to inverse affine matrix, for more precise back alignment
325
+ if self.upscale_factor > 1:
326
+ extra_offset = 0.5 * self.upscale_factor
327
+ else:
328
+ extra_offset = 0
329
+ inverse_affine[:, 2] += extra_offset
330
+ face_size = self.face_size
331
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
332
+
333
+ # if draw_box or not self.use_parse: # use square parse maps
334
+ # mask = np.ones(face_size, dtype=np.float32)
335
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
336
+ # # remove the black borders
337
+ # inv_mask_erosion = cv2.erode(
338
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
339
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
340
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
341
+ # # add border
342
+ # if draw_box:
343
+ # h, w = face_size
344
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
345
+ # border = int(1400/np.sqrt(total_face_area))
346
+ # mask_border[border:h-border, border:w-border,:] = 0
347
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
348
+ # inv_mask_borders.append(inv_mask_border)
349
+ # if not self.use_parse:
350
+ # # compute the fusion edge based on the area of face
351
+ # w_edge = int(total_face_area**0.5) // 20
352
+ # erosion_radius = w_edge * 2
353
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
354
+ # blur_size = w_edge * 2
355
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
356
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
357
+ # upsample_img = upsample_img[:, :, None]
358
+ # inv_soft_mask = inv_soft_mask[:, :, None]
359
+
360
+ # always use square mask
361
+ mask = np.ones(face_size, dtype=np.float32)
362
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
363
+ # remove the black borders
364
+ inv_mask_erosion = cv2.erode(
365
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
366
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
367
+ total_face_area = np.sum(inv_mask_erosion) # // 3
368
+ # add border
369
+ if draw_box:
370
+ h, w = face_size
371
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
372
+ border = int(1400/np.sqrt(total_face_area))
373
+ mask_border[border:h-border, border:w-border,:] = 0
374
+ inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
375
+ inv_mask_borders.append(inv_mask_border)
376
+ # compute the fusion edge based on the area of face
377
+ w_edge = int(total_face_area**0.5) // 20
378
+ erosion_radius = w_edge * 2
379
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
380
+ blur_size = w_edge * 2
381
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
382
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
383
+ upsample_img = upsample_img[:, :, None]
384
+ inv_soft_mask = inv_soft_mask[:, :, None]
385
+
386
+ # parse mask
387
+ if self.use_parse:
388
+ # inference
389
+ face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
390
+ face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
391
+ normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
392
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
393
+ with torch.no_grad():
394
+ out = self.face_parse(face_input)[0]
395
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
396
+
397
+ parse_mask = np.zeros(out.shape)
398
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
399
+ for idx, color in enumerate(MASK_COLORMAP):
400
+ parse_mask[out == idx] = color
401
+ # blur the mask
402
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
403
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
404
+ # remove the black borders
405
+ thres = 10
406
+ parse_mask[:thres, :] = 0
407
+ parse_mask[-thres:, :] = 0
408
+ parse_mask[:, :thres] = 0
409
+ parse_mask[:, -thres:] = 0
410
+ parse_mask = parse_mask / 255.
411
+
412
+ parse_mask = cv2.resize(parse_mask, face_size)
413
+ parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
414
+ inv_soft_parse_mask = parse_mask[:, :, None]
415
+ # pasted_face = inv_restored
416
+ fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
417
+ inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
418
+
419
+ if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
420
+ alpha = upsample_img[:, :, 3:]
421
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
422
+ upsample_img = np.concatenate((upsample_img, alpha), axis=2)
423
+ else:
424
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
425
+
426
+ if np.max(upsample_img) > 256: # 16-bit image
427
+ upsample_img = upsample_img.astype(np.uint16)
428
+ else:
429
+ upsample_img = upsample_img.astype(np.uint8)
430
+
431
+ # draw bounding box
432
+ if draw_box:
433
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
434
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
435
+ img_color[:,:,0] = 0
436
+ img_color[:,:,1] = 255
437
+ img_color[:,:,2] = 0
438
+ for inv_mask_border in inv_mask_borders:
439
+ upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
440
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
441
+
442
+ if save_path is not None:
443
+ path = os.path.splitext(save_path)[0]
444
+ save_path = f'{path}.{self.save_ext}'
445
+ imwrite(upsample_img, save_path)
446
+ return upsample_img
447
+
448
+ def clean_all(self):
449
+ self.all_landmarks_5 = []
450
+ self.restored_faces = []
451
+ self.affine_matrices = []
452
+ self.cropped_faces = []
453
+ self.inverse_affine_matrices = []
454
+ self.det_faces = []
455
+ self.pad_input_imgs = []
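For reference while reading this file: a minimal usage sketch of the call sequence the helper expects, mirroring how inference_codeformer.py (added later in this diff) drives it. The image path is one of the sample inputs in this commit, and the restoration step is stubbed out with a pass-through placeholder.

import torch
from facelib.utils.face_restoration_helper import FaceRestoreHelper

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
helper = FaceRestoreHelper(2, face_size=512, crop_ratio=(1, 1),
                           det_model='retinaface_resnet50', save_ext='png',
                           use_parse=True, device=device)

helper.read_image('inputs/whole_imgs/00.jpg')        # accepts a path or a BGR uint8 array
num_faces = helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
helper.align_warp_face()                             # fills helper.cropped_faces with 512x512 crops

for cropped_face in helper.cropped_faces:
    restored_face = cropped_face                     # placeholder: the actual restorer output goes here
    helper.add_restored_face(restored_face)

helper.get_inverse_affine(None)
restored_img = helper.paste_faces_to_input_image()   # upscaled BGR result with faces pasted back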
repositories/CodeFormer/facelib/utils/face_utils.py ADDED
@@ -0,0 +1,248 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
7
+ left, top, right, bot = bbox
8
+ width = right - left
9
+ height = bot - top
10
+
11
+ if preserve_aspect:
12
+ width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
13
+ height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
14
+ else:
15
+ width_increase = height_increase = increase_area
16
+ left = int(left - width_increase * width)
17
+ top = int(top - height_increase * height)
18
+ right = int(right + width_increase * width)
19
+ bot = int(bot + height_increase * height)
20
+ return (left, top, right, bot)
21
+
22
+
23
+ def get_valid_bboxes(bboxes, h, w):
24
+ left = max(bboxes[0], 0)
25
+ top = max(bboxes[1], 0)
26
+ right = min(bboxes[2], w)
27
+ bottom = min(bboxes[3], h)
28
+ return (left, top, right, bottom)
29
+
30
+
31
+ def align_crop_face_landmarks(img,
32
+ landmarks,
33
+ output_size,
34
+ transform_size=None,
35
+ enable_padding=True,
36
+ return_inverse_affine=False,
37
+ shrink_ratio=(1, 1)):
38
+ """Align and crop face with landmarks.
39
+
40
+ The output_size and transform_size are based on width. The height is
41
+ adjusted based on shrink_ratio_h/shrink_ratio_w.
42
+
43
+ Modified from:
44
+ https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
45
+
46
+ Args:
47
+ img (Numpy array): Input image.
48
+ landmarks (Numpy array): 5 or 68 or 98 landmarks.
49
+ output_size (int): Output face size.
50
+ transform_size (int): Transform size. Usually four times the
51
+ output_size.
52
+ enable_padding (bool): Whether to pad the image. Default: True.
53
+ shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
54
+ face for height and width (crop larger area). Default: (1, 1).
55
+
56
+ Returns:
57
+ (Numpy array): Cropped face.
58
+ """
59
+ lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
60
+
61
+ if isinstance(shrink_ratio, (float, int)):
62
+ shrink_ratio = (shrink_ratio, shrink_ratio)
63
+ if transform_size is None:
64
+ transform_size = output_size * 4
65
+
66
+ # Parse landmarks
67
+ lm = np.array(landmarks)
68
+ if lm.shape[0] == 5 and lm_type == 'retinaface_5':
69
+ eye_left = lm[0]
70
+ eye_right = lm[1]
71
+ mouth_avg = (lm[3] + lm[4]) * 0.5
72
+ elif lm.shape[0] == 5 and lm_type == 'dlib_5':
73
+ lm_eye_left = lm[2:4]
74
+ lm_eye_right = lm[0:2]
75
+ eye_left = np.mean(lm_eye_left, axis=0)
76
+ eye_right = np.mean(lm_eye_right, axis=0)
77
+ mouth_avg = lm[4]
78
+ elif lm.shape[0] == 68:
79
+ lm_eye_left = lm[36:42]
80
+ lm_eye_right = lm[42:48]
81
+ eye_left = np.mean(lm_eye_left, axis=0)
82
+ eye_right = np.mean(lm_eye_right, axis=0)
83
+ mouth_avg = (lm[48] + lm[54]) * 0.5
84
+ elif lm.shape[0] == 98:
85
+ lm_eye_left = lm[60:68]
86
+ lm_eye_right = lm[68:76]
87
+ eye_left = np.mean(lm_eye_left, axis=0)
88
+ eye_right = np.mean(lm_eye_right, axis=0)
89
+ mouth_avg = (lm[76] + lm[82]) * 0.5
90
+
91
+ eye_avg = (eye_left + eye_right) * 0.5
92
+ eye_to_eye = eye_right - eye_left
93
+ eye_to_mouth = mouth_avg - eye_avg
94
+
95
+ # Get the oriented crop rectangle
96
+ # x: half width of the oriented crop rectangle
97
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
98
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
99
+ # norm with the hypotenuse: get the direction
100
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
101
+ rect_scale = 1 # TODO: you can edit it to get larger rect
102
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
103
+ # y: half height of the oriented crop rectangle
104
+ y = np.flipud(x) * [-1, 1]
105
+
106
+ x *= shrink_ratio[1] # width
107
+ y *= shrink_ratio[0] # height
108
+
109
+ # c: center
110
+ c = eye_avg + eye_to_mouth * 0.1
111
+ # quad: (left_top, left_bottom, right_bottom, right_top)
112
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
113
+ # qsize: side length of the square
114
+ qsize = np.hypot(*x) * 2
115
+
116
+ quad_ori = np.copy(quad)
117
+ # Shrink, for large face
118
+ # TODO: do we really need shrink
119
+ shrink = int(np.floor(qsize / output_size * 0.5))
120
+ if shrink > 1:
121
+ h, w = img.shape[0:2]
122
+ rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
123
+ img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
124
+ quad /= shrink
125
+ qsize /= shrink
126
+
127
+ # Crop
128
+ h, w = img.shape[0:2]
129
+ border = max(int(np.rint(qsize * 0.1)), 3)
130
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
131
+ int(np.ceil(max(quad[:, 1]))))
132
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
133
+ if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
134
+ img = img[crop[1]:crop[3], crop[0]:crop[2], :]
135
+ quad -= crop[0:2]
136
+
137
+ # Pad
138
+ # pad: (width_left, height_top, width_right, height_bottom)
139
+ h, w = img.shape[0:2]
140
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
141
+ int(np.ceil(max(quad[:, 1]))))
142
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
143
+ if enable_padding and max(pad) > border - 4:
144
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
145
+ img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
146
+ h, w = img.shape[0:2]
147
+ y, x, _ = np.ogrid[:h, :w, :1]
148
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
149
+ np.float32(w - 1 - x) / pad[2]),
150
+ 1.0 - np.minimum(np.float32(y) / pad[1],
151
+ np.float32(h - 1 - y) / pad[3]))
152
+ blur = int(qsize * 0.02)
153
+ if blur % 2 == 0:
154
+ blur += 1
155
+ blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
156
+
157
+ img = img.astype('float32')
158
+ img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
159
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
160
+ img = np.clip(img, 0, 255) # float32, [0, 255]
161
+ quad += pad[:2]
162
+
163
+ # Transform use cv2
164
+ h_ratio = shrink_ratio[0] / shrink_ratio[1]
165
+ dst_h, dst_w = int(transform_size * h_ratio), transform_size
166
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
167
+ # use cv2.LMEDS method for the equivalence to skimage transform
168
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
169
+ affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
170
+ cropped_face = cv2.warpAffine(
171
+ img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
172
+
173
+ if output_size < transform_size:
174
+ cropped_face = cv2.resize(
175
+ cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
176
+
177
+ if return_inverse_affine:
178
+ dst_h, dst_w = int(output_size * h_ratio), output_size
179
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
180
+ # use cv2.LMEDS method for the equivalence to skimage transform
181
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
182
+ affine_matrix = cv2.estimateAffinePartial2D(
183
+ quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
184
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
185
+ else:
186
+ inverse_affine = None
187
+ return cropped_face, inverse_affine
188
+
189
+
190
+ def paste_face_back(img, face, inverse_affine):
191
+ h, w = img.shape[0:2]
192
+ face_h, face_w = face.shape[0:2]
193
+ inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
194
+ mask = np.ones((face_h, face_w, 3), dtype=np.float32)
195
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
196
+ # remove the black borders
197
+ inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
198
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
199
+ total_face_area = np.sum(inv_mask_erosion) // 3
200
+ # compute the fusion edge based on the area of face
201
+ w_edge = int(total_face_area**0.5) // 20
202
+ erosion_radius = w_edge * 2
203
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
204
+ blur_size = w_edge * 2
205
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
206
+ img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
207
+ # float32, [0, 255]
208
+ return img
209
+
210
+
211
+ if __name__ == '__main__':
212
+ import os
213
+
214
+ from facelib.detection import init_detection_model
215
+ from facelib.utils.face_restoration_helper import get_largest_face
216
+
217
+ img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
218
+ img_name = os.path.splitext(os.path.basename(img_path))[0]
219
+
220
+ # initialize model
221
+ det_net = init_detection_model('retinaface_resnet50', half=False)
222
+ img_ori = cv2.imread(img_path)
223
+ h, w = img_ori.shape[0:2]
224
+ # if larger than 800, scale it
225
+ scale = max(h / 800, w / 800)
226
+ if scale > 1:
227
+ img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
228
+
229
+ with torch.no_grad():
230
+ bboxes = det_net.detect_faces(img, 0.97)
231
+ if scale > 1:
232
+ bboxes *= scale # the score is incorrect
233
+ bboxes = get_largest_face(bboxes, h, w)[0]
234
+
235
+ landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
236
+
237
+ cropped_face, inverse_affine = align_crop_face_landmarks(
238
+ img_ori,
239
+ landmarks,
240
+ output_size=512,
241
+ transform_size=None,
242
+ enable_padding=True,
243
+ return_inverse_affine=True,
244
+ shrink_ratio=(1, 1))
245
+
246
+ cv2.imwrite(f'tmp/{img_name}_cropped_face.png', cropped_face)
247
+ img = paste_face_back(img_ori, cropped_face, inverse_affine)
248
+ cv2.imwrite(f'tmp/{img_name}_back.png', img)
repositories/CodeFormer/facelib/utils/misc.py ADDED
@@ -0,0 +1,141 @@
1
+ import cv2
2
+ import os
3
+ import os.path as osp
4
+ import torch
5
+ from torch.hub import download_url_to_file, get_dir
6
+ from urllib.parse import urlparse
7
+ # from basicsr.utils.download_util import download_file_from_google_drive
8
+ import gdown
9
+
10
+
11
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+
13
+
14
+ def download_pretrained_models(file_ids, save_path_root):
15
+ os.makedirs(save_path_root, exist_ok=True)
16
+
17
+ for file_name, file_id in file_ids.items():
18
+ file_url = 'https://drive.google.com/uc?id='+file_id
19
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
20
+ if osp.exists(save_path):
21
+ user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
22
+ if user_response.lower() == 'y':
23
+ print(f'Overwriting {file_name} at {save_path}')
24
+ gdown.download(file_url, save_path, quiet=False)
25
+ # download_file_from_google_drive(file_id, save_path)
26
+ elif user_response.lower() == 'n':
27
+ print(f'Skipping {file_name}')
28
+ else:
29
+ raise ValueError('Wrong input. Only accepts Y/N.')
30
+ else:
31
+ print(f'Downloading {file_name} to {save_path}')
32
+ gdown.download(file_url, save_path, quiet=False)
33
+ # download_file_from_google_drive(file_id, save_path)
34
+
35
+
36
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
37
+ """Write image to file.
38
+
39
+ Args:
40
+ img (ndarray): Image array to be written.
41
+ file_path (str): Image file path.
42
+ params (None or list): Same as opencv's :func:`imwrite` interface.
43
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
44
+ whether to create it automatically.
45
+
46
+ Returns:
47
+ bool: Successful or not.
48
+ """
49
+ if auto_mkdir:
50
+ dir_name = os.path.abspath(os.path.dirname(file_path))
51
+ os.makedirs(dir_name, exist_ok=True)
52
+ return cv2.imwrite(file_path, img, params)
53
+
54
+
55
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
56
+ """Numpy array to tensor.
57
+
58
+ Args:
59
+ imgs (list[ndarray] | ndarray): Input images.
60
+ bgr2rgb (bool): Whether to change bgr to rgb.
61
+ float32 (bool): Whether to change to float32.
62
+
63
+ Returns:
64
+ list[tensor] | tensor: Tensor images. If returned results only have
65
+ one element, just return tensor.
66
+ """
67
+
68
+ def _totensor(img, bgr2rgb, float32):
69
+ if img.shape[2] == 3 and bgr2rgb:
70
+ if img.dtype == 'float64':
71
+ img = img.astype('float32')
72
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
73
+ img = torch.from_numpy(img.transpose(2, 0, 1))
74
+ if float32:
75
+ img = img.float()
76
+ return img
77
+
78
+ if isinstance(imgs, list):
79
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
80
+ else:
81
+ return _totensor(imgs, bgr2rgb, float32)
82
+
83
+
84
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
85
+ """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
86
+ """
87
+ if model_dir is None:
88
+ hub_dir = get_dir()
89
+ model_dir = os.path.join(hub_dir, 'checkpoints')
90
+
91
+ os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
92
+
93
+ parts = urlparse(url)
94
+ filename = os.path.basename(parts.path)
95
+ if file_name is not None:
96
+ filename = file_name
97
+ cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
98
+ if not os.path.exists(cached_file):
99
+ print(f'Downloading: "{url}" to {cached_file}\n')
100
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
101
+ return cached_file
102
+
103
+
104
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
105
+ """Scan a directory to find the interested files.
106
+ Args:
107
+ dir_path (str): Path of the directory.
108
+ suffix (str | tuple(str), optional): File suffix that we are
109
+ interested in. Default: None.
110
+ recursive (bool, optional): If set to True, recursively scan the
111
+ directory. Default: False.
112
+ full_path (bool, optional): If set to True, include the dir_path.
113
+ Default: False.
114
+ Returns:
115
+ A generator for all the interested files with relative paths.
116
+ """
117
+
118
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
119
+ raise TypeError('"suffix" must be a string or tuple of strings')
120
+
121
+ root = dir_path
122
+
123
+ def _scandir(dir_path, suffix, recursive):
124
+ for entry in os.scandir(dir_path):
125
+ if not entry.name.startswith('.') and entry.is_file():
126
+ if full_path:
127
+ return_path = entry.path
128
+ else:
129
+ return_path = osp.relpath(entry.path, root)
130
+
131
+ if suffix is None:
132
+ yield return_path
133
+ elif return_path.endswith(suffix):
134
+ yield return_path
135
+ else:
136
+ if recursive:
137
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
138
+ else:
139
+ continue
140
+
141
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
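For reference, a small usage sketch of the helpers above. The checkpoint URL is the one used by inference_codeformer.py further down in this diff; the directories are illustrative.

from facelib.utils.misc import load_file_from_url, scandir

# download the CodeFormer checkpoint (or reuse the cached copy) under the repository root
ckpt_path = load_file_from_url(
    url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    model_dir='weights/CodeFormer', progress=True, file_name=None)

# iterate over input images with matching suffixes (paths relative to the scanned directory)
for rel_path in scandir('inputs/whole_imgs', suffix=('.jpg', '.png')):
    print(rel_path)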
repositories/CodeFormer/inference_codeformer.py ADDED
@@ -0,0 +1,189 @@
1
+ # Modified by Shangchen Zhou from: https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
2
+ import os
3
+ import cv2
4
+ import argparse
5
+ import glob
6
+ import torch
7
+ from torchvision.transforms.functional import normalize
8
+ from basicsr.utils import imwrite, img2tensor, tensor2img
9
+ from basicsr.utils.download_util import load_file_from_url
10
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
11
+ import torch.nn.functional as F
12
+
13
+ from basicsr.utils.registry import ARCH_REGISTRY
14
+
15
+ pretrain_model_url = {
16
+ 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
17
+ }
18
+
19
+ def set_realesrgan():
20
+ if not torch.cuda.is_available(): # CPU
21
+ import warnings
22
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
23
+ 'If you really want to use it, please modify the corresponding codes.',
24
+ category=RuntimeWarning)
25
+ bg_upsampler = None
26
+ else:
27
+ from basicsr.archs.rrdbnet_arch import RRDBNet
28
+ from basicsr.utils.realesrgan_utils import RealESRGANer
29
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
30
+ bg_upsampler = RealESRGANer(
31
+ scale=2,
32
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
33
+ model=model,
34
+ tile=args.bg_tile,
35
+ tile_pad=40,
36
+ pre_pad=0,
37
+ half=True) # need to set False in CPU mode
38
+ return bg_upsampler
39
+
40
+ if __name__ == '__main__':
41
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
+ parser = argparse.ArgumentParser()
43
+
44
+ parser.add_argument('--w', type=float, default=0.5, help='Balance the quality and fidelity')
45
+ parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
46
+ parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
47
+ parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
48
+ parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
49
+ # large det_model: 'YOLOv5l', 'retinaface_resnet50'
50
+ # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
51
+ parser.add_argument('--detection_model', type=str, default='retinaface_resnet50')
52
+ parser.add_argument('--draw_box', action='store_true')
53
+ parser.add_argument('--bg_upsampler', type=str, default='None', help='background upsampler. Optional: realesrgan')
54
+ parser.add_argument('--face_upsample', action='store_true', help='face upsampler after enhancement.')
55
+ parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background upsampler. Default: 400')
56
+
57
+ args = parser.parse_args()
58
+
59
+ # ------------------------ input & output ------------------------
60
+ if args.test_path.endswith('/'): # strip a trailing slash
61
+ args.test_path = args.test_path[:-1]
62
+
63
+ w = args.w
64
+ result_root = f'results/{os.path.basename(args.test_path)}_{w}'
65
+
66
+ # ------------------ set up background upsampler ------------------
67
+ if args.bg_upsampler == 'realesrgan':
68
+ bg_upsampler = set_realesrgan()
69
+ else:
70
+ bg_upsampler = None
71
+
72
+ # ------------------ set up face upsampler ------------------
73
+ if args.face_upsample:
74
+ if bg_upsampler is not None:
75
+ face_upsampler = bg_upsampler
76
+ else:
77
+ face_upsampler = set_realesrgan()
78
+ else:
79
+ face_upsampler = None
80
+
81
+ # ------------------ set up CodeFormer restorer -------------------
82
+ net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
83
+ connect_list=['32', '64', '128', '256']).to(device)
84
+
85
+ # ckpt_path = 'weights/CodeFormer/codeformer.pth'
86
+ ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
87
+ model_dir='weights/CodeFormer', progress=True, file_name=None)
88
+ checkpoint = torch.load(ckpt_path)['params_ema']
89
+ net.load_state_dict(checkpoint)
90
+ net.eval()
91
+
92
+ # ------------------ set up FaceRestoreHelper -------------------
93
+ # large det_model: 'YOLOv5l', 'retinaface_resnet50'
94
+ # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
95
+ if not args.has_aligned:
96
+ print(f'Face detection model: {args.detection_model}')
97
+ if bg_upsampler is not None:
98
+ print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
99
+ else:
100
+ print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')
101
+
102
+ face_helper = FaceRestoreHelper(
103
+ args.upscale,
104
+ face_size=512,
105
+ crop_ratio=(1, 1),
106
+ det_model = args.detection_model,
107
+ save_ext='png',
108
+ use_parse=True,
109
+ device=device)
110
+
111
+ # -------------------- start processing ---------------------
112
+ # scan all the jpg and png images
113
+ for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
114
+ # clean all the intermediate results to process the next image
115
+ face_helper.clean_all()
116
+
117
+ img_name = os.path.basename(img_path)
118
+ print(f'Processing: {img_name}')
119
+ basename, ext = os.path.splitext(img_name)
120
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
121
+
122
+ if args.has_aligned:
123
+ # the input faces are already cropped and aligned
124
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
125
+ face_helper.cropped_faces = [img]
126
+ else:
127
+ face_helper.read_image(img)
128
+ # get face landmarks for each face
129
+ num_det_faces = face_helper.get_face_landmarks_5(
130
+ only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
131
+ print(f'\tdetect {num_det_faces} faces')
132
+ # align and warp each face
133
+ face_helper.align_warp_face()
134
+
135
+ # face restoration for each cropped face
136
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
137
+ # prepare data
138
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
139
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
140
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
141
+
142
+ try:
143
+ with torch.no_grad():
144
+ output = net(cropped_face_t, w=w, adain=True)[0]
145
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
146
+ del output
147
+ torch.cuda.empty_cache()
148
+ except Exception as error:
149
+ print(f'\tFailed inference for CodeFormer: {error}')
150
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
151
+
152
+ restored_face = restored_face.astype('uint8')
153
+ face_helper.add_restored_face(restored_face)
154
+
155
+ # paste_back
156
+ if not args.has_aligned:
157
+ # upsample the background
158
+ if bg_upsampler is not None:
159
+ # Now only support RealESRGAN for upsampling background
160
+ bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
161
+ else:
162
+ bg_img = None
163
+ face_helper.get_inverse_affine(None)
164
+ # paste each restored face to the input image
165
+ if args.face_upsample and face_upsampler is not None:
166
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
167
+ else:
168
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
169
+
170
+ # save faces
171
+ for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
172
+ # save cropped face
173
+ if not args.has_aligned:
174
+ save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
175
+ imwrite(cropped_face, save_crop_path)
176
+ # save restored face
177
+ if args.has_aligned:
178
+ save_face_name = f'{basename}.png'
179
+ else:
180
+ save_face_name = f'{basename}_{idx:02d}.png'
181
+ save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
182
+ imwrite(restored_face, save_restore_path)
183
+
184
+ # save restored img
185
+ if not args.has_aligned and restored_img is not None:
186
+ save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
187
+ imwrite(restored_img, save_restore_path)
188
+
189
+ print(f'\nAll results are saved in {result_root}')
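Typical invocations of this script, using the arguments defined above and the sample inputs added in this commit; results are written under results/<input_dir>_<w>/:

python inference_codeformer.py --w 0.7 --test_path ./inputs/cropped_faces --has_aligned
python inference_codeformer.py --w 0.7 --upscale 2 --test_path ./inputs/whole_imgs --bg_upsampler realesrgan --face_upsample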
repositories/CodeFormer/inputs/cropped_faces/0143.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0240.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0342.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0345.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0368.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0412.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0444.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0478.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0500.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0599.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0717.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0720.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0729.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0763.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0770.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0777.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0885.png ADDED
repositories/CodeFormer/inputs/cropped_faces/0934.png ADDED
repositories/CodeFormer/inputs/cropped_faces/Solvay_conference_1927_0018.png ADDED
repositories/CodeFormer/inputs/cropped_faces/Solvay_conference_1927_2_16.png ADDED
repositories/CodeFormer/inputs/whole_imgs/00.jpg ADDED
repositories/CodeFormer/inputs/whole_imgs/01.jpg ADDED
repositories/CodeFormer/inputs/whole_imgs/02.png ADDED
repositories/CodeFormer/inputs/whole_imgs/03.jpg ADDED
repositories/CodeFormer/inputs/whole_imgs/04.jpg ADDED
repositories/CodeFormer/inputs/whole_imgs/05.jpg ADDED
repositories/CodeFormer/inputs/whole_imgs/06.png ADDED
repositories/CodeFormer/predict.py ADDED
@@ -0,0 +1,188 @@
1
+ """
2
+ download checkpoints to ./weights beforehand
3
+ python scripts/download_pretrained_models.py facelib
4
+ python scripts/download_pretrained_models.py CodeFormer
5
+ wget 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
6
+ """
7
+
8
+ import tempfile
9
+ import cv2
10
+ import torch
11
+ from torchvision.transforms.functional import normalize
12
+ from cog import BasePredictor, Input, Path
13
+
14
+ from basicsr.utils import imwrite, img2tensor, tensor2img
15
+ from basicsr.archs.rrdbnet_arch import RRDBNet
16
+ from basicsr.utils.realesrgan_utils import RealESRGANer
17
+ from basicsr.utils.registry import ARCH_REGISTRY
18
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
19
+
20
+
21
+ class Predictor(BasePredictor):
22
+ def setup(self):
23
+ """Load the model into memory to make running multiple predictions efficient"""
24
+ self.device = "cuda:0"
25
+ self.bg_upsampler = set_realesrgan()
26
+ self.net = ARCH_REGISTRY.get("CodeFormer")(
27
+ dim_embd=512,
28
+ codebook_size=1024,
29
+ n_head=8,
30
+ n_layers=9,
31
+ connect_list=["32", "64", "128", "256"],
32
+ ).to(self.device)
33
+ ckpt_path = "weights/CodeFormer/codeformer.pth"
34
+ checkpoint = torch.load(ckpt_path)[
35
+ "params_ema"
36
+ ] # update file permission if cannot load
37
+ self.net.load_state_dict(checkpoint)
38
+ self.net.eval()
39
+
40
+ def predict(
41
+ self,
42
+ image: Path = Input(description="Input image"),
43
+ codeformer_fidelity: float = Input(
44
+ default=0.5,
45
+ ge=0,
46
+ le=1,
47
+ description="Balance the quality (lower number) and fidelity (higher number).",
48
+ ),
49
+ background_enhance: bool = Input(
50
+ description="Enhance background image with Real-ESRGAN", default=True
51
+ ),
52
+ face_upsample: bool = Input(
53
+ description="Upsample restored faces for high-resolution AI-created images",
54
+ default=True,
55
+ ),
56
+ upscale: int = Input(
57
+ description="The final upsampling scale of the image",
58
+ default=2,
59
+ ),
60
+ ) -> Path:
61
+ """Run a single prediction on the model"""
62
+
63
+ # take the default setting for the demo
64
+ has_aligned = False
65
+ only_center_face = False
66
+ draw_box = False
67
+ detection_model = "retinaface_resnet50"
68
+
69
+ self.face_helper = FaceRestoreHelper(
70
+ upscale,
71
+ face_size=512,
72
+ crop_ratio=(1, 1),
73
+ det_model=detection_model,
74
+ save_ext="png",
75
+ use_parse=True,
76
+ device=self.device,
77
+ )
78
+
79
+ bg_upsampler = self.bg_upsampler if background_enhance else None
80
+ face_upsampler = self.bg_upsampler if face_upsample else None
81
+
82
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
83
+
84
+ if has_aligned:
85
+ # the input faces are already cropped and aligned
86
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
87
+ self.face_helper.cropped_faces = [img]
88
+ else:
89
+ self.face_helper.read_image(img)
90
+ # get face landmarks for each face
91
+ num_det_faces = self.face_helper.get_face_landmarks_5(
92
+ only_center_face=only_center_face, resize=640, eye_dist_threshold=5
93
+ )
94
+ print(f"\tdetect {num_det_faces} faces")
95
+ # align and warp each face
96
+ self.face_helper.align_warp_face()
97
+
98
+ # face restoration for each cropped face
99
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
100
+ # prepare data
101
+ cropped_face_t = img2tensor(
102
+ cropped_face / 255.0, bgr2rgb=True, float32=True
103
+ )
104
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
105
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
106
+
107
+ try:
108
+ with torch.no_grad():
109
+ output = self.net(
110
+ cropped_face_t, w=codeformer_fidelity, adain=True
111
+ )[0]
112
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
113
+ del output
114
+ torch.cuda.empty_cache()
115
+ except Exception as error:
116
+ print(f"\tFailed inference for CodeFormer: {error}")
117
+ restored_face = tensor2img(
118
+ cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
119
+ )
120
+
121
+ restored_face = restored_face.astype("uint8")
122
+ self.face_helper.add_restored_face(restored_face)
123
+
124
+ # paste_back
125
+ if not has_aligned:
126
+ # upsample the background
127
+ if bg_upsampler is not None:
128
+ # Now only support RealESRGAN for upsampling background
129
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
130
+ else:
131
+ bg_img = None
132
+ self.face_helper.get_inverse_affine(None)
133
+ # paste each restored face to the input image
134
+ if face_upsample and face_upsampler is not None:
135
+ restored_img = self.face_helper.paste_faces_to_input_image(
136
+ upsample_img=bg_img,
137
+ draw_box=draw_box,
138
+ face_upsampler=face_upsampler,
139
+ )
140
+ else:
141
+ restored_img = self.face_helper.paste_faces_to_input_image(
142
+ upsample_img=bg_img, draw_box=draw_box
143
+ )
144
+
145
+ # save restored img
146
+ out_path = Path(tempfile.mkdtemp()) / "output.png"
147
+
148
+ if not has_aligned and restored_img is not None:
149
+ imwrite(restored_img, str(out_path))
150
+
151
+ return out_path
152
+
153
+
154
+ def imread(img_path):
155
+ img = cv2.imread(img_path)
156
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
157
+ return img
158
+
159
+
160
+ def set_realesrgan():
161
+ if not torch.cuda.is_available(): # CPU
162
+ import warnings
163
+
164
+ warnings.warn(
165
+ "The unoptimized RealESRGAN is slow on CPU. We do not use it. "
166
+ "If you really want to use it, please modify the corresponding codes.",
167
+ category=RuntimeWarning,
168
+ )
169
+ bg_upsampler = None
170
+ else:
171
+ model = RRDBNet(
172
+ num_in_ch=3,
173
+ num_out_ch=3,
174
+ num_feat=64,
175
+ num_block=23,
176
+ num_grow_ch=32,
177
+ scale=2,
178
+ )
179
+ bg_upsampler = RealESRGANer(
180
+ scale=2,
181
+ model_path="./weights/RealESRGAN_x2plus.pth",
182
+ model=model,
183
+ tile=400,
184
+ tile_pad=40,
185
+ pre_pad=0,
186
+ half=True,
187
+ )
188
+ return bg_upsampler
repositories/CodeFormer/requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ addict
2
+ future
3
+ lmdb
4
+ numpy
5
+ opencv-python
6
+ Pillow
7
+ pyyaml
8
+ requests
9
+ scikit-image
10
+ scipy
11
+ tb-nightly
12
+ torch>=1.7.1
13
+ torchvision
14
+ tqdm
15
+ yapf
16
+ lpips
17
+ gdown # supports downloading the large file from Google Drive
18
+ # cmake
19
+ # dlib
20
+ # conda install -c conda-forge dlib
repositories/CodeFormer/scripts/crop_align_face.py ADDED
@@ -0,0 +1,192 @@
1
+ """
2
+ brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
3
+ author: lzhbrian (https://lzhbrian.me)
4
+ link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
5
+ date: 2020.1.5
6
+ note: code is heavily borrowed from
7
+ https://github.com/NVlabs/ffhq-dataset
8
+ http://dlib.net/face_landmark_detection.py.html
9
+ requirements:
10
+ conda install Pillow numpy scipy
11
+ conda install -c conda-forge dlib
12
+ # download face landmark model from:
13
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
14
+ """
15
+
16
+ import cv2
17
+ import dlib
18
+ import glob
19
+ import numpy as np
20
+ import os
21
+ import PIL
22
+ import PIL.Image
23
+ import scipy
24
+ import scipy.ndimage
25
+ import sys
26
+ import argparse
27
+
28
+ # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
29
+ predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')
30
+
31
+
32
+ def get_landmark(filepath, only_keep_largest=True):
33
+ """get landmark with dlib
34
+ :return: np.array shape=(68, 2)
35
+ """
36
+ detector = dlib.get_frontal_face_detector()
37
+
38
+ img = dlib.load_rgb_image(filepath)
39
+ dets = detector(img, 1)
40
+
41
+ # Shangchen modified
42
+ print("Number of faces detected: {}".format(len(dets)))
43
+ if only_keep_largest:
44
+ print('If several faces are detected, only the largest one is kept.')
45
+ face_areas = []
46
+ for k, d in enumerate(dets):
47
+ face_area = (d.right() - d.left()) * (d.bottom() - d.top())
48
+ face_areas.append(face_area)
49
+
50
+ largest_idx = face_areas.index(max(face_areas))
51
+ d = dets[largest_idx]
52
+ shape = predictor(img, d)
53
+ print("Part 0: {}, Part 1: {} ...".format(
54
+ shape.part(0), shape.part(1)))
55
+ else:
56
+ for k, d in enumerate(dets):
57
+ print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
58
+ k, d.left(), d.top(), d.right(), d.bottom()))
59
+ # Get the landmarks/parts for the face in box d.
60
+ shape = predictor(img, d)
61
+ print("Part 0: {}, Part 1: {} ...".format(
62
+ shape.part(0), shape.part(1)))
63
+
64
+ t = list(shape.parts())
65
+ a = []
66
+ for tt in t:
67
+ a.append([tt.x, tt.y])
68
+ lm = np.array(a)
69
+ # lm is a shape=(68,2) np.array
70
+ return lm
71
+
72
+ def align_face(filepath, out_path):
73
+ """
74
+ :param filepath: str
75
+ :return: PIL Image
76
+ """
77
+ try:
78
+ lm = get_landmark(filepath)
79
+ except Exception:
80
+ print('No landmark ...')
81
+ return
82
+
83
+ lm_chin = lm[0:17] # left-right
84
+ lm_eyebrow_left = lm[17:22] # left-right
85
+ lm_eyebrow_right = lm[22:27] # left-right
86
+ lm_nose = lm[27:31] # top-down
87
+ lm_nostrils = lm[31:36] # top-down
88
+ lm_eye_left = lm[36:42] # left-clockwise
89
+ lm_eye_right = lm[42:48] # left-clockwise
90
+ lm_mouth_outer = lm[48:60] # left-clockwise
91
+ lm_mouth_inner = lm[60:68] # left-clockwise
92
+
93
+ # Calculate auxiliary vectors.
94
+ eye_left = np.mean(lm_eye_left, axis=0)
95
+ eye_right = np.mean(lm_eye_right, axis=0)
96
+ eye_avg = (eye_left + eye_right) * 0.5
97
+ eye_to_eye = eye_right - eye_left
98
+ mouth_left = lm_mouth_outer[0]
99
+ mouth_right = lm_mouth_outer[6]
100
+ mouth_avg = (mouth_left + mouth_right) * 0.5
101
+ eye_to_mouth = mouth_avg - eye_avg
102
+
103
+ # Choose oriented crop rectangle.
104
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
105
+ x /= np.hypot(*x)
106
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
107
+ y = np.flipud(x) * [-1, 1]
108
+ c = eye_avg + eye_to_mouth * 0.1
109
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
110
+ qsize = np.hypot(*x) * 2
111
+
112
+ # read image
113
+ img = PIL.Image.open(filepath)
114
+
115
+ output_size = 512
116
+ transform_size = 4096
117
+ enable_padding = False
118
+
119
+ # Shrink.
120
+ shrink = int(np.floor(qsize / output_size * 0.5))
121
+ if shrink > 1:
122
+ rsize = (int(np.rint(float(img.size[0]) / shrink)),
123
+ int(np.rint(float(img.size[1]) / shrink)))
124
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
125
+ quad /= shrink
126
+ qsize /= shrink
127
+
128
+ # Crop.
129
+ border = max(int(np.rint(qsize * 0.1)), 3)
130
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
131
+ int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
132
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
133
+ min(crop[2] + border,
134
+ img.size[0]), min(crop[3] + border, img.size[1]))
135
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
136
+ img = img.crop(crop)
137
+ quad -= crop[0:2]
138
+
139
+ # Pad.
140
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
141
+ int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
142
+ pad = (max(-pad[0] + border,
143
+ 0), max(-pad[1] + border,
144
+ 0), max(pad[2] - img.size[0] + border,
145
+ 0), max(pad[3] - img.size[1] + border, 0))
146
+ if enable_padding and max(pad) > border - 4:
147
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
148
+ img = np.pad(
149
+ np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
150
+ 'reflect')
151
+ h, w, _ = img.shape
152
+ y, x, _ = np.ogrid[:h, :w, :1]
153
+ mask = np.maximum(
154
+ 1.0 -
155
+ np.minimum(np.float32(x) / pad[0],
156
+ np.float32(w - 1 - x) / pad[2]), 1.0 -
157
+ np.minimum(np.float32(y) / pad[1],
158
+ np.float32(h - 1 - y) / pad[3]))
159
+ blur = qsize * 0.02
160
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
161
+ img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
162
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
163
+ img = PIL.Image.fromarray(
164
+ np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
165
+ quad += pad[:2]
166
+
167
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
168
+ (quad + 0.5).flatten(), PIL.Image.BILINEAR)
169
+
170
+ if output_size < transform_size:
171
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
172
+
173
+ # Save aligned image.
174
+ print('saving: ', out_path)
175
+ img.save(out_path)
176
+
177
+ return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
178
+
179
+
180
+ if __name__ == '__main__':
181
+ parser = argparse.ArgumentParser()
182
+ parser.add_argument('--in_dir', type=str, default='./inputs/whole_imgs')
183
+ parser.add_argument('--out_dir', type=str, default='./inputs/cropped_faces')
184
+ args = parser.parse_args()
185
+
186
+ img_list = sorted(glob.glob(f'{args.in_dir}/*.png'))
187
+ img_list = sorted(img_list)
188
+
189
+ for in_path in img_list:
190
+ out_path = os.path.join(args.out_dir, in_path.split("/")[-1])
191
+ out_path = out_path.replace('.jpg', '.png')
192
+ size_ = align_face(in_path, out_path)
repositories/CodeFormer/scripts/download_pretrained_models.py ADDED
@@ -0,0 +1,40 @@
1
+ import argparse
2
+ import os
3
+ from os import path as osp
4
+
5
+ from basicsr.utils.download_util import load_file_from_url
6
+
7
+
8
+ def download_pretrained_models(method, file_urls):
9
+ save_path_root = f'./weights/{method}'
10
+ os.makedirs(save_path_root, exist_ok=True)
11
+
12
+ for file_name, file_url in file_urls.items():
13
+ save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
14
+
15
+
16
+ if __name__ == '__main__':
17
+ parser = argparse.ArgumentParser()
18
+
19
+ parser.add_argument(
20
+ 'method',
21
+ type=str,
22
+ help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
23
+ args = parser.parse_args()
24
+
25
+ file_urls = {
26
+ 'CodeFormer': {
27
+ 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
28
+ },
29
+ 'facelib': {
30
+ # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
31
+ 'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
32
+ 'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
33
+ }
34
+ }
35
+
36
+ if args.method == 'all':
37
+ for method in file_urls.keys():
38
+ download_pretrained_models(method, file_urls[method])
39
+ else:
40
+ download_pretrained_models(args.method, file_urls[args.method])
repositories/CodeFormer/scripts/download_pretrained_models_from_gdrive.py ADDED
@@ -0,0 +1,60 @@
1
+ import argparse
2
+ import os
3
+ from os import path as osp
4
+
5
+ # from basicsr.utils.download_util import download_file_from_google_drive
6
+ import gdown
7
+
8
+
9
+ def download_pretrained_models(method, file_ids):
10
+ save_path_root = f'./weights/{method}'
11
+ os.makedirs(save_path_root, exist_ok=True)
12
+
13
+ for file_name, file_id in file_ids.items():
14
+ file_url = 'https://drive.google.com/uc?id='+file_id
15
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
16
+ if osp.exists(save_path):
17
+ user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
18
+ if user_response.lower() == 'y':
19
+ print(f'Overwriting {file_name} at {save_path}')
20
+ gdown.download(file_url, save_path, quiet=False)
21
+ # download_file_from_google_drive(file_id, save_path)
22
+ elif user_response.lower() == 'n':
23
+ print(f'Skipping {file_name}')
24
+ else:
25
+ raise ValueError('Wrong input. Only accepts Y/N.')
26
+ else:
27
+ print(f'Downloading {file_name} to {save_path}')
28
+ gdown.download(file_url, save_path, quiet=False)
29
+ # download_file_from_google_drive(file_id, save_path)
30
+
31
+ if __name__ == '__main__':
32
+ parser = argparse.ArgumentParser()
33
+
34
+ parser.add_argument(
35
+ 'method',
36
+ type=str,
37
+ help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
38
+ args = parser.parse_args()
39
+
40
+ # file name: file id
41
+ # 'dlib': {
42
+ # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
43
+ # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
44
+ # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
45
+ # }
46
+ file_ids = {
47
+ 'CodeFormer': {
48
+ 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
49
+ },
50
+ 'facelib': {
51
+ 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
52
+ 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
53
+ }
54
+ }
55
+
56
+ if args.method == 'all':
57
+ for method in file_ids.keys():
58
+ download_pretrained_models(method, file_ids[method])
59
+ else:
60
+ download_pretrained_models(args.method, file_ids[args.method])
repositories/CodeFormer/weights/CodeFormer/.gitkeep ADDED
File without changes
repositories/CodeFormer/weights/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # Weights
2
+
3
+ Put the downloaded pre-trained models in this folder.
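+
+ For example, after downloading the CodeFormer and facelib models with `scripts/download_pretrained_models.py` (see that script for the exact URLs), the layout should look roughly like this:
+
+ ```
+ weights/
+ ├── CodeFormer/
+ │   └── codeformer.pth
+ └── facelib/
+     ├── detection_Resnet50_Final.pth
+     └── parsing_parsenet.pth
+ ```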
repositories/CodeFormer/weights/facelib/.gitkeep ADDED
File without changes
repositories/generative-models/.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ .pt2
2
+ .pt2_2
3
+ .pt13
4
+ *.egg-info
5
+ build
6
+ /outputs
7
+ /checkpoints
repositories/generative-models/LICENSE ADDED
@@ -0,0 +1,75 @@
1
+ SDXL 0.9 RESEARCH LICENSE AGREEMENT
2
+ Copyright (c) Stability AI Ltd.
3
+ This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”).
4
+ By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.
5
+ 1. LICENSE GRANT
6
+
7
+ a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.
8
+
9
+ b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.
10
+
11
+ c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License.
12
+
13
+
14
+ 2. RESTRICTIONS
15
+
16
+ You will not, and will not permit, assist or cause any third party to:
17
+
18
+ a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
19
+
20
+ b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
21
+
22
+ c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or
23
+
24
+ d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.
25
+
26
+ e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
27
+
28
+
29
+ 3. ATTRIBUTION
30
+
31
+ Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “SDXL 0.9 is licensed under the SDXL Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”
32
+
33
+
34
+ 4. DISCLAIMERS
35
+
36
+ THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. STABILITY AI EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
37
+
38
+
39
+ 5. LIMITATION OF LIABILITY
40
+
41
+ TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
42
+
43
+
44
+ 6. INDEMNIFICATION
45
+
46
+ You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims. You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties.
47
+
48
+
49
+ 7. TERMINATION; SURVIVAL
50
+
51
+ a. This License will automatically terminate upon any breach by you of the terms of this License.
52
+
53
+ b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
54
+
55
+ c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
56
+
57
+
58
+ 8. THIRD PARTY MATERIALS
59
+
60
+ The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
61
+
62
+
63
+ 9. TRADEMARKS
64
+
65
+ Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
66
+
67
+
68
+ 10. APPLICABLE LAW; DISPUTE RESOLUTION
69
+
70
+ This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.
71
+
72
+
73
+ 11. MISCELLANEOUS
74
+
75
+ If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI.
repositories/generative-models/README.md ADDED
@@ -0,0 +1,194 @@
1
+ # Generative Models by Stability AI
2
+
3
+ ![sample1](assets/000.jpg)
4
+
5
+ ## News
6
+
7
+ **July 4, 2023**
8
+ - A technical report on SDXL is now available [here](assets/sdxl_report.pdf).
9
+
10
+ **June 22, 2023**
11
+
12
+
13
+ - We are releasing two new diffusion models for research purposes:
14
+ - `SD-XL 0.9-base`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
15
+ - `SD-XL 0.9-refiner`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
16
+
17
+ If you would like to access these models for your research, please apply using one of the following links:
18
+ [SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
19
+ This means that you can apply via either of the two links, and once you are granted access, you can use both.
20
+ Please log in to your Hugging Face Account with your organization email to request access.
21
+ **We plan to do a full release soon (July).**
22
+
23
+ ## The codebase
24
+
25
+ ### General Philosophy
26
+
27
+ Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
28
+
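+ As a minimal sketch of the mechanism (assuming the helper is importable from `sgm.util`, as in the earlier `ldm` codebase), a config is just a `target` class path plus optional `params` that are forwarded to the constructor:
+
+ ```python
+ from sgm.util import instantiate_from_config  # assumed import path for the helper
+
+ # a config is a mapping with a "target" class path and optional "params"
+ cfg = {
+     "target": "torch.nn.Linear",
+     "params": {"in_features": 16, "out_features": 4},
+ }
+
+ layer = instantiate_from_config(cfg)  # builds torch.nn.Linear(in_features=16, out_features=4)
+ ```
+
+ In practice these dicts come from the YAML files under `configs/`, and nested submodules are built the same way.
+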
29
+ ### Changelog from the old `ldm` codebase
30
+
31
+ For training, we use [pytorch-lightning](https://www.pytorchlightning.ai/index.html), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
32
+
33
+ - No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
34
+ - We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
35
+ samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
36
+ - We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (the most notable change is probably the option to train continuous-time models):
37
+ * Discrete-time models (denoisers) are simply a special case of continuous-time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
38
+ * The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
39
+ - Autoencoding models have also been cleaned up.
40
+
41
+ ## Installation:
42
+ <a name="installation"></a>
43
+
44
+ #### 1. Clone the repo
45
+
46
+ ```shell
47
+ git clone git@github.com:Stability-AI/generative-models.git
48
+ cd generative-models
49
+ ```
50
+
51
+ #### 2. Setting up the virtualenv
52
+
53
+ This is assuming you have navigated to the `generative-models` root after cloning it.
54
+
55
+ **NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
56
+
57
+
58
+ **PyTorch 1.13**
59
+
60
+ ```shell
61
+ # install required packages from pypi
62
+ python3 -m venv .pt1
63
+ source .pt1/bin/activate
64
+ pip3 install wheel
65
+ pip3 install -r requirements_pt13.txt
66
+ ```
67
+
68
+ **PyTorch 2.0**
69
+
70
+
71
+ ```shell
72
+ # install required packages from pypi
73
+ python3 -m venv .pt2
74
+ source .pt2/bin/activate
75
+ pip3 install wheel
76
+ pip3 install -r requirements_pt2.txt
77
+ ```
78
+
79
+ ## Inference:
80
+
81
+ We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported:
82
+ - [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
83
+ - [SD-XL 0.9-refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
84
+ - [SD 2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
85
+ - [SD 2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
86
+
87
+ **Weights for SDXL**:
88
+ If you would like to access these models for your research, please apply using one of the following links:
89
+ [SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
90
+ This means that you can apply via either of the two links, and once you are granted access, you can use both.
91
+ Please log in to your Hugging Face Account with your organization email to request access.
92
+
93
+ After obtaining the weights, place them into `checkpoints/`.
94
+ Next, start the demo using
95
+
96
+ ```
97
+ streamlit run scripts/demo/sampling.py --server.port <your_port>
98
+ ```
99
+
100
+ ### Invisible Watermark Detection
101
+
102
+ Images generated with our code use the
103
+ [invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)
104
+ library to embed an invisible watermark into the model output. We also provide
105
+ a script to easily detect that watermark. Please note that this watermark is
106
+ not the same as in previous Stable Diffusion 1.x/2.x versions.
107
+
108
+ To run the script you need to either have a working installation as above or
109
+ try an _experimental_ import using only a minimal amount of packages:
110
+ ```bash
111
+ python -m venv .detect
112
+ source .detect/bin/activate
113
+
114
+ pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
115
+ pip install --no-deps invisible-watermark
116
+ ```
117
+
118
+ With one of these installations in place, the script
119
+ is then usable in the following ways (don't forget to activate your
120
+ virtual environment beforehand, e.g. `source .pt1/bin/activate`):
121
+ ```bash
122
+ # test a single file
123
+ python scripts/demo/detect.py <your filename here>
124
+ # test multiple files at once
125
+ python scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>
126
+ # test all files in a specific folder
127
+ python scripts/demo/detect.py <your folder name here>/*
128
+ ```
129
+
130
+ ## Training:
131
+
132
+ We are providing example training configs in `configs/example_training`. To launch a training, run
133
+
134
+ ```
135
+ python main.py --base configs/<config1.yaml> configs/<config2.yaml>
136
+ ```
137
+
138
+ where configs are merged from left to right (later configs overwrite the same values).
139
+ This can be used to combine model, training and data configs. However, all of them can also be
140
+ defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,
141
+ run
142
+
143
+ ```bash
144
+ python main.py --base configs/example_training/toy/mnist_cond.yaml
145
+ ```
146
+
147
+ **NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the dataset used (which is expected to be stored as tar files in the [webdataset format](https://github.com/webdataset/webdataset)). To find the parts that have to be adapted, search for comments containing `USER:` in the respective config.
148
+
149
+ **NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for autoencoder training, e.g. `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
150
+
151
+ **NOTE 3:** Training latent generative models (e.g. `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same applies to the provided text-to-image configs.
152
+
153
+ ### Building New Diffusion Models
154
+
155
+ #### Conditioner
156
+
157
+ The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
158
+ different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
159
+ All embedders should define whether or not they are trainable (`is_trainable`, default `False`), the classifier-free
160
+ guidance dropout rate used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
161
+ When computing conditionings, the embedder will get `batch[input_key]` as input.
162
+ We currently support two- to four-dimensional conditionings, and conditionings from different embedders are concatenated
163
+ appropriately.
164
+ Note that the order of the embedders in the `conditioner_config` is important.
165
+
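+ For a concrete example, the class-conditional toy config `configs/example_training/toy/cifar10_cond.yaml` in this repository conditions on a class label with a single trainable embedder:
+
+ ```yaml
+ conditioner_config:
+   target: sgm.modules.GeneralConditioner
+   params:
+     emb_models:
+       - is_trainable: True
+         input_key: cls        # reads batch["cls"]
+         ucg_rate: 0.2         # classifier-free guidance dropout rate
+         target: sgm.modules.encoders.modules.ClassEmbedder
+         params:
+           embed_dim: 128
+           n_classes: 10
+ ```
+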
166
+ #### Network
167
+
168
+ The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
169
+ enough as we plan to experiment with transformer-based diffusion backbones.
170
+
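+ For reference, the toy CIFAR-10 config in this repository uses a small `UNetModel` (shown here abridged):
+
+ ```yaml
+ network_config:
+   target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+   params:
+     in_channels: 3
+     out_channels: 3
+     model_channels: 32
+     attention_resolutions: []
+     num_res_blocks: 4
+     channel_mult: [1, 2, 2]
+     num_head_channels: 32
+ ```
+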
171
+ #### Loss
172
+
173
+ The loss is configured through `loss_fn_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
174
+
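+ A minimal example, following `configs/example_training/toy/mnist.yaml` in this repository:
+
+ ```yaml
+ loss_fn_config:
+   target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+   params:
+     sigma_sampler_config:
+       target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+ ```
+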
175
+ #### Sampler config
176
+
177
+ As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical
178
+ solver, the number of steps, and the type of discretization, as well as, for example, guidance wrappers for classifier-free
179
+ guidance.
180
+
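+ Again following the toy CIFAR-10 config, a `sampler_config` with an Euler EDM solver, 50 steps, EDM discretization and classifier-free guidance looks like this:
+
+ ```yaml
+ sampler_config:
+   target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+   params:
+     num_steps: 50
+     discretization_config:
+       target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+     guider_config:
+       target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+       params:
+         scale: 3.0
+ ```
+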
181
+ ### Dataset Handling
182
+
183
+
184
+ For large-scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is included in the requirements and is installed automatically when following the steps from the [Installation section](#installation).
185
+ Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
186
+ data keys/values,
187
+ e.g.,
188
+
189
+ ```python
190
+ example = {"jpg": x, # this is a tensor -1...1 chw
191
+ "txt": "a beautiful image"}
192
+ ```
193
+
194
+ where we expect images in -1...1, channel-first format.
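+
+ A minimal map-style dataset sketch producing this format (the tensors here are random and purely illustrative):
+
+ ```python
+ import torch
+ from torch.utils.data import Dataset
+
+
+ class ToyImageTextDataset(Dataset):
+     """Returns dicts with a chw image tensor in [-1, 1] and a caption string."""
+
+     def __len__(self):
+         return 1000
+
+     def __getitem__(self, idx):
+         x = torch.rand(3, 256, 256) * 2.0 - 1.0  # chw image in [-1, 1]
+         return {"jpg": x, "txt": "a beautiful image"}
+ ```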
repositories/generative-models/assets/000.jpg ADDED
repositories/generative-models/assets/sdxl_report.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d63345686bc36e6f6de1c20610a7657fafba4f24a9e892ea6f0b9a9f36b5c00
3
+ size 18172854
repositories/generative-models/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml ADDED
@@ -0,0 +1,115 @@
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: sgm.models.autoencoder.AutoencodingEngine
4
+ params:
5
+ input_key: jpg
6
+ monitor: val/rec_loss
7
+
8
+ loss_config:
9
+ target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
10
+ params:
11
+ perceptual_weight: 0.25
12
+ disc_start: 20001
13
+ disc_weight: 0.5
14
+ learn_logvar: True
15
+
16
+ regularization_weights:
17
+ kl_loss: 1.0
18
+
19
+ regularizer_config:
20
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
21
+
22
+ encoder_config:
23
+ target: sgm.modules.diffusionmodules.model.Encoder
24
+ params:
25
+ attn_type: none
26
+ double_z: True
27
+ z_channels: 4
28
+ resolution: 256
29
+ in_channels: 3
30
+ out_ch: 3
31
+ ch: 128
32
+ ch_mult: [ 1, 2, 4 ]
33
+ num_res_blocks: 4
34
+ attn_resolutions: [ ]
35
+ dropout: 0.0
36
+
37
+ decoder_config:
38
+ target: sgm.modules.diffusionmodules.model.Decoder
39
+ params:
40
+ attn_type: none
41
+ double_z: False
42
+ z_channels: 4
43
+ resolution: 256
44
+ in_channels: 3
45
+ out_ch: 3
46
+ ch: 128
47
+ ch_mult: [ 1, 2, 4 ]
48
+ num_res_blocks: 4
49
+ attn_resolutions: [ ]
50
+ dropout: 0.0
51
+
52
+ data:
53
+ target: sgm.data.dataset.StableDataModuleFromConfig
54
+ params:
55
+ train:
56
+ datapipeline:
57
+ urls:
58
+ - "DATA-PATH"
59
+ pipeline_config:
60
+ shardshuffle: 10000
61
+ sample_shuffle: 10000
62
+
63
+ decoders:
64
+ - "pil"
65
+
66
+ postprocessors:
67
+ - target: sdata.mappers.TorchVisionImageTransforms
68
+ params:
69
+ key: 'jpg'
70
+ transforms:
71
+ - target: torchvision.transforms.Resize
72
+ params:
73
+ size: 256
74
+ interpolation: 3
75
+ - target: torchvision.transforms.ToTensor
76
+ - target: sdata.mappers.Rescaler
77
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
78
+ params:
79
+ h_key: height
80
+ w_key: width
81
+
82
+ loader:
83
+ batch_size: 8
84
+ num_workers: 4
85
+
86
+
87
+ lightning:
88
+ strategy:
89
+ target: pytorch_lightning.strategies.DDPStrategy
90
+ params:
91
+ find_unused_parameters: True
92
+
93
+ modelcheckpoint:
94
+ params:
95
+ every_n_train_steps: 5000
96
+
97
+ callbacks:
98
+ metrics_over_trainsteps_checkpoint:
99
+ params:
100
+ every_n_train_steps: 50000
101
+
102
+ image_logger:
103
+ target: main.ImageLogger
104
+ params:
105
+ enable_autocast: False
106
+ batch_frequency: 1000
107
+ max_images: 8
108
+ increase_log_steps: True
109
+
110
+ trainer:
111
+ devices: 0,
112
+ limit_val_batches: 50
113
+ benchmark: True
114
+ accumulate_grad_batches: 1
115
+ val_check_interval: 10000
repositories/generative-models/configs/example_training/imagenet-f8_cond.yaml ADDED
@@ -0,0 +1,188 @@
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ log_keys:
8
+ - cls
9
+
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps: [10000]
14
+ cycle_lengths: [10000000000000]
15
+ f_start: [1.e-6]
16
+ f_max: [1.]
17
+ f_min: [1.]
18
+
19
+ denoiser_config:
20
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
+ params:
22
+ num_idx: 1000
23
+
24
+ weighting_config:
25
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
26
+ scaling_config:
27
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
28
+ discretization_config:
29
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
30
+
31
+ network_config:
32
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
33
+ params:
34
+ use_checkpoint: True
35
+ use_fp16: True
36
+ in_channels: 4
37
+ out_channels: 4
38
+ model_channels: 256
39
+ attention_resolutions: [1, 2, 4]
40
+ num_res_blocks: 2
41
+ channel_mult: [1, 2, 4]
42
+ num_head_channels: 64
43
+ num_classes: sequential
44
+ adm_in_channels: 1024
45
+ use_spatial_transformer: true
46
+ transformer_depth: 1
47
+ context_dim: 1024
48
+ spatial_transformer_attn_type: softmax-xformers
49
+
50
+ conditioner_config:
51
+ target: sgm.modules.GeneralConditioner
52
+ params:
53
+ emb_models:
54
+ # crossattn cond
55
+ - is_trainable: True
56
+ input_key: cls
57
+ ucg_rate: 0.2
58
+ target: sgm.modules.encoders.modules.ClassEmbedder
59
+ params:
60
+ add_sequence_dim: True # will be used through crossattn then
61
+ embed_dim: 1024
62
+ n_classes: 1000
63
+ # vector cond
64
+ - is_trainable: False
65
+ ucg_rate: 0.2
66
+ input_key: original_size_as_tuple
67
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
68
+ params:
69
+ outdim: 256 # multiplied by two
70
+ # vector cond
71
+ - is_trainable: False
72
+ input_key: crop_coords_top_left
73
+ ucg_rate: 0.2
74
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
75
+ params:
76
+ outdim: 256 # multiplied by two
77
+
78
+ first_stage_config:
79
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
80
+ params:
81
+ ckpt_path: CKPT_PATH
82
+ embed_dim: 4
83
+ monitor: val/rec_loss
84
+ ddconfig:
85
+ attn_type: vanilla-xformers
86
+ double_z: true
87
+ z_channels: 4
88
+ resolution: 256
89
+ in_channels: 3
90
+ out_ch: 3
91
+ ch: 128
92
+ ch_mult: [1, 2, 4, 4]
93
+ num_res_blocks: 2
94
+ attn_resolutions: []
95
+ dropout: 0.0
96
+ lossconfig:
97
+ target: torch.nn.Identity
98
+
99
+ loss_fn_config:
100
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
101
+ params:
102
+ sigma_sampler_config:
103
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
104
+ params:
105
+ num_idx: 1000
106
+
107
+ discretization_config:
108
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
109
+
110
+ sampler_config:
111
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
112
+ params:
113
+ num_steps: 50
114
+
115
+ discretization_config:
116
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
117
+
118
+ guider_config:
119
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
120
+ params:
121
+ scale: 5.0
122
+
123
+ data:
124
+ target: sgm.data.dataset.StableDataModuleFromConfig
125
+ params:
126
+ train:
127
+ datapipeline:
128
+ urls:
129
+ # USER: adapt this path the root of your custom dataset
130
+ - "DATA_PATH"
131
+ pipeline_config:
132
+ shardshuffle: 10000
133
+ sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
134
+
135
+ decoders:
136
+ - "pil"
137
+
138
+ postprocessors:
139
+ - target: sdata.mappers.TorchVisionImageTransforms
140
+ params:
141
+ key: 'jpg' # USER: you might wanna adapt this for your custom dataset
142
+ transforms:
143
+ - target: torchvision.transforms.Resize
144
+ params:
145
+ size: 256
146
+ interpolation: 3
147
+ - target: torchvision.transforms.ToTensor
148
+ - target: sdata.mappers.Rescaler
149
+
150
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
151
+ params:
152
+ h_key: height # USER: you might wanna adapt this for your custom dataset
153
+ w_key: width # USER: you might wanna adapt this for your custom dataset
154
+
155
+ loader:
156
+ batch_size: 64
157
+ num_workers: 6
158
+
159
+ lightning:
160
+ modelcheckpoint:
161
+ params:
162
+ every_n_train_steps: 5000
163
+
164
+ callbacks:
165
+ metrics_over_trainsteps_checkpoint:
166
+ params:
167
+ every_n_train_steps: 25000
168
+
169
+ image_logger:
170
+ target: main.ImageLogger
171
+ params:
172
+ disabled: False
173
+ enable_autocast: False
174
+ batch_frequency: 1000
175
+ max_images: 8
176
+ increase_log_steps: True
177
+ log_first_step: False
178
+ log_images_kwargs:
179
+ use_ema_scope: False
180
+ N: 8
181
+ n_rows: 2
182
+
183
+ trainer:
184
+ devices: 0,
185
+ benchmark: True
186
+ num_sanity_val_steps: 0
187
+ accumulate_grad_batches: 1
188
+ max_epochs: 1000
repositories/generative-models/configs/example_training/toy/cifar10_cond.yaml ADDED
@@ -0,0 +1,99 @@
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ weighting_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
10
+ params:
11
+ sigma_data: 1.0
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
14
+ params:
15
+ sigma_data: 1.0
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ use_checkpoint: True
21
+ in_channels: 3
22
+ out_channels: 3
23
+ model_channels: 32
24
+ attention_resolutions: []
25
+ num_res_blocks: 4
26
+ channel_mult: [1, 2, 2]
27
+ num_head_channels: 32
28
+ num_classes: sequential
29
+ adm_in_channels: 128
30
+
31
+ conditioner_config:
32
+ target: sgm.modules.GeneralConditioner
33
+ params:
34
+ emb_models:
35
+ - is_trainable: True
36
+ input_key: cls
37
+ ucg_rate: 0.2
38
+ target: sgm.modules.encoders.modules.ClassEmbedder
39
+ params:
40
+ embed_dim: 128
41
+ n_classes: 10
42
+
43
+ first_stage_config:
44
+ target: sgm.models.autoencoder.IdentityFirstStage
45
+
46
+ loss_fn_config:
47
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
48
+ params:
49
+ sigma_sampler_config:
50
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
51
+
52
+ sampler_config:
53
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
54
+ params:
55
+ num_steps: 50
56
+
57
+ discretization_config:
58
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
59
+
60
+ guider_config:
61
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
62
+ params:
63
+ scale: 3.0
64
+
65
+ data:
66
+ target: sgm.data.cifar10.CIFAR10Loader
67
+ params:
68
+ batch_size: 512
69
+ num_workers: 1
70
+
71
+ lightning:
72
+ modelcheckpoint:
73
+ params:
74
+ every_n_train_steps: 5000
75
+
76
+ callbacks:
77
+ metrics_over_trainsteps_checkpoint:
78
+ params:
79
+ every_n_train_steps: 25000
80
+
81
+ image_logger:
82
+ target: main.ImageLogger
83
+ params:
84
+ disabled: False
85
+ batch_frequency: 1000
86
+ max_images: 64
87
+ increase_log_steps: True
88
+ log_first_step: False
89
+ log_images_kwargs:
90
+ use_ema_scope: False
91
+ N: 64
92
+ n_rows: 8
93
+
94
+ trainer:
95
+ devices: 0,
96
+ benchmark: True
97
+ num_sanity_val_steps: 0
98
+ accumulate_grad_batches: 1
99
+ max_epochs: 20
repositories/generative-models/configs/example_training/toy/mnist.yaml ADDED
@@ -0,0 +1,80 @@
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ weighting_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
10
+ params:
11
+ sigma_data: 1.0
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
14
+ params:
15
+ sigma_data: 1.0
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ use_checkpoint: True
21
+ in_channels: 1
22
+ out_channels: 1
23
+ model_channels: 32
24
+ attention_resolutions: []
25
+ num_res_blocks: 4
26
+ channel_mult: [1, 2, 2]
27
+ num_head_channels: 32
28
+
29
+ first_stage_config:
30
+ target: sgm.models.autoencoder.IdentityFirstStage
31
+
32
+ loss_fn_config:
33
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
34
+ params:
35
+ sigma_sampler_config:
36
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
37
+
38
+ sampler_config:
39
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
40
+ params:
41
+ num_steps: 50
42
+
43
+ discretization_config:
44
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
45
+
46
+ data:
47
+ target: sgm.data.mnist.MNISTLoader
48
+ params:
49
+ batch_size: 512
50
+ num_workers: 1
51
+
52
+ lightning:
53
+ modelcheckpoint:
54
+ params:
55
+ every_n_train_steps: 5000
56
+
57
+ callbacks:
58
+ metrics_over_trainsteps_checkpoint:
59
+ params:
60
+ every_n_train_steps: 25000
61
+
62
+ image_logger:
63
+ target: main.ImageLogger
64
+ params:
65
+ disabled: False
66
+ batch_frequency: 1000
67
+ max_images: 64
68
+ increase_log_steps: False
69
+ log_first_step: False
70
+ log_images_kwargs:
71
+ use_ema_scope: False
72
+ N: 64
73
+ n_rows: 8
74
+
75
+ trainer:
76
+ devices: 0,
77
+ benchmark: True
78
+ num_sanity_val_steps: 0
79
+ accumulate_grad_batches: 1
80
+ max_epochs: 10