guocheng66 commited on
Commit
57f6383
1 Parent(s): c04f3e8

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import ImageColor
3
+
4
+ import onnxruntime
5
+ import cv2
6
+ import numpy as np
7
+
8
+ # The common resume photo size is 35mmx45mm
9
+ RESUME_PHOTO_W = 350
10
+ RESUME_PHOTO_H = 450
11
+
12
+
13
+ # modified from https://github.com/opencv/opencv_zoo/blob/main/models/face_detection_yunet/yunet.py
14
+ class YuNet:
15
+ def __init__(
16
+ self,
17
+ modelPath,
18
+ inputSize=[320, 320],
19
+ confThreshold=0.6,
20
+ nmsThreshold=0.3,
21
+ topK=5000,
22
+ backendId=0,
23
+ targetId=0,
24
+ ):
25
+ self._modelPath = modelPath
26
+ self._inputSize = tuple(inputSize) # [w, h]
27
+ self._confThreshold = confThreshold
28
+ self._nmsThreshold = nmsThreshold
29
+ self._topK = topK
30
+ self._backendId = backendId
31
+ self._targetId = targetId
32
+
33
+ self._model = cv2.FaceDetectorYN.create(
34
+ model=self._modelPath,
35
+ config="",
36
+ input_size=self._inputSize,
37
+ score_threshold=self._confThreshold,
38
+ nms_threshold=self._nmsThreshold,
39
+ top_k=self._topK,
40
+ backend_id=self._backendId,
41
+ target_id=self._targetId,
42
+ )
43
+
44
+ @property
45
+ def name(self):
46
+ return self.__class__.__name__
47
+
48
+ def setBackendAndTarget(self, backendId, targetId):
49
+ self._backendId = backendId
50
+ self._targetId = targetId
51
+ self._model = cv2.FaceDetectorYN.create(
52
+ model=self._modelPath,
53
+ config="",
54
+ input_size=self._inputSize,
55
+ score_threshold=self._confThreshold,
56
+ nms_threshold=self._nmsThreshold,
57
+ top_k=self._topK,
58
+ backend_id=self._backendId,
59
+ target_id=self._targetId,
60
+ )
61
+
62
+ def setInputSize(self, input_size):
63
+ self._model.setInputSize(tuple(input_size))
64
+
65
+ def infer(self, image):
66
+ # Forward
67
+ faces = self._model.detect(image)
68
+ return faces[1]
69
+
70
+
71
+ class ONNXModel:
72
+ def __init__(self, model_path, input_w, input_h):
73
+ self.model = onnxruntime.InferenceSession(model_path)
74
+ self.input_w = input_w
75
+ self.input_h = input_h
76
+
77
+ def preprocess(self, rgb, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
78
+ # convert the input data into the float32 input
79
+ img_data = (
80
+ np.array(cv2.resize(rgb, (self.input_w, self.input_h)))
81
+ .transpose(2, 0, 1)
82
+ .astype("float32")
83
+ )
84
+
85
+ # normalize
86
+ norm_img_data = np.zeros(img_data.shape).astype("float32")
87
+
88
+ for i in range(img_data.shape[0]):
89
+ norm_img_data[i, :, :] = img_data[i, :, :] / 255
90
+ norm_img_data[i, :, :] = (norm_img_data[i, :, :] - mean[i]) / std[i]
91
+
92
+ # add batch channel
93
+ norm_img_data = norm_img_data.reshape(1, 3, self.input_h, self.input_w).astype(
94
+ "float32"
95
+ )
96
+ return norm_img_data
97
+
98
+ def forward(self, image):
99
+ input_data = self.preprocess(image)
100
+ output_data = self.model.run(["argmax_0.tmp_0"], {"x": input_data})
101
+
102
+ return output_data
103
+
104
+
105
+ def make_resume_photo(rgb, background_color):
106
+ h, w, _ = rgb.shape
107
+ bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
108
+
109
+ # Initialize models
110
+ face_detector = YuNet("models/face_detection_yunet_2023mar.onnx")
111
+ face_detector.setInputSize([w, h])
112
+ human_segmentor = ONNXModel(
113
+ "models/human_pp_humansegv2_lite_192x192_inference_model.onnx", 192, 192
114
+ )
115
+
116
+ # yunet uses opencv bgr image format
117
+ detections = face_detector.infer(bgr)
118
+
119
+ results = []
120
+ for idx, det in enumerate(detections):
121
+ # bounding box
122
+ pt1 = np.array((det[0], det[1]))
123
+ pt2 = np.array((det[0] + det[2], det[1] + det[3]))
124
+
125
+ # face landmarks
126
+ landmarks = det[4:14].reshape((5, 2))
127
+ right_eye = landmarks[0]
128
+ left_eye = landmarks[1]
129
+
130
+ angle = np.arctan2(right_eye[1] - left_eye[1], (right_eye[0] - left_eye[0]))
131
+ rmat = cv2.getRotationMatrix2D((0, 0), -angle, 1)
132
+
133
+ # apply rotation
134
+ rotated_bgr = cv2.warpAffine(bgr, rmat, (bgr.shape[1], bgr.shape[0]))
135
+ rotated_pt1 = rmat[:, :-1] @ pt1
136
+ rotated_pt2 = rmat[:, :-1] @ pt2
137
+
138
+ face_w, face_h = rotated_pt2 - rotated_pt1
139
+ up_length = int(face_h / 4)
140
+ down_length = int(face_h / 3)
141
+ crop_h = face_h + up_length + down_length
142
+ crop_w = int(crop_h * (RESUME_PHOTO_W / RESUME_PHOTO_H))
143
+
144
+ pt1 = np.array(
145
+ (rotated_pt1[0] - (crop_w - face_w) / 2, rotated_pt1[1] - up_length)
146
+ ).astype(np.int32)
147
+ pt2 = np.array((pt1[0] + crop_w, pt1[1] + crop_h)).astype(np.int32)
148
+
149
+ resume_photo = rotated_bgr[pt1[1] : pt2[1], pt1[0] : pt2[0], :]
150
+
151
+ rgb = cv2.cvtColor(resume_photo, cv2.COLOR_BGR2RGB)
152
+ mask = human_segmentor.forward(rgb)
153
+ mask = mask[0].transpose(1, 2, 0)
154
+ mask = cv2.resize(
155
+ mask.astype(np.uint8), (resume_photo.shape[1], resume_photo.shape[0])
156
+ )
157
+
158
+ resume_photo = cv2.cvtColor(resume_photo, cv2.COLOR_BGR2RGB)
159
+ resume_photo[mask == 0] = ImageColor.getcolor(background_color, "RGB")
160
+ resume_photo = cv2.resize(resume_photo, (RESUME_PHOTO_W, RESUME_PHOTO_H))
161
+ results.append(resume_photo)
162
+
163
+ return results
164
+
165
+
166
+ title = "Resume Photo Maker"
167
+
168
+ demo = gr.Interface(
169
+ fn=make_resume_photo,
170
+ inputs=[
171
+ gr.Image(type="numpy", label="input"),
172
+ gr.ColorPicker(label="background color"),
173
+ ],
174
+ outputs=gr.Gallery(label="output"),
175
+ examples=[
176
+ ["images/elon.jpg", "#FFFFFF"],
177
+ ["images/9_Press_Conference_Press_Conference_9_45.jpg", "#FFFFFF"],
178
+ ],
179
+ title=title,
180
+ allow_flagging="never",
181
+ article="<p style='text-align: center;'><a href='https://github.com/bot66/resume-photo-maker' target='_blank'>Github Repo</a></p>",
182
+ )
183
+
184
+ if __name__ == "__main__":
185
+ demo.launch()
images/9_Press_Conference_Press_Conference_9_45.jpg ADDED
images/elon.jpg ADDED
models/face_detection_yunet_2023mar.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f2383e4dd3cfbb4553ea8718107fc0423210dc964f9f4280604804ed2552fa4
3
+ size 232589
models/human_pp_humansegv2_lite_192x192_inference_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34edc335d7833f5a96bb2dadafb1d9da24bac072a26b447c18dd021ea8f29215
3
+ size 12219997
requirements.txt ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.1.2
3
+ annotated-types==0.6.0
4
+ anyio==3.7.1
5
+ attrs==23.1.0
6
+ certifi==2023.7.22
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ colorama==0.4.6
10
+ coloredlogs==15.0.1
11
+ contourpy==1.2.0
12
+ cycler==0.12.1
13
+ exceptiongroup==1.1.3
14
+ fastapi==0.104.1
15
+ ffmpy==0.3.1
16
+ filelock==3.13.1
17
+ flatbuffers==23.5.26
18
+ fonttools==4.44.0
19
+ fsspec==2023.10.0
20
+ gradio==4.3.0
21
+ gradio_client==0.7.0
22
+ h11==0.14.0
23
+ httpcore==1.0.1
24
+ httpx==0.25.1
25
+ huggingface-hub==0.19.2
26
+ humanfriendly==10.0
27
+ idna==3.4
28
+ importlib-resources==6.1.1
29
+ Jinja2==3.1.2
30
+ jsonschema==4.19.2
31
+ jsonschema-specifications==2023.7.1
32
+ kiwisolver==1.4.5
33
+ markdown-it-py==3.0.0
34
+ MarkupSafe==2.1.3
35
+ matplotlib==3.8.1
36
+ mdurl==0.1.2
37
+ mpmath==1.3.0
38
+ numpy==1.26.1
39
+ onnxruntime==1.16.1
40
+ opencv-python==4.8.1.78
41
+ orjson==3.9.10
42
+ packaging==23.2
43
+ pandas==2.1.2
44
+ Pillow==10.1.0
45
+ protobuf==4.25.0
46
+ pydantic==2.4.2
47
+ pydantic_core==2.10.1
48
+ pydub==0.25.1
49
+ Pygments==2.16.1
50
+ pyparsing==3.1.1
51
+ python-dateutil==2.8.2
52
+ python-multipart==0.0.6
53
+ pytz==2023.3.post1
54
+ PyYAML==6.0.1
55
+ referencing==0.30.2
56
+ requests==2.31.0
57
+ rich==13.6.0
58
+ rpds-py==0.12.0
59
+ semantic-version==2.10.0
60
+ shellingham==1.5.4
61
+ six==1.16.0
62
+ sniffio==1.3.0
63
+ starlette==0.27.0
64
+ sympy==1.12
65
+ tomlkit==0.12.0
66
+ toolz==0.12.0
67
+ tqdm==4.66.1
68
+ typer==0.9.0
69
+ typing_extensions==4.8.0
70
+ tzdata==2023.3
71
+ urllib3==2.0.7
72
+ uvicorn==0.24.0.post1
73
+ websockets==11.0.3