MonkeyJuice commited on
Commit
aa7c58e
1 Parent(s): 7c078a3

add batch handle

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +51 -41
  3. script.js +24 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏃
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.37.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.19.2
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations
4
 
5
  import gradio as gr
6
  import PIL.Image
 
7
  from genTag import genTag
8
 
9
  def predict(image: PIL.Image.Image, score_threshold: float):
@@ -16,47 +17,51 @@ def predict(image: PIL.Image.Image, score_threshold: float):
16
  result_text = '<div id="m5dd_result">' + str(result_text) + '</div>'
17
  return result_html, result_text
18
 
19
- js = """
20
- async () => {
21
- document.addEventListener('click', function(event) {
22
- let tagItem = event.target.closest('.m5dd_list')
23
- let resultArea = event.target.closest('#m5dd_result')
24
- if (tagItem){
25
- if (tagItem.classList.contains('use')){
26
- tagItem.classList.remove('use')
27
- }else{
28
- tagItem.classList.add('use')
29
- }
30
- document.getElementById('m5dd_result').innerText =
31
- Array.from(document.querySelectorAll('.m5dd_list.use>span:nth-child(1)'))
32
- .map(v=>v.innerText)
33
- .join(', ')
34
- }else if (resultArea){
35
- const selection = window.getSelection()
36
- selection.removeAllRanges()
37
- const range = document.createRange()
38
- range.selectNodeContents(resultArea)
39
- selection.addRange(range)
40
- }else{
41
- return
42
- }
43
- })
44
- }
45
- """
46
 
47
- with gr.Blocks(css="style.css") as demo:
48
- with gr.Row():
49
- with gr.Column(scale=1):
50
- image = gr.Image(label='Input', type='pil')
51
- score_threshold = gr.Slider(label='Score threshold',
52
- minimum=0,
53
- maximum=1,
54
- step=0.05,
55
- value=0.5)
56
- run_button = gr.Button('Run')
57
- result_text = gr.HTML(value="<div></div>")
58
- with gr.Column(scale=3):
59
- result_html = gr.HTML(value="<div></div>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  run_button.click(
62
  fn=predict,
@@ -64,6 +69,11 @@ with gr.Blocks(css="style.css") as demo:
64
  outputs=[result_html, result_text],
65
  api_name='predict',
66
  )
67
- demo.load(None,None,None,_js=js)
 
 
 
 
 
68
 
69
  demo.queue().launch()
 
4
 
5
  import gradio as gr
6
  import PIL.Image
7
+ import zipfile
8
  from genTag import genTag
9
 
10
  def predict(image: PIL.Image.Image, score_threshold: float):
 
17
  result_text = '<div id="m5dd_result">' + str(result_text) + '</div>'
18
  return result_html, result_text
19
 
20
+ def predict_batch(zip_file, score_threshold: float, progress=gr.Progress()):
21
+ result = ''
22
+ with zipfile.ZipFile(zip_file) as zf:
23
+ for file in progress.tqdm(zf.namelist()):
24
+ print(file)
25
+ if file.endswith(".png") or file.endswith(".jpg"):
26
+ image_file = zf.open(file)
27
+ image = PIL.Image.open(image_file)
28
+ image = image.convert("RGB")
29
+ result_threshold = genTag(image, score_threshold)
30
+ tag = ', '.join(result_threshold.keys())
31
+ result = result + str(file) + '\n' + str(tag) + '\n'
32
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ with gr.Blocks(css="style.css", js="script.js") as demo:
35
+ with gr.Tab(label='Single'):
36
+ with gr.Row():
37
+ with gr.Column(scale=1):
38
+ image = gr.Image(label='Upload a image',
39
+ type='pil',
40
+ sources=["upload", "clipboard"],
41
+ height='20em')
42
+ score_threshold = gr.Slider(label='Score threshold',
43
+ minimum=0,
44
+ maximum=1,
45
+ step=0.05,
46
+ value=0.5)
47
+ run_button = gr.Button('Run')
48
+ result_text = gr.HTML(value="<div></div>")
49
+ with gr.Column(scale=2):
50
+ result_html = gr.HTML(value="<div></div>")
51
+ with gr.Tab(label='Batch'):
52
+ with gr.Row():
53
+ with gr.Column(scale=1):
54
+ batch_file = gr.File(label="Upload a ZIP file containing images",
55
+ file_types=['.zip'],
56
+ height='20em')
57
+ score_threshold2 = gr.Slider(label='Score threshold',
58
+ minimum=0,
59
+ maximum=1,
60
+ step=0.05,
61
+ value=0.5)
62
+ run_button2 = gr.Button('Run')
63
+ with gr.Column(scale=2):
64
+ result_text2 = gr.Textbox(lines=5, show_copy_button=True)
65
 
66
  run_button.click(
67
  fn=predict,
 
69
  outputs=[result_html, result_text],
70
  api_name='predict',
71
  )
72
+ run_button2.click(
73
+ fn=predict_batch,
74
+ inputs=[batch_file, score_threshold2],
75
+ outputs=[result_text2],
76
+ api_name='predict_batch',
77
+ )
78
 
79
  demo.queue().launch()
script.js ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ document.addEventListener('click', function (event) {
3
+ let tagItem = event.target.closest('.m5dd_list')
4
+ let resultArea = event.target.closest('#m5dd_result')
5
+ if (tagItem) {
6
+ if (tagItem.classList.contains('use')) {
7
+ tagItem.classList.remove('use')
8
+ } else {
9
+ tagItem.classList.add('use')
10
+ }
11
+ document.getElementById('m5dd_result').innerText =
12
+ Array.from(document.querySelectorAll('.m5dd_list.use>span:nth-child(1)'))
13
+ .map(v => v.innerText)
14
+ .join(', ')
15
+ } else if (resultArea) {
16
+ const selection = window.getSelection()
17
+ selection.removeAllRanges()
18
+ const range = document.createRange()
19
+ range.selectNodeContents(resultArea)
20
+ selection.addRange(range)
21
+ } else {
22
+ return
23
+ }
24
+ })