Spaces:
Build error
Build error
File size: 5,005 Bytes
9553b1d 19ee04a 9553b1d 19ee04a 9553b1d 57a5749 19ee04a 9553b1d d9a1638 57a5749 19ee04a 9553b1d 19ee04a 9553b1d 7e5df5d 19ee04a d9a1638 f6efc66 19ee04a 2310f00 19ee04a 2310f00 9553b1d 19ee04a 9553b1d 19ee04a 9553b1d 2310f00 9553b1d 19ee04a 9553b1d 19ee04a 9553b1d 2310f00 d9a1638 7e5df5d d9a1638 19ee04a 9553b1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
#!/usr/bin/env python
from __future__ import annotations

import argparse
import functools
import os
import pathlib
import subprocess
import sys
import tarfile

import mim
if os.environ.get('SYSTEM') == 'spaces':
    # On Hugging Face Spaces, replace whatever mmcv-full / OpenCV builds the
    # environment ships with by the versions this demo uses. This has to run
    # before anime_face_detector and cv2 are imported below.
    mim.uninstall('mmcv-full', confirm_yes=True)
    mim.install('mmcv-full==1.3.16', is_yes=True)
    # Invoke pip through the running interpreter so packages are changed in
    # the environment this script actually imports from, not in whichever
    # `pip` executable happens to come first on PATH.
    _pip = [sys.executable, '-m', 'pip']
    subprocess.call([*_pip, 'uninstall', '-y', 'opencv-python'])
    subprocess.call([*_pip, 'uninstall', '-y', 'opencv-python-headless'])
    subprocess.call([*_pip, 'install', 'opencv-python-headless'])
import anime_face_detector
import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import torch
REPO_URL = 'https://github.com/hysts/anime-face-detector'
TITLE = 'hysts/anime-face-detector'
DESCRIPTION = f'A demo for {REPO_URL}'
ARTICLE = None

# Hugging Face access token used only to download the sample-image dataset.
# It is optional: when unset, hf_hub_download receives None and falls back to
# huggingface_hub's default credentials (a cached login if present, otherwise
# anonymous access). No token is needed at all once ./images already exists.
TOKEN = os.environ.get('TOKEN')
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        The parsed namespace. ``enable_queue`` is ``True`` unless
        ``--disable-queue`` is given; ``theme`` and ``port`` default to
        ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--detector', type=str, default='yolov3')
    parser.add_argument('--face-score-slider-step', type=float, default=0.05)
    parser.add_argument('--face-score-threshold', type=float, default=0.5)
    parser.add_argument('--landmark-score-slider-step',
                        type=float,
                        default=0.05)
    parser.add_argument('--landmark-score-threshold', type=float, default=0.3)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    # Queueing is on by default; the flag stores False into `enable_queue`.
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args(argv)
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sorted paths of the sample images used as examples.

    If ``./images`` does not exist, ``images.tar.gz`` is downloaded from the
    ``hysts/sample-images-TADNE`` dataset repo (using ``TOKEN``) and extracted
    into the current working directory. The archive presumably contains a
    top-level ``images/`` directory; verify if the dataset changes.

    Returns:
        Every entry directly under ``./images``, sorted.
    """
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        dataset_repo = 'hysts/sample-images-TADNE'
        path = huggingface_hub.hf_hub_download(dataset_repo,
                                               'images.tar.gz',
                                               repo_type='dataset',
                                               use_auth_token=TOKEN)
        with tarfile.open(path) as f:
            # The archive comes from the network: when the running Python
            # supports extraction filters (3.12+ and security backports),
            # reject absolute paths, '..' components and links that would
            # escape the destination directory.
            if hasattr(tarfile, 'data_filter'):
                f.extractall(filter='data')
            else:
                f.extractall()
    return sorted(image_dir.glob('*'))
def detect(image: np.ndarray | None, face_score_threshold: float,
           landmark_score_threshold: float,
           detector: anime_face_detector.LandmarkDetector
           ) -> np.ndarray | None:
    """Draw detected face boxes and landmarks onto an RGB image.

    Args:
        image: RGB image from the Gradio input, or ``None`` when the user
            submitted without an image. Assumed to be H x W x 3 (the channel
            axis is reversed below) — TODO confirm for other image modes.
        face_score_threshold: Faces whose bbox score is below this value are
            not drawn.
        landmark_score_threshold: Landmarks scoring below this value are drawn
            in yellow; the others are drawn in red.
        detector: Detector created by ``anime_face_detector.create_detector``.

    Returns:
        A new RGB array with the annotations drawn, or ``None`` if ``image``
        is ``None``.
    """
    if image is None:
        return None

    # RGB -> BGR as a view; the detector is fed BGR and the drawing below is
    # done in BGR, so the colour tuples are BGR.
    bgr = image[:, :, ::-1]
    preds = detector(bgr)
    res = bgr.copy()

    for pred in preds:
        # pred['bbox'] is read as [x0, y0, x1, y1, score].
        bbox = pred['bbox']
        face_score = bbox[4]
        if face_score < face_score_threshold:
            continue
        box = np.round(bbox[:4]).astype(int)
        # Line thickness (also used as landmark radius) scales with the
        # longer side of the box, with a floor of 2 px.
        lt = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
        # .tolist() yields plain Python ints, which OpenCV's point parsing
        # always accepts.
        x0, y0, x1, y1 = box.tolist()
        cv2.rectangle(res, (x0, y0), (x1, y1), (0, 255, 0), lt)

        # Each keypoint row is unpacked as (*coords, score); presumably
        # (x, y, score).
        for *pt, pt_score in pred['keypoints']:
            if pt_score < landmark_score_threshold:
                color = (0, 255, 255)
            else:
                color = (0, 0, 255)
            center = tuple(np.round(pt).astype(int).tolist())
            cv2.circle(res, center, lt, color, cv2.FILLED)

    return cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
def main():
    """Build the Gradio interface around ``detect`` and launch it.

    Options parsed by ``parse_args`` select the detector model, the torch
    device, the slider steps and default thresholds, and launch settings.
    """
    gr.close_all()

    args = parse_args()
    device = torch.device(args.device)

    # Each sample image becomes an example row pre-filled with the default
    # face and landmark thresholds.
    examples = [
        [
            sample.as_posix(),
            args.face_score_threshold,
            args.landmark_score_threshold,
        ]
        for sample in load_sample_image_paths()
    ]

    detector = anime_face_detector.create_detector(args.detector,
                                                   device=device)
    # Bind the detector so Gradio supplies only the image and thresholds;
    # update_wrapper copies detect's metadata (name, docstring) onto the
    # partial.
    predict = functools.update_wrapper(
        functools.partial(detect, detector=detector), detect)

    face_slider = gr.inputs.Slider(0,
                                   1,
                                   step=args.face_score_slider_step,
                                   default=args.face_score_threshold,
                                   label='Face Score Threshold')
    landmark_slider = gr.inputs.Slider(0,
                                       1,
                                       step=args.landmark_score_slider_step,
                                       default=args.landmark_score_threshold,
                                       label='Landmark Score Threshold')
    inputs = [
        gr.inputs.Image(type='numpy', label='Input'),
        face_slider,
        landmark_slider,
    ]
    output = gr.outputs.Image(type='numpy', label='Output')

    demo = gr.Interface(
        predict,
        inputs,
        output,
        examples=examples,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        theme=args.theme,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Start the demo only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()
|