File size: 6,036 Bytes
21ce843
 
 
 
 
 
 
5f16544
 
ce0031d
5f16544
 
 
 
 
 
 
 
 
 
21ce843
 
8a19b0b
71a75ca
8a19b0b
38df08a
 
8a19b0b
21ce843
 
 
8a19b0b
5f16544
 
8a19b0b
 
21ce843
e752590
8a19b0b
 
21ce843
e752590
85e6850
e752590
025ec32
e752590
 
 
 
 
 
 
21ce843
8a19b0b
f4e3f66
 
 
 
 
 
 
5f16544
21ce843
f4e3f66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f16544
 
025ec32
21ce843
ce0031d
8a19b0b
 
21ce843
ce0031d
e752590
 
 
21ce843
 
f4e3f66
21ce843
e752590
 
 
 
 
 
 
 
 
 
21ce843
e752590
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a19b0b
 
 
 
 
 
 
 
 
 
 
 
 
2c10ed3
8a19b0b
 
 
 
 
 
 
e752590
8a19b0b
e752590
8a19b0b
 
e752590
 
 
 
 
 
 
85e6850
 
e752590
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python
import pathlib
import os
import gradio as gr
import huggingface_hub
import numpy as np
import functools
from dataclasses import dataclass

from mivolo.predictor import Predictor


@dataclass
class Cfg:
    """Configuration consumed by ``mivolo.predictor.Predictor``.

    NOTE: field order matters — instances are built positionally as
    ``Cfg(detector_weights, checkpoint)`` in ``load_models``.
    """
    detector_weights: str  # local path to the YOLOv8 person/face detector checkpoint
    checkpoint: str  # local path to the MiVOLO age/gender checkpoint
    device: str = "cpu"  # inference device; demo runs on CPU
    with_persons: bool = True  # use person crops in addition to face crops
    disable_faces: bool = False  # if True, ignore face crops entirely
    draw: bool = True  # draw predictions on the output image


# Markdown banner rendered at the top of the combined demo UI.
DESCRIPTION = """
# MiVOLO: Multi-input Transformer for Age and Gender Estimation

This is an official demo for https://github.com/WildChlamydia/MiVOLO.\n
Telegram channel: https://t.me/+K0i2fLGpVKBjNzUy (Russian language)
"""

# Optional Hugging Face access token for hf_hub_download; None when the
# environment variable is unset (public repos still download fine).
HF_TOKEN = os.getenv('HF_TOKEN')


def load_models():
    """Download all model weights and construct both MiVOLO predictors.

    Downloads the shared YOLOv8 person/face detector plus the v1 and v2
    age/gender checkpoints from the Hugging Face Hub, then builds one
    ``Predictor`` per checkpoint.

    Returns:
        tuple: ``(predictor_v1, predictor_v2)`` ready for inference.
    """

    def _download(repo_id: str, filename: str) -> str:
        # One place for hub access; `token` replaces the deprecated
        # `use_auth_token` argument of hf_hub_download.
        return huggingface_hub.hf_hub_download(repo_id, filename, token=HF_TOKEN)

    detector_path = _download('iitolstykh/demo_yolov8_detector', 'yolov8x_person_face.pt')
    age_gender_path_v1 = _download('iitolstykh/demo_xnet_volo_cross', 'checkpoint-377.pth.tar')
    age_gender_path_v2 = _download('iitolstykh/demo_xnet_volo_cross', 'mivolo_v2_384_0.15.pth.tar')

    # Both predictors share the same detector weights.
    predictor_v1 = Predictor(Cfg(detector_path, age_gender_path_v1))
    predictor_v2 = Predictor(Cfg(detector_path, age_gender_path_v2))

    return predictor_v1, predictor_v2


def detect(
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor
) -> np.ndarray:
    """Run person/face detection plus age & gender recognition.

    Args:
        image: input image as an RGB ``np.ndarray`` (H, W, 3).
        score_threshold: detector confidence threshold.
        iou_threshold: NMS IoU threshold.
        mode: one of "Use persons and faces", "Use persons only",
            "Use faces only".
        predictor: configured MiVOLO ``Predictor`` (mutated in place:
            its detector kwargs and meta flags are updated per call).

    Returns:
        Annotated output image as an RGB ``np.ndarray``.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values
            (previously this surfaced as an UnboundLocalError).
    """
    predictor.detector.detector_kwargs['conf'] = score_threshold
    predictor.detector.detector_kwargs['iou'] = iou_threshold

    if mode == "Use persons and faces":
        use_persons, disable_faces = True, False
    elif mode == "Use persons only":
        use_persons, disable_faces = True, True
    elif mode == "Use faces only":
        use_persons, disable_faces = False, False
    else:
        raise ValueError(f"Unknown inference mode: {mode!r}")

    predictor.age_gender_model.meta.use_persons = use_persons
    predictor.age_gender_model.meta.disable_faces = disable_faces

    # Predictor works on BGR; convert on the way in and back out so the
    # Gradio interface stays RGB end to end.
    bgr_image = image[:, :, ::-1]  # RGB -> BGR
    _, out_im = predictor.recognize(bgr_image)
    return out_im[:, :, ::-1]  # BGR -> RGB


def clear():
    """Return the initial values for (image, score, iou, mode, result).

    Used as the Clear button callback to reset every input and the
    output back to its default state.
    """
    defaults = (None, 0.4, 0.7, "Use persons and faces", None)
    return defaults


def _build_demo(prediction_fn) -> gr.Blocks:
    """Build one inference tab wired to the given prediction function.

    The v1 and v2 tabs were previously two ~25-line copy-pasted Blocks
    definitions differing only in the callback; this helper builds either
    one. Reads the module-level ``examples`` list and ``clear`` callback.
    """
    with gr.Blocks(theme=gr.themes.Default(), css="style.css") as block:
        with gr.Row():
            with gr.Column():
                image = gr.Image(label='Input', type='numpy')
                score_threshold = gr.Slider(0, 1, value=0.4, step=0.05, label='Detector Score Threshold')
                iou_threshold = gr.Slider(0, 1, value=0.7, step=0.05, label='NMS Iou Threshold')
                mode = gr.Radio(["Use persons and faces", "Use persons only", "Use faces only"],
                                value="Use persons and faces",
                                label="Inference mode",
                                info="What to use for gender and age recognition")

                with gr.Row():
                    clear_button = gr.Button("Clear")
                    with gr.Column():
                        run_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                result = gr.Image(label='Output', type='numpy')

        inputs = [image, score_threshold, iou_threshold, mode]
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=prediction_fn,
                    cache_examples=False)
        run_button.click(fn=prediction_fn, inputs=inputs, outputs=result, api_name='predict')
        clear_button.click(fn=clear, inputs=None, outputs=[image, score_threshold, iou_threshold, mode, result])
    return block


# Load both predictors once at import time and bind each to the shared
# `detect` callback.
predictor_v1, predictor_v2 = load_models()
prediction_func_v1 = functools.partial(detect, predictor=predictor_v1)
prediction_func_v2 = functools.partial(detect, predictor=predictor_v2)

# Example rows: [image path, score threshold, iou threshold, mode].
image_dir = pathlib.Path('images')
examples = [[path.as_posix(), 0.4, 0.7, "Use persons and faces"] for path in sorted(image_dir.glob('*.jpg'))]

demo_v1 = _build_demo(prediction_func_v1)
demo_v2 = _build_demo(prediction_func_v2)

# Top-level app: description banner plus one tab per model version.
with gr.Blocks(theme=gr.themes.Default(), css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        with gr.Tab(label="MiVOLO_V1"):
            demo_v1.render()
        with gr.Tab(label="MiVOLO_V2"):
            demo_v2.render()


if __name__ == "__main__":
    demo.queue(max_size=15).launch()