hysts HF staff commited on
Commit
a038241
1 Parent(s): 7942d4e
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. app.py +142 -0
  3. face_detection +1 -0
  4. requirements.txt +4 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "face_detection"]
2
+ path = face_detection
3
+ url = https://github.com/ibug-group/face_detection
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import pathlib
8
+ import sys
9
+ import urllib.request
10
+ from typing import Union
11
+
12
+ import cv2
13
+ import gradio as gr
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ sys.path.insert(0, 'face_detection')
19
+
20
+ from ibug.face_detection import RetinaFacePredictor, S3FDPredictor
21
+
22
+ REPO_URL = 'https://github.com/ibug-group/face_detection'
23
+ TITLE = 'ibug-group/face_detection'
24
+ DESCRIPTION = f'This is a demo for {REPO_URL}.'
25
+ ARTICLE = None
26
+
27
+
28
+ def parse_args() -> argparse.Namespace:
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument('--face-score-slider-step', type=float, default=0.05)
31
+ parser.add_argument('--face-score-threshold', type=float, default=0.8)
32
+ parser.add_argument('--device', type=str, default='cpu')
33
+ parser.add_argument('--theme', type=str)
34
+ parser.add_argument('--live', action='store_true')
35
+ parser.add_argument('--share', action='store_true')
36
+ parser.add_argument('--port', type=int)
37
+ parser.add_argument('--disable-queue',
38
+ dest='enable_queue',
39
+ action='store_false')
40
+ parser.add_argument('--allow-flagging', type=str, default='never')
41
+ parser.add_argument('--allow-screenshot', action='store_true')
42
+ return parser.parse_args()
43
+
44
+
45
+ def load_model(
46
+ model_name: str, threshold: float,
47
+ device: torch.device) -> Union[RetinaFacePredictor, S3FDPredictor]:
48
+ if model_name == 's3fd':
49
+ model = S3FDPredictor(threshold=threshold, device=device)
50
+ else:
51
+ model_name = model_name.replace('retinaface_', '')
52
+ model = RetinaFacePredictor(
53
+ threshold=threshold,
54
+ device=device,
55
+ model=RetinaFacePredictor.get_model(model_name))
56
+ return model
57
+
58
+
59
+ def detect(image: np.ndarray, model_name: str, face_score_threshold: float,
60
+ detectors: dict[str, nn.Module]) -> np.ndarray:
61
+ model = detectors[model_name]
62
+ model.threshold = face_score_threshold
63
+
64
+ # RGB -> BGR
65
+ image = image[:, :, ::-1]
66
+ preds = model(image, rgb=False)
67
+
68
+ res = image.copy()
69
+ for pred in preds:
70
+ box = np.round(pred[:4]).astype(int)
71
+
72
+ line_width = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
73
+ cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0),
74
+ line_width)
75
+
76
+ if len(pred) == 15:
77
+ pts = pred[5:].reshape(-1, 2)
78
+ for pt in np.round(pts).astype(int):
79
+ cv2.circle(res, tuple(pt), line_width, (0, 255, 0), cv2.FILLED)
80
+
81
+ return res[:, :, ::-1]
82
+
83
+
84
+ def main():
85
+ gr.close_all()
86
+
87
+ args = parse_args()
88
+ device = torch.device(args.device)
89
+
90
+ model_names = [
91
+ 'retinaface_mobilenet0.25',
92
+ 'retinaface_resnet50',
93
+ 's3fd',
94
+ ]
95
+ detectors = {
96
+ name: load_model(name,
97
+ threshold=args.face_score_threshold,
98
+ device=device)
99
+ for name in model_names
100
+ }
101
+
102
+ func = functools.partial(detect, detectors=detectors)
103
+ func = functools.update_wrapper(func, detect)
104
+
105
+ image_path = pathlib.Path('selfie.jpg')
106
+ if not image_path.exists():
107
+ url = 'https://raw.githubusercontent.com/peiyunh/tiny/master/data/demo/selfie.jpg'
108
+ urllib.request.urlretrieve(url, image_path)
109
+ examples = [[image_path.as_posix(), model_names[1], 0.8]]
110
+
111
+ gr.Interface(
112
+ func,
113
+ [
114
+ gr.inputs.Image(type='numpy', label='Input'),
115
+ gr.inputs.Radio(model_names,
116
+ type='value',
117
+ default='retinaface_resnet50',
118
+ label='Model'),
119
+ gr.inputs.Slider(0,
120
+ 1,
121
+ step=args.face_score_slider_step,
122
+ default=args.face_score_threshold,
123
+ label='Face Score Threshold'),
124
+ ],
125
+ gr.outputs.Image(type='numpy', label='Output'),
126
+ examples=examples,
127
+ title=TITLE,
128
+ description=DESCRIPTION,
129
+ article=ARTICLE,
130
+ theme=args.theme,
131
+ allow_screenshot=args.allow_screenshot,
132
+ allow_flagging=args.allow_flagging,
133
+ live=args.live,
134
+ ).launch(
135
+ enable_queue=args.enable_queue,
136
+ server_port=args.port,
137
+ share=args.share,
138
+ )
139
+
140
+
141
+ if __name__ == '__main__':
142
+ main()
face_detection ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit bc1e392b11d731fa20b1397c8ff3faed5e7fc76e
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy==1.22.3
2
+ opencv-python-headless==4.5.5.64
3
+ torch==1.11.0
4
+ torchvision==0.12.0