hysts HF staff commited on
Commit
e9e04ce
1 Parent(s): 629fcb5
Files changed (8) hide show
  1. .gitignore +1 -0
  2. .gitmodules +12 -0
  3. app.py +174 -0
  4. face_detection +1 -0
  5. face_parsing +1 -0
  6. fpage +1 -0
  7. requirements.txt +5 -0
  8. roi_tanh_warping +1 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ images
.gitmodules ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [submodule "face_detection"]
2
+ path = face_detection
3
+ url = https://github.com/ibug-group/face_detection
4
+ [submodule "roi_tanh_warping"]
5
+ path = roi_tanh_warping
6
+ url = https://github.com/ibug-group/roi_tanh_warping
7
+ [submodule "face_parsing"]
8
+ path = face_parsing
9
+ url = https://github.com/hhj1897/face_parsing
10
+ [submodule "fpage"]
11
+ path = fpage
12
+ url = https://github.com/ibug-group/fpage
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import sys
10
+ import tarfile
11
+
12
+ import cv2
13
+ import gradio as gr
14
+ import huggingface_hub
15
+ import numpy as np
16
+ import torch
17
+
18
+ sys.path.insert(0, 'face_detection')
19
+ sys.path.insert(0, 'face_parsing')
20
+ sys.path.insert(0, 'fpage')
21
+ sys.path.insert(0, 'roi_tanh_warping')
22
+
23
+ from ibug.age_estimation import AgeEstimator
24
+ from ibug.face_detection import RetinaFacePredictor
25
+ from ibug.face_parsing.utils import label_colormap
26
+
27
+ REPO_URL = 'https://github.com/ibug-group/fpage'
28
+ TITLE = 'ibug-group/fpage'
29
+ DESCRIPTION = f'This is a demo for {REPO_URL}.'
30
+ ARTICLE = None
31
+
32
+ TOKEN = os.environ['TOKEN']
33
+
34
+
35
+ def parse_args() -> argparse.Namespace:
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument('--device', type=str, default='cpu')
38
+ parser.add_argument('--theme', type=str)
39
+ parser.add_argument('--live', action='store_true')
40
+ parser.add_argument('--share', action='store_true')
41
+ parser.add_argument('--port', type=int)
42
+ parser.add_argument('--disable-queue',
43
+ dest='enable_queue',
44
+ action='store_false')
45
+ parser.add_argument('--allow-flagging', type=str, default='never')
46
+ parser.add_argument('--allow-screenshot', action='store_true')
47
+ return parser.parse_args()
48
+
49
+
50
+ def load_sample_images() -> list[pathlib.Path]:
51
+ image_dir = pathlib.Path('images')
52
+ if not image_dir.exists():
53
+ image_dir.mkdir()
54
+ dataset_repo = 'hysts/input-images'
55
+ filenames = ['003.tar']
56
+ for name in filenames:
57
+ path = huggingface_hub.hf_hub_download(dataset_repo,
58
+ name,
59
+ repo_type='dataset',
60
+ use_auth_token=TOKEN)
61
+ with tarfile.open(path) as f:
62
+ f.extractall(image_dir.as_posix())
63
+ return sorted(image_dir.rglob('*.jpg'))
64
+
65
+
66
+ def load_detector(device: torch.device) -> RetinaFacePredictor:
67
+ model = RetinaFacePredictor(
68
+ threshold=0.8,
69
+ device=device,
70
+ model=RetinaFacePredictor.get_model('mobilenet0.25'))
71
+ return model
72
+
73
+
74
+ def load_model(device: torch.device) -> AgeEstimator:
75
+ ckpt_path = huggingface_hub.hf_hub_download(
76
+ 'hysts/ibug',
77
+ 'fpage/models/fpage-resnet50-fcn-14-97.torch',
78
+ use_auth_token=TOKEN)
79
+ model = AgeEstimator(
80
+ device=device,
81
+ ckpt=ckpt_path,
82
+ encoder='resnet50',
83
+ decoder='fcn',
84
+ age_classes=97,
85
+ face_classes=14,
86
+ )
87
+ return model
88
+
89
+
90
+ def predict(image: np.ndarray, max_num_faces: int,
91
+ detector: RetinaFacePredictor, model: AgeEstimator) -> np.ndarray:
92
+ colormap = label_colormap(14)
93
+
94
+ # RGB -> BGR
95
+ image = image[:, :, ::-1]
96
+
97
+ faces = detector(image, rgb=False)
98
+ if len(faces) == 0:
99
+ raise RuntimeError('No face was found.')
100
+ faces = sorted(list(faces), key=lambda x: -x[4])[:max_num_faces][::-1]
101
+ ages, masks = model.predict_img(image, faces, rgb=False)
102
+
103
+ mask_image = np.zeros_like(image)
104
+ for mask in masks:
105
+ temp = colormap[mask]
106
+ mask_image[temp > 0] = temp[temp > 0]
107
+
108
+ res = image.astype(float) * 0.5 + mask_image[:, :, ::-1] * 0.5
109
+ res = np.clip(np.round(res), 0, 255).astype(np.uint8)
110
+
111
+ for face, age in zip(faces, ages):
112
+ bbox = np.round(face[:4]).astype(int)
113
+ cv2.rectangle(
114
+ res,
115
+ (bbox[0], bbox[1]),
116
+ (bbox[2], bbox[3]),
117
+ color=(0, 255, 0),
118
+ thickness=2,
119
+ )
120
+
121
+ text_content = f'Age: ({age: .1f})'
122
+ cv2.putText(
123
+ res,
124
+ text_content,
125
+ (bbox[0], bbox[1] - 10),
126
+ cv2.FONT_HERSHEY_DUPLEX,
127
+ 1,
128
+ (255, 255, 255),
129
+ lineType=cv2.LINE_AA,
130
+ )
131
+
132
+ return res[:, :, ::-1]
133
+
134
+
135
+ def main():
136
+ gr.close_all()
137
+
138
+ args = parse_args()
139
+ device = torch.device(args.device)
140
+
141
+ detector = load_detector(device)
142
+ model = load_model(device)
143
+
144
+ func = functools.partial(predict, detector=detector, model=model)
145
+ func = functools.update_wrapper(func, predict)
146
+
147
+ image_paths = load_sample_images()
148
+ examples = [[path.as_posix(), 5] for path in image_paths]
149
+
150
+ gr.Interface(
151
+ func,
152
+ [
153
+ gr.inputs.Image(type='numpy', label='Input'),
154
+ gr.inputs.Slider(
155
+ 1, 20, step=1, default=5, label='Max Number of Faces'),
156
+ ],
157
+ gr.outputs.Image(type='numpy', label='Output'),
158
+ examples=examples,
159
+ title=TITLE,
160
+ description=DESCRIPTION,
161
+ article=ARTICLE,
162
+ theme=args.theme,
163
+ allow_screenshot=args.allow_screenshot,
164
+ allow_flagging=args.allow_flagging,
165
+ live=args.live,
166
+ ).launch(
167
+ enable_queue=args.enable_queue,
168
+ server_port=args.port,
169
+ share=args.share,
170
+ )
171
+
172
+
173
+ if __name__ == '__main__':
174
+ main()
face_detection ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit bc1e392b11d731fa20b1397c8ff3faed5e7fc76e
face_parsing ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 8ce84123d0433e6ed389b33e5d3dc2a6a1609d70
fpage ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 4e6bef126f18c638423cabfe09362a309731e471
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==1.22.3
2
+ opencv-python-headless==4.5.5.64
3
+ timm==0.5.4
4
+ torch==1.11.0
5
+ torchvision==0.12.0
roi_tanh_warping ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit f9cb77ed9d4ce4e40f026b2425d62efe517691c9