Spaces:
Running
Running
File size: 3,877 Bytes
fb59cb8 aa7c58e 7c078a3 c410097 01b28b7 8cdd8e4 7c078a3 8cdd8e4 7c078a3 c410097 228bfd8 8cdd8e4 aa7c58e c410097 228bfd8 aa7c58e 8cdd8e4 aa7c58e 228bfd8 aa7c58e bfae66b aa7c58e 228bfd8 aa7c58e 228bfd8 aa7c58e bfae66b aa7c58e 228bfd8 c410097 01b28b7 bbe49e5 8cdd8e4 bbe49e5 aa7c58e 8cdd8e4 01b28b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
#!/usr/bin/env python
from __future__ import annotations

import html
import zipfile

import gradio as gr
import PIL.Image

from genTag import genTag
from checkIgnore import is_ignore
def predict(image: PIL.Image.Image, score_threshold: float):
    """Tag a single image and render the tags for the "Single" tab.

    Args:
        image: The uploaded image (the UI component uses ``type='pil'``).
        score_threshold: Minimum score forwarded to ``genTag``.

    Returns:
        A ``(result_html, result_text)`` pair:
        - ``result_html``: a ``<div>`` holding one ``<p class="m5dd_list">``
          row per tag returned by ``genTag``; rows for tags that
          ``is_ignore(label, 1)`` rejects lack the extra ``use`` class.
          Each row has the tag name and its score rounded to 3 places.
        - ``result_text``: the non-ignored tags joined with ``", "``.
    """
    import html  # local import so this block is self-contained

    result_threshold = genTag(image, score_threshold)
    rows = []
    kept = []
    for label, prob in result_threshold.items():
        # Evaluate the ignore rule once per tag and reuse it for both outputs.
        ignored = is_ignore(label, 1)
        css_class = 'm5dd_list' if ignored else 'm5dd_list use'
        # Tag names can contain markup characters (e.g. ">_<", "<3"); escape
        # them so they are displayed literally instead of being parsed as HTML.
        rows.append(
            '<p class="' + css_class + '">'
            + '<span>' + html.escape(str(label)) + '</span>'
            + '<span>' + str(round(prob, 3)) + '</span></p>'
        )
        if not ignored:
            kept.append(label)
    result_html = '<div>' + ''.join(rows) + '</div>'
    result_text = ', '.join(kept)
    return result_html, result_text
def predict_batch(zip_file, score_threshold: float, progress=gr.Progress()):
    """Tag every PNG/JPEG image inside an uploaded ZIP archive.

    Args:
        zip_file: The archive passed in by the ``gr.File`` component (anything
            ``zipfile.ZipFile`` accepts), or ``None`` when nothing was uploaded.
        score_threshold: Minimum score forwarded to ``genTag``.
        progress: Gradio progress tracker; iterates the archive members so the
            UI shows per-entry progress.

    Returns:
        For each processed image, in archive order: the member name, a newline,
        its tags not rejected by ``is_ignore(tag, 2)`` joined with ``", "``,
        then a blank line. Returns ``''`` if no archive was supplied.
    """
    if zip_file is None:
        return ''
    entries = []
    with zipfile.ZipFile(zip_file) as zf:
        for file in progress.tqdm(zf.namelist()):
            print(file)
            # Compare extensions case-insensitively so e.g. "IMG_01.JPG" is kept.
            if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            # macOS archivers add "__MACOSX/._name.png" resource-fork entries
            # that share an image extension but are not images.
            if file.startswith('__MACOSX/') or file.rsplit('/', 1)[-1].startswith('._'):
                continue
            try:
                # Close both the archive member stream and the decoded source
                # image; convert() returns an independent RGB copy.
                with zf.open(file) as fp, PIL.Image.open(fp) as img:
                    image = img.convert("RGB")
            except PIL.UnidentifiedImageError:
                # One bad entry should not abort the whole batch.
                print(f'skipping unreadable image: {file}')
                continue
            result_threshold = genTag(image, score_threshold)
            tag = ', '.join(key for key in result_threshold if not is_ignore(key, 2))
            entries.append(file + '\n' + tag + '\n\n')
    return ''.join(entries)
def _score_slider():
    """Build the 0-1 score-threshold slider shared by the Single and Batch tabs."""
    return gr.Slider(label='Score threshold',
                     minimum=0,
                     maximum=1,
                     step=0.05,
                     value=0.3)


# Two tabs: "Single" tags one uploaded/pasted image; "Batch" tags every image
# in an uploaded ZIP. Event listeners must be attached inside the Blocks context.
with gr.Blocks(css="style.css", js="script.js") as demo:
    with gr.Tab(label='Single'):
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(label='Upload an image',
                                 type='pil',
                                 elem_classes='m5dd_image',
                                 sources=["upload", "clipboard"])
                score_threshold = _score_slider()
                run_button = gr.Button('Run')
            with gr.Column(scale=2):
                result_text = gr.Textbox(lines=3,
                                         max_lines=3,
                                         label='Result',
                                         show_copy_button=True,
                                         elem_id="m5dd_result")
                result_html = gr.HTML(value="")
    with gr.Tab(label='Batch'):
        with gr.Row():
            with gr.Column(scale=1):
                batch_file = gr.File(label="Upload a ZIP file containing images",
                                     file_types=['.zip'])
                score_threshold2 = _score_slider()
                run_button2 = gr.Button('Run')
            with gr.Column(scale=2):
                result_text2 = gr.Textbox(lines=20,
                                          max_lines=20,
                                          label='Result',
                                          show_copy_button=True)
    run_button.click(
        fn=predict,
        inputs=[image, score_threshold],
        outputs=[result_html, result_text],
        api_name='predict',
    )
    run_button2.click(
        fn=predict_batch,
        inputs=[batch_file, score_threshold2],
        outputs=[result_text2],
        api_name='predict_batch',
    )

# Launch only when run as a script (as Spaces does), so importing the module
# (e.g. `gradio app.py` reload mode) does not start a second server.
if __name__ == "__main__":
    demo.queue().launch()
|