File size: 3,778 Bytes
fb59cb8
 
 
 
 
 
 
 
e9c5f95
fb59cb8
 
 
 
 
 
 
 
f5184a5
16d3ab6
60df4aa
fb59cb8
601183c
e9c5f95
 
 
fb59cb8
 
 
 
 
 
 
 
 
 
e9c5f95
 
 
 
fb59cb8
e9c5f95
fb59cb8
601183c
e9c5f95
 
 
fb59cb8
 
e9c5f95
 
 
601183c
e9c5f95
 
fb59cb8
 
e9c5f95
 
 
601183c
e9c5f95
 
 
 
 
 
5d94de1
fb59cb8
 
 
 
 
 
 
 
 
 
 
 
18988d2
fb59cb8
 
 
5d94de1
fb59cb8
 
 
 
 
e9c5f95
fb59cb8
 
 
e9c5f95
 
fb59cb8
 
 
 
 
 
6d55a5a
 
 
 
ef20843
6d55a5a
fb59cb8
5d94de1
fb59cb8
f5184a5
 
 
ba37cb9
fb59cb8
ba37cb9
fb59cb8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python

from __future__ import annotations

import argparse
import functools
import os
import pathlib
import tarfile

import deepdanbooru as dd
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import tensorflow as tf

# Text shown on the Gradio page: heading, description and footer HTML.
TITLE = 'KichangKim/DeepDanbooru'
DESCRIPTION = 'This is an unofficial demo for https://github.com/KichangKim/DeepDanbooru.'
ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.deepdanbooru" alt="visitor badge"/></center>'

# Access token forwarded to every hub download in this file. It is read
# leniently: when the HF_TOKEN environment variable is unset this is None
# instead of raising KeyError at import time, and huggingface_hub then uses
# its default authentication (enough for public repos; a private repo will
# fail at download time with an authorization error).
HF_TOKEN = os.environ.get('HF_TOKEN')
# Hub repository holding the model, and the file names fetched from it.
MODEL_REPO = 'hysts/DeepDanbooru'
MODEL_FILENAME = 'model-resnet_custom_v3.h5'
LABEL_FILENAME = 'tags.txt'


def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from ``sys.argv``.

    Recognised options:
        --score-slider-step: float, defaults to 0.05.
        --score-threshold: float, defaults to 0.5.
        --share: boolean flag, False unless present.

    Returns:
        The populated argparse namespace.
    """
    cli = argparse.ArgumentParser()
    # The two numeric options differ only in name and default value.
    for flag, default in (('--score-slider-step', 0.05),
                          ('--score-threshold', 0.5)):
        cli.add_argument(flag, type=float, default=default)
    cli.add_argument('--share', action='store_true')
    return cli.parse_args()


def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sorted sample-image paths, fetching the images if needed.

    When no ``images`` entry exists in the current working directory,
    ``images.tar.gz`` is downloaded from the ``hysts/sample-images-TADNE``
    dataset repo and unpacked into the current working directory
    (presumably the archive contains an ``images/`` directory; the code
    relies on that but does not check it).

    Returns:
        Every entry directly inside ``images``, in sorted order.
    """
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        dataset_repo = 'hysts/sample-images-TADNE'
        path = huggingface_hub.hf_hub_download(dataset_repo,
                                               'images.tar.gz',
                                               repo_type='dataset',
                                               use_auth_token=HF_TOKEN)
        with tarfile.open(path) as f:
            # The archive is downloaded content. Where the interpreter
            # supports extraction filters, use the 'data' filter so members
            # with absolute paths, '..' components or escaping links are
            # rejected; older interpreters keep the plain extraction.
            if hasattr(tarfile, 'data_filter'):
                f.extractall(filter='data')
            else:
                f.extractall()
    return sorted(image_dir.glob('*'))


def load_model() -> tf.keras.Model:
    """Fetch the model file from the hub and load it with Keras.

    Downloads ``MODEL_FILENAME`` from ``MODEL_REPO`` (authenticating with
    ``HF_TOKEN``) and returns the result of ``tf.keras.models.load_model``
    on the downloaded path.
    """
    weights_path = huggingface_hub.hf_hub_download(
        MODEL_REPO, MODEL_FILENAME, use_auth_token=HF_TOKEN)
    return tf.keras.models.load_model(weights_path)


def load_labels() -> list[str]:
    """Download the label file from the hub and return its lines.

    Fetches ``LABEL_FILENAME`` from ``MODEL_REPO`` (authenticating with
    ``HF_TOKEN``) and reads it as text.

    Returns:
        One entry per line of the file, in file order, with surrounding
        whitespace stripped. Blank lines are kept as empty strings so that
        line positions are not shifted.
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                           LABEL_FILENAME,
                                           use_auth_token=HF_TOKEN)
    # Pin the encoding so the result does not depend on the platform's
    # locale default. NOTE(review): assumes the tag file is UTF-8 (or plain
    # ASCII) -- confirm against the published file.
    with open(path, encoding='utf-8') as f:
        labels = [line.strip() for line in f]
    return labels


def predict(image: PIL.Image.Image, score_threshold: float,
            model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
    """Run the model on one image and return the labels that pass the threshold.

    The image is resized to the model's input height/width (area method,
    aspect ratio preserved), passed through
    ``dd.image.transform_and_pad_image``, divided by 255 and fed to the
    model as a batch of one.

    Args:
        image: Input picture.
        score_threshold: Labels scoring below this value are dropped.
        model: Model whose ``input_shape`` unpacks as
            ``(_, height, width, _)``.
        labels: Label names, paired positionally with the model outputs.

    Returns:
        Mapping from label to score for every kept label.
    """
    _, target_h, target_w, _ = model.input_shape
    pixels = tf.image.resize(np.asarray(image),
                             size=(target_h, target_w),
                             method=tf.image.ResizeMethod.AREA,
                             preserve_aspect_ratio=True).numpy()
    padded = dd.image.transform_and_pad_image(pixels, target_w, target_h)
    batch = (padded / 255.)[None, ...]
    scores = model.predict(batch)[0].astype(float).tolist()
    # Written as "not score < threshold" so that values which do not compare
    # as smaller (NaN included) are kept.
    return {
        label: score
        for score, label in zip(scores, labels)
        if not score < score_threshold
    }


def main():
    """Assemble the Gradio interface for the tagger and launch it."""
    args = parse_args()

    # Each example row pairs a sample image path with the default threshold.
    examples = [[sample.as_posix(), args.score_threshold]
                for sample in load_sample_image_paths()]

    # Bind the model and labels once so the interface callback only needs
    # the image and the threshold.
    tagger = functools.partial(predict,
                               model=load_model(),
                               labels=load_labels())

    image_input = gr.Image(type='pil', label='Input')
    threshold_input = gr.Slider(0,
                                1,
                                step=args.score_slider_step,
                                value=args.score_threshold,
                                label='Score Threshold')
    tag_output = gr.Label(label='Output')

    demo = gr.Interface(
        tagger,
        [image_input, threshold_input],
        tag_output,
        examples=examples,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_flagging='never',
    )
    demo.launch(enable_queue=True, share=args.share)


# Start the demo only when this file is executed as a script, so importing
# the module does not launch anything.
if __name__ == '__main__':
    main()