0975a6be0dfbdd3323a94ec491197546e4fb2f8cc2232996355251f3fece7777
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- repositories/CodeFormer/facelib/utils/__pycache__/misc.cpython-310.pyc +0 -0
- repositories/CodeFormer/facelib/utils/face_restoration_helper.py +455 -0
- repositories/CodeFormer/facelib/utils/face_utils.py +248 -0
- repositories/CodeFormer/facelib/utils/misc.py +141 -0
- repositories/CodeFormer/inference_codeformer.py +189 -0
- repositories/CodeFormer/inputs/cropped_faces/0143.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0240.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0342.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0345.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0368.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0412.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0444.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0478.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0500.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0599.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0717.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0720.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0729.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0763.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0770.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0777.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0885.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/0934.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/Solvay_conference_1927_0018.png +0 -0
- repositories/CodeFormer/inputs/cropped_faces/Solvay_conference_1927_2_16.png +0 -0
- repositories/CodeFormer/inputs/whole_imgs/00.jpg +0 -0
- repositories/CodeFormer/inputs/whole_imgs/01.jpg +0 -0
- repositories/CodeFormer/inputs/whole_imgs/02.png +0 -0
- repositories/CodeFormer/inputs/whole_imgs/03.jpg +0 -0
- repositories/CodeFormer/inputs/whole_imgs/04.jpg +0 -0
- repositories/CodeFormer/inputs/whole_imgs/05.jpg +0 -0
- repositories/CodeFormer/inputs/whole_imgs/06.png +0 -0
- repositories/CodeFormer/predict.py +188 -0
- repositories/CodeFormer/requirements.txt +20 -0
- repositories/CodeFormer/scripts/crop_align_face.py +192 -0
- repositories/CodeFormer/scripts/download_pretrained_models.py +40 -0
- repositories/CodeFormer/scripts/download_pretrained_models_from_gdrive.py +60 -0
- repositories/CodeFormer/weights/CodeFormer/.gitkeep +0 -0
- repositories/CodeFormer/weights/README.md +3 -0
- repositories/CodeFormer/weights/facelib/.gitkeep +0 -0
- repositories/generative-models/.gitignore +7 -0
- repositories/generative-models/LICENSE +75 -0
- repositories/generative-models/README.md +194 -0
- repositories/generative-models/assets/000.jpg +0 -0
- repositories/generative-models/assets/sdxl_report.pdf +3 -0
- repositories/generative-models/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +115 -0
- repositories/generative-models/configs/example_training/imagenet-f8_cond.yaml +188 -0
- repositories/generative-models/configs/example_training/toy/cifar10_cond.yaml +99 -0
- repositories/generative-models/configs/example_training/toy/mnist.yaml +80 -0
.gitattributes
CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
36 |
extensions/Stable-Diffusion-Webui-Civitai-Helper/img/all_in_one.png filter=lfs diff=lfs merge=lfs -text
|
37 |
extensions/addtional/models/lora/README.md filter=lfs diff=lfs merge=lfs -text
|
38 |
repositories/BLIP/BLIP.gif filter=lfs diff=lfs merge=lfs -text
|
|
|
|
36 |
extensions/Stable-Diffusion-Webui-Civitai-Helper/img/all_in_one.png filter=lfs diff=lfs merge=lfs -text
|
37 |
extensions/addtional/models/lora/README.md filter=lfs diff=lfs merge=lfs -text
|
38 |
repositories/BLIP/BLIP.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
repositories/generative-models/assets/sdxl_report.pdf filter=lfs diff=lfs merge=lfs -text
|
repositories/CodeFormer/facelib/utils/__pycache__/misc.cpython-310.pyc
ADDED
Binary file (4.61 kB). View file
|
|
repositories/CodeFormer/facelib/utils/face_restoration_helper.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
from torchvision.transforms.functional import normalize
|
6 |
+
|
7 |
+
from facelib.detection import init_detection_model
|
8 |
+
from facelib.parsing import init_parsing_model
|
9 |
+
from facelib.utils.misc import img2tensor, imwrite
|
10 |
+
|
11 |
+
|
12 |
+
def get_largest_face(det_faces, h, w):
|
13 |
+
|
14 |
+
def get_location(val, length):
|
15 |
+
if val < 0:
|
16 |
+
return 0
|
17 |
+
elif val > length:
|
18 |
+
return length
|
19 |
+
else:
|
20 |
+
return val
|
21 |
+
|
22 |
+
face_areas = []
|
23 |
+
for det_face in det_faces:
|
24 |
+
left = get_location(det_face[0], w)
|
25 |
+
right = get_location(det_face[2], w)
|
26 |
+
top = get_location(det_face[1], h)
|
27 |
+
bottom = get_location(det_face[3], h)
|
28 |
+
face_area = (right - left) * (bottom - top)
|
29 |
+
face_areas.append(face_area)
|
30 |
+
largest_idx = face_areas.index(max(face_areas))
|
31 |
+
return det_faces[largest_idx], largest_idx
|
32 |
+
|
33 |
+
|
34 |
+
def get_center_face(det_faces, h=0, w=0, center=None):
|
35 |
+
if center is not None:
|
36 |
+
center = np.array(center)
|
37 |
+
else:
|
38 |
+
center = np.array([w / 2, h / 2])
|
39 |
+
center_dist = []
|
40 |
+
for det_face in det_faces:
|
41 |
+
face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
|
42 |
+
dist = np.linalg.norm(face_center - center)
|
43 |
+
center_dist.append(dist)
|
44 |
+
center_idx = center_dist.index(min(center_dist))
|
45 |
+
return det_faces[center_idx], center_idx
|
46 |
+
|
47 |
+
|
48 |
+
class FaceRestoreHelper(object):
|
49 |
+
"""Helper for the face restoration pipeline (base class)."""
|
50 |
+
|
51 |
+
def __init__(self,
|
52 |
+
upscale_factor,
|
53 |
+
face_size=512,
|
54 |
+
crop_ratio=(1, 1),
|
55 |
+
det_model='retinaface_resnet50',
|
56 |
+
save_ext='png',
|
57 |
+
template_3points=False,
|
58 |
+
pad_blur=False,
|
59 |
+
use_parse=False,
|
60 |
+
device=None):
|
61 |
+
self.template_3points = template_3points # improve robustness
|
62 |
+
self.upscale_factor = upscale_factor
|
63 |
+
# the cropped face ratio based on the square face
|
64 |
+
self.crop_ratio = crop_ratio # (h, w)
|
65 |
+
assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
|
66 |
+
self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
|
67 |
+
|
68 |
+
if self.template_3points:
|
69 |
+
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
|
70 |
+
else:
|
71 |
+
# standard 5 landmarks for FFHQ faces with 512 x 512
|
72 |
+
# facexlib
|
73 |
+
self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
|
74 |
+
[201.26117, 371.41043], [313.08905, 371.15118]])
|
75 |
+
|
76 |
+
# dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
|
77 |
+
# self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
|
78 |
+
# [198.22603, 372.82502], [313.91018, 372.75659]])
|
79 |
+
|
80 |
+
|
81 |
+
self.face_template = self.face_template * (face_size / 512.0)
|
82 |
+
if self.crop_ratio[0] > 1:
|
83 |
+
self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
|
84 |
+
if self.crop_ratio[1] > 1:
|
85 |
+
self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
|
86 |
+
self.save_ext = save_ext
|
87 |
+
self.pad_blur = pad_blur
|
88 |
+
if self.pad_blur is True:
|
89 |
+
self.template_3points = False
|
90 |
+
|
91 |
+
self.all_landmarks_5 = []
|
92 |
+
self.det_faces = []
|
93 |
+
self.affine_matrices = []
|
94 |
+
self.inverse_affine_matrices = []
|
95 |
+
self.cropped_faces = []
|
96 |
+
self.restored_faces = []
|
97 |
+
self.pad_input_imgs = []
|
98 |
+
|
99 |
+
if device is None:
|
100 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
101 |
+
else:
|
102 |
+
self.device = device
|
103 |
+
|
104 |
+
# init face detection model
|
105 |
+
self.face_det = init_detection_model(det_model, half=False, device=self.device)
|
106 |
+
|
107 |
+
# init face parsing model
|
108 |
+
self.use_parse = use_parse
|
109 |
+
self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
|
110 |
+
|
111 |
+
def set_upscale_factor(self, upscale_factor):
|
112 |
+
self.upscale_factor = upscale_factor
|
113 |
+
|
114 |
+
def read_image(self, img):
|
115 |
+
"""img can be image path or cv2 loaded image."""
|
116 |
+
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
|
117 |
+
if isinstance(img, str):
|
118 |
+
img = cv2.imread(img)
|
119 |
+
|
120 |
+
if np.max(img) > 256: # 16-bit image
|
121 |
+
img = img / 65535 * 255
|
122 |
+
if len(img.shape) == 2: # gray image
|
123 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
124 |
+
elif img.shape[2] == 4: # BGRA image with alpha channel
|
125 |
+
img = img[:, :, 0:3]
|
126 |
+
|
127 |
+
self.input_img = img
|
128 |
+
|
129 |
+
if min(self.input_img.shape[:2])<512:
|
130 |
+
f = 512.0/min(self.input_img.shape[:2])
|
131 |
+
self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
|
132 |
+
|
133 |
+
def get_face_landmarks_5(self,
|
134 |
+
only_keep_largest=False,
|
135 |
+
only_center_face=False,
|
136 |
+
resize=None,
|
137 |
+
blur_ratio=0.01,
|
138 |
+
eye_dist_threshold=None):
|
139 |
+
if resize is None:
|
140 |
+
scale = 1
|
141 |
+
input_img = self.input_img
|
142 |
+
else:
|
143 |
+
h, w = self.input_img.shape[0:2]
|
144 |
+
scale = resize / min(h, w)
|
145 |
+
scale = max(1, scale) # always scale up
|
146 |
+
h, w = int(h * scale), int(w * scale)
|
147 |
+
interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
|
148 |
+
input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
|
149 |
+
|
150 |
+
with torch.no_grad():
|
151 |
+
bboxes = self.face_det.detect_faces(input_img)
|
152 |
+
|
153 |
+
if bboxes is None or bboxes.shape[0] == 0:
|
154 |
+
return 0
|
155 |
+
else:
|
156 |
+
bboxes = bboxes / scale
|
157 |
+
|
158 |
+
for bbox in bboxes:
|
159 |
+
# remove faces with too small eye distance: side faces or too small faces
|
160 |
+
eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
|
161 |
+
if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
|
162 |
+
continue
|
163 |
+
|
164 |
+
if self.template_3points:
|
165 |
+
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
|
166 |
+
else:
|
167 |
+
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
|
168 |
+
self.all_landmarks_5.append(landmark)
|
169 |
+
self.det_faces.append(bbox[0:5])
|
170 |
+
|
171 |
+
if len(self.det_faces) == 0:
|
172 |
+
return 0
|
173 |
+
if only_keep_largest:
|
174 |
+
h, w, _ = self.input_img.shape
|
175 |
+
self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
|
176 |
+
self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
|
177 |
+
elif only_center_face:
|
178 |
+
h, w, _ = self.input_img.shape
|
179 |
+
self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
|
180 |
+
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
|
181 |
+
|
182 |
+
# pad blurry images
|
183 |
+
if self.pad_blur:
|
184 |
+
self.pad_input_imgs = []
|
185 |
+
for landmarks in self.all_landmarks_5:
|
186 |
+
# get landmarks
|
187 |
+
eye_left = landmarks[0, :]
|
188 |
+
eye_right = landmarks[1, :]
|
189 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
190 |
+
mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
|
191 |
+
eye_to_eye = eye_right - eye_left
|
192 |
+
eye_to_mouth = mouth_avg - eye_avg
|
193 |
+
|
194 |
+
# Get the oriented crop rectangle
|
195 |
+
# x: half width of the oriented crop rectangle
|
196 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
197 |
+
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
|
198 |
+
# norm with the hypotenuse: get the direction
|
199 |
+
x /= np.hypot(*x) # get the hypotenuse of a right triangle
|
200 |
+
rect_scale = 1.5
|
201 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
|
202 |
+
# y: half height of the oriented crop rectangle
|
203 |
+
y = np.flipud(x) * [-1, 1]
|
204 |
+
|
205 |
+
# c: center
|
206 |
+
c = eye_avg + eye_to_mouth * 0.1
|
207 |
+
# quad: (left_top, left_bottom, right_bottom, right_top)
|
208 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
209 |
+
# qsize: side length of the square
|
210 |
+
qsize = np.hypot(*x) * 2
|
211 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
212 |
+
|
213 |
+
# get pad
|
214 |
+
# pad: (width_left, height_top, width_right, height_bottom)
|
215 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
216 |
+
int(np.ceil(max(quad[:, 1]))))
|
217 |
+
pad = [
|
218 |
+
max(-pad[0] + border, 1),
|
219 |
+
max(-pad[1] + border, 1),
|
220 |
+
max(pad[2] - self.input_img.shape[0] + border, 1),
|
221 |
+
max(pad[3] - self.input_img.shape[1] + border, 1)
|
222 |
+
]
|
223 |
+
|
224 |
+
if max(pad) > 1:
|
225 |
+
# pad image
|
226 |
+
pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
227 |
+
# modify landmark coords
|
228 |
+
landmarks[:, 0] += pad[0]
|
229 |
+
landmarks[:, 1] += pad[1]
|
230 |
+
# blur pad images
|
231 |
+
h, w, _ = pad_img.shape
|
232 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
233 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
|
234 |
+
np.float32(w - 1 - x) / pad[2]),
|
235 |
+
1.0 - np.minimum(np.float32(y) / pad[1],
|
236 |
+
np.float32(h - 1 - y) / pad[3]))
|
237 |
+
blur = int(qsize * blur_ratio)
|
238 |
+
if blur % 2 == 0:
|
239 |
+
blur += 1
|
240 |
+
blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
|
241 |
+
# blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
|
242 |
+
|
243 |
+
pad_img = pad_img.astype('float32')
|
244 |
+
pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
245 |
+
pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
|
246 |
+
pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
|
247 |
+
self.pad_input_imgs.append(pad_img)
|
248 |
+
else:
|
249 |
+
self.pad_input_imgs.append(np.copy(self.input_img))
|
250 |
+
|
251 |
+
return len(self.all_landmarks_5)
|
252 |
+
|
253 |
+
def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
|
254 |
+
"""Align and warp faces with face template.
|
255 |
+
"""
|
256 |
+
if self.pad_blur:
|
257 |
+
assert len(self.pad_input_imgs) == len(
|
258 |
+
self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
|
259 |
+
for idx, landmark in enumerate(self.all_landmarks_5):
|
260 |
+
# use 5 landmarks to get affine matrix
|
261 |
+
# use cv2.LMEDS method for the equivalence to skimage transform
|
262 |
+
# ref: https://blog.csdn.net/yichxi/article/details/115827338
|
263 |
+
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
|
264 |
+
self.affine_matrices.append(affine_matrix)
|
265 |
+
# warp and crop faces
|
266 |
+
if border_mode == 'constant':
|
267 |
+
border_mode = cv2.BORDER_CONSTANT
|
268 |
+
elif border_mode == 'reflect101':
|
269 |
+
border_mode = cv2.BORDER_REFLECT101
|
270 |
+
elif border_mode == 'reflect':
|
271 |
+
border_mode = cv2.BORDER_REFLECT
|
272 |
+
if self.pad_blur:
|
273 |
+
input_img = self.pad_input_imgs[idx]
|
274 |
+
else:
|
275 |
+
input_img = self.input_img
|
276 |
+
cropped_face = cv2.warpAffine(
|
277 |
+
input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
|
278 |
+
self.cropped_faces.append(cropped_face)
|
279 |
+
# save the cropped face
|
280 |
+
if save_cropped_path is not None:
|
281 |
+
path = os.path.splitext(save_cropped_path)[0]
|
282 |
+
save_path = f'{path}_{idx:02d}.{self.save_ext}'
|
283 |
+
imwrite(cropped_face, save_path)
|
284 |
+
|
285 |
+
def get_inverse_affine(self, save_inverse_affine_path=None):
|
286 |
+
"""Get inverse affine matrix."""
|
287 |
+
for idx, affine_matrix in enumerate(self.affine_matrices):
|
288 |
+
inverse_affine = cv2.invertAffineTransform(affine_matrix)
|
289 |
+
inverse_affine *= self.upscale_factor
|
290 |
+
self.inverse_affine_matrices.append(inverse_affine)
|
291 |
+
# save inverse affine matrices
|
292 |
+
if save_inverse_affine_path is not None:
|
293 |
+
path, _ = os.path.splitext(save_inverse_affine_path)
|
294 |
+
save_path = f'{path}_{idx:02d}.pth'
|
295 |
+
torch.save(inverse_affine, save_path)
|
296 |
+
|
297 |
+
|
298 |
+
def add_restored_face(self, face):
|
299 |
+
self.restored_faces.append(face)
|
300 |
+
|
301 |
+
|
302 |
+
def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
|
303 |
+
h, w, _ = self.input_img.shape
|
304 |
+
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
|
305 |
+
|
306 |
+
if upsample_img is None:
|
307 |
+
# simply resize the background
|
308 |
+
# upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
309 |
+
upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
|
310 |
+
else:
|
311 |
+
upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
312 |
+
|
313 |
+
assert len(self.restored_faces) == len(
|
314 |
+
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
|
315 |
+
|
316 |
+
inv_mask_borders = []
|
317 |
+
for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
|
318 |
+
if face_upsampler is not None:
|
319 |
+
restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
|
320 |
+
inverse_affine /= self.upscale_factor
|
321 |
+
inverse_affine[:, 2] *= self.upscale_factor
|
322 |
+
face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
|
323 |
+
else:
|
324 |
+
# Add an offset to inverse affine matrix, for more precise back alignment
|
325 |
+
if self.upscale_factor > 1:
|
326 |
+
extra_offset = 0.5 * self.upscale_factor
|
327 |
+
else:
|
328 |
+
extra_offset = 0
|
329 |
+
inverse_affine[:, 2] += extra_offset
|
330 |
+
face_size = self.face_size
|
331 |
+
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
|
332 |
+
|
333 |
+
# if draw_box or not self.use_parse: # use square parse maps
|
334 |
+
# mask = np.ones(face_size, dtype=np.float32)
|
335 |
+
# inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
336 |
+
# # remove the black borders
|
337 |
+
# inv_mask_erosion = cv2.erode(
|
338 |
+
# inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
|
339 |
+
# pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
340 |
+
# total_face_area = np.sum(inv_mask_erosion) # // 3
|
341 |
+
# # add border
|
342 |
+
# if draw_box:
|
343 |
+
# h, w = face_size
|
344 |
+
# mask_border = np.ones((h, w, 3), dtype=np.float32)
|
345 |
+
# border = int(1400/np.sqrt(total_face_area))
|
346 |
+
# mask_border[border:h-border, border:w-border,:] = 0
|
347 |
+
# inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
|
348 |
+
# inv_mask_borders.append(inv_mask_border)
|
349 |
+
# if not self.use_parse:
|
350 |
+
# # compute the fusion edge based on the area of face
|
351 |
+
# w_edge = int(total_face_area**0.5) // 20
|
352 |
+
# erosion_radius = w_edge * 2
|
353 |
+
# inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
354 |
+
# blur_size = w_edge * 2
|
355 |
+
# inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
356 |
+
# if len(upsample_img.shape) == 2: # upsample_img is gray image
|
357 |
+
# upsample_img = upsample_img[:, :, None]
|
358 |
+
# inv_soft_mask = inv_soft_mask[:, :, None]
|
359 |
+
|
360 |
+
# always use square mask
|
361 |
+
mask = np.ones(face_size, dtype=np.float32)
|
362 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
363 |
+
# remove the black borders
|
364 |
+
inv_mask_erosion = cv2.erode(
|
365 |
+
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
|
366 |
+
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
367 |
+
total_face_area = np.sum(inv_mask_erosion) # // 3
|
368 |
+
# add border
|
369 |
+
if draw_box:
|
370 |
+
h, w = face_size
|
371 |
+
mask_border = np.ones((h, w, 3), dtype=np.float32)
|
372 |
+
border = int(1400/np.sqrt(total_face_area))
|
373 |
+
mask_border[border:h-border, border:w-border,:] = 0
|
374 |
+
inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
|
375 |
+
inv_mask_borders.append(inv_mask_border)
|
376 |
+
# compute the fusion edge based on the area of face
|
377 |
+
w_edge = int(total_face_area**0.5) // 20
|
378 |
+
erosion_radius = w_edge * 2
|
379 |
+
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
380 |
+
blur_size = w_edge * 2
|
381 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
382 |
+
if len(upsample_img.shape) == 2: # upsample_img is gray image
|
383 |
+
upsample_img = upsample_img[:, :, None]
|
384 |
+
inv_soft_mask = inv_soft_mask[:, :, None]
|
385 |
+
|
386 |
+
# parse mask
|
387 |
+
if self.use_parse:
|
388 |
+
# inference
|
389 |
+
face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
|
390 |
+
face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
|
391 |
+
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
392 |
+
face_input = torch.unsqueeze(face_input, 0).to(self.device)
|
393 |
+
with torch.no_grad():
|
394 |
+
out = self.face_parse(face_input)[0]
|
395 |
+
out = out.argmax(dim=1).squeeze().cpu().numpy()
|
396 |
+
|
397 |
+
parse_mask = np.zeros(out.shape)
|
398 |
+
MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
|
399 |
+
for idx, color in enumerate(MASK_COLORMAP):
|
400 |
+
parse_mask[out == idx] = color
|
401 |
+
# blur the mask
|
402 |
+
parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
|
403 |
+
parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
|
404 |
+
# remove the black borders
|
405 |
+
thres = 10
|
406 |
+
parse_mask[:thres, :] = 0
|
407 |
+
parse_mask[-thres:, :] = 0
|
408 |
+
parse_mask[:, :thres] = 0
|
409 |
+
parse_mask[:, -thres:] = 0
|
410 |
+
parse_mask = parse_mask / 255.
|
411 |
+
|
412 |
+
parse_mask = cv2.resize(parse_mask, face_size)
|
413 |
+
parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
|
414 |
+
inv_soft_parse_mask = parse_mask[:, :, None]
|
415 |
+
# pasted_face = inv_restored
|
416 |
+
fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
|
417 |
+
inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
|
418 |
+
|
419 |
+
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
|
420 |
+
alpha = upsample_img[:, :, 3:]
|
421 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
|
422 |
+
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
|
423 |
+
else:
|
424 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
|
425 |
+
|
426 |
+
if np.max(upsample_img) > 256: # 16-bit image
|
427 |
+
upsample_img = upsample_img.astype(np.uint16)
|
428 |
+
else:
|
429 |
+
upsample_img = upsample_img.astype(np.uint8)
|
430 |
+
|
431 |
+
# draw bounding box
|
432 |
+
if draw_box:
|
433 |
+
# upsample_input_img = cv2.resize(input_img, (w_up, h_up))
|
434 |
+
img_color = np.ones([*upsample_img.shape], dtype=np.float32)
|
435 |
+
img_color[:,:,0] = 0
|
436 |
+
img_color[:,:,1] = 255
|
437 |
+
img_color[:,:,2] = 0
|
438 |
+
for inv_mask_border in inv_mask_borders:
|
439 |
+
upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
|
440 |
+
# upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
|
441 |
+
|
442 |
+
if save_path is not None:
|
443 |
+
path = os.path.splitext(save_path)[0]
|
444 |
+
save_path = f'{path}.{self.save_ext}'
|
445 |
+
imwrite(upsample_img, save_path)
|
446 |
+
return upsample_img
|
447 |
+
|
448 |
+
def clean_all(self):
|
449 |
+
self.all_landmarks_5 = []
|
450 |
+
self.restored_faces = []
|
451 |
+
self.affine_matrices = []
|
452 |
+
self.cropped_faces = []
|
453 |
+
self.inverse_affine_matrices = []
|
454 |
+
self.det_faces = []
|
455 |
+
self.pad_input_imgs = []
|
repositories/CodeFormer/facelib/utils/face_utils.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
|
7 |
+
left, top, right, bot = bbox
|
8 |
+
width = right - left
|
9 |
+
height = bot - top
|
10 |
+
|
11 |
+
if preserve_aspect:
|
12 |
+
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
|
13 |
+
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
|
14 |
+
else:
|
15 |
+
width_increase = height_increase = increase_area
|
16 |
+
left = int(left - width_increase * width)
|
17 |
+
top = int(top - height_increase * height)
|
18 |
+
right = int(right + width_increase * width)
|
19 |
+
bot = int(bot + height_increase * height)
|
20 |
+
return (left, top, right, bot)
|
21 |
+
|
22 |
+
|
23 |
+
def get_valid_bboxes(bboxes, h, w):
|
24 |
+
left = max(bboxes[0], 0)
|
25 |
+
top = max(bboxes[1], 0)
|
26 |
+
right = min(bboxes[2], w)
|
27 |
+
bottom = min(bboxes[3], h)
|
28 |
+
return (left, top, right, bottom)
|
29 |
+
|
30 |
+
|
31 |
+
def align_crop_face_landmarks(img,
|
32 |
+
landmarks,
|
33 |
+
output_size,
|
34 |
+
transform_size=None,
|
35 |
+
enable_padding=True,
|
36 |
+
return_inverse_affine=False,
|
37 |
+
shrink_ratio=(1, 1)):
|
38 |
+
"""Align and crop face with landmarks.
|
39 |
+
|
40 |
+
The output_size and transform_size are based on width. The height is
|
41 |
+
adjusted based on shrink_ratio_h/shring_ration_w.
|
42 |
+
|
43 |
+
Modified from:
|
44 |
+
https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
|
45 |
+
|
46 |
+
Args:
|
47 |
+
img (Numpy array): Input image.
|
48 |
+
landmarks (Numpy array): 5 or 68 or 98 landmarks.
|
49 |
+
output_size (int): Output face size.
|
50 |
+
transform_size (ing): Transform size. Usually the four time of
|
51 |
+
output_size.
|
52 |
+
enable_padding (float): Default: True.
|
53 |
+
shrink_ratio (float | tuple[float] | list[float]): Shring the whole
|
54 |
+
face for height and width (crop larger area). Default: (1, 1).
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
(Numpy array): Cropped face.
|
58 |
+
"""
|
59 |
+
lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
|
60 |
+
|
61 |
+
if isinstance(shrink_ratio, (float, int)):
|
62 |
+
shrink_ratio = (shrink_ratio, shrink_ratio)
|
63 |
+
if transform_size is None:
|
64 |
+
transform_size = output_size * 4
|
65 |
+
|
66 |
+
# Parse landmarks
|
67 |
+
lm = np.array(landmarks)
|
68 |
+
if lm.shape[0] == 5 and lm_type == 'retinaface_5':
|
69 |
+
eye_left = lm[0]
|
70 |
+
eye_right = lm[1]
|
71 |
+
mouth_avg = (lm[3] + lm[4]) * 0.5
|
72 |
+
elif lm.shape[0] == 5 and lm_type == 'dlib_5':
|
73 |
+
lm_eye_left = lm[2:4]
|
74 |
+
lm_eye_right = lm[0:2]
|
75 |
+
eye_left = np.mean(lm_eye_left, axis=0)
|
76 |
+
eye_right = np.mean(lm_eye_right, axis=0)
|
77 |
+
mouth_avg = lm[4]
|
78 |
+
elif lm.shape[0] == 68:
|
79 |
+
lm_eye_left = lm[36:42]
|
80 |
+
lm_eye_right = lm[42:48]
|
81 |
+
eye_left = np.mean(lm_eye_left, axis=0)
|
82 |
+
eye_right = np.mean(lm_eye_right, axis=0)
|
83 |
+
mouth_avg = (lm[48] + lm[54]) * 0.5
|
84 |
+
elif lm.shape[0] == 98:
|
85 |
+
lm_eye_left = lm[60:68]
|
86 |
+
lm_eye_right = lm[68:76]
|
87 |
+
eye_left = np.mean(lm_eye_left, axis=0)
|
88 |
+
eye_right = np.mean(lm_eye_right, axis=0)
|
89 |
+
mouth_avg = (lm[76] + lm[82]) * 0.5
|
90 |
+
|
91 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
92 |
+
eye_to_eye = eye_right - eye_left
|
93 |
+
eye_to_mouth = mouth_avg - eye_avg
|
94 |
+
|
95 |
+
# Get the oriented crop rectangle
|
96 |
+
# x: half width of the oriented crop rectangle
|
97 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
98 |
+
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
|
99 |
+
# norm with the hypotenuse: get the direction
|
100 |
+
x /= np.hypot(*x) # get the hypotenuse of a right triangle
|
101 |
+
rect_scale = 1 # TODO: you can edit it to get larger rect
|
102 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
|
103 |
+
# y: half height of the oriented crop rectangle
|
104 |
+
y = np.flipud(x) * [-1, 1]
|
105 |
+
|
106 |
+
x *= shrink_ratio[1] # width
|
107 |
+
y *= shrink_ratio[0] # height
|
108 |
+
|
109 |
+
# c: center
|
110 |
+
c = eye_avg + eye_to_mouth * 0.1
|
111 |
+
# quad: (left_top, left_bottom, right_bottom, right_top)
|
112 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
113 |
+
# qsize: side length of the square
|
114 |
+
qsize = np.hypot(*x) * 2
|
115 |
+
|
116 |
+
quad_ori = np.copy(quad)
|
117 |
+
# Shrink, for large face
|
118 |
+
# TODO: do we really need shrink
|
119 |
+
shrink = int(np.floor(qsize / output_size * 0.5))
|
120 |
+
if shrink > 1:
|
121 |
+
h, w = img.shape[0:2]
|
122 |
+
rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
|
123 |
+
img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
|
124 |
+
quad /= shrink
|
125 |
+
qsize /= shrink
|
126 |
+
|
127 |
+
# Crop
|
128 |
+
h, w = img.shape[0:2]
|
129 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
130 |
+
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
131 |
+
int(np.ceil(max(quad[:, 1]))))
|
132 |
+
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
|
133 |
+
if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
|
134 |
+
img = img[crop[1]:crop[3], crop[0]:crop[2], :]
|
135 |
+
quad -= crop[0:2]
|
136 |
+
|
137 |
+
# Pad
|
138 |
+
# pad: (width_left, height_top, width_right, height_bottom)
|
139 |
+
h, w = img.shape[0:2]
|
140 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
141 |
+
int(np.ceil(max(quad[:, 1]))))
|
142 |
+
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
|
143 |
+
if enable_padding and max(pad) > border - 4:
|
144 |
+
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
|
145 |
+
img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
146 |
+
h, w = img.shape[0:2]
|
147 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
148 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
|
149 |
+
np.float32(w - 1 - x) / pad[2]),
|
150 |
+
1.0 - np.minimum(np.float32(y) / pad[1],
|
151 |
+
np.float32(h - 1 - y) / pad[3]))
|
152 |
+
blur = int(qsize * 0.02)
|
153 |
+
if blur % 2 == 0:
|
154 |
+
blur += 1
|
155 |
+
blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
|
156 |
+
|
157 |
+
img = img.astype('float32')
|
158 |
+
img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
159 |
+
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
|
160 |
+
img = np.clip(img, 0, 255) # float32, [0, 255]
|
161 |
+
quad += pad[:2]
|
162 |
+
|
163 |
+
# Transform use cv2
|
164 |
+
h_ratio = shrink_ratio[0] / shrink_ratio[1]
|
165 |
+
dst_h, dst_w = int(transform_size * h_ratio), transform_size
|
166 |
+
template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
|
167 |
+
# use cv2.LMEDS method for the equivalence to skimage transform
|
168 |
+
# ref: https://blog.csdn.net/yichxi/article/details/115827338
|
169 |
+
affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
|
170 |
+
cropped_face = cv2.warpAffine(
|
171 |
+
img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
|
172 |
+
|
173 |
+
if output_size < transform_size:
|
174 |
+
cropped_face = cv2.resize(
|
175 |
+
cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
|
176 |
+
|
177 |
+
if return_inverse_affine:
|
178 |
+
dst_h, dst_w = int(output_size * h_ratio), output_size
|
179 |
+
template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
|
180 |
+
# use cv2.LMEDS method for the equivalence to skimage transform
|
181 |
+
# ref: https://blog.csdn.net/yichxi/article/details/115827338
|
182 |
+
affine_matrix = cv2.estimateAffinePartial2D(
|
183 |
+
quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
|
184 |
+
inverse_affine = cv2.invertAffineTransform(affine_matrix)
|
185 |
+
else:
|
186 |
+
inverse_affine = None
|
187 |
+
return cropped_face, inverse_affine
|
188 |
+
|
189 |
+
|
190 |
+
def paste_face_back(img, face, inverse_affine):
|
191 |
+
h, w = img.shape[0:2]
|
192 |
+
face_h, face_w = face.shape[0:2]
|
193 |
+
inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
|
194 |
+
mask = np.ones((face_h, face_w, 3), dtype=np.float32)
|
195 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
|
196 |
+
# remove the black borders
|
197 |
+
inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
|
198 |
+
inv_restored_remove_border = inv_mask_erosion * inv_restored
|
199 |
+
total_face_area = np.sum(inv_mask_erosion) // 3
|
200 |
+
# compute the fusion edge based on the area of face
|
201 |
+
w_edge = int(total_face_area**0.5) // 20
|
202 |
+
erosion_radius = w_edge * 2
|
203 |
+
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
204 |
+
blur_size = w_edge * 2
|
205 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
206 |
+
img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
|
207 |
+
# float32, [0, 255]
|
208 |
+
return img
|
209 |
+
|
210 |
+
|
211 |
+
if __name__ == '__main__':
|
212 |
+
import os
|
213 |
+
|
214 |
+
from facelib.detection import init_detection_model
|
215 |
+
from facelib.utils.face_restoration_helper import get_largest_face
|
216 |
+
|
217 |
+
img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
|
218 |
+
img_name = os.splitext(os.path.basename(img_path))[0]
|
219 |
+
|
220 |
+
# initialize model
|
221 |
+
det_net = init_detection_model('retinaface_resnet50', half=False)
|
222 |
+
img_ori = cv2.imread(img_path)
|
223 |
+
h, w = img_ori.shape[0:2]
|
224 |
+
# if larger than 800, scale it
|
225 |
+
scale = max(h / 800, w / 800)
|
226 |
+
if scale > 1:
|
227 |
+
img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
|
228 |
+
|
229 |
+
with torch.no_grad():
|
230 |
+
bboxes = det_net.detect_faces(img, 0.97)
|
231 |
+
if scale > 1:
|
232 |
+
bboxes *= scale # the score is incorrect
|
233 |
+
bboxes = get_largest_face(bboxes, h, w)[0]
|
234 |
+
|
235 |
+
landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
|
236 |
+
|
237 |
+
cropped_face, inverse_affine = align_crop_face_landmarks(
|
238 |
+
img_ori,
|
239 |
+
landmarks,
|
240 |
+
output_size=512,
|
241 |
+
transform_size=None,
|
242 |
+
enable_padding=True,
|
243 |
+
return_inverse_affine=True,
|
244 |
+
shrink_ratio=(1, 1))
|
245 |
+
|
246 |
+
cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face)
|
247 |
+
img = paste_face_back(img_ori, cropped_face, inverse_affine)
|
248 |
+
cv2.imwrite(f'tmp/{img_name}_back.png', img)
|
repositories/CodeFormer/facelib/utils/misc.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import torch
|
5 |
+
from torch.hub import download_url_to_file, get_dir
|
6 |
+
from urllib.parse import urlparse
|
7 |
+
# from basicsr.utils.download_util import download_file_from_google_drive
|
8 |
+
import gdown
|
9 |
+
|
10 |
+
|
11 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
12 |
+
|
13 |
+
|
14 |
+
def download_pretrained_models(file_ids, save_path_root):
|
15 |
+
os.makedirs(save_path_root, exist_ok=True)
|
16 |
+
|
17 |
+
for file_name, file_id in file_ids.items():
|
18 |
+
file_url = 'https://drive.google.com/uc?id='+file_id
|
19 |
+
save_path = osp.abspath(osp.join(save_path_root, file_name))
|
20 |
+
if osp.exists(save_path):
|
21 |
+
user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
|
22 |
+
if user_response.lower() == 'y':
|
23 |
+
print(f'Covering {file_name} to {save_path}')
|
24 |
+
gdown.download(file_url, save_path, quiet=False)
|
25 |
+
# download_file_from_google_drive(file_id, save_path)
|
26 |
+
elif user_response.lower() == 'n':
|
27 |
+
print(f'Skipping {file_name}')
|
28 |
+
else:
|
29 |
+
raise ValueError('Wrong input. Only accepts Y/N.')
|
30 |
+
else:
|
31 |
+
print(f'Downloading {file_name} to {save_path}')
|
32 |
+
gdown.download(file_url, save_path, quiet=False)
|
33 |
+
# download_file_from_google_drive(file_id, save_path)
|
34 |
+
|
35 |
+
|
36 |
+
def imwrite(img, file_path, params=None, auto_mkdir=True):
|
37 |
+
"""Write image to file.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
img (ndarray): Image array to be written.
|
41 |
+
file_path (str): Image file path.
|
42 |
+
params (None or list): Same as opencv's :func:`imwrite` interface.
|
43 |
+
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
|
44 |
+
whether to create it automatically.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
bool: Successful or not.
|
48 |
+
"""
|
49 |
+
if auto_mkdir:
|
50 |
+
dir_name = os.path.abspath(os.path.dirname(file_path))
|
51 |
+
os.makedirs(dir_name, exist_ok=True)
|
52 |
+
return cv2.imwrite(file_path, img, params)
|
53 |
+
|
54 |
+
|
55 |
+
def img2tensor(imgs, bgr2rgb=True, float32=True):
|
56 |
+
"""Numpy array to tensor.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
imgs (list[ndarray] | ndarray): Input images.
|
60 |
+
bgr2rgb (bool): Whether to change bgr to rgb.
|
61 |
+
float32 (bool): Whether to change to float32.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
list[tensor] | tensor: Tensor images. If returned results only have
|
65 |
+
one element, just return tensor.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def _totensor(img, bgr2rgb, float32):
|
69 |
+
if img.shape[2] == 3 and bgr2rgb:
|
70 |
+
if img.dtype == 'float64':
|
71 |
+
img = img.astype('float32')
|
72 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
73 |
+
img = torch.from_numpy(img.transpose(2, 0, 1))
|
74 |
+
if float32:
|
75 |
+
img = img.float()
|
76 |
+
return img
|
77 |
+
|
78 |
+
if isinstance(imgs, list):
|
79 |
+
return [_totensor(img, bgr2rgb, float32) for img in imgs]
|
80 |
+
else:
|
81 |
+
return _totensor(imgs, bgr2rgb, float32)
|
82 |
+
|
83 |
+
|
84 |
+
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
85 |
+
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
86 |
+
"""
|
87 |
+
if model_dir is None:
|
88 |
+
hub_dir = get_dir()
|
89 |
+
model_dir = os.path.join(hub_dir, 'checkpoints')
|
90 |
+
|
91 |
+
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
|
92 |
+
|
93 |
+
parts = urlparse(url)
|
94 |
+
filename = os.path.basename(parts.path)
|
95 |
+
if file_name is not None:
|
96 |
+
filename = file_name
|
97 |
+
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
|
98 |
+
if not os.path.exists(cached_file):
|
99 |
+
print(f'Downloading: "{url}" to {cached_file}\n')
|
100 |
+
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
101 |
+
return cached_file
|
102 |
+
|
103 |
+
|
104 |
+
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
|
105 |
+
"""Scan a directory to find the interested files.
|
106 |
+
Args:
|
107 |
+
dir_path (str): Path of the directory.
|
108 |
+
suffix (str | tuple(str), optional): File suffix that we are
|
109 |
+
interested in. Default: None.
|
110 |
+
recursive (bool, optional): If set to True, recursively scan the
|
111 |
+
directory. Default: False.
|
112 |
+
full_path (bool, optional): If set to True, include the dir_path.
|
113 |
+
Default: False.
|
114 |
+
Returns:
|
115 |
+
A generator for all the interested files with relative paths.
|
116 |
+
"""
|
117 |
+
|
118 |
+
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
119 |
+
raise TypeError('"suffix" must be a string or tuple of strings')
|
120 |
+
|
121 |
+
root = dir_path
|
122 |
+
|
123 |
+
def _scandir(dir_path, suffix, recursive):
|
124 |
+
for entry in os.scandir(dir_path):
|
125 |
+
if not entry.name.startswith('.') and entry.is_file():
|
126 |
+
if full_path:
|
127 |
+
return_path = entry.path
|
128 |
+
else:
|
129 |
+
return_path = osp.relpath(entry.path, root)
|
130 |
+
|
131 |
+
if suffix is None:
|
132 |
+
yield return_path
|
133 |
+
elif return_path.endswith(suffix):
|
134 |
+
yield return_path
|
135 |
+
else:
|
136 |
+
if recursive:
|
137 |
+
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
|
138 |
+
else:
|
139 |
+
continue
|
140 |
+
|
141 |
+
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
repositories/CodeFormer/inference_codeformer.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified by Shangchen Zhou from: https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import argparse
|
5 |
+
import glob
|
6 |
+
import torch
|
7 |
+
from torchvision.transforms.functional import normalize
|
8 |
+
from basicsr.utils import imwrite, img2tensor, tensor2img
|
9 |
+
from basicsr.utils.download_util import load_file_from_url
|
10 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
14 |
+
|
15 |
+
pretrain_model_url = {
|
16 |
+
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
|
17 |
+
}
|
18 |
+
|
19 |
+
def set_realesrgan():
|
20 |
+
if not torch.cuda.is_available(): # CPU
|
21 |
+
import warnings
|
22 |
+
warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
|
23 |
+
'If you really want to use it, please modify the corresponding codes.',
|
24 |
+
category=RuntimeWarning)
|
25 |
+
bg_upsampler = None
|
26 |
+
else:
|
27 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
28 |
+
from basicsr.utils.realesrgan_utils import RealESRGANer
|
29 |
+
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
30 |
+
bg_upsampler = RealESRGANer(
|
31 |
+
scale=2,
|
32 |
+
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
33 |
+
model=model,
|
34 |
+
tile=args.bg_tile,
|
35 |
+
tile_pad=40,
|
36 |
+
pre_pad=0,
|
37 |
+
half=True) # need to set False in CPU mode
|
38 |
+
return bg_upsampler
|
39 |
+
|
40 |
+
if __name__ == '__main__':
|
41 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
42 |
+
parser = argparse.ArgumentParser()
|
43 |
+
|
44 |
+
parser.add_argument('--w', type=float, default=0.5, help='Balance the quality and fidelity')
|
45 |
+
parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
|
46 |
+
parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
|
47 |
+
parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
|
48 |
+
parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
|
49 |
+
# large det_model: 'YOLOv5l', 'retinaface_resnet50'
|
50 |
+
# small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
|
51 |
+
parser.add_argument('--detection_model', type=str, default='retinaface_resnet50')
|
52 |
+
parser.add_argument('--draw_box', action='store_true')
|
53 |
+
parser.add_argument('--bg_upsampler', type=str, default='None', help='background upsampler. Optional: realesrgan')
|
54 |
+
parser.add_argument('--face_upsample', action='store_true', help='face upsampler after enhancement.')
|
55 |
+
parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
|
56 |
+
|
57 |
+
args = parser.parse_args()
|
58 |
+
|
59 |
+
# ------------------------ input & output ------------------------
|
60 |
+
if args.test_path.endswith('/'): # solve when path ends with /
|
61 |
+
args.test_path = args.test_path[:-1]
|
62 |
+
|
63 |
+
w = args.w
|
64 |
+
result_root = f'results/{os.path.basename(args.test_path)}_{w}'
|
65 |
+
|
66 |
+
# ------------------ set up background upsampler ------------------
|
67 |
+
if args.bg_upsampler == 'realesrgan':
|
68 |
+
bg_upsampler = set_realesrgan()
|
69 |
+
else:
|
70 |
+
bg_upsampler = None
|
71 |
+
|
72 |
+
# ------------------ set up face upsampler ------------------
|
73 |
+
if args.face_upsample:
|
74 |
+
if bg_upsampler is not None:
|
75 |
+
face_upsampler = bg_upsampler
|
76 |
+
else:
|
77 |
+
face_upsampler = set_realesrgan()
|
78 |
+
else:
|
79 |
+
face_upsampler = None
|
80 |
+
|
81 |
+
# ------------------ set up CodeFormer restorer -------------------
|
82 |
+
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
|
83 |
+
connect_list=['32', '64', '128', '256']).to(device)
|
84 |
+
|
85 |
+
# ckpt_path = 'weights/CodeFormer/codeformer.pth'
|
86 |
+
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
|
87 |
+
model_dir='weights/CodeFormer', progress=True, file_name=None)
|
88 |
+
checkpoint = torch.load(ckpt_path)['params_ema']
|
89 |
+
net.load_state_dict(checkpoint)
|
90 |
+
net.eval()
|
91 |
+
|
92 |
+
# ------------------ set up FaceRestoreHelper -------------------
|
93 |
+
# large det_model: 'YOLOv5l', 'retinaface_resnet50'
|
94 |
+
# small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
|
95 |
+
if not args.has_aligned:
|
96 |
+
print(f'Face detection model: {args.detection_model}')
|
97 |
+
if bg_upsampler is not None:
|
98 |
+
print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
|
99 |
+
else:
|
100 |
+
print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')
|
101 |
+
|
102 |
+
face_helper = FaceRestoreHelper(
|
103 |
+
args.upscale,
|
104 |
+
face_size=512,
|
105 |
+
crop_ratio=(1, 1),
|
106 |
+
det_model = args.detection_model,
|
107 |
+
save_ext='png',
|
108 |
+
use_parse=True,
|
109 |
+
device=device)
|
110 |
+
|
111 |
+
# -------------------- start to processing ---------------------
|
112 |
+
# scan all the jpg and png images
|
113 |
+
for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
|
114 |
+
# clean all the intermediate results to process the next image
|
115 |
+
face_helper.clean_all()
|
116 |
+
|
117 |
+
img_name = os.path.basename(img_path)
|
118 |
+
print(f'Processing: {img_name}')
|
119 |
+
basename, ext = os.path.splitext(img_name)
|
120 |
+
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
121 |
+
|
122 |
+
if args.has_aligned:
|
123 |
+
# the input faces are already cropped and aligned
|
124 |
+
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
|
125 |
+
face_helper.cropped_faces = [img]
|
126 |
+
else:
|
127 |
+
face_helper.read_image(img)
|
128 |
+
# get face landmarks for each face
|
129 |
+
num_det_faces = face_helper.get_face_landmarks_5(
|
130 |
+
only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
|
131 |
+
print(f'\tdetect {num_det_faces} faces')
|
132 |
+
# align and warp each face
|
133 |
+
face_helper.align_warp_face()
|
134 |
+
|
135 |
+
# face restoration for each cropped face
|
136 |
+
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
137 |
+
# prepare data
|
138 |
+
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
139 |
+
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
140 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
141 |
+
|
142 |
+
try:
|
143 |
+
with torch.no_grad():
|
144 |
+
output = net(cropped_face_t, w=w, adain=True)[0]
|
145 |
+
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
146 |
+
del output
|
147 |
+
torch.cuda.empty_cache()
|
148 |
+
except Exception as error:
|
149 |
+
print(f'\tFailed inference for CodeFormer: {error}')
|
150 |
+
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
151 |
+
|
152 |
+
restored_face = restored_face.astype('uint8')
|
153 |
+
face_helper.add_restored_face(restored_face)
|
154 |
+
|
155 |
+
# paste_back
|
156 |
+
if not args.has_aligned:
|
157 |
+
# upsample the background
|
158 |
+
if bg_upsampler is not None:
|
159 |
+
# Now only support RealESRGAN for upsampling background
|
160 |
+
bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
|
161 |
+
else:
|
162 |
+
bg_img = None
|
163 |
+
face_helper.get_inverse_affine(None)
|
164 |
+
# paste each restored face to the input image
|
165 |
+
if args.face_upsample and face_upsampler is not None:
|
166 |
+
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
|
167 |
+
else:
|
168 |
+
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
|
169 |
+
|
170 |
+
# save faces
|
171 |
+
for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
|
172 |
+
# save cropped face
|
173 |
+
if not args.has_aligned:
|
174 |
+
save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
|
175 |
+
imwrite(cropped_face, save_crop_path)
|
176 |
+
# save restored face
|
177 |
+
if args.has_aligned:
|
178 |
+
save_face_name = f'{basename}.png'
|
179 |
+
else:
|
180 |
+
save_face_name = f'{basename}_{idx:02d}.png'
|
181 |
+
save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
|
182 |
+
imwrite(restored_face, save_restore_path)
|
183 |
+
|
184 |
+
# save restored img
|
185 |
+
if not args.has_aligned and restored_img is not None:
|
186 |
+
save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
|
187 |
+
imwrite(restored_img, save_restore_path)
|
188 |
+
|
189 |
+
print(f'\nAll results are saved in {result_root}')
|
repositories/CodeFormer/inputs/cropped_faces/0143.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0240.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0342.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0345.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0368.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0412.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0444.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0478.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0500.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0599.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0717.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0720.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0729.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0763.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0770.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0777.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0885.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/0934.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/Solvay_conference_1927_0018.png
ADDED
repositories/CodeFormer/inputs/cropped_faces/Solvay_conference_1927_2_16.png
ADDED
repositories/CodeFormer/inputs/whole_imgs/00.jpg
ADDED
repositories/CodeFormer/inputs/whole_imgs/01.jpg
ADDED
repositories/CodeFormer/inputs/whole_imgs/02.png
ADDED
repositories/CodeFormer/inputs/whole_imgs/03.jpg
ADDED
repositories/CodeFormer/inputs/whole_imgs/04.jpg
ADDED
repositories/CodeFormer/inputs/whole_imgs/05.jpg
ADDED
repositories/CodeFormer/inputs/whole_imgs/06.png
ADDED
repositories/CodeFormer/predict.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
download checkpoints to ./weights beforehand
|
3 |
+
python scripts/download_pretrained_models.py facelib
|
4 |
+
python scripts/download_pretrained_models.py CodeFormer
|
5 |
+
wget 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
|
6 |
+
"""
|
7 |
+
|
8 |
+
import tempfile
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
from torchvision.transforms.functional import normalize
|
12 |
+
from cog import BasePredictor, Input, Path
|
13 |
+
|
14 |
+
from basicsr.utils import imwrite, img2tensor, tensor2img
|
15 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
16 |
+
from basicsr.utils.realesrgan_utils import RealESRGANer
|
17 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
18 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
19 |
+
|
20 |
+
|
21 |
+
class Predictor(BasePredictor):
|
22 |
+
def setup(self):
|
23 |
+
"""Load the model into memory to make running multiple predictions efficient"""
|
24 |
+
self.device = "cuda:0"
|
25 |
+
self.bg_upsampler = set_realesrgan()
|
26 |
+
self.net = ARCH_REGISTRY.get("CodeFormer")(
|
27 |
+
dim_embd=512,
|
28 |
+
codebook_size=1024,
|
29 |
+
n_head=8,
|
30 |
+
n_layers=9,
|
31 |
+
connect_list=["32", "64", "128", "256"],
|
32 |
+
).to(self.device)
|
33 |
+
ckpt_path = "weights/CodeFormer/codeformer.pth"
|
34 |
+
checkpoint = torch.load(ckpt_path)[
|
35 |
+
"params_ema"
|
36 |
+
] # update file permission if cannot load
|
37 |
+
self.net.load_state_dict(checkpoint)
|
38 |
+
self.net.eval()
|
39 |
+
|
40 |
+
def predict(
|
41 |
+
self,
|
42 |
+
image: Path = Input(description="Input image"),
|
43 |
+
codeformer_fidelity: float = Input(
|
44 |
+
default=0.5,
|
45 |
+
ge=0,
|
46 |
+
le=1,
|
47 |
+
description="Balance the quality (lower number) and fidelity (higher number).",
|
48 |
+
),
|
49 |
+
background_enhance: bool = Input(
|
50 |
+
description="Enhance background image with Real-ESRGAN", default=True
|
51 |
+
),
|
52 |
+
face_upsample: bool = Input(
|
53 |
+
description="Upsample restored faces for high-resolution AI-created images",
|
54 |
+
default=True,
|
55 |
+
),
|
56 |
+
upscale: int = Input(
|
57 |
+
description="The final upsampling scale of the image",
|
58 |
+
default=2,
|
59 |
+
),
|
60 |
+
) -> Path:
|
61 |
+
"""Run a single prediction on the model"""
|
62 |
+
|
63 |
+
# take the default setting for the demo
|
64 |
+
has_aligned = False
|
65 |
+
only_center_face = False
|
66 |
+
draw_box = False
|
67 |
+
detection_model = "retinaface_resnet50"
|
68 |
+
|
69 |
+
self.face_helper = FaceRestoreHelper(
|
70 |
+
upscale,
|
71 |
+
face_size=512,
|
72 |
+
crop_ratio=(1, 1),
|
73 |
+
det_model=detection_model,
|
74 |
+
save_ext="png",
|
75 |
+
use_parse=True,
|
76 |
+
device=self.device,
|
77 |
+
)
|
78 |
+
|
79 |
+
bg_upsampler = self.bg_upsampler if background_enhance else None
|
80 |
+
face_upsampler = self.bg_upsampler if face_upsample else None
|
81 |
+
|
82 |
+
img = cv2.imread(str(image), cv2.IMREAD_COLOR)
|
83 |
+
|
84 |
+
if has_aligned:
|
85 |
+
# the input faces are already cropped and aligned
|
86 |
+
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
|
87 |
+
self.face_helper.cropped_faces = [img]
|
88 |
+
else:
|
89 |
+
self.face_helper.read_image(img)
|
90 |
+
# get face landmarks for each face
|
91 |
+
num_det_faces = self.face_helper.get_face_landmarks_5(
|
92 |
+
only_center_face=only_center_face, resize=640, eye_dist_threshold=5
|
93 |
+
)
|
94 |
+
print(f"\tdetect {num_det_faces} faces")
|
95 |
+
# align and warp each face
|
96 |
+
self.face_helper.align_warp_face()
|
97 |
+
|
98 |
+
# face restoration for each cropped face
|
99 |
+
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
|
100 |
+
# prepare data
|
101 |
+
cropped_face_t = img2tensor(
|
102 |
+
cropped_face / 255.0, bgr2rgb=True, float32=True
|
103 |
+
)
|
104 |
+
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
105 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
|
106 |
+
|
107 |
+
try:
|
108 |
+
with torch.no_grad():
|
109 |
+
output = self.net(
|
110 |
+
cropped_face_t, w=codeformer_fidelity, adain=True
|
111 |
+
)[0]
|
112 |
+
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
113 |
+
del output
|
114 |
+
torch.cuda.empty_cache()
|
115 |
+
except Exception as error:
|
116 |
+
print(f"\tFailed inference for CodeFormer: {error}")
|
117 |
+
restored_face = tensor2img(
|
118 |
+
cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
|
119 |
+
)
|
120 |
+
|
121 |
+
restored_face = restored_face.astype("uint8")
|
122 |
+
self.face_helper.add_restored_face(restored_face)
|
123 |
+
|
124 |
+
# paste_back
|
125 |
+
if not has_aligned:
|
126 |
+
# upsample the background
|
127 |
+
if bg_upsampler is not None:
|
128 |
+
# Now only support RealESRGAN for upsampling background
|
129 |
+
bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
|
130 |
+
else:
|
131 |
+
bg_img = None
|
132 |
+
self.face_helper.get_inverse_affine(None)
|
133 |
+
# paste each restored face to the input image
|
134 |
+
if face_upsample and face_upsampler is not None:
|
135 |
+
restored_img = self.face_helper.paste_faces_to_input_image(
|
136 |
+
upsample_img=bg_img,
|
137 |
+
draw_box=draw_box,
|
138 |
+
face_upsampler=face_upsampler,
|
139 |
+
)
|
140 |
+
else:
|
141 |
+
restored_img = self.face_helper.paste_faces_to_input_image(
|
142 |
+
upsample_img=bg_img, draw_box=draw_box
|
143 |
+
)
|
144 |
+
|
145 |
+
# save restored img
|
146 |
+
out_path = Path(tempfile.mkdtemp()) / "output.png"
|
147 |
+
|
148 |
+
if not has_aligned and restored_img is not None:
|
149 |
+
imwrite(restored_img, str(out_path))
|
150 |
+
|
151 |
+
return out_path
|
152 |
+
|
153 |
+
|
154 |
+
def imread(img_path):
|
155 |
+
img = cv2.imread(img_path)
|
156 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
157 |
+
return img
|
158 |
+
|
159 |
+
|
160 |
+
def set_realesrgan():
|
161 |
+
if not torch.cuda.is_available(): # CPU
|
162 |
+
import warnings
|
163 |
+
|
164 |
+
warnings.warn(
|
165 |
+
"The unoptimized RealESRGAN is slow on CPU. We do not use it. "
|
166 |
+
"If you really want to use it, please modify the corresponding codes.",
|
167 |
+
category=RuntimeWarning,
|
168 |
+
)
|
169 |
+
bg_upsampler = None
|
170 |
+
else:
|
171 |
+
model = RRDBNet(
|
172 |
+
num_in_ch=3,
|
173 |
+
num_out_ch=3,
|
174 |
+
num_feat=64,
|
175 |
+
num_block=23,
|
176 |
+
num_grow_ch=32,
|
177 |
+
scale=2,
|
178 |
+
)
|
179 |
+
bg_upsampler = RealESRGANer(
|
180 |
+
scale=2,
|
181 |
+
model_path="./weights/RealESRGAN_x2plus.pth",
|
182 |
+
model=model,
|
183 |
+
tile=400,
|
184 |
+
tile_pad=40,
|
185 |
+
pre_pad=0,
|
186 |
+
half=True,
|
187 |
+
)
|
188 |
+
return bg_upsampler
|
repositories/CodeFormer/requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
addict
|
2 |
+
future
|
3 |
+
lmdb
|
4 |
+
numpy
|
5 |
+
opencv-python
|
6 |
+
Pillow
|
7 |
+
pyyaml
|
8 |
+
requests
|
9 |
+
scikit-image
|
10 |
+
scipy
|
11 |
+
tb-nightly
|
12 |
+
torch>=1.7.1
|
13 |
+
torchvision
|
14 |
+
tqdm
|
15 |
+
yapf
|
16 |
+
lpips
|
17 |
+
gdown # supports downloading the large file from Google Drive
|
18 |
+
# cmake
|
19 |
+
# dlib
|
20 |
+
# conda install -c conda-forge dlib
|
repositories/CodeFormer/scripts/crop_align_face.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
|
3 |
+
author: lzhbrian (https://lzhbrian.me)
|
4 |
+
link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
|
5 |
+
date: 2020.1.5
|
6 |
+
note: code is heavily borrowed from
|
7 |
+
https://github.com/NVlabs/ffhq-dataset
|
8 |
+
http://dlib.net/face_landmark_detection.py.html
|
9 |
+
requirements:
|
10 |
+
conda install Pillow numpy scipy
|
11 |
+
conda install -c conda-forge dlib
|
12 |
+
# download face landmark model from:
|
13 |
+
# http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
|
14 |
+
"""
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import dlib
|
18 |
+
import glob
|
19 |
+
import numpy as np
|
20 |
+
import os
|
21 |
+
import PIL
|
22 |
+
import PIL.Image
|
23 |
+
import scipy
|
24 |
+
import scipy.ndimage
|
25 |
+
import sys
|
26 |
+
import argparse
|
27 |
+
|
28 |
+
# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
|
29 |
+
predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')
|
30 |
+
|
31 |
+
|
32 |
+
def get_landmark(filepath, only_keep_largest=True):
|
33 |
+
"""get landmark with dlib
|
34 |
+
:return: np.array shape=(68, 2)
|
35 |
+
"""
|
36 |
+
detector = dlib.get_frontal_face_detector()
|
37 |
+
|
38 |
+
img = dlib.load_rgb_image(filepath)
|
39 |
+
dets = detector(img, 1)
|
40 |
+
|
41 |
+
# Shangchen modified
|
42 |
+
print("Number of faces detected: {}".format(len(dets)))
|
43 |
+
if only_keep_largest:
|
44 |
+
print('Detect several faces and only keep the largest.')
|
45 |
+
face_areas = []
|
46 |
+
for k, d in enumerate(dets):
|
47 |
+
face_area = (d.right() - d.left()) * (d.bottom() - d.top())
|
48 |
+
face_areas.append(face_area)
|
49 |
+
|
50 |
+
largest_idx = face_areas.index(max(face_areas))
|
51 |
+
d = dets[largest_idx]
|
52 |
+
shape = predictor(img, d)
|
53 |
+
print("Part 0: {}, Part 1: {} ...".format(
|
54 |
+
shape.part(0), shape.part(1)))
|
55 |
+
else:
|
56 |
+
for k, d in enumerate(dets):
|
57 |
+
print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
|
58 |
+
k, d.left(), d.top(), d.right(), d.bottom()))
|
59 |
+
# Get the landmarks/parts for the face in box d.
|
60 |
+
shape = predictor(img, d)
|
61 |
+
print("Part 0: {}, Part 1: {} ...".format(
|
62 |
+
shape.part(0), shape.part(1)))
|
63 |
+
|
64 |
+
t = list(shape.parts())
|
65 |
+
a = []
|
66 |
+
for tt in t:
|
67 |
+
a.append([tt.x, tt.y])
|
68 |
+
lm = np.array(a)
|
69 |
+
# lm is a shape=(68,2) np.array
|
70 |
+
return lm
|
71 |
+
|
72 |
+
def align_face(filepath, out_path):
|
73 |
+
"""
|
74 |
+
:param filepath: str
|
75 |
+
:return: PIL Image
|
76 |
+
"""
|
77 |
+
try:
|
78 |
+
lm = get_landmark(filepath)
|
79 |
+
except:
|
80 |
+
print('No landmark ...')
|
81 |
+
return
|
82 |
+
|
83 |
+
lm_chin = lm[0:17] # left-right
|
84 |
+
lm_eyebrow_left = lm[17:22] # left-right
|
85 |
+
lm_eyebrow_right = lm[22:27] # left-right
|
86 |
+
lm_nose = lm[27:31] # top-down
|
87 |
+
lm_nostrils = lm[31:36] # top-down
|
88 |
+
lm_eye_left = lm[36:42] # left-clockwise
|
89 |
+
lm_eye_right = lm[42:48] # left-clockwise
|
90 |
+
lm_mouth_outer = lm[48:60] # left-clockwise
|
91 |
+
lm_mouth_inner = lm[60:68] # left-clockwise
|
92 |
+
|
93 |
+
# Calculate auxiliary vectors.
|
94 |
+
eye_left = np.mean(lm_eye_left, axis=0)
|
95 |
+
eye_right = np.mean(lm_eye_right, axis=0)
|
96 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
97 |
+
eye_to_eye = eye_right - eye_left
|
98 |
+
mouth_left = lm_mouth_outer[0]
|
99 |
+
mouth_right = lm_mouth_outer[6]
|
100 |
+
mouth_avg = (mouth_left + mouth_right) * 0.5
|
101 |
+
eye_to_mouth = mouth_avg - eye_avg
|
102 |
+
|
103 |
+
# Choose oriented crop rectangle.
|
104 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
105 |
+
x /= np.hypot(*x)
|
106 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
|
107 |
+
y = np.flipud(x) * [-1, 1]
|
108 |
+
c = eye_avg + eye_to_mouth * 0.1
|
109 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
110 |
+
qsize = np.hypot(*x) * 2
|
111 |
+
|
112 |
+
# read image
|
113 |
+
img = PIL.Image.open(filepath)
|
114 |
+
|
115 |
+
output_size = 512
|
116 |
+
transform_size = 4096
|
117 |
+
enable_padding = False
|
118 |
+
|
119 |
+
# Shrink.
|
120 |
+
shrink = int(np.floor(qsize / output_size * 0.5))
|
121 |
+
if shrink > 1:
|
122 |
+
rsize = (int(np.rint(float(img.size[0]) / shrink)),
|
123 |
+
int(np.rint(float(img.size[1]) / shrink)))
|
124 |
+
img = img.resize(rsize, PIL.Image.ANTIALIAS)
|
125 |
+
quad /= shrink
|
126 |
+
qsize /= shrink
|
127 |
+
|
128 |
+
# Crop.
|
129 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
130 |
+
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
|
131 |
+
int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
|
132 |
+
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
|
133 |
+
min(crop[2] + border,
|
134 |
+
img.size[0]), min(crop[3] + border, img.size[1]))
|
135 |
+
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
|
136 |
+
img = img.crop(crop)
|
137 |
+
quad -= crop[0:2]
|
138 |
+
|
139 |
+
# Pad.
|
140 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
|
141 |
+
int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
|
142 |
+
pad = (max(-pad[0] + border,
|
143 |
+
0), max(-pad[1] + border,
|
144 |
+
0), max(pad[2] - img.size[0] + border,
|
145 |
+
0), max(pad[3] - img.size[1] + border, 0))
|
146 |
+
if enable_padding and max(pad) > border - 4:
|
147 |
+
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
|
148 |
+
img = np.pad(
|
149 |
+
np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
|
150 |
+
'reflect')
|
151 |
+
h, w, _ = img.shape
|
152 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
153 |
+
mask = np.maximum(
|
154 |
+
1.0 -
|
155 |
+
np.minimum(np.float32(x) / pad[0],
|
156 |
+
np.float32(w - 1 - x) / pad[2]), 1.0 -
|
157 |
+
np.minimum(np.float32(y) / pad[1],
|
158 |
+
np.float32(h - 1 - y) / pad[3]))
|
159 |
+
blur = qsize * 0.02
|
160 |
+
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
|
161 |
+
img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
162 |
+
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
|
163 |
+
img = PIL.Image.fromarray(
|
164 |
+
np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
|
165 |
+
quad += pad[:2]
|
166 |
+
|
167 |
+
img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
|
168 |
+
(quad + 0.5).flatten(), PIL.Image.BILINEAR)
|
169 |
+
|
170 |
+
if output_size < transform_size:
|
171 |
+
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
|
172 |
+
|
173 |
+
# Save aligned image.
|
174 |
+
print('saveing: ', out_path)
|
175 |
+
img.save(out_path)
|
176 |
+
|
177 |
+
return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
|
178 |
+
|
179 |
+
|
180 |
+
if __name__ == '__main__':
|
181 |
+
parser = argparse.ArgumentParser()
|
182 |
+
parser.add_argument('--in_dir', type=str, default='./inputs/whole_imgs')
|
183 |
+
parser.add_argument('--out_dir', type=str, default='./inputs/cropped_faces')
|
184 |
+
args = parser.parse_args()
|
185 |
+
|
186 |
+
img_list = sorted(glob.glob(f'{args.in_dir}/*.png'))
|
187 |
+
img_list = sorted(img_list)
|
188 |
+
|
189 |
+
for in_path in img_list:
|
190 |
+
out_path = os.path.join(args.out_dir, in_path.split("/")[-1])
|
191 |
+
out_path = out_path.replace('.jpg', '.png')
|
192 |
+
size_ = align_face(in_path, out_path)
|
repositories/CodeFormer/scripts/download_pretrained_models.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
from basicsr.utils.download_util import load_file_from_url
|
6 |
+
|
7 |
+
|
8 |
+
def download_pretrained_models(method, file_urls):
|
9 |
+
save_path_root = f'./weights/{method}'
|
10 |
+
os.makedirs(save_path_root, exist_ok=True)
|
11 |
+
|
12 |
+
for file_name, file_url in file_urls.items():
|
13 |
+
save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
|
14 |
+
|
15 |
+
|
16 |
+
if __name__ == '__main__':
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
|
19 |
+
parser.add_argument(
|
20 |
+
'method',
|
21 |
+
type=str,
|
22 |
+
help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
file_urls = {
|
26 |
+
'CodeFormer': {
|
27 |
+
'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
28 |
+
},
|
29 |
+
'facelib': {
|
30 |
+
# 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
|
31 |
+
'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
|
32 |
+
'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
|
33 |
+
}
|
34 |
+
}
|
35 |
+
|
36 |
+
if args.method == 'all':
|
37 |
+
for method in file_urls.keys():
|
38 |
+
download_pretrained_models(method, file_urls[method])
|
39 |
+
else:
|
40 |
+
download_pretrained_models(args.method, file_urls[args.method])
|
repositories/CodeFormer/scripts/download_pretrained_models_from_gdrive.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
# from basicsr.utils.download_util import download_file_from_google_drive
|
6 |
+
import gdown
|
7 |
+
|
8 |
+
|
9 |
+
def download_pretrained_models(method, file_ids):
|
10 |
+
save_path_root = f'./weights/{method}'
|
11 |
+
os.makedirs(save_path_root, exist_ok=True)
|
12 |
+
|
13 |
+
for file_name, file_id in file_ids.items():
|
14 |
+
file_url = 'https://drive.google.com/uc?id='+file_id
|
15 |
+
save_path = osp.abspath(osp.join(save_path_root, file_name))
|
16 |
+
if osp.exists(save_path):
|
17 |
+
user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
|
18 |
+
if user_response.lower() == 'y':
|
19 |
+
print(f'Covering {file_name} to {save_path}')
|
20 |
+
gdown.download(file_url, save_path, quiet=False)
|
21 |
+
# download_file_from_google_drive(file_id, save_path)
|
22 |
+
elif user_response.lower() == 'n':
|
23 |
+
print(f'Skipping {file_name}')
|
24 |
+
else:
|
25 |
+
raise ValueError('Wrong input. Only accepts Y/N.')
|
26 |
+
else:
|
27 |
+
print(f'Downloading {file_name} to {save_path}')
|
28 |
+
gdown.download(file_url, save_path, quiet=False)
|
29 |
+
# download_file_from_google_drive(file_id, save_path)
|
30 |
+
|
31 |
+
if __name__ == '__main__':
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
|
34 |
+
parser.add_argument(
|
35 |
+
'method',
|
36 |
+
type=str,
|
37 |
+
help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
|
38 |
+
args = parser.parse_args()
|
39 |
+
|
40 |
+
# file name: file id
|
41 |
+
# 'dlib': {
|
42 |
+
# 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
|
43 |
+
# 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
|
44 |
+
# 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
|
45 |
+
# }
|
46 |
+
file_ids = {
|
47 |
+
'CodeFormer': {
|
48 |
+
'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
|
49 |
+
},
|
50 |
+
'facelib': {
|
51 |
+
'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
|
52 |
+
'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
|
53 |
+
}
|
54 |
+
}
|
55 |
+
|
56 |
+
if args.method == 'all':
|
57 |
+
for method in file_ids.keys():
|
58 |
+
download_pretrained_models(method, file_ids[method])
|
59 |
+
else:
|
60 |
+
download_pretrained_models(args.method, file_ids[args.method])
|
repositories/CodeFormer/weights/CodeFormer/.gitkeep
ADDED
File without changes
|
repositories/CodeFormer/weights/README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Weights
|
2 |
+
|
3 |
+
Put the downloaded pre-trained models to this folder.
|
repositories/CodeFormer/weights/facelib/.gitkeep
ADDED
File without changes
|
repositories/generative-models/.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.pt2
|
2 |
+
.pt2_2
|
3 |
+
.pt13
|
4 |
+
*.egg-info
|
5 |
+
build
|
6 |
+
/outputs
|
7 |
+
/checkpoints
|
repositories/generative-models/LICENSE
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SDXL 0.9 RESEARCH LICENSE AGREEMENT
|
2 |
+
Copyright (c) Stability AI Ltd.
|
3 |
+
This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”).
|
4 |
+
By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.
|
5 |
+
1. LICENSE GRANT
|
6 |
+
|
7 |
+
a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.
|
8 |
+
|
9 |
+
b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.
|
10 |
+
|
11 |
+
c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License.
|
12 |
+
|
13 |
+
|
14 |
+
2. RESTRICTIONS
|
15 |
+
|
16 |
+
You will not, and will not permit, assist or cause any third party to:
|
17 |
+
|
18 |
+
a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
|
19 |
+
|
20 |
+
b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
|
21 |
+
|
22 |
+
c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or
|
23 |
+
|
24 |
+
d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.
|
25 |
+
|
26 |
+
e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
|
27 |
+
|
28 |
+
|
29 |
+
3. ATTRIBUTION
|
30 |
+
|
31 |
+
Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “SDXL 0.9 is licensed under the SDXL Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”
|
32 |
+
|
33 |
+
|
34 |
+
4. DISCLAIMERS
|
35 |
+
|
36 |
+
THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. STABILITY AIEXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
|
37 |
+
|
38 |
+
|
39 |
+
5. LIMITATION OF LIABILITY
|
40 |
+
|
41 |
+
TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
|
42 |
+
|
43 |
+
|
44 |
+
6. INDEMNIFICATION
|
45 |
+
|
46 |
+
You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims. You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties.
|
47 |
+
|
48 |
+
|
49 |
+
7. TERMINATION; SURVIVAL
|
50 |
+
|
51 |
+
a. This License will automatically terminate upon any breach by you of the terms of this License.
|
52 |
+
|
53 |
+
b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
|
54 |
+
|
55 |
+
c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
|
56 |
+
|
57 |
+
|
58 |
+
8. THIRD PARTY MATERIALS
|
59 |
+
|
60 |
+
The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
|
61 |
+
|
62 |
+
|
63 |
+
9. TRADEMARKS
|
64 |
+
|
65 |
+
Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
|
66 |
+
|
67 |
+
|
68 |
+
10. APPLICABLE LAW; DISPUTE RESOLUTION
|
69 |
+
|
70 |
+
This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.
|
71 |
+
|
72 |
+
|
73 |
+
11. MISCELLANEOUS
|
74 |
+
|
75 |
+
If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI.
|
repositories/generative-models/README.md
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Generative Models by Stability AI
|
2 |
+
|
3 |
+
![sample1](assets/000.jpg)
|
4 |
+
|
5 |
+
## News
|
6 |
+
|
7 |
+
**July 4, 2023**
|
8 |
+
- A technical report on SDXL is now available [here](assets/sdxl_report.pdf).
|
9 |
+
|
10 |
+
**June 22, 2023**
|
11 |
+
|
12 |
+
|
13 |
+
- We are releasing two new diffusion models for research purposes:
|
14 |
+
- `SD-XL 0.9-base`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
|
15 |
+
- `SD-XL 0.9-refiner`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
|
16 |
+
|
17 |
+
If you would like to access these models for your research, please apply using one of the following links:
|
18 |
+
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
19 |
+
This means that you can apply for any of the two links - and if you are granted - you can access both.
|
20 |
+
Please log in to your Hugging Face Account with your organization email to request access.
|
21 |
+
**We plan to do a full release soon (July).**
|
22 |
+
|
23 |
+
## The codebase
|
24 |
+
|
25 |
+
### General Philosophy
|
26 |
+
|
27 |
+
Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
|
28 |
+
|
29 |
+
### Changelog from the old `ldm` codebase
|
30 |
+
|
31 |
+
For training, we use [pytorch-lightning](https://www.pytorchlightning.ai/index.html), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
|
32 |
+
|
33 |
+
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
|
34 |
+
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
|
35 |
+
samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
|
36 |
+
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable change is probably now the option to train continuous time models):
|
37 |
+
* Discrete times models (denoisers) are simply a special case of continuous time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
|
38 |
+
* The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
|
39 |
+
- Autoencoding models have also been cleaned up.
|
40 |
+
|
41 |
+
## Installation:
|
42 |
+
<a name="installation"></a>
|
43 |
+
|
44 |
+
#### 1. Clone the repo
|
45 |
+
|
46 |
+
```shell
|
47 |
+
git clone git@github.com:Stability-AI/generative-models.git
|
48 |
+
cd generative-models
|
49 |
+
```
|
50 |
+
|
51 |
+
#### 2. Setting up the virtualenv
|
52 |
+
|
53 |
+
This is assuming you have navigated to the `generative-models` root after cloning it.
|
54 |
+
|
55 |
+
**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
|
56 |
+
|
57 |
+
|
58 |
+
**PyTorch 1.13**
|
59 |
+
|
60 |
+
```shell
|
61 |
+
# install required packages from pypi
|
62 |
+
python3 -m venv .pt1
|
63 |
+
source .pt1/bin/activate
|
64 |
+
pip3 install wheel
|
65 |
+
pip3 install -r requirements_pt13.txt
|
66 |
+
```
|
67 |
+
|
68 |
+
**PyTorch 2.0**
|
69 |
+
|
70 |
+
|
71 |
+
```shell
|
72 |
+
# install required packages from pypi
|
73 |
+
python3 -m venv .pt2
|
74 |
+
source .pt2/bin/activate
|
75 |
+
pip3 install wheel
|
76 |
+
pip3 install -r requirements_pt2.txt
|
77 |
+
```
|
78 |
+
|
79 |
+
## Inference:
|
80 |
+
|
81 |
+
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported:
|
82 |
+
- [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
|
83 |
+
- [SD-XL 0.9-refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
|
84 |
+
- [SD 2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
|
85 |
+
- [SD 2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
|
86 |
+
|
87 |
+
**Weights for SDXL**:
|
88 |
+
If you would like to access these models for your research, please apply using one of the following links:
|
89 |
+
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
90 |
+
This means that you can apply for any of the two links - and if you are granted - you can access both.
|
91 |
+
Please log in to your Hugging Face Account with your organization email to request access.
|
92 |
+
|
93 |
+
After obtaining the weights, place them into `checkpoints/`.
|
94 |
+
Next, start the demo using
|
95 |
+
|
96 |
+
```
|
97 |
+
streamlit run scripts/demo/sampling.py --server.port <your_port>
|
98 |
+
```
|
99 |
+
|
100 |
+
### Invisible Watermark Detection
|
101 |
+
|
102 |
+
Images generated with our code use the
|
103 |
+
[invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)
|
104 |
+
library to embed an invisible watermark into the model output. We also provide
|
105 |
+
a script to easily detect that watermark. Please note that this watermark is
|
106 |
+
not the same as in previous Stable Diffusion 1.x/2.x versions.
|
107 |
+
|
108 |
+
To run the script you need to either have a working installation as above or
|
109 |
+
try an _experimental_ import using only a minimal amount of packages:
|
110 |
+
```bash
|
111 |
+
python -m venv .detect
|
112 |
+
source .detect/bin/activate
|
113 |
+
|
114 |
+
pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
|
115 |
+
pip install --no-deps invisible-watermark
|
116 |
+
```
|
117 |
+
|
118 |
+
To run the script you need to have a working installation as above. The script
|
119 |
+
is then useable in the following ways (don't forget to activate your
|
120 |
+
virtual environment beforehand, e.g. `source .pt1/bin/activate`):
|
121 |
+
```bash
|
122 |
+
# test a single file
|
123 |
+
python scripts/demo/detect.py <your filename here>
|
124 |
+
# test multiple files at once
|
125 |
+
python scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>
|
126 |
+
# test all files in a specific folder
|
127 |
+
python scripts/demo/detect.py <your folder name here>/*
|
128 |
+
```
|
129 |
+
|
130 |
+
## Training:
|
131 |
+
|
132 |
+
We are providing example training configs in `configs/example_training`. To launch a training, run
|
133 |
+
|
134 |
+
```
|
135 |
+
python main.py --base configs/<config1.yaml> configs/<config2.yaml>
|
136 |
+
```
|
137 |
+
|
138 |
+
where configs are merged from left to right (later configs overwrite the same values).
|
139 |
+
This can be used to combine model, training and data configs. However, all of them can also be
|
140 |
+
defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,
|
141 |
+
run
|
142 |
+
|
143 |
+
```bash
|
144 |
+
python main.py --base configs/example_training/toy/mnist_cond.yaml
|
145 |
+
```
|
146 |
+
|
147 |
+
**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the used dataset (which is expected to stored in tar-file in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
|
148 |
+
|
149 |
+
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
|
150 |
+
|
151 |
+
**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs.
|
152 |
+
|
153 |
+
### Building New Diffusion Models
|
154 |
+
|
155 |
+
#### Conditioner
|
156 |
+
|
157 |
+
The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
|
158 |
+
different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
|
159 |
+
All embedders should define whether or not they are trainable (`is_trainable`, default `False`), a classifier-free
|
160 |
+
guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
|
161 |
+
When computing conditionings, the embedder will get `batch[input_key]` as input.
|
162 |
+
We currently support two to four dimensional conditionings and conditionings of different embedders are concatenated
|
163 |
+
appropriately.
|
164 |
+
Note that the order of the embedders in the `conditioner_config` is important.
|
165 |
+
|
166 |
+
#### Network
|
167 |
+
|
168 |
+
The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
|
169 |
+
enough as we plan to experiment with transformer-based diffusion backbones.
|
170 |
+
|
171 |
+
#### Loss
|
172 |
+
|
173 |
+
The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
|
174 |
+
|
175 |
+
#### Sampler config
|
176 |
+
|
177 |
+
As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical
|
178 |
+
solver, number of steps, type of discretization, as well as, for example, guidance wrappers for classifier-free
|
179 |
+
guidance.
|
180 |
+
|
181 |
+
### Dataset Handling
|
182 |
+
|
183 |
+
|
184 |
+
For large scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
|
185 |
+
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
|
186 |
+
data keys/values,
|
187 |
+
e.g.,
|
188 |
+
|
189 |
+
```python
|
190 |
+
example = {"jpg": x, # this is a tensor -1...1 chw
|
191 |
+
"txt": "a beautiful image"}
|
192 |
+
```
|
193 |
+
|
194 |
+
where we expect images in -1...1, channel-first format.
|
repositories/generative-models/assets/000.jpg
ADDED
repositories/generative-models/assets/sdxl_report.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d63345686bc36e6f6de1c20610a7657fafba4f24a9e892ea6f0b9a9f36b5c00
|
3 |
+
size 18172854
|
repositories/generative-models/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-6
|
3 |
+
target: sgm.models.autoencoder.AutoencodingEngine
|
4 |
+
params:
|
5 |
+
input_key: jpg
|
6 |
+
monitor: val/rec_loss
|
7 |
+
|
8 |
+
loss_config:
|
9 |
+
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
|
10 |
+
params:
|
11 |
+
perceptual_weight: 0.25
|
12 |
+
disc_start: 20001
|
13 |
+
disc_weight: 0.5
|
14 |
+
learn_logvar: True
|
15 |
+
|
16 |
+
regularization_weights:
|
17 |
+
kl_loss: 1.0
|
18 |
+
|
19 |
+
regularizer_config:
|
20 |
+
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
21 |
+
|
22 |
+
encoder_config:
|
23 |
+
target: sgm.modules.diffusionmodules.model.Encoder
|
24 |
+
params:
|
25 |
+
attn_type: none
|
26 |
+
double_z: True
|
27 |
+
z_channels: 4
|
28 |
+
resolution: 256
|
29 |
+
in_channels: 3
|
30 |
+
out_ch: 3
|
31 |
+
ch: 128
|
32 |
+
ch_mult: [ 1, 2, 4 ]
|
33 |
+
num_res_blocks: 4
|
34 |
+
attn_resolutions: [ ]
|
35 |
+
dropout: 0.0
|
36 |
+
|
37 |
+
decoder_config:
|
38 |
+
target: sgm.modules.diffusionmodules.model.Decoder
|
39 |
+
params:
|
40 |
+
attn_type: none
|
41 |
+
double_z: False
|
42 |
+
z_channels: 4
|
43 |
+
resolution: 256
|
44 |
+
in_channels: 3
|
45 |
+
out_ch: 3
|
46 |
+
ch: 128
|
47 |
+
ch_mult: [ 1, 2, 4 ]
|
48 |
+
num_res_blocks: 4
|
49 |
+
attn_resolutions: [ ]
|
50 |
+
dropout: 0.0
|
51 |
+
|
52 |
+
data:
|
53 |
+
target: sgm.data.dataset.StableDataModuleFromConfig
|
54 |
+
params:
|
55 |
+
train:
|
56 |
+
datapipeline:
|
57 |
+
urls:
|
58 |
+
- "DATA-PATH"
|
59 |
+
pipeline_config:
|
60 |
+
shardshuffle: 10000
|
61 |
+
sample_shuffle: 10000
|
62 |
+
|
63 |
+
decoders:
|
64 |
+
- "pil"
|
65 |
+
|
66 |
+
postprocessors:
|
67 |
+
- target: sdata.mappers.TorchVisionImageTransforms
|
68 |
+
params:
|
69 |
+
key: 'jpg'
|
70 |
+
transforms:
|
71 |
+
- target: torchvision.transforms.Resize
|
72 |
+
params:
|
73 |
+
size: 256
|
74 |
+
interpolation: 3
|
75 |
+
- target: torchvision.transforms.ToTensor
|
76 |
+
- target: sdata.mappers.Rescaler
|
77 |
+
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
78 |
+
params:
|
79 |
+
h_key: height
|
80 |
+
w_key: width
|
81 |
+
|
82 |
+
loader:
|
83 |
+
batch_size: 8
|
84 |
+
num_workers: 4
|
85 |
+
|
86 |
+
|
87 |
+
lightning:
|
88 |
+
strategy:
|
89 |
+
target: pytorch_lightning.strategies.DDPStrategy
|
90 |
+
params:
|
91 |
+
find_unused_parameters: True
|
92 |
+
|
93 |
+
modelcheckpoint:
|
94 |
+
params:
|
95 |
+
every_n_train_steps: 5000
|
96 |
+
|
97 |
+
callbacks:
|
98 |
+
metrics_over_trainsteps_checkpoint:
|
99 |
+
params:
|
100 |
+
every_n_train_steps: 50000
|
101 |
+
|
102 |
+
image_logger:
|
103 |
+
target: main.ImageLogger
|
104 |
+
params:
|
105 |
+
enable_autocast: False
|
106 |
+
batch_frequency: 1000
|
107 |
+
max_images: 8
|
108 |
+
increase_log_steps: True
|
109 |
+
|
110 |
+
trainer:
|
111 |
+
devices: 0,
|
112 |
+
limit_val_batches: 50
|
113 |
+
benchmark: True
|
114 |
+
accumulate_grad_batches: 1
|
115 |
+
val_check_interval: 10000
|
repositories/generative-models/configs/example_training/imagenet-f8_cond.yaml
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-4
|
3 |
+
target: sgm.models.diffusion.DiffusionEngine
|
4 |
+
params:
|
5 |
+
scale_factor: 0.13025
|
6 |
+
disable_first_stage_autocast: True
|
7 |
+
log_keys:
|
8 |
+
- cls
|
9 |
+
|
10 |
+
scheduler_config:
|
11 |
+
target: sgm.lr_scheduler.LambdaLinearScheduler
|
12 |
+
params:
|
13 |
+
warm_up_steps: [10000]
|
14 |
+
cycle_lengths: [10000000000000]
|
15 |
+
f_start: [1.e-6]
|
16 |
+
f_max: [1.]
|
17 |
+
f_min: [1.]
|
18 |
+
|
19 |
+
denoiser_config:
|
20 |
+
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
21 |
+
params:
|
22 |
+
num_idx: 1000
|
23 |
+
|
24 |
+
weighting_config:
|
25 |
+
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
26 |
+
scaling_config:
|
27 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
28 |
+
discretization_config:
|
29 |
+
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
30 |
+
|
31 |
+
network_config:
|
32 |
+
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
33 |
+
params:
|
34 |
+
use_checkpoint: True
|
35 |
+
use_fp16: True
|
36 |
+
in_channels: 4
|
37 |
+
out_channels: 4
|
38 |
+
model_channels: 256
|
39 |
+
attention_resolutions: [1, 2, 4]
|
40 |
+
num_res_blocks: 2
|
41 |
+
channel_mult: [1, 2, 4]
|
42 |
+
num_head_channels: 64
|
43 |
+
num_classes: sequential
|
44 |
+
adm_in_channels: 1024
|
45 |
+
use_spatial_transformer: true
|
46 |
+
transformer_depth: 1
|
47 |
+
context_dim: 1024
|
48 |
+
spatial_transformer_attn_type: softmax-xformers
|
49 |
+
|
50 |
+
conditioner_config:
|
51 |
+
target: sgm.modules.GeneralConditioner
|
52 |
+
params:
|
53 |
+
emb_models:
|
54 |
+
# crossattn cond
|
55 |
+
- is_trainable: True
|
56 |
+
input_key: cls
|
57 |
+
ucg_rate: 0.2
|
58 |
+
target: sgm.modules.encoders.modules.ClassEmbedder
|
59 |
+
params:
|
60 |
+
add_sequence_dim: True # will be used through crossattn then
|
61 |
+
embed_dim: 1024
|
62 |
+
n_classes: 1000
|
63 |
+
# vector cond
|
64 |
+
- is_trainable: False
|
65 |
+
ucg_rate: 0.2
|
66 |
+
input_key: original_size_as_tuple
|
67 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
68 |
+
params:
|
69 |
+
outdim: 256 # multiplied by two
|
70 |
+
# vector cond
|
71 |
+
- is_trainable: False
|
72 |
+
input_key: crop_coords_top_left
|
73 |
+
ucg_rate: 0.2
|
74 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
75 |
+
params:
|
76 |
+
outdim: 256 # multiplied by two
|
77 |
+
|
78 |
+
first_stage_config:
|
79 |
+
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
80 |
+
params:
|
81 |
+
ckpt_path: CKPT_PATH
|
82 |
+
embed_dim: 4
|
83 |
+
monitor: val/rec_loss
|
84 |
+
ddconfig:
|
85 |
+
attn_type: vanilla-xformers
|
86 |
+
double_z: true
|
87 |
+
z_channels: 4
|
88 |
+
resolution: 256
|
89 |
+
in_channels: 3
|
90 |
+
out_ch: 3
|
91 |
+
ch: 128
|
92 |
+
ch_mult: [1, 2, 4, 4]
|
93 |
+
num_res_blocks: 2
|
94 |
+
attn_resolutions: []
|
95 |
+
dropout: 0.0
|
96 |
+
lossconfig:
|
97 |
+
target: torch.nn.Identity
|
98 |
+
|
99 |
+
loss_fn_config:
|
100 |
+
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
101 |
+
params:
|
102 |
+
sigma_sampler_config:
|
103 |
+
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
104 |
+
params:
|
105 |
+
num_idx: 1000
|
106 |
+
|
107 |
+
discretization_config:
|
108 |
+
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
109 |
+
|
110 |
+
sampler_config:
|
111 |
+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
112 |
+
params:
|
113 |
+
num_steps: 50
|
114 |
+
|
115 |
+
discretization_config:
|
116 |
+
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
117 |
+
|
118 |
+
guider_config:
|
119 |
+
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
120 |
+
params:
|
121 |
+
scale: 5.0
|
122 |
+
|
123 |
+
data:
|
124 |
+
target: sgm.data.dataset.StableDataModuleFromConfig
|
125 |
+
params:
|
126 |
+
train:
|
127 |
+
datapipeline:
|
128 |
+
urls:
|
129 |
+
# USER: adapt this path the root of your custom dataset
|
130 |
+
- "DATA_PATH"
|
131 |
+
pipeline_config:
|
132 |
+
shardshuffle: 10000
|
133 |
+
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
|
134 |
+
|
135 |
+
decoders:
|
136 |
+
- "pil"
|
137 |
+
|
138 |
+
postprocessors:
|
139 |
+
- target: sdata.mappers.TorchVisionImageTransforms
|
140 |
+
params:
|
141 |
+
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
|
142 |
+
transforms:
|
143 |
+
- target: torchvision.transforms.Resize
|
144 |
+
params:
|
145 |
+
size: 256
|
146 |
+
interpolation: 3
|
147 |
+
- target: torchvision.transforms.ToTensor
|
148 |
+
- target: sdata.mappers.Rescaler
|
149 |
+
|
150 |
+
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
151 |
+
params:
|
152 |
+
h_key: height # USER: you might wanna adapt this for your custom dataset
|
153 |
+
w_key: width # USER: you might wanna adapt this for your custom dataset
|
154 |
+
|
155 |
+
loader:
|
156 |
+
batch_size: 64
|
157 |
+
num_workers: 6
|
158 |
+
|
159 |
+
lightning:
|
160 |
+
modelcheckpoint:
|
161 |
+
params:
|
162 |
+
every_n_train_steps: 5000
|
163 |
+
|
164 |
+
callbacks:
|
165 |
+
metrics_over_trainsteps_checkpoint:
|
166 |
+
params:
|
167 |
+
every_n_train_steps: 25000
|
168 |
+
|
169 |
+
image_logger:
|
170 |
+
target: main.ImageLogger
|
171 |
+
params:
|
172 |
+
disabled: False
|
173 |
+
enable_autocast: False
|
174 |
+
batch_frequency: 1000
|
175 |
+
max_images: 8
|
176 |
+
increase_log_steps: True
|
177 |
+
log_first_step: False
|
178 |
+
log_images_kwargs:
|
179 |
+
use_ema_scope: False
|
180 |
+
N: 8
|
181 |
+
n_rows: 2
|
182 |
+
|
183 |
+
trainer:
|
184 |
+
devices: 0,
|
185 |
+
benchmark: True
|
186 |
+
num_sanity_val_steps: 0
|
187 |
+
accumulate_grad_batches: 1
|
188 |
+
max_epochs: 1000
|
repositories/generative-models/configs/example_training/toy/cifar10_cond.yaml
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-4
|
3 |
+
target: sgm.models.diffusion.DiffusionEngine
|
4 |
+
params:
|
5 |
+
denoiser_config:
|
6 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
7 |
+
params:
|
8 |
+
weighting_config:
|
9 |
+
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
|
10 |
+
params:
|
11 |
+
sigma_data: 1.0
|
12 |
+
scaling_config:
|
13 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
14 |
+
params:
|
15 |
+
sigma_data: 1.0
|
16 |
+
|
17 |
+
network_config:
|
18 |
+
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
+
params:
|
20 |
+
use_checkpoint: True
|
21 |
+
in_channels: 3
|
22 |
+
out_channels: 3
|
23 |
+
model_channels: 32
|
24 |
+
attention_resolutions: []
|
25 |
+
num_res_blocks: 4
|
26 |
+
channel_mult: [1, 2, 2]
|
27 |
+
num_head_channels: 32
|
28 |
+
num_classes: sequential
|
29 |
+
adm_in_channels: 128
|
30 |
+
|
31 |
+
conditioner_config:
|
32 |
+
target: sgm.modules.GeneralConditioner
|
33 |
+
params:
|
34 |
+
emb_models:
|
35 |
+
- is_trainable: True
|
36 |
+
input_key: cls
|
37 |
+
ucg_rate: 0.2
|
38 |
+
target: sgm.modules.encoders.modules.ClassEmbedder
|
39 |
+
params:
|
40 |
+
embed_dim: 128
|
41 |
+
n_classes: 10
|
42 |
+
|
43 |
+
first_stage_config:
|
44 |
+
target: sgm.models.autoencoder.IdentityFirstStage
|
45 |
+
|
46 |
+
loss_fn_config:
|
47 |
+
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
48 |
+
params:
|
49 |
+
sigma_sampler_config:
|
50 |
+
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
51 |
+
|
52 |
+
sampler_config:
|
53 |
+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
54 |
+
params:
|
55 |
+
num_steps: 50
|
56 |
+
|
57 |
+
discretization_config:
|
58 |
+
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
59 |
+
|
60 |
+
guider_config:
|
61 |
+
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
62 |
+
params:
|
63 |
+
scale: 3.0
|
64 |
+
|
65 |
+
data:
|
66 |
+
target: sgm.data.cifar10.CIFAR10Loader
|
67 |
+
params:
|
68 |
+
batch_size: 512
|
69 |
+
num_workers: 1
|
70 |
+
|
71 |
+
lightning:
|
72 |
+
modelcheckpoint:
|
73 |
+
params:
|
74 |
+
every_n_train_steps: 5000
|
75 |
+
|
76 |
+
callbacks:
|
77 |
+
metrics_over_trainsteps_checkpoint:
|
78 |
+
params:
|
79 |
+
every_n_train_steps: 25000
|
80 |
+
|
81 |
+
image_logger:
|
82 |
+
target: main.ImageLogger
|
83 |
+
params:
|
84 |
+
disabled: False
|
85 |
+
batch_frequency: 1000
|
86 |
+
max_images: 64
|
87 |
+
increase_log_steps: True
|
88 |
+
log_first_step: False
|
89 |
+
log_images_kwargs:
|
90 |
+
use_ema_scope: False
|
91 |
+
N: 64
|
92 |
+
n_rows: 8
|
93 |
+
|
94 |
+
trainer:
|
95 |
+
devices: 0,
|
96 |
+
benchmark: True
|
97 |
+
num_sanity_val_steps: 0
|
98 |
+
accumulate_grad_batches: 1
|
99 |
+
max_epochs: 20
|
repositories/generative-models/configs/example_training/toy/mnist.yaml
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-4
|
3 |
+
target: sgm.models.diffusion.DiffusionEngine
|
4 |
+
params:
|
5 |
+
denoiser_config:
|
6 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
7 |
+
params:
|
8 |
+
weighting_config:
|
9 |
+
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
|
10 |
+
params:
|
11 |
+
sigma_data: 1.0
|
12 |
+
scaling_config:
|
13 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
14 |
+
params:
|
15 |
+
sigma_data: 1.0
|
16 |
+
|
17 |
+
network_config:
|
18 |
+
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
+
params:
|
20 |
+
use_checkpoint: True
|
21 |
+
in_channels: 1
|
22 |
+
out_channels: 1
|
23 |
+
model_channels: 32
|
24 |
+
attention_resolutions: []
|
25 |
+
num_res_blocks: 4
|
26 |
+
channel_mult: [1, 2, 2]
|
27 |
+
num_head_channels: 32
|
28 |
+
|
29 |
+
first_stage_config:
|
30 |
+
target: sgm.models.autoencoder.IdentityFirstStage
|
31 |
+
|
32 |
+
loss_fn_config:
|
33 |
+
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
34 |
+
params:
|
35 |
+
sigma_sampler_config:
|
36 |
+
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
37 |
+
|
38 |
+
sampler_config:
|
39 |
+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
40 |
+
params:
|
41 |
+
num_steps: 50
|
42 |
+
|
43 |
+
discretization_config:
|
44 |
+
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
45 |
+
|
46 |
+
data:
|
47 |
+
target: sgm.data.mnist.MNISTLoader
|
48 |
+
params:
|
49 |
+
batch_size: 512
|
50 |
+
num_workers: 1
|
51 |
+
|
52 |
+
lightning:
|
53 |
+
modelcheckpoint:
|
54 |
+
params:
|
55 |
+
every_n_train_steps: 5000
|
56 |
+
|
57 |
+
callbacks:
|
58 |
+
metrics_over_trainsteps_checkpoint:
|
59 |
+
params:
|
60 |
+
every_n_train_steps: 25000
|
61 |
+
|
62 |
+
image_logger:
|
63 |
+
target: main.ImageLogger
|
64 |
+
params:
|
65 |
+
disabled: False
|
66 |
+
batch_frequency: 1000
|
67 |
+
max_images: 64
|
68 |
+
increase_log_steps: False
|
69 |
+
log_first_step: False
|
70 |
+
log_images_kwargs:
|
71 |
+
use_ema_scope: False
|
72 |
+
N: 64
|
73 |
+
n_rows: 8
|
74 |
+
|
75 |
+
trainer:
|
76 |
+
devices: 0,
|
77 |
+
benchmark: True
|
78 |
+
num_sanity_val_steps: 0
|
79 |
+
accumulate_grad_batches: 1
|
80 |
+
max_epochs: 10
|