File size: 3,677 Bytes
fb59cb8
 
 
 
 
 
aa7c58e
7c078a3
c410097
01b28b7
8cdd8e4
7c078a3
8cdd8e4
7c078a3
c410097
 
 
 
 
 
 
 
8cdd8e4
 
aa7c58e
 
 
 
 
 
 
 
 
 
c410097
 
aa7c58e
 
8cdd8e4
aa7c58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c410097
 
 
01b28b7
bbe49e5
 
 
8cdd8e4
bbe49e5
 
aa7c58e
 
 
 
 
 
8cdd8e4
01b28b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python

from __future__ import annotations

import gradio as gr
import PIL.Image
import zipfile
from genTag import genTag
from checkIgnore import is_ignore

def predict(image: PIL.Image.Image, score_threshold: float):
    """Tag a single image and render the result for the "Single" tab.

    Args:
        image: The uploaded image, or ``None`` when Run is clicked before
            anything has been uploaded.
        score_threshold: Threshold forwarded unchanged to ``genTag``.

    Returns:
        A ``(result_html, result_text)`` pair of HTML strings:

        * ``result_html`` -- a ``<div>`` with one ``<p>`` per tag returned by
          ``genTag``, showing the tag and its score rounded to 3 places. Tags
          for which ``is_ignore(label, 1)`` is false carry the extra ``use``
          CSS class.
        * ``result_text`` -- a ``<div id="m5dd_result">`` containing the
          comma-separated tags that are not ignored.
    """
    import html  # function-scope import keeps this block self-contained

    if image is None:
        # Nothing uploaded: emit the same empty markup an empty tag set yields
        # instead of passing None on to the tagger.
        return '<div></div>', '<div id="m5dd_result"></div>'

    result_threshold = genTag(image, score_threshold)
    rows = []
    kept = []
    for label, prob in result_threshold.items():
        ignored = is_ignore(label, 1)  # evaluated once per label
        css_class = 'm5dd_list' if ignored else 'm5dd_list use'
        # Tag names can contain HTML-significant characters (e.g. "<3",
        # ">_<", "&"), so escape them before embedding them in markup.
        safe_label = html.escape(str(label))
        rows.append(
            '<p class="' + css_class + '"><span>' + safe_label
            + '</span><span>' + str(round(prob, 3)) + '</span></p>'
        )
        if not ignored:
            kept.append(safe_label)
    result_html = '<div>' + ''.join(rows) + '</div>'
    result_text = '<div id="m5dd_result">' + ', '.join(kept) + '</div>'
    return result_html, result_text

def predict_batch(zip_file, score_threshold: float, progress=gr.Progress()):
    """Tag every PNG/JPEG image inside an uploaded ZIP archive.

    Args:
        zip_file: The archive as delivered by the ``gr.File`` input (anything
            ``zipfile.ZipFile`` accepts), or ``None`` if nothing was uploaded.
        score_threshold: Threshold forwarded unchanged to ``genTag``.
        progress: Gradio progress tracker; iteration over the archive members
            is wrapped in ``progress.tqdm`` so the UI shows a progress bar.

    Returns:
        Plain text with two lines per image: the archive member name, then
        its comma-separated tags, with tags for which ``is_ignore(tag, 2)`` is
        true dropped. Returns ``''`` when no archive was supplied.
    """
    if zip_file is None:
        return ''

    image_exts = ('.png', '.jpg', '.jpeg')
    chunks = []
    with zipfile.ZipFile(zip_file) as zf:
        for name in progress.tqdm(zf.namelist()):
            # Case-insensitive match so e.g. "IMG_01.JPG" is not skipped.
            if not name.lower().endswith(image_exts):
                continue
            # Archives zipped on macOS carry "__MACOSX/._*.png" resource-fork
            # entries that are not images and would make PIL abort the batch.
            if name.startswith('__MACOSX/'):
                continue
            # Close both the member stream and the PIL handle once decoded;
            # convert() fully loads the pixels, so the RGB copy outlives them.
            with zf.open(name) as member, PIL.Image.open(member) as img:
                rgb = img.convert('RGB')
            result_threshold = genTag(rgb, score_threshold)
            tags = ', '.join(
                key for key in result_threshold if not is_ignore(key, 2)
            )
            chunks.append(str(name) + '\n' + str(tags) + '\n')
    return ''.join(chunks)

# UI layout: a "Single" tab that tags one uploaded image via predict(), and a
# "Batch" tab that tags every image in an uploaded ZIP archive via
# predict_batch(). Extra styling/behaviour is supplied through style.css and
# script.js, which are passed to gr.Blocks.
with gr.Blocks(css="style.css", js="script.js") as demo:
    with gr.Tab(label='Single'):
        with gr.Row():
            # Left column: inputs, the Run button, and the filtered tag text.
            with gr.Column(scale=1):
                image = gr.Image(label='Upload an image',
                                 type='pil',
                                 sources=["upload", "clipboard"],
                                 height='20em')
                score_threshold = gr.Slider(label='Score threshold',
                                            minimum=0,
                                            maximum=1,
                                            step=0.05,
                                            value=0.5)
                run_button = gr.Button('Run')
                # Receives predict()'s second output (the filtered tag list).
                result_text = gr.HTML(value="<div></div>")
            # Right column: every tag with its score (predict()'s first output).
            with gr.Column(scale=2):
                result_html = gr.HTML(value="<div></div>")
    with gr.Tab(label='Batch'):
        with gr.Row():
            with gr.Column(scale=1):
                batch_file = gr.File(label="Upload a ZIP file containing images",
                                     file_types=['.zip'],
                                     height='20em')
                score_threshold2 = gr.Slider(label='Score threshold',
                                             minimum=0,
                                             maximum=1,
                                             step=0.05,
                                             value=0.5)
                run_button2 = gr.Button('Run')
            with gr.Column(scale=2):
                # Plain-text "name\ntags\n" listing produced by predict_batch().
                result_text2 = gr.Textbox(lines=5,
                                          label='Result',
                                          show_copy_button=True)

    # Output order must match each function's return order.
    run_button.click(
        fn=predict,
        inputs=[image, score_threshold],
        outputs=[result_html, result_text],
        api_name='predict',
    )
    run_button2.click(
        fn=predict_batch,
        inputs=[batch_file, score_threshold2],
        outputs=[result_text2],
        api_name='predict_batch',
    )

# queue() routes events through Gradio's queue before the server starts.
demo.queue().launch()