samariddin commited on
Commit
a0d67fb
1 Parent(s): e3fa86e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import functools
5
+ import os
6
+ import pathlib
7
+ import subprocess
8
+ import sys
9
+ import urllib.request
10
+
11
+ if os.environ.get('SYSTEM') == 'spaces':
12
+ # import mim
13
+ # mim.install('mmcv-full==1.3.3', is_yes=True)
14
+
15
+ subprocess.call('pip uninstall -y opencv-python'.split())
16
+ subprocess.call('pip uninstall -y opencv-python-headless'.split())
17
+ subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
18
+ subprocess.call('pip install terminaltables==3.1.0'.split())
19
+ subprocess.call('pip install mmpycocotools==12.0.3'.split())
20
+
21
+ subprocess.call('pip install insightface==0.6.2'.split())
22
+
23
+
24
+ import cv2
25
+ import gradio as gr
26
+ import huggingface_hub
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+
31
+ sys.path.insert(0, 'insightface/detection/scrfd')
32
+
33
+ from mmdet.apis import inference_detector, init_detector, show_result_pyplot
34
+
35
+ TITLE = 'insightface Face Detection (SCRFD)'
36
+ DESCRIPTION = 'This is an unofficial demo for https://github.com/deepinsight/insightface/tree/master/detection/scrfd.'
37
+ ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.insightface-scrfd" alt="visitor badge"/></center>'
38
+
39
+ TOKEN = os.environ['TOKEN']
40
+
41
+
42
+ def parse_args() -> argparse.Namespace:
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument('--face-score-slider-step', type=float, default=0.05)
45
+ parser.add_argument('--face-score-threshold', type=float, default=0.3)
46
+ parser.add_argument('--device', type=str, default='cpu')
47
+ parser.add_argument('--theme', type=str)
48
+ parser.add_argument('--live', action='store_true')
49
+ parser.add_argument('--share', action='store_true')
50
+ parser.add_argument('--port', type=int)
51
+ parser.add_argument('--disable-queue',
52
+ dest='enable_queue',
53
+ action='store_false')
54
+ parser.add_argument('--allow-flagging', type=str, default='never')
55
+ return parser.parse_args()
56
+
57
+
58
+ def load_model(model_size: str, device) -> nn.Module:
59
+ ckpt_path = huggingface_hub.hf_hub_download(
60
+ 'hysts/insightface',
61
+ f'models/scrfd_{model_size}/model.pth',
62
+ use_auth_token=TOKEN)
63
+ scrfd_dir = 'insightface/detection/scrfd'
64
+ config_path = f'{scrfd_dir}/configs/scrfd/scrfd_{model_size}.py'
65
+ model = init_detector(config_path, ckpt_path, device.type)
66
+ return model
67
+
68
+
69
+ def update_test_pipeline(model: nn.Module, mode: int):
70
+ cfg = model.cfg
71
+ pipelines = cfg.data.test.pipeline
72
+ for pipeline in pipelines:
73
+ if pipeline.type == 'MultiScaleFlipAug':
74
+ if mode == 0: #640 scale
75
+ pipeline.img_scale = (640, 640)
76
+ if hasattr(pipeline, 'scale_factor'):
77
+ del pipeline.scale_factor
78
+ elif mode == 1: #for single scale in other pages
79
+ pipeline.img_scale = (1100, 1650)
80
+ if hasattr(pipeline, 'scale_factor'):
81
+ del pipeline.scale_factor
82
+ elif mode == 2: #original scale
83
+ pipeline.img_scale = None
84
+ pipeline.scale_factor = 1.0
85
+ transforms = pipeline.transforms
86
+ for transform in transforms:
87
+ if transform.type == 'Pad':
88
+ if mode != 2:
89
+ transform.size = pipeline.img_scale
90
+ if hasattr(transform, 'size_divisor'):
91
+ del transform.size_divisor
92
+ else:
93
+ transform.size = None
94
+ transform.size_divisor = 32
95
+
96
+
97
+ def detect(image: np.ndarray, model_size: str, mode: int,
98
+ face_score_threshold: float,
99
+ detectors: dict[str, nn.Module]) -> np.ndarray:
100
+ model = detectors[model_size]
101
+ update_test_pipeline(model, mode)
102
+
103
+ # RGB -> BGR
104
+ image = image[:, :, ::-1]
105
+ preds = inference_detector(model, image)
106
+ boxes = preds[0]
107
+
108
+ res = image.copy()
109
+ for box in boxes:
110
+ box, score = box[:4], box[4]
111
+ if score < face_score_threshold:
112
+ continue
113
+ box = np.round(box).astype(int)
114
+
115
+ line_width = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
116
+ cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0),
117
+ line_width)
118
+
119
+ res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
120
+ return res
121
+
122
+
123
+ def main():
124
+ args = parse_args()
125
+ device = torch.device(args.device)
126
+
127
+ model_sizes = [
128
+ '500m',
129
+ '1g',
130
+ '2.5g',
131
+ '10g',
132
+ '34g',
133
+ ]
134
+ detectors = {
135
+ model_size: load_model(model_size, device=device)
136
+ for model_size in model_sizes
137
+ }
138
+ modes = [
139
+ '(640, 640)',
140
+ '(1100, 1650)',
141
+ 'original',
142
+ ]
143
+
144
+ func = functools.partial(detect, detectors=detectors)
145
+ func = functools.update_wrapper(func, detect)
146
+
147
+ image_path = pathlib.Path('selfie.jpg')
148
+ if not image_path.exists():
149
+ url = 'https://raw.githubusercontent.com/peiyunh/tiny/master/data/demo/selfie.jpg'
150
+ urllib.request.urlretrieve(url, image_path)
151
+ examples = [[image_path.as_posix(), '10g', modes[0], 0.3]]
152
+
153
+ gr.Interface(
154
+ func,
155
+ [
156
+ gr.inputs.Image(type='numpy', label='Input'),
157
+ gr.inputs.Radio(
158
+ model_sizes, type='value', default='10g', label='Model'),
159
+ gr.inputs.Radio(
160
+ modes, type='index', default=modes[0], label='Mode'),
161
+ gr.inputs.Slider(0,
162
+ 1,
163
+ step=args.face_score_slider_step,
164
+ default=args.face_score_threshold,
165
+ label='Face Score Threshold'),
166
+ ],
167
+ gr.outputs.Image(type='numpy', label='Output'),
168
+ examples=examples,
169
+ title=TITLE,
170
+ description=DESCRIPTION,
171
+ article=ARTICLE,
172
+ theme=args.theme,
173
+ allow_flagging=args.allow_flagging,
174
+ live=args.live,
175
+ ).launch(
176
+ enable_queue=args.enable_queue,
177
+ server_port=args.port,
178
+ share=args.share,
179
+ )
180
+
181
+
182
+ if __name__ == '__main__':
183
+ main()