File size: 4,157 Bytes
906e212
 
 
 
0db1bcd
 
72b39a2
0db1bcd
 
 
 
 
72b39a2
 
 
 
 
0db1bcd
906e212
 
72b39a2
906e212
 
 
 
 
72b39a2
34383d4
72b39a2
 
 
 
3b6e0c9
72b39a2
 
 
 
 
906e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f0f47d
906e212
 
 
 
 
 
 
 
 
5f0f47d
906e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f45b35d
906e212
 
 
 
 
 
 
 
 
72b39a2
5f0f47d
906e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72b39a2
906e212
0db1bcd
906e212
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import functools
import pathlib

import os
import subprocess
import tarfile

# When the SYSTEM environment variable equals "spaces", reinstall pinned builds
# of mmcv-full and OpenCV. This runs at import time, before cv2 is imported.
# NOTE(review): presumably SYSTEM=spaces is set by the Hugging Face Spaces
# runtime — confirm.
if os.environ.get("SYSTEM") == "spaces":
    # Imported lazily so that mim is only required when this branch runs.
    import mim

    # Drop whatever mmcv-full is installed, then install version 1.6.2 via mim.
    mim.uninstall("mmcv-full", confirm_yes=True)
    subprocess.call("mim install mmcv-full==1.6.2".split())

    # Remove both OpenCV packages, then install only the pinned headless build.
    # The exit codes of these commands are not checked, so a failed step does
    # not stop the script here.
    subprocess.call("pip uninstall -y opencv-python".split())
    subprocess.call("pip uninstall -y opencv-python-headless".split())
    subprocess.call("pip install opencv-python-headless==4.7.0.72".split())

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image

import anime_face_detector


def load_sample_image_paths():
    """Return the sorted paths of the sample images, fetching them if needed.

    The images are looked up in an ``images`` directory relative to the
    current working directory. When that directory does not exist,
    ``images.tar.gz`` is downloaded from the ``hysts/sample-images-TADNE``
    dataset repository on the Hugging Face Hub and extracted into the current
    working directory.

    Returns:
        list[pathlib.Path]: The entries of ``images``, sorted.
    """
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(
            dataset_repo, "images.tar.gz", repo_type="dataset"
        )
        with tarfile.open(path) as f:
            # The archive comes from the network, so prefer the "data"
            # extraction filter (rejects absolute paths, links escaping the
            # destination, etc.). Older interpreters without extraction
            # filters fall back to the plain call.
            if hasattr(tarfile, "data_filter"):
                f.extractall(filter="data")
            else:
                f.extractall()
    # NOTE(review): assumes the archive unpacks into a top-level "images"
    # directory; otherwise the result is empty — confirm.
    return sorted(image_dir.glob("*"))


def detect(
    img,
    face_score_threshold: float,
    landmark_score_threshold: float,
    detector: anime_face_detector.LandmarkDetector,
) -> PIL.Image.Image:
    """Draw detected faces and their landmarks onto an image.

    Args:
        img: Path of the image file to process. A falsy value (``None`` or an
            empty string) means there is no input.
        face_score_threshold: Detections whose box score is below this value
            are skipped entirely (neither box nor landmarks are drawn).
        landmark_score_threshold: Landmarks scoring below this value are drawn
            in yellow; the others are drawn in red.
        detector: Callable applied to the image as loaded by ``cv2.imread``.
            Each prediction it returns is read as a mapping with a ``"bbox"``
            entry (four box coordinates followed by a score) and a
            ``"keypoints"`` entry (rows whose last element is a score).

    Returns:
        The annotated image converted to RGB as a ``PIL.Image.Image``, or
        ``None`` when ``img`` is falsy.

    Raises:
        ValueError: If the file at ``img`` cannot be read as an image.
    """
    if not img:
        return None

    image = cv2.imread(img)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than deeper in the detector.
        raise ValueError(f"Could not read image file: {img}")
    preds = detector(image)

    res = image.copy()
    for pred in preds:
        box = pred["bbox"]
        box, face_score = box[:4], box[4]
        if face_score < face_score_threshold:
            continue
        box = np.round(box).astype(int)

        # Line thickness / landmark radius grows with the larger box side,
        # with a floor of 2 pixels.
        lt = max(2, int(3 * (box[2:] - box[:2]).max() / 256))

        # .tolist() hands OpenCV plain Python ints rather than NumPy scalars.
        top_left = tuple(box[:2].tolist())
        bottom_right = tuple(box[2:].tolist())
        cv2.rectangle(res, top_left, bottom_right, (0, 255, 0), lt)

        pred_pts = pred["keypoints"]
        for *pt, landmark_score in pred_pts:
            # Colours are BGR here: yellow for low-confidence landmarks,
            # red for the rest.
            if landmark_score < landmark_score_threshold:
                color = (0, 255, 255)
            else:
                color = (0, 0, 255)
            center = tuple(np.round(pt).astype(int).tolist())
            cv2.circle(res, center, lt, color, cv2.FILLED)
    res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)

    image_pil = PIL.Image.fromarray(res)
    return image_pil


def main():
    """Parse command-line options, build the detector, and launch the demo.

    Command-line options select the detector model and device, the initial
    slider values and step, and the Gradio launch settings (port, debug,
    share, live).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--detector", type=str, default="yolov3", choices=["yolov3", "faster-rcnn"]
    )
    parser.add_argument("--device", type=str, default="cpu", choices=["cuda:0", "cpu"])
    parser.add_argument("--face-score-threshold", type=float, default=0.5)
    parser.add_argument("--landmark-score-threshold", type=float, default=0.3)
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--port", type=int)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--live", action="store_true")
    args = parser.parse_args()

    image_paths = load_sample_image_paths()
    # Each example row supplies the three inputs in order: image path, face
    # threshold, landmark threshold. Use the configured thresholds so the
    # examples agree with the slider defaults when they are overridden.
    examples = [
        [path.as_posix(), args.face_score_threshold, args.landmark_score_threshold]
        for path in image_paths
    ]

    detector = anime_face_detector.create_detector(args.detector, device=args.device)
    # Bind the detector so the interface callback only takes the UI inputs.
    func = functools.partial(detect, detector=detector)

    title = "edisonlee55/hysts-anime-face-detector"
    description = "Demo for edisonlee55/hysts-anime-face-detector. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
    article = "<a href='https://github.com/edisonlee55/hysts-anime-face-detector'>GitHub Repo</a>"

    gr.Interface(
        func,
        [
            gr.Image(type="filepath", label="Input"),
            gr.Slider(
                0,
                1,
                step=args.score_slider_step,
                value=args.face_score_threshold,
                label="Face Score Threshold",
            ),
            gr.Slider(
                0,
                1,
                step=args.score_slider_step,
                value=args.landmark_score_threshold,
                label="Landmark Score Threshold",
            ),
        ],
        gr.Image(type="pil", label="Output"),
        title=title,
        description=description,
        article=article,
        examples=examples,
        live=args.live,
    ).launch(debug=args.debug, share=args.share, server_port=args.port)


# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()