File size: 3,716 Bytes
fb59cb8
 
 
 
 
 
 
 
 
 
 
e9c5f95
01b28b7
 
e9c5f95
 
fb59cb8
 
e9c5f95
01b28b7
 
e9c5f95
 
 
 
 
01b28b7
 
 
 
8cdd8e4
fb59cb8
 
 
 
 
 
 
 
 
 
 
bbe49e5
 
 
 
8cdd8e4
bbe49e5
 
 
 
fb59cb8
bbe49e5
8cdd8e4
bbe49e5
8cdd8e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01b28b7
8cdd8e4
01b28b7
 
 
 
 
 
 
8cdd8e4
 
 
01b28b7
bbe49e5
 
 
8cdd8e4
bbe49e5
 
8cdd8e4
 
01b28b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python

from __future__ import annotations

import html

import deepdanbooru as dd
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import tensorflow as tf

def load_model() -> tf.keras.Model:
    """Fetch the DeepDanbooru Keras model from the Hugging Face Hub.

    Returns:
        The model deserialised from ``model-resnet_custom_v3.h5`` in the
        ``public-data/DeepDanbooru`` repository.
    """
    weights_file = huggingface_hub.hf_hub_download(
        'public-data/DeepDanbooru', 'model-resnet_custom_v3.h5')
    return tf.keras.models.load_model(weights_file)


def load_labels() -> list[str]:
    """Fetch the DeepDanbooru tag vocabulary from the Hugging Face Hub.

    ``tags.txt`` is read line by line and each line is stripped of
    surrounding whitespace. Position ``i`` in the returned list names the tag
    for output ``i`` of the model (``predict`` indexes it that way).

    Returns:
        The tag names, in file order.
    """
    path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
                                           'tags.txt')
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(path, encoding='utf-8') as f:
        labels = [line.strip() for line in f]
    return labels


# Build the network and its tag vocabulary once, when the module is imported,
# so every `predict` call reuses the same objects.
model = load_model()
labels = load_labels()


def _render_tags(probs: np.ndarray, tag_names: list[str],
                 score_threshold: float) -> tuple[str, str]:
    """Render the tags scoring at least ``score_threshold`` as HTML.

    Args:
        probs: One score per tag, aligned index-for-index with ``tag_names``.
        tag_names: The tag vocabulary.
        score_threshold: Minimum score a tag needs in order to be listed.

    Returns:
        ``(result_html, result_text)``:

        - ``result_html`` is a ``<div>`` holding one
          ``<p class="m5dd_list use">`` row per kept tag (name and score
          rounded to 3 places), in descending score order.
        - ``result_text`` is a ``<div id="m5dd_result">`` holding the kept
          tag names joined with ``', '``.
    """
    rows = []
    kept = {}  # insertion-ordered; duplicate names collapse as before
    for index in np.argsort(probs)[::-1]:
        prob = float(probs[index])
        # Scores are visited in descending order, so the first one below the
        # threshold means no later tag can pass either.
        if prob < score_threshold:
            break
        # Tag names are spliced into HTML; escape them so characters such as
        # '<', '>' or '&' are shown literally instead of parsed as markup.
        label = html.escape(str(tag_names[index]))
        rows.append('<p class="m5dd_list use"><span>' + label +
                    '</span><span>' + str(round(prob, 3)) + '</span></p>')
        kept[label] = prob
    result_html = '<div>' + ''.join(rows) + '</div>'
    result_text = '<div id="m5dd_result">' + ', '.join(kept) + '</div>'
    return result_html, result_text


def predict(image: PIL.Image.Image, score_threshold: float):
    """Tag ``image`` with DeepDanbooru and render the tags above a threshold.

    Args:
        image: The picture from the Gradio ``Image`` input (``type='pil'``).
        score_threshold: Minimum model score for a tag to be listed.

    Returns:
        ``(result_html, result_text)`` HTML strings: the per-tag rows and the
        comma-separated tag list, for the two ``gr.HTML`` outputs.
    """
    _, height, width, _ = model.input_shape
    image = np.asarray(image)
    # Area-average resize to fit inside the model input, keeping the aspect
    # ratio; the deepdanbooru helper then (presumably) pads to the full
    # width x height canvas.
    image = tf.image.resize(image,
                            size=(height, width),
                            method=tf.image.ResizeMethod.AREA,
                            preserve_aspect_ratio=True)
    image = image.numpy()
    image = dd.image.transform_and_pad_image(image, width, height)
    image = image / 255.
    # Add a batch axis, run the model, and take the single result row.
    probs = model.predict(image[None, ...])[0].astype(float)
    return _render_tags(probs, labels, score_threshold)

js = """
async () => {
    document.addEventListener('click', function(event) {
        let tagItem = event.target.closest('.m5dd_list')
        let resultArea = event.target.closest('#m5dd_result')
        if (tagItem){
            if (tagItem.classList.contains('use')){
                tagItem.classList.remove('use')
            }else{
                tagItem.classList.add('use')
            }
            document.getElementById('m5dd_result').innerText =
                Array.from(document.querySelectorAll('.m5dd_list.use>span:nth-child(1)'))
                .map(v=>v.innerText)
                .join(', ')
        }else if (resultArea){
            const selection = window.getSelection()
            selection.removeAllRanges()
            const range = document.createRange()
            range.selectNodeContents(resultArea)
            selection.addRange(range)
        }else{
            return
        }
    })
}
"""

with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(label='Input', type='pil')
            score_threshold = gr.Slider(label='Score threshold',
                                        minimum=0,
                                        maximum=1,
                                        step=0.05,
                                        value=0.5)
            run_button = gr.Button('Run')
            result_text = gr.HTML(value="<div></div>")
        with gr.Column(scale=3):
            result_html = gr.HTML(value="<div></div>")

    run_button.click(
        fn=predict,
        inputs=[image, score_threshold],
        outputs=[result_html, result_text],
        api_name='predict',
    )
    demo.load(None,None,None,_js=js)

# Start the (queued) web server only when run as a script. Importing this
# module from tests or tooling then does not start the server.
if __name__ == '__main__':
    demo.queue().launch()