ethanNeuralImage commited on
Commit
d1b9ef3
1 Parent(s): 92ec8d3
Files changed (2) hide show
  1. utils/alignment.py +122 -0
  2. utils/mapper_utils.py +49 -0
utils/alignment.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import PIL
3
+ import PIL.Image
4
+ import scipy
5
+ import scipy.ndimage
6
+ import dlib
7
+
8
+
9
+ def get_landmark(filepath, predictor):
10
+ """get landmark with dlib
11
+ :return: np.array shape=(68, 2)
12
+ """
13
+ detector = dlib.get_frontal_face_detector()
14
+
15
+ img = dlib.load_rgb_image(filepath)
16
+ dets = detector(img, 2)
17
+
18
+ if len(dets) != 0:
19
+ max_area = 0
20
+ for k, d in enumerate(dets):
21
+ if d.area() > max_area or max_area == 0:
22
+ shape = predictor(img, d)
23
+ max_area = d.area()
24
+ else:
25
+ d = dlib.rectangle(0, 0, img.shape[1], img.shape[0])
26
+ shape = predictor(img, d)
27
+
28
+ t = list(shape.parts())
29
+ a = []
30
+ for tt in t:
31
+ a.append([tt.x, tt.y])
32
+ lm = np.array(a)
33
+ return lm
34
+
35
+
36
+ def align_face(filepath, predictor):
37
+ """
38
+ :param filepath: str
39
+ :return: PIL Image
40
+ """
41
+
42
+ lm = get_landmark(filepath, predictor)
43
+
44
+ lm_chin = lm[0: 17] # left-right
45
+ lm_eyebrow_left = lm[17: 22] # left-right
46
+ lm_eyebrow_right = lm[22: 27] # left-right
47
+ lm_nose = lm[27: 31] # top-down
48
+ lm_nostrils = lm[31: 36] # top-down
49
+ lm_eye_left = lm[36: 42] # left-clockwise
50
+ lm_eye_right = lm[42: 48] # left-clockwise
51
+ lm_mouth_outer = lm[48: 60] # left-clockwise
52
+ lm_mouth_inner = lm[60: 68] # left-clockwise
53
+
54
+ # Calculate auxiliary vectors.
55
+ eye_left = np.mean(lm_eye_left, axis=0)
56
+ eye_right = np.mean(lm_eye_right, axis=0)
57
+ eye_avg = (eye_left + eye_right) * 0.5
58
+ eye_to_eye = eye_right - eye_left
59
+ mouth_left = lm_mouth_outer[0]
60
+ mouth_right = lm_mouth_outer[6]
61
+ mouth_avg = (mouth_left + mouth_right) * 0.5
62
+ eye_to_mouth = mouth_avg - eye_avg
63
+
64
+ # Choose oriented crop rectangle.
65
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
66
+ x /= np.hypot(*x)
67
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
68
+ y = np.flipud(x) * [-1, 1]
69
+ c = eye_avg + eye_to_mouth * 0.1
70
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
71
+ qsize = np.hypot(*x) * 2
72
+
73
+ # read image
74
+ img = PIL.Image.open(filepath).convert("RGB")
75
+
76
+ output_size = 256
77
+ transform_size = 256
78
+ enable_padding = True
79
+
80
+ # Shrink.
81
+ shrink = int(np.floor(qsize / output_size * 0.5))
82
+ if shrink > 1:
83
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
84
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
85
+ quad /= shrink
86
+ qsize /= shrink
87
+
88
+ # Crop.
89
+ border = max(int(np.rint(qsize * 0.1)), 3)
90
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
91
+ int(np.ceil(max(quad[:, 1]))))
92
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
93
+ min(crop[3] + border, img.size[1]))
94
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
95
+ img = img.crop(crop)
96
+ quad -= crop[0:2]
97
+
98
+ # Pad.
99
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
100
+ int(np.ceil(max(quad[:, 1]))))
101
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
102
+ max(pad[3] - img.size[1] + border, 0))
103
+ if enable_padding and max(pad) > border - 4:
104
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
105
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
106
+ h, w, _ = img.shape
107
+ y, x, _ = np.ogrid[:h, :w, :1]
108
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
109
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
110
+ blur = qsize * 0.02
111
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
112
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
113
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
114
+ quad += pad[:2]
115
+
116
+ # Transform.
117
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
118
+ if output_size < transform_size:
119
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
120
+
121
+ # Return aligned image.
122
+ return img
utils/mapper_utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ google_drive_paths = {
5
+ "stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT",
6
+
7
+ "mapper/pretrained/afro.pt": "https://drive.google.com/uc?id=1i5vAqo4z0I-Yon3FNft_YZOq7ClWayQJ",
8
+ "mapper/pretrained/angry.pt": "https://drive.google.com/uc?id=1g82HEH0jFDrcbCtn3M22gesWKfzWV_ma",
9
+ "mapper/pretrained/beyonce.pt": "https://drive.google.com/uc?id=1KJTc-h02LXs4zqCyo7pzCp0iWeO6T9fz",
10
+ "mapper/pretrained/bobcut.pt": "https://drive.google.com/uc?id=1IvyqjZzKS-vNdq_OhwapAcwrxgLAY8UF",
11
+ "mapper/pretrained/bowlcut.pt": "https://drive.google.com/uc?id=1xwdxI2YCewSt05dEHgkpmmzoauPjEnnZ",
12
+ "mapper/pretrained/curly_hair.pt": "https://drive.google.com/uc?id=1xZ7fFB12Ci6rUbUfaHPpo44xUFzpWQ6M",
13
+ "mapper/pretrained/depp.pt": "https://drive.google.com/uc?id=1FPiJkvFPG_y-bFanxLLP91wUKuy-l3IV",
14
+ "mapper/pretrained/hilary_clinton.pt": "https://drive.google.com/uc?id=1X7U2zj2lt0KFifIsTfOOzVZXqYyCWVll",
15
+ "mapper/pretrained/mohawk.pt": "https://drive.google.com/uc?id=1oMMPc8iQZ7dhyWavZ7VNWLwzf9aX4C09",
16
+ "mapper/pretrained/purple_hair.pt": "https://drive.google.com/uc?id=14H0CGXWxePrrKIYmZnDD2Ccs65EEww75",
17
+ "mapper/pretrained/surprised.pt": "https://drive.google.com/uc?id=1F-mPrhO-UeWrV1QYMZck63R43aLtPChI",
18
+ "mapper/pretrained/taylor_swift.pt": "https://drive.google.com/uc?id=10jHuHsKKJxuf3N0vgQbX_SMEQgFHDrZa",
19
+ "mapper/pretrained/trump.pt": "https://drive.google.com/uc?id=14v8D0uzy4tOyfBU3ca9T0AzTt3v-dNyh",
20
+ "mapper/pretrained/zuckerberg.pt": "https://drive.google.com/uc?id=1NjDcMUL8G-pO3i_9N6EPpQNXeMc3Ar1r",
21
+
22
+ "example_celebs.pt": "https://drive.google.com/uc?id=1VL3lP4avRhz75LxSza6jgDe-pHd2veQG"
23
+ }
24
+
25
+
26
+ def ensure_checkpoint_exists(model_weights_filename):
27
+ if not os.path.isfile(model_weights_filename) and (
28
+ model_weights_filename in google_drive_paths
29
+ ):
30
+ gdrive_url = google_drive_paths[model_weights_filename]
31
+ try:
32
+ from gdown import download as drive_download
33
+
34
+ drive_download(gdrive_url, model_weights_filename, quiet=False)
35
+ except ModuleNotFoundError:
36
+ print(
37
+ "gdown module not found.",
38
+ "pip3 install gdown or, manually download the checkpoint file:",
39
+ gdrive_url
40
+ )
41
+
42
+ if not os.path.isfile(model_weights_filename) and (
43
+ model_weights_filename not in google_drive_paths
44
+ ):
45
+ print(
46
+ model_weights_filename,
47
+ " not found, you may need to manually download the model weights."
48
+ )
49
+