yuna0x0's picture
Init
906e212 unverified
raw
history blame
3.73 kB
import argparse
import functools
import pathlib
import cv2
import gradio as gr
import numpy as np
import PIL.Image
import torch
import anime_face_detector
def detect(
img,
face_score_threshold: float,
landmark_score_threshold: float,
detector: anime_face_detector.LandmarkDetector,
) -> PIL.Image.Image:
if not img:
return None
image = cv2.imread(img)
preds = detector(image)
res = image.copy()
for pred in preds:
box = pred["bbox"]
box, score = box[:4], box[4]
if score < face_score_threshold:
continue
box = np.round(box).astype(int)
lt = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), lt)
pred_pts = pred["keypoints"]
for *pt, score in pred_pts:
if score < landmark_score_threshold:
color = (0, 255, 255)
else:
color = (0, 0, 255)
pt = np.round(pt).astype(int)
cv2.circle(res, tuple(pt), lt, color, cv2.FILLED)
res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
image_pil = PIL.Image.fromarray(res)
return image_pil
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--detector", type=str, default="yolov3", choices=["yolov3", "faster-rcnn"]
)
parser.add_argument(
"--device", type=str, default="cuda:0", choices=["cuda:0", "cpu"]
)
parser.add_argument("--face-score-threshold", type=float, default=0.5)
parser.add_argument("--landmark-score-threshold", type=float, default=0.3)
parser.add_argument("--score-slider-step", type=float, default=0.05)
parser.add_argument("--port", type=int)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--share", action="store_true")
parser.add_argument("--live", action="store_true")
args = parser.parse_args()
sample_path = pathlib.Path("assets/input.jpg")
if not sample_path.exists():
torch.hub.download_url_to_file(
"https://raw.githubusercontent.com/edisonlee55/hysts-anime-face-detector/main/assets/input.jpg",
sample_path.as_posix(),
)
detector = anime_face_detector.create_detector(args.detector, device=args.device)
func = functools.partial(detect, detector=detector)
title = "edisonlee55/hysts-anime-face-detector"
description = "Demo for edisonlee55/hysts-anime-face-detector. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<a href='https://github.com/edisonlee55/hysts-anime-face-detector'>GitHub Repo</a>"
gr.Interface(
func,
[
gr.Image(type="filepath", label="Input"),
gr.Slider(
0,
1,
step=args.score_slider_step,
value=args.face_score_threshold,
label="Face Score Threshold",
),
gr.Slider(
0,
1,
step=args.score_slider_step,
value=args.landmark_score_threshold,
label="Landmark Score Threshold",
),
],
gr.Image(type="pil", label="Output"),
title=title,
description=description,
article=article,
examples=[
[
sample_path.as_posix(),
args.face_score_threshold,
args.landmark_score_threshold,
],
],
live=args.live,
).launch(
debug=args.debug, share=args.share, server_port=args.port, enable_queue=True
)
if __name__ == "__main__":
main()