hysts's picture
hysts HF staff
Update
384ccc2
raw
history blame
3.03 kB
#!/usr/bin/env python
from __future__ import annotations
import os
import pathlib
import sys
import urllib.request
from typing import Union
import cv2
import gradio as gr
import numpy as np
import torch
sys.path.insert(0, "face_detection")
from ibug.face_detection import RetinaFacePredictor, S3FDPredictor
# Markdown heading rendered at the top of the demo page (passed to gr.Markdown below).
DESCRIPTION = "# [ibug-group/face_detection](https://github.com/ibug-group/face_detection)"
def load_model(model_name: str, threshold: float, device: torch.device) -> Union[RetinaFacePredictor, S3FDPredictor]:
    """Construct the face detector identified by ``model_name``.

    Args:
        model_name: ``"s3fd"`` selects the S3FD detector; any other value is
            treated as a RetinaFace variant whose ``"retinaface_"`` prefix is
            removed before the remainder is handed to
            ``RetinaFacePredictor.get_model``.
        threshold: Detection score threshold given to the predictor.
        device: Torch device the predictor is created on.

    Returns:
        An ``S3FDPredictor`` or a ``RetinaFacePredictor`` instance.
    """
    if model_name == "s3fd":
        return S3FDPredictor(threshold=threshold, device=device)

    # e.g. "retinaface_resnet50" -> "resnet50"; presumably get_model resolves
    # that backbone name to the matching RetinaFace model — confirm in ibug.
    backbone = model_name.replace("retinaface_", "")
    return RetinaFacePredictor(
        threshold=threshold,
        device=device,
        model=RetinaFacePredictor.get_model(backbone),
    )
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Detector choices offered in the UI. Order matters: the example row picks
# model_names[1].
model_names = [
    "retinaface_mobilenet0.25",
    "retinaface_resnet50",
    "s3fd",
]

# Build every detector once at import time (initial score threshold 0.8) so a
# request only has to look its model up by name.
detectors = {
    key: load_model(key, threshold=0.8, device=device)
    for key in model_names
}
def detect(image: np.ndarray, model_name: str, face_score_threshold: float) -> np.ndarray:
    """Run the selected face detector on an image and draw its detections.

    Args:
        image: Image from the ``gr.Image(type="numpy")`` input, indexed below as
            (H, W, C) and treated as RGB. ``None`` when nothing was uploaded.
        model_name: Key into ``detectors`` selecting which model to run.
        face_score_threshold: Score threshold written onto the selected
            detector before it is called.

    Returns:
        A copy of the image in RGB order with a green rectangle around every
        detection and, for predictions that carry landmarks, a filled green
        dot on each landmark.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard, "Run" on an empty input fails inside the slicing
        # below with an opaque "'NoneType' object is not subscriptable".
        raise gr.Error("Please upload an image first.")

    model = detectors[model_name]
    # NOTE(review): this mutates a detector shared by all requests; if events
    # ever run concurrently, one request's threshold could affect another's.
    model.threshold = face_score_threshold

    # RGB -> BGR, since the detector is called with rgb=False.
    image = image[:, :, ::-1]
    preds = model(image, rgb=False)

    # The reversed slice is a negative-stride view; copying yields a fresh
    # contiguous array that OpenCV can draw on in place.
    res = image.copy()
    for pred in preds:
        box = np.round(pred[:4]).astype(int)
        # Stroke width grows with the box's longer side, never below 2 px.
        line_width = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
        top_left = tuple(box[:2].tolist())
        bottom_right = tuple(box[2:].tolist())
        cv2.rectangle(res, top_left, bottom_right, (0, 255, 0), line_width)
        # 15 values = 4 box coords, one more value (presumably the score), and
        # 10 trailing values read as 5 (x, y) landmark pairs.
        if len(pred) == 15:
            pts = pred[5:].reshape(-1, 2)
            for pt in np.round(pts).astype(int):
                cv2.circle(res, tuple(pt.tolist()), line_width, (0, 255, 0), cv2.FILLED)
    # BGR -> RGB for display.
    return res[:, :, ::-1]
# Sample image used in the demo's example row; fetched once and reused.
example_image_path = pathlib.Path("selfie.jpg")
if not example_image_path.exists():
    url = "https://raw.githubusercontent.com/peiyunh/tiny/master/data/demo/selfie.jpg"
    # urlretrieve writes straight into its target and leaves whatever it got
    # on disk if the transfer fails or comes up short. Download into a ".part"
    # sibling and rename only on success, so the exists() check above never
    # trusts a truncated image on a later start.
    _partial_path = example_image_path.with_name(example_image_path.name + ".part")
    try:
        urllib.request.urlretrieve(url, _partial_path)
        _partial_path.replace(example_image_path)
    finally:
        # Nothing left after a successful rename; removes the fragment otherwise.
        if _partial_path.exists():
            _partial_path.unlink()
# Demo page layout. Components are created in display order, so the statement
# order inside these context managers defines the page:
#   heading, then a row of [input column | output column], then examples.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: the inputs consumed by detect(), in its argument order.
        with gr.Column():
            image = gr.Image(type="numpy", label="Input")
            model_name = gr.Radio(model_names, type="value", value="retinaface_resnet50", label="Model")
            score_threshold = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.8, label="Face Score Threshold")
            run_button = gr.Button()
        # Right column: the annotated image returned by detect().
        with gr.Column():
            result = gr.Image(label="Output")
    # One example row: the downloaded selfie, model_names[1]
    # ("retinaface_resnet50"), and threshold 0.8 (same as the slider default).
    # Example caching is opt-in via the CACHE_EXAMPLES=1 environment variable.
    gr.Examples(
        examples=[[example_image_path.as_posix(), model_names[1], 0.8]],
        inputs=[image, model_name, score_threshold],
        outputs=result,
        fn=detect,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    # Wire the button to detect(); api_name registers the event as "detect".
    run_button.click(
        fn=detect,
        inputs=[image, model_name, score_threshold],
        outputs=result,
        api_name="detect",
    )
if __name__ == "__main__":
    # Enable request queuing with at most 20 pending requests, then serve the app.
    demo.queue(max_size=20)
    demo.launch()