File size: 4,099 Bytes
fb59cb8
 
 
 
 
 
e9c5f95
fb59cb8
 
 
 
 
 
 
 
 
e9c5f95
 
 
01b28b7
 
 
 
e9c5f95
 
 
fb59cb8
 
e9c5f95
01b28b7
 
e9c5f95
 
fb59cb8
 
e9c5f95
01b28b7
 
e9c5f95
 
 
 
 
01b28b7
 
 
40613f2
 
 
 
 
 
 
 
ac32df9
40613f2
01b28b7
bbe49e5
 
 
fb59cb8
 
 
 
 
 
 
 
 
 
 
bbe49e5
 
 
 
 
 
6274240
bbe49e5
40613f2
6274240
40613f2
 
 
 
bbe49e5
fb59cb8
bbe49e5
 
 
 
fb59cb8
 
01b28b7
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe49e5
 
 
 
 
 
 
 
 
 
01b28b7
 
bbe49e5
01b28b7
 
 
bbe49e5
 
 
 
 
 
01b28b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python

from __future__ import annotations

import os
import pathlib
import tarfile

import deepdanbooru as dd
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import tensorflow as tf


def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sorted paths of the sample images used as Gradio examples.

    If ``./images`` does not exist yet, ``images.tar.gz`` is downloaded from
    the ``public-data/sample-images-TADNE`` dataset repo on the Hugging Face
    Hub and extracted into the current working directory. The archive is
    assumed to contain a top-level ``images/`` directory.

    Returns:
        Every entry directly inside ``images/``, sorted by path.
    """
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        path = huggingface_hub.hf_hub_download(
            'public-data/sample-images-TADNE',
            'images.tar.gz',
            repo_type='dataset')
        with tarfile.open(path) as f:
            # The archive comes from the network: use the 'data' extraction
            # filter when this Python provides it (3.12+ and the security
            # backports to 3.8-3.11). This stops members from escaping the
            # working directory via absolute paths, '..' or links.
            if hasattr(tarfile, 'data_filter'):
                f.extractall(filter='data')
            else:
                f.extractall()
    return sorted(image_dir.glob('*'))


def load_model() -> tf.keras.Model:
    """Fetch the DeepDanbooru ``resnet_custom_v3`` model file and load it.

    The ``.h5`` file is downloaded (or taken from the local cache) via
    ``huggingface_hub.hf_hub_download`` from the ``public-data/DeepDanbooru``
    repo and handed to ``tf.keras.models.load_model``.
    """
    model_file = huggingface_hub.hf_hub_download(
        'public-data/DeepDanbooru', 'model-resnet_custom_v3.h5')
    return tf.keras.models.load_model(model_file)


def load_labels() -> list[str]:
    """Download DeepDanbooru's ``tags.txt`` and return one label per line.

    Each line is stripped of surrounding whitespace. Blank lines are kept as
    empty strings so that list positions stay aligned with the model's
    output indices, which ``predict`` uses to look up labels.
    """
    path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
                                           'tags.txt')
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default encoding.
    with open(path, encoding='utf-8') as f:
        return [line.strip() for line in f]


# Model and tag vocabulary are loaded once, at import time; predict() reads
# these module globals on every call.
model = load_model()
labels = load_labels()

# Tags that predict() drops from its results whatever their score.
skip = [
    "rating:safe",
    "rating:questionable",
    "rating:explicit",
    "3d",
    "photorealistic",
    "realistic",
    "uncensored",
]

# Renames applied by predict() to matching tag names before they are reported.
translate = dict(yuri='lesbian', paizuri='tit job')


def _prepare_input(image: PIL.Image.Image) -> np.ndarray:
    """Convert a PIL image into a one-item batch sized for ``model``.

    The image is resized with area interpolation (aspect ratio preserved) to
    fit the model's input height/width. It is then padded by DeepDanbooru's
    ``transform_and_pad_image`` and divided by 255.
    """
    _, height, width, _ = model.input_shape
    # NOTE(review): assumes the model takes 3-channel RGB input. Converting
    # here stops RGBA / grayscale / palette images from reaching it with the
    # wrong channel count (Gradio's Image component usually delivers RGB
    # already, in which case this is just a copy).
    pixels = np.asarray(image.convert('RGB'))
    pixels = tf.image.resize(pixels,
                             size=(height, width),
                             method=tf.image.ResizeMethod.AREA,
                             preserve_aspect_ratio=True)
    pixels = dd.image.transform_and_pad_image(pixels.numpy(), width, height)
    pixels = pixels / 255.
    return pixels[None, ...]


def _collect_tags(
        probs: np.ndarray, score_threshold: float
) -> tuple[dict[str, float], dict[str, float]]:
    """Walk tags from most to least likely and bucket them by score.

    Tags listed in ``skip`` are dropped and tags found in ``translate`` are
    renamed.

    Returns:
        ``(result_threshold, result_all)``. ``result_threshold`` holds every
        remaining tag scoring at least ``score_threshold``. ``result_all``
        holds the same tags plus the highest-scoring tag below the threshold.
    """
    result_all: dict[str, float] = {}
    result_threshold: dict[str, float] = {}
    for index in np.argsort(probs)[::-1]:
        label = labels[index]
        if label in skip:
            continue
        label = translate.get(label, label)
        prob = probs[index]
        result_all[label] = prob
        # Scores are visited in descending order, so once one tag falls below
        # the threshold, every later tag does too.
        if prob < score_threshold:
            break
        result_threshold[label] = prob
    return result_threshold, result_all


def predict(
        image: PIL.Image.Image, score_threshold: float
) -> tuple[dict[str, float], dict[str, float], str]:
    """Tag ``image`` with DeepDanbooru.

    Args:
        image: Input picture. ``None`` yields empty results.
        score_threshold: Minimum probability for a tag to be reported.

    Returns:
        ``(result_threshold, result_all, result_text)``:

        - ``result_threshold``: tag -> probability for tags at or above the
          threshold.
        - ``result_all``: the same, plus the best-scoring tag just below it.
        - ``result_text``: comma-separated names of the thresholded tags.
    """
    if image is None:
        # Presumably reached when Run is clicked before an image is uploaded;
        # return empty outputs instead of crashing inside TensorFlow.
        return {}, {}, ''
    probs = model.predict(_prepare_input(image))[0].astype(float)
    result_threshold, result_all = _collect_tags(probs, score_threshold)
    # Join only the tags that passed the threshold. result_all also carries
    # one sub-threshold tag, which must not leak into the text output.
    result_text = ', '.join(result_threshold)
    return result_threshold, result_all, result_text


image_paths = load_sample_image_paths()
# Each sample image becomes an example row, paired with a 0.5 threshold.
examples = [[path.as_posix(), 0.5] for path in image_paths]

with gr.Blocks(css='style.css') as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Input', type='pil')
            score_threshold = gr.Slider(label='Score threshold',
                                        minimum=0,
                                        maximum=1,
                                        step=0.05,
                                        value=0.5)
            run_button = gr.Button('Run')
        with gr.Column():
            with gr.Tabs():
                with gr.Tab(label='Output'):
                    result = gr.Label(label='Output', show_label=False)
                with gr.Tab(label='JSON'):
                    result_json = gr.JSON(label='JSON output',
                                          show_label=False)
                with gr.Tab(label='Text'):
                    result_text = gr.Text(label='Text output',
                                          show_label=False,
                                          lines=5)
    # Example outputs are precomputed only when CACHE_EXAMPLES=1.
    gr.Examples(examples=examples,
                inputs=[image, score_threshold],
                outputs=[result, result_json, result_text],
                fn=predict,
                cache_examples=os.getenv('CACHE_EXAMPLES') == '1')

    run_button.click(
        fn=predict,
        inputs=[image, score_threshold],
        outputs=[result, result_json, result_text],
        api_name='predict',
    )

# Launch the server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not start it.
if __name__ == '__main__':
    demo.queue().launch()