Ahsen Khaliq committed on
Commit
56a97f7
1 Parent(s): 0b2237b

Create app.py

Files changed (1)
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
+ import os
+ os.system("git clone https://github.com/bryandlee/animegan2-pytorch")
+
+ os.system("gdown https://drive.google.com/uc?id=1WK5Mdt6mwlcsqCZMHkCUSDJxN1UyFi0-")
+ os.system("gdown https://drive.google.com/uc?id=18H3iK09_d54qEDoWIc82SyWB2xun4gjU")
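+ # The gdown calls above fetch pretrained AnimeGANv2 weights from Google Drive;
+ # one of them is "face_paint_512_v2_0.pt", which is loaded further down.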
+
+ import sys
+ sys.path.append("animegan2-pytorch")
+
+ import gradio as gr  # required by the gr.Interface demo at the bottom of this file
+
+ import torch
+ torch.set_grad_enabled(False)  # inference only; disable autograd globally
+
+ from model import Generator  # Generator comes from the cloned animegan2-pytorch repo
+
+ device = "cpu"
+
+ model = Generator().eval().to(device)
+ model.load_state_dict(torch.load("face_paint_512_v2_0.pt", map_location=device))
+
+ from PIL import Image
+ from torchvision.transforms.functional import to_tensor, to_pil_image
+
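+ # face2paint: center-crop the input to a square, resize it, scale pixels to
+ # [-1, 1], run the AnimeGANv2 generator, and convert the result back to PIL
+ # (optionally concatenating input and output side by side).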
+ def face2paint(
+     img: Image.Image,
+     size: int,
+     side_by_side: bool = True,
+ ) -> Image.Image:
+
+     w, h = img.size
+     s = min(w, h)
+     img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
+     img = img.resize((size, size), Image.LANCZOS)
+
+     input = to_tensor(img).unsqueeze(0) * 2 - 1
+     output = model(input.to(device)).cpu()[0]
+
+     if side_by_side:
+         output = torch.cat([input[0], output], dim=2)
+
+     output = (output * 0.5 + 0.5).clip(0, 1)
+
+     return to_pil_image(output)
+
+
+ #@title Face Detector & FFHQ-style Alignment
+
+ # https://github.com/woctezuma/stylegan2-projecting-images
+
+ import os
+ import dlib
+ import collections
+ from typing import Union, List
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+
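+ # get_dlib_face_detector: downloads dlib's 68-point shape predictor on first
+ # use, then returns a closure mapping an image to a list of (68, 2) landmark
+ # arrays, one per detected face.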
+
+ def get_dlib_face_detector(predictor_path: str = "shape_predictor_68_face_landmarks.dat"):
+
+     if not os.path.isfile(predictor_path):
+         model_file = "shape_predictor_68_face_landmarks.dat.bz2"
+         os.system(f"wget http://dlib.net/files/{model_file}")
+         os.system(f"bzip2 -dk {model_file}")
+
+     detector = dlib.get_frontal_face_detector()
+     shape_predictor = dlib.shape_predictor(predictor_path)
+
+     def detect_face_landmarks(img: Union[Image.Image, np.ndarray]):
+         if isinstance(img, Image.Image):
+             img = np.array(img)
+         faces = []
+         dets = detector(img)
+         for d in dets:
+             shape = shape_predictor(img, d)
+             faces.append(np.array([[v.x, v.y] for v in shape.parts()]))
+         return faces
+
+     return detect_face_landmarks
+
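+ # display_facial_landmarks: overlays the 68 landmarks on the image with
+ # matplotlib, color-coded by facial region (jaw, brows, nose, eyes, lips, teeth).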
+
+ def display_facial_landmarks(
+     img: Image.Image,
+     landmarks: List[np.ndarray],
+     fig_size=[15, 15]
+ ):
+     plot_style = dict(
+         marker='o',
+         markersize=4,
+         linestyle='-',
+         lw=2
+     )
+     pred_type = collections.namedtuple('prediction_type', ['slice', 'color'])
+     pred_types = {
+         'face': pred_type(slice(0, 17), (0.682, 0.780, 0.909, 0.5)),
+         'eyebrow1': pred_type(slice(17, 22), (1.0, 0.498, 0.055, 0.4)),
+         'eyebrow2': pred_type(slice(22, 27), (1.0, 0.498, 0.055, 0.4)),
+         'nose': pred_type(slice(27, 31), (0.345, 0.239, 0.443, 0.4)),
+         'nostril': pred_type(slice(31, 36), (0.345, 0.239, 0.443, 0.4)),
+         'eye1': pred_type(slice(36, 42), (0.596, 0.875, 0.541, 0.3)),
+         'eye2': pred_type(slice(42, 48), (0.596, 0.875, 0.541, 0.3)),
+         'lips': pred_type(slice(48, 60), (0.596, 0.875, 0.541, 0.3)),
+         'teeth': pred_type(slice(60, 68), (0.596, 0.875, 0.541, 0.4))
+     }
+
+     fig = plt.figure(figsize=fig_size)
+     ax = fig.add_subplot(1, 1, 1)
+     ax.imshow(img)
+     ax.axis('off')
+
+     for face in landmarks:
+         for pred_type in pred_types.values():
+             ax.plot(
+                 face[pred_type.slice, 0],
+                 face[pred_type.slice, 1],
+                 color=pred_type.color, **plot_style
+             )
+     plt.show()  # a no-op on headless servers, where matplotlib uses a non-interactive backend
+
+ import PIL.Image
+ import PIL.ImageFile
+ import numpy as np
+ import scipy.ndimage
+
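+ # align_and_crop_face: the FFHQ-style alignment referenced above. It builds an
+ # oriented crop quad from the eye/mouth landmarks, optionally shrinks the image,
+ # pads with blurred reflection borders, and warps the quad to a square crop.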
+
+ def align_and_crop_face(
+     img: Image.Image,
+     landmarks: np.ndarray,
+     expand: float = 1.0,
+     output_size: int = 1024,
+     transform_size: int = 4096,
+     enable_padding: bool = True,
+ ):
+     # Parse landmarks.
+     # pylint: disable=unused-variable
+     lm = landmarks
+     lm_chin = lm[0 : 17]  # left-right
+     lm_eyebrow_left = lm[17 : 22]  # left-right
+     lm_eyebrow_right = lm[22 : 27]  # left-right
+     lm_nose = lm[27 : 31]  # top-down
+     lm_nostrils = lm[31 : 36]  # top-down
+     lm_eye_left = lm[36 : 42]  # left-clockwise
+     lm_eye_right = lm[42 : 48]  # left-clockwise
+     lm_mouth_outer = lm[48 : 60]  # left-clockwise
+     lm_mouth_inner = lm[60 : 68]  # left-clockwise
+
+     # Calculate auxiliary vectors.
+     eye_left = np.mean(lm_eye_left, axis=0)
+     eye_right = np.mean(lm_eye_right, axis=0)
+     eye_avg = (eye_left + eye_right) * 0.5
+     eye_to_eye = eye_right - eye_left
+     mouth_left = lm_mouth_outer[0]
+     mouth_right = lm_mouth_outer[6]
+     mouth_avg = (mouth_left + mouth_right) * 0.5
+     eye_to_mouth = mouth_avg - eye_avg
+
+     # Choose oriented crop rectangle.
+     x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+     x /= np.hypot(*x)
+     x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+     x *= expand
+     y = np.flipud(x) * [-1, 1]
+     c = eye_avg + eye_to_mouth * 0.1
+     quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+     qsize = np.hypot(*x) * 2
+
+     # Shrink.
+     shrink = int(np.floor(qsize / output_size * 0.5))
+     if shrink > 1:
+         rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+         img = img.resize(rsize, PIL.Image.LANCZOS)  # ANTIALIAS was an alias for LANCZOS and was removed in Pillow 10
+         quad /= shrink
+         qsize /= shrink
+
+     # Crop.
+     border = max(int(np.rint(qsize * 0.1)), 3)
+     crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
+     crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
+     if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+         img = img.crop(crop)
+         quad -= crop[0:2]
+
+     # Pad.
+     pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
+     pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
+     if enable_padding and max(pad) > border - 4:
+         pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+         img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+         h, w, _ = img.shape
+         y, x, _ = np.ogrid[:h, :w, :1]
+         mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
+         blur = qsize * 0.02
+         img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+         img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
+         img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+         quad += pad[:2]
+
+     # Transform.
+     img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
+     if output_size < transform_size:
+         img = img.resize((output_size, output_size), PIL.Image.LANCZOS)
+
+     return img
+
+
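+ # inference: the Gradio handler. Detects faces, renders the landmark overlay,
+ # then aligns, crops, and stylizes each face; only the last face found is returned.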
+ def inference(image):
+     # Gradio may pass the input as a numpy array; the alignment code expects PIL.
+     img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
+     face_detector = get_dlib_face_detector()
+     landmarks = face_detector(img)
+
+     display_facial_landmarks(img, landmarks, fig_size=[5, 5])
+
+     out = None  # stays None when no face is detected
+     for landmark in landmarks:
+         face = align_and_crop_face(img, landmark, expand=1.3)
+         out = face2paint(face, 512)
+
+     return out
+
+
+
+ iface = gr.Interface(inference, "image", "image")
+ iface.launch()