Ankur Goyal committed on
Commit
253dc57
1 Parent(s): d703b38

May remove the fields variable

Browse files
Files changed (1) hide show
  1. app.py +71 -54
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image, ImageDraw
6
  import traceback
7
 
8
  import gradio as gr
 
9
 
10
  import torch
11
  from docquery import pipeline
@@ -99,16 +100,17 @@ FIELDS = {
99
  "Payment Terms": ["Payment Terms?"],
100
  }
101
 
102
- EMPTY_TABLE = dict(
103
- headers=["Field", "Value"], value=[[name, None] for name in FIELDS.keys()]
104
- )
105
 
 
 
106
 
107
- def process_document(document, model, error=None):
 
108
  if document is not None and error is None:
109
- preview, json_output, table = process_fields(document, model)
110
  return (
111
  document,
 
112
  preview,
113
  gr.update(visible=True),
114
  gr.update(visible=False, value=None),
@@ -118,6 +120,7 @@ def process_document(document, model, error=None):
118
  else:
119
  return (
120
  None,
 
121
  None,
122
  gr.update(visible=False),
123
  gr.update(visible=True, value=error) if error is not None else None,
@@ -129,6 +132,7 @@ def process_document(document, model, error=None):
129
  def process_path(path, model):
130
  error = None
131
  document = None
 
132
  if path:
133
  try:
134
  document = load_document(path)
@@ -136,7 +140,7 @@ def process_path(path, model):
136
  traceback.print_exc()
137
  error = str(e)
138
 
139
- return process_document(document, model, error)
140
 
141
 
142
  def process_upload(file, model):
@@ -159,40 +163,36 @@ def annotate_page(prediction, pages, document):
159
  draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
160
 
161
 
162
- def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
 
 
163
  if not question or document is None:
164
  return None, None, None
165
 
166
  text_value = None
167
- predictions = run_pipeline(model, question, document, 3)
168
- pages = [x.copy().convert("RGB") for x in document.preview]
169
- for i, p in enumerate(ensure_list(predictions)):
170
- if i == 0:
171
- text_value = p["answer"]
172
- else:
173
- # Keep the code around to produce multiple boxes, but only show the top
174
- # prediction for now
175
- break
176
-
177
- annotate_page(p, pages, document)
178
-
179
  return (
180
  gr.update(visible=True, value=pages),
181
- gr.update(visible=True, value=predictions),
182
- gr.update(
183
- visible=True,
184
- value=text_value,
185
- ),
186
  )
187
 
188
 
189
- def process_fields(document, model=list(CHECKPOINTS.keys())[0]):
190
  pages = [x.copy().convert("RGB") for x in document.preview]
191
 
192
  ret = {}
193
  table = []
194
 
195
- for (field_name, questions) in FIELDS.items():
196
  answers = [run_pipeline(model, q, document, top_k=1) for q in questions]
197
  answers.sort(key=lambda x: -x.get("score", 0) if x else 0)
198
  top = answers[0]
@@ -208,23 +208,22 @@ def process_fields(document, model=list(CHECKPOINTS.keys())[0]):
208
 
209
 
210
  def load_example_document(img, title, model):
 
 
211
  if img is not None:
212
  if title in QUESTION_FILES:
213
- print("using document")
214
  document = load_document(QUESTION_FILES[title])
215
  else:
216
  document = ImageDocument(Image.fromarray(img), ocr_reader=get_ocr_reader())
217
- else:
218
- document = None
219
 
220
- return process_document(document, model)
221
 
222
 
223
  CSS = """
224
  #question input {
225
  font-size: 16px;
226
  }
227
- #url-textbox {
228
  padding: 0 !important;
229
  }
230
  #short-upload-box .w-full {
@@ -327,6 +326,7 @@ with gr.Blocks(css=CSS) as demo:
327
  )
328
 
329
  document = gr.Variable()
 
330
  example_question = gr.Textbox(visible=False)
331
  example_image = gr.Image(visible=False)
332
 
@@ -364,13 +364,16 @@ with gr.Blocks(css=CSS) as demo:
364
  )
365
 
366
  with gr.Column() as col:
 
 
 
 
 
 
 
 
 
367
  gr.Markdown("## 2. Ask a question")
368
- question = gr.Textbox(
369
- label="Question",
370
- placeholder="e.g. What is the invoice number?",
371
- lines=1,
372
- max_lines=1,
373
- )
374
  model = gr.Radio(
375
  choices=list(CHECKPOINTS.keys()),
376
  value=list(CHECKPOINTS.keys())[0],
@@ -379,24 +382,27 @@ with gr.Blocks(css=CSS) as demo:
379
  )
380
 
381
  with gr.Row():
382
- clear_button = gr.Button("Clear", variant="secondary")
 
 
 
 
 
 
 
 
383
  submit_button = gr.Button(
384
- "Submit", variant="primary", elem_id="submit-button"
385
  )
386
- with gr.Tabs():
387
- with gr.TabItem("Table"):
388
- output_table = gr.Dataframe(**EMPTY_TABLE)
389
-
390
- with gr.TabItem("JSON"):
391
- output = gr.JSON(label="Output", visible=False)
392
 
393
  for cb in [img_clear_button, clear_button]:
394
  cb.click(
395
  lambda _: (
396
- gr.update(visible=False, value=None), # image
397
- None, # document
398
- gr.update(visible=False, value=None), # output
399
- gr.update(**EMPTY_TABLE), # output_table
 
400
  gr.update(visible=False),
401
  None,
402
  None,
@@ -408,6 +414,7 @@ with gr.Blocks(css=CSS) as demo:
408
  outputs=[
409
  image,
410
  document,
 
411
  output,
412
  output_table,
413
  img_clear_button,
@@ -419,22 +426,32 @@ with gr.Blocks(css=CSS) as demo:
419
  ],
420
  )
421
 
 
 
 
 
 
 
 
 
 
 
422
  upload.change(
423
  fn=process_upload,
424
  inputs=[upload, model],
425
- outputs=[document, image, img_clear_button, url_error, output, output_table],
426
  )
427
 
428
  submit.click(
429
  fn=process_path,
430
  inputs=[url, model],
431
- outputs=[document, image, img_clear_button, url_error, output, output_table],
432
  )
433
 
434
  question.submit(
435
  fn=process_question,
436
- inputs=[question, document, model],
437
- outputs=[image, output, output_table],
438
  )
439
 
440
  submit_button.click(
@@ -452,7 +469,7 @@ with gr.Blocks(css=CSS) as demo:
452
  example_image.change(
453
  fn=load_example_document,
454
  inputs=[example_image, example_question, model],
455
- outputs=[document, image, img_clear_button, url_error, output, output_table],
456
  )
457
 
458
  if __name__ == "__main__":
 
6
  import traceback
7
 
8
  import gradio as gr
9
+ from gradio import processing_utils
10
 
11
  import torch
12
  from docquery import pipeline
 
100
  "Payment Terms": ["Payment Terms?"],
101
  }
102
 
 
 
 
103
 
104
+ def empty_table(fields):
105
+ return {"value": [[name, None] for name in fields.keys()], "interactive": False}
106
 
107
+
108
+ def process_document(document, fields, model, error=None):
109
  if document is not None and error is None:
110
+ preview, json_output, table = process_fields(document, fields, model)
111
  return (
112
  document,
113
+ fields,
114
  preview,
115
  gr.update(visible=True),
116
  gr.update(visible=False, value=None),
 
120
  else:
121
  return (
122
  None,
123
+ fields,
124
  None,
125
  gr.update(visible=False),
126
  gr.update(visible=True, value=error) if error is not None else None,
 
132
  def process_path(path, model):
133
  error = None
134
  document = None
135
+ fields = {**FIELDS}
136
  if path:
137
  try:
138
  document = load_document(path)
 
140
  traceback.print_exc()
141
  error = str(e)
142
 
143
+ return process_document(document, fields, model, error)
144
 
145
 
146
  def process_upload(file, model):
 
163
  draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
164
 
165
 
166
+ def process_question(
167
+ question, document, img_gallery, model, fields, output, output_table
168
+ ):
169
  if not question or document is None:
170
  return None, None, None
171
 
172
  text_value = None
173
+ pages = [processing_utils.decode_base64_to_image(p) for p in img_gallery]
174
+ prediction = run_pipeline(model, question, document, 1)
175
+ annotate_page(prediction, pages, document)
176
+
177
+ field_name = question.rstrip("?")
178
+ fields = {**FIELDS, field_name: [question]}
179
+ output[field_name] = prediction
180
+ table = output_table.values.tolist() + [[field_name, prediction.get("answer")]]
 
 
 
 
181
  return (
182
  gr.update(visible=True, value=pages),
183
+ fields,
184
+ output,
185
+ gr.update(value=table, interactive=False),
 
 
186
  )
187
 
188
 
189
+ def process_fields(document, fields, model=list(CHECKPOINTS.keys())[0]):
190
  pages = [x.copy().convert("RGB") for x in document.preview]
191
 
192
  ret = {}
193
  table = []
194
 
195
+ for (field_name, questions) in fields.items():
196
  answers = [run_pipeline(model, q, document, top_k=1) for q in questions]
197
  answers.sort(key=lambda x: -x.get("score", 0) if x else 0)
198
  top = answers[0]
 
208
 
209
 
210
  def load_example_document(img, title, model):
211
+ document = None
212
+ fields = {**FIELDS}
213
  if img is not None:
214
  if title in QUESTION_FILES:
 
215
  document = load_document(QUESTION_FILES[title])
216
  else:
217
  document = ImageDocument(Image.fromarray(img), ocr_reader=get_ocr_reader())
 
 
218
 
219
+ return process_document(document, fields, model)
220
 
221
 
222
  CSS = """
223
  #question input {
224
  font-size: 16px;
225
  }
226
+ #url-textbox, #question-textbox {
227
  padding: 0 !important;
228
  }
229
  #short-upload-box .w-full {
 
326
  )
327
 
328
  document = gr.Variable()
329
+ fields = gr.Variable(value={**FIELDS})
330
  example_question = gr.Textbox(visible=False)
331
  example_image = gr.Image(visible=False)
332
 
 
364
  )
365
 
366
  with gr.Column() as col:
367
+ with gr.Tabs():
368
+ with gr.TabItem("Table"):
369
+ output_table = gr.Dataframe(
370
+ headers=["Field", "Value"], **empty_table(fields.value)
371
+ )
372
+
373
+ with gr.TabItem("JSON"):
374
+ output = gr.JSON(label="Output", visible=False)
375
+
376
  gr.Markdown("## 2. Ask a question")
 
 
 
 
 
 
377
  model = gr.Radio(
378
  choices=list(CHECKPOINTS.keys()),
379
  value=list(CHECKPOINTS.keys())[0],
 
382
  )
383
 
384
  with gr.Row():
385
+ question = gr.Textbox(
386
+ label="Question",
387
+ show_label=False,
388
+ placeholder="e.g. What is the invoice number?",
389
+ lines=1,
390
+ max_lines=1,
391
+ elem_id="question-textbox",
392
+ )
393
+ clear_button = gr.Button("Clear", variant="secondary", visible=False)
394
  submit_button = gr.Button(
395
+ "Add", variant="primary", elem_id="submit-button"
396
  )
 
 
 
 
 
 
397
 
398
  for cb in [img_clear_button, clear_button]:
399
  cb.click(
400
  lambda _: (
401
+ gr.update(visible=False, value=None), # image
402
+ None, # document
403
+ {**FIELDS}, # fields
404
+ gr.update(visible=False, value=None), # output
405
+ gr.update(**empty_table(FIELDS)), # output_table
406
  gr.update(visible=False),
407
  None,
408
  None,
 
414
  outputs=[
415
  image,
416
  document,
417
+ fields,
418
  output,
419
  output_table,
420
  img_clear_button,
 
426
  ],
427
  )
428
 
429
+ submit_outputs = [
430
+ document,
431
+ fields,
432
+ image,
433
+ img_clear_button,
434
+ url_error,
435
+ output,
436
+ output_table,
437
+ ]
438
+
439
  upload.change(
440
  fn=process_upload,
441
  inputs=[upload, model],
442
+ outputs=submit_outputs,
443
  )
444
 
445
  submit.click(
446
  fn=process_path,
447
  inputs=[url, model],
448
+ outputs=submit_outputs,
449
  )
450
 
451
  question.submit(
452
  fn=process_question,
453
+ inputs=[question, document, image, model, fields, output, output_table],
454
+ outputs=[image, fields, output, output_table],
455
  )
456
 
457
  submit_button.click(
 
469
  example_image.change(
470
  fn=load_example_document,
471
  inputs=[example_image, example_question, model],
472
+ outputs=submit_outputs,
473
  )
474
 
475
  if __name__ == "__main__":