File size: 4,115 Bytes
fb59cb8
 
 
 
 
 
aa7c58e
7c078a3
166cd36
c410097
895ccaa
01b28b7
32a1cbe
7c078a3
8cdd8e4
7c078a3
895ccaa
c410097
 
935457c
166cd36
 
8cdd8e4
aa7c58e
 
 
 
 
6bd4266
aa7c58e
 
 
 
c410097
 
228bfd8
aa7c58e
8cdd8e4
aa7c58e
 
 
 
 
 
228bfd8
 
aa7c58e
 
 
935457c
bfae66b
32a1cbe
d3cab6f
 
166cd36
 
d3cab6f
166cd36
d3cab6f
aa7c58e
228bfd8
aa7c58e
 
 
 
228bfd8
aa7c58e
 
 
935457c
bfae66b
aa7c58e
 
228bfd8
 
c410097
935457c
 
01b28b7
bbe49e5
 
 
166cd36
bbe49e5
 
aa7c58e
 
 
 
 
 
8cdd8e4
32a1cbe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python

from __future__ import annotations

import gradio as gr
import PIL.Image
import zipfile
from genTag import genTag
from cropImage import cropImage
from checkIgnore import is_ignore
from createTagDom import create_tag_dom

def predict(image: PIL.Image.Image | None, score_threshold: float):
    """Tag a single image and build the outputs shown in the "Single" tab.

    Args:
        image: The uploaded image, or ``None`` when Run is clicked with
            nothing uploaded.
        score_threshold: Score threshold forwarded unchanged to ``genTag``.

    Returns:
        A tuple ``(result_html, result_text, crop_image)``:

        * ``result_html`` -- a ``<div>`` wrapping one ``create_tag_dom``
          fragment per tag returned by ``genTag``. Ignored tags are included;
          their ignore flag is passed to ``create_tag_dom``.
        * ``result_text`` -- a ``<div id="m5dd_result">`` holding the
          comma-separated tags that ``is_ignore(tag, 1)`` does not filter out.
        * ``crop_image`` -- whatever ``cropImage`` returns for ``image``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Fail early with a readable message instead of handing None to
        # genTag/cropImage.
        raise gr.Error('Please upload an image first.')

    result_threshold = genTag(image, score_threshold)

    # Evaluate the ignore rule once per tag; both views below reuse it.
    ignored = {label: is_ignore(label, 1) for label in result_threshold}

    tag_doms = ''.join(
        create_tag_dom(label, ignored[label], prob)
        for label, prob in result_threshold.items()
    )
    result_html = '<div>' + tag_doms + '</div>'

    # NOTE(review): tags are inserted into HTML unescaped. Confirm how
    # script.js reads #m5dd_result before switching to html.escape.
    kept_labels = [label for label in result_threshold if not ignored[label]]
    result_text = '<div id="m5dd_result">' + ', '.join(kept_labels) + '</div>'

    crop_image = cropImage(image)
    return result_html, result_text, crop_image

# Archive entries with these (lower-cased) suffixes are treated as images.
_BATCH_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _is_batch_image(name: str) -> bool:
    """Return True if the ZIP entry ``name`` should be tagged as an image."""
    base = name.rsplit('/', 1)[-1]
    # macOS Finder adds a __MACOSX/ folder and "._*" AppleDouble files that
    # carry image extensions but are not decodable images.
    if name.startswith('__MACOSX/') or base.startswith('._'):
        return False
    # Case-insensitive, so "photo.PNG" / "photo.JPG" are not skipped.
    return name.lower().endswith(_BATCH_IMAGE_EXTENSIONS)


def predict_batch(zip_file, score_threshold: float, progress=gr.Progress()):
    """Tag every PNG/JPEG image inside an uploaded ZIP archive.

    Args:
        zip_file: The upload from the ``gr.File`` component. It must be
            something ``zipfile.ZipFile`` accepts (e.g. a path), or ``None``
            when nothing was uploaded.
        score_threshold: Score threshold forwarded unchanged to ``genTag``.
        progress: Gradio progress tracker. Its ``tqdm`` wrapper reports
            progress over the image entries.

    Returns:
        One section per image, in archive order. Each section is the entry
        name, a line of comma-separated tags that ``is_ignore(tag, 2)`` does
        not filter out, and a blank line.

    Raises:
        gr.Error: If no archive was provided.
    """
    if zip_file is None:
        raise gr.Error('Please upload a ZIP file first.')

    sections = []
    with zipfile.ZipFile(zip_file) as zf:
        image_names = [name for name in zf.namelist() if _is_batch_image(name)]
        for name in progress.tqdm(image_names):
            # Close both the archive member and the decoded image promptly.
            # convert() loads the pixels, so `rgb` stays valid after the close.
            with zf.open(name) as member, PIL.Image.open(member) as img:
                rgb = img.convert("RGB")
            result_threshold = genTag(rgb, score_threshold)
            tags = ', '.join(
                key for key in result_threshold if not is_ignore(key, 2)
            )
            sections.append(name + '\n' + tags + '\n\n')
    return ''.join(sections)

# Two-tab UI: "Single" tags one uploaded/pasted image, "Batch" tags every
# image in an uploaded ZIP. The elem_classes below are presumably styled by
# style.css / script.js — confirm there before renaming them.
with gr.Blocks(css="style.css", js="script.js") as demo:
    with gr.Tab(label='Single'):
        with gr.Row():
            # Left column: input image, controls, crop preview, plain tag list.
            with gr.Column(scale=1):
                image = gr.Image(label='Upload an image',
                                 type='pil',
                                 elem_classes='m5dd_image',
                                 sources=["upload", "clipboard"])
                score_threshold = gr.Slider(label='Score threshold',
                                            minimum=0,
                                            maximum=1,
                                            step=0.1,
                                            value=0.3)
                run_button = gr.Button('Run')
                # Collapsed by default; receives predict()'s third output.
                with gr.Accordion(label="Crop Image", open=False):
                    crop_image = gr.Image(elem_classes='m5dd_image2',
                                          format='jpg',
                                          show_label=False,
                                          show_share_button=False,
                                          container=False)
                # Receives predict()'s comma-separated tag list.
                result_text = gr.HTML(value="")
            # Right (wider) column: receives predict()'s per-tag HTML.
            with gr.Column(scale=2):
                result_html = gr.HTML(value="")
    with gr.Tab(label='Batch'):
        with gr.Row():
            with gr.Column(scale=1):
                batch_file = gr.File(label="Upload a ZIP file containing images",
                                     file_types=['.zip'])
                score_threshold2 = gr.Slider(label='Score threshold',
                                             minimum=0,
                                             maximum=1,
                                             step=0.1,
                                             value=0.3)
                run_button2 = gr.Button('Run')
            with gr.Column(scale=2):
                result_text2 = gr.Textbox(lines=20,
                                          max_lines=20,
                                          label='Result',
                                          show_copy_button=True,
                                          autoscroll=False)

    # Output order must match the tuple predict() returns.
    run_button.click(
        fn=predict,
        inputs=[image, score_threshold],
        outputs=[result_html, result_text, crop_image],
        api_name='predict',
    )
    run_button2.click(
        fn=predict_batch,
        inputs=[batch_file, score_threshold2],
        outputs=[result_text2],
        api_name='predict_batch',
    )

if __name__ == "__main__":
    # Enable Gradio's request queue (capped at max_size=20), then start serving.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()