DeepDanbooru / app.py
MonkeyJuice's picture
Update app.py
6bd4266 verified
raw history blame
No virus
4.12 kB
#!/usr/bin/env python
from __future__ import annotations

import html
import zipfile

import gradio as gr
import PIL.Image

from genTag import genTag
from cropImage import cropImage
from checkIgnore import is_ignore
from createTagDom import create_tag_dom
def predict(image: PIL.Image.Image, score_threshold: float):
    """Tag a single image and build the outputs of the "Single" tab.

    Args:
        image: The uploaded image, or ``None`` when the user pressed "Run"
            without uploading anything.
        score_threshold: Minimum score, passed straight through to ``genTag``.

    Returns:
        A 3-tuple ``(result_html, result_text, crop_image)``:

        * ``result_html`` -- every tag returned by ``genTag`` rendered with
          ``create_tag_dom`` and wrapped in a single ``<div>``.
        * ``result_text`` -- the comma-separated tags that ``is_ignore(tag, 1)``
          did not flag, HTML-escaped, inside ``<div id="m5dd_result">``.
        * ``crop_image`` -- whatever ``cropImage(image)`` returns.
    """
    # With no upload Gradio passes None; return empty outputs instead of
    # letting genTag/cropImage fail on it.
    if image is None:
        return '<div></div>', '<div id="m5dd_result"></div>', None

    result_threshold = genTag(image, score_threshold)

    tag_doms = []
    kept_labels = []
    for label, prob in result_threshold.items():
        # Evaluate the ignore rule once; it drives both the per-tag DOM and
        # the plain-text tag list.
        ignored = is_ignore(label, 1)
        tag_doms.append(create_tag_dom(label, ignored, prob))
        if not ignored:
            kept_labels.append(label)

    result_html = '<div>' + ''.join(tag_doms) + '</div>'
    # Some tags contain markup characters (e.g. "<o>_<o>"), which the browser
    # would otherwise parse as elements. Quotes are safe in element content,
    # so only &, < and > are escaped.
    escaped_labels = (html.escape(label, quote=False) for label in kept_labels)
    result_text = '<div id="m5dd_result">' + ', '.join(escaped_labels) + '</div>'
    crop_image = cropImage(image)
    return result_html, result_text, crop_image
# Archive members with these extensions (compared case-insensitively) are tagged.
_BATCH_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")


def _is_batch_image(name: str) -> bool:
    """Return True if the ZIP member *name* should be tagged as an image."""
    base = name.rsplit('/', 1)[-1]
    # macOS Finder adds "__MACOSX/" folders with "._<name>" AppleDouble files.
    # They carry image extensions but are not images, so PIL cannot open them.
    if name.startswith('__MACOSX/') or base.startswith('._'):
        return False
    # Directory entries end with "/" and therefore never match an extension.
    return name.lower().endswith(_BATCH_IMAGE_EXTENSIONS)


def predict_batch(zip_file, score_threshold: float, progress=gr.Progress()):
    """Tag every image inside an uploaded ZIP archive.

    Args:
        zip_file: The upload delivered by ``gr.File`` (a path or file-like
            object accepted by ``zipfile.ZipFile``), or ``None`` when nothing
            was uploaded.
        score_threshold: Minimum score, passed straight through to ``genTag``.
        progress: Gradio progress tracker; wraps the member listing so the UI
            shows per-file progress.

    Returns:
        One entry per tagged image, concatenated. Each entry is the member
        name, a newline, the comma-separated tags that ``is_ignore(tag, 2)``
        did not flag, and a blank line. Returns ``''`` when no archive was
        supplied.
    """
    if zip_file is None:
        return ''

    entries = []
    with zipfile.ZipFile(zip_file) as zf:
        for name in progress.tqdm(zf.namelist()):
            print(name)
            if not _is_batch_image(name):
                continue
            try:
                # Close both the archive member and the PIL image once the
                # RGB copy has been made (convert() fully loads the pixels).
                with zf.open(name) as image_file, PIL.Image.open(image_file) as raw:
                    image = raw.convert("RGB")
            except OSError as err:
                # PIL.UnidentifiedImageError and truncated-image errors are
                # OSError subclasses; skip the bad member rather than abort
                # the whole batch.
                print(f"skipping {name}: {err}")
                continue
            result_threshold = genTag(image, score_threshold)
            tags = ', '.join(key for key in result_threshold if not is_ignore(key, 2))
            entries.append(f"{name}\n{tags}\n\n")
    return ''.join(entries)
# Page layout and event wiring. Two tabs share the same tagger:
#   * "Single": tag one uploaded image (handled by predict()).
#   * "Batch": tag every image in an uploaded ZIP archive (handled by predict_batch()).
# style.css and script.js are loaded into the page by gr.Blocks. The
# elem_classes / ids used below ("m5dd_image", "m5dd_image2", and the
# "m5dd_result" div emitted by predict()) are presumably what they target;
# confirm there before renaming.
with gr.Blocks(css="style.css", js="script.js") as demo:
    with gr.Tab(label='Single'):
        with gr.Row():
            # Left column: inputs, the cropped preview and the plain tag list.
            with gr.Column(scale=1):
                image = gr.Image(label='Upload an image',
                                 type='pil',  # predict() receives a PIL image
                                 elem_classes='m5dd_image',
                                 sources=["upload", "clipboard"])
                score_threshold = gr.Slider(label='Score threshold',
                                            minimum=0,
                                            maximum=1,
                                            step=0.1,
                                            value=0.3)
                run_button = gr.Button('Run')
                with gr.Accordion(label="Crop Image", open=False):
                    # Shows the third value returned by predict() (cropImage output).
                    crop_image = gr.Image(elem_classes='m5dd_image2',
                                          format='jpg',
                                          show_label=False,
                                          show_share_button=False,
                                          container=False)
                # Receives the '<div id="m5dd_result">...</div>' string from predict().
                result_text = gr.HTML(value="")
            # Right column: per-tag markup built by create_tag_dom().
            with gr.Column(scale=2):
                result_html = gr.HTML(value="")
    with gr.Tab(label='Batch'):
        with gr.Row():
            with gr.Column(scale=1):
                batch_file = gr.File(label="Upload a ZIP file containing images",
                                     file_types=['.zip'])
                score_threshold2 = gr.Slider(label='Score threshold',
                                             minimum=0,
                                             maximum=1,
                                             step=0.1,
                                             value=0.3)
                run_button2 = gr.Button('Run')
            with gr.Column(scale=2):
                # Plain-text report produced by predict_batch().
                result_text2 = gr.Textbox(lines=20,
                                          max_lines=20,
                                          label='Result',
                                          show_copy_button=True,
                                          autoscroll=False)
    # api_name also exposes each handler as a named API endpoint.
    run_button.click(
        fn=predict,
        inputs=[image, score_threshold],
        outputs=[result_html, result_text, crop_image],
        api_name='predict',
    )
    run_button2.click(
        fn=predict_batch,
        inputs=[batch_file, score_threshold2],
        outputs=[result_text2],
        api_name='predict_batch',
    )
if __name__ == "__main__":
    # Cap the request queue at 20 waiting events, then start the web server.
    # Blocks.queue() returns the same Blocks object, so launching it is
    # equivalent to chaining the calls.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()