Ankur Goyal committed on
Commit
2af0878
1 Parent(s): 8bd074d

Plumbing for fields

Browse files
Files changed (1) hide show
  1. app.py +57 -33
app.py CHANGED
@@ -21,9 +21,7 @@ def ensure_list(x):
21
 
22
 
23
  CHECKPOINTS = {
24
- "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
25
- "LayoutLMv1 for Invoices 💸": "impira/layoutlm-invoices",
26
- "Donut 🍩": "naver-clova-ix/donut-base-finetuned-docvqa",
27
  }
28
 
29
  PIPELINES = {}
@@ -71,10 +69,10 @@ def normalize_bbox(box, width, height, padding=0.005):
71
  return [min_x * width, min_y * height, max_x * width, max_y * height]
72
 
73
 
74
- examples = [
75
  [
76
  "invoice.png",
77
- "What is the invoice number?",
78
  ],
79
  [
80
  "contract.jpeg",
@@ -86,8 +84,12 @@ examples = [
86
  ],
87
  ]
88
 
89
- question_files = {
90
- "What are net sales for 2020?": "statement.pdf",
 
 
 
 
91
  }
92
 
93
 
@@ -135,6 +137,19 @@ def process_upload(file):
135
  colors = ["#64A087", "green", "black"]
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
139
  if not question or document is None:
140
  return None, None, None
@@ -150,16 +165,7 @@ def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
150
  # prediction for now
151
  break
152
 
153
- if "word_ids" in p:
154
- image = pages[p["page"]]
155
- draw = ImageDraw.Draw(image, "RGBA")
156
- word_boxes = lift_word_boxes(document, p["page"])
157
- x1, y1, x2, y2 = normalize_bbox(
158
- expand_bbox([word_boxes[i] for i in p["word_ids"]]),
159
- image.width,
160
- image.height,
161
- )
162
- draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
163
 
164
  return (
165
  gr.update(visible=True, value=pages),
@@ -171,16 +177,33 @@ def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
171
  )
172
 
173
 
174
- def load_example_document(img, question, model):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  if img is not None:
176
- if question in question_files:
177
- document = load_document(question_files[question])
 
178
  else:
179
- document = ImageDocument(Image.fromarray(img), get_ocr_reader())
180
- preview, answer, answer_text = process_question(question, document, model)
181
- return document, question, preview, gr.update(visible=True), answer, answer_text
182
  else:
183
- return None, None, None, gr.update(visible=False), None, None
184
 
185
 
186
  CSS = """
@@ -280,12 +303,13 @@ gradio-app h2, .gradio-app h2 {
280
  with gr.Blocks(css=CSS) as demo:
281
  gr.Markdown("# DocQuery: Document Query Engine")
282
  gr.Markdown(
283
- "DocQuery (created by [Impira](https://impira.com?utm_source=huggingface&utm_medium=referral&utm_campaign=docquery_space))"
284
- " uses LayoutLMv1 fine-tuned on DocVQA, a document visual question"
285
- " answering dataset, as well as SQuAD, which boosts its English-language comprehension."
286
- " To use it, simply upload an image or PDF, type a question, and click 'submit', or "
287
- " click one of the examples to load them."
288
- " DocQuery is MIT-licensed and available on [Github](https://github.com/impira/docquery)."
 
289
  )
290
 
291
  document = gr.Variable()
@@ -295,7 +319,7 @@ with gr.Blocks(css=CSS) as demo:
295
  with gr.Row(equal_height=True):
296
  with gr.Column():
297
  with gr.Row():
298
- gr.Markdown("## 1. Select a file", elem_id="select-a-file")
299
  img_clear_button = gr.Button(
300
  "Clear", variant="secondary", elem_id="file-clear", visible=False
301
  )
@@ -321,7 +345,7 @@ with gr.Blocks(css=CSS) as demo:
321
  gr.Markdown("— or —")
322
  upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
323
  gr.Examples(
324
- examples=examples,
325
  inputs=[example_image, example_question],
326
  )
327
 
@@ -411,7 +435,7 @@ with gr.Blocks(css=CSS) as demo:
411
  example_image.change(
412
  fn=load_example_document,
413
  inputs=[example_image, example_question, model],
414
- outputs=[document, question, image, img_clear_button, output, output_text],
415
  )
416
 
417
  if __name__ == "__main__":
 
21
 
22
 
23
  CHECKPOINTS = {
24
+ "LayoutLMv1 for Invoices 🧾": "impira/layoutlm-invoices",
 
 
25
  }
26
 
27
  PIPELINES = {}
 
69
  return [min_x * width, min_y * height, max_x * width, max_y * height]
70
 
71
 
72
+ EXAMPLES = [
73
  [
74
  "invoice.png",
75
+ "Invoice 1",
76
  ],
77
  [
78
  "contract.jpeg",
 
84
  ],
85
  ]
86
 
87
+ QUESTION_FILES = {}
88
+
89
+ FIELDS = {
90
+ "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
91
+ "Vendor Address": ["Vendor Address?"],
92
+ "Invoice Total": ["Invoice Total?"],
93
  }
94
 
95
 
 
137
  colors = ["#64A087", "green", "black"]
138
 
139
 
140
+ def annotate_page(prediction, pages, document):
141
+ if "word_ids" in prediction:
142
+ image = pages[prediction["page"]]
143
+ draw = ImageDraw.Draw(image, "RGBA")
144
+ word_boxes = lift_word_boxes(document, prediction["page"])
145
+ x1, y1, x2, y2 = normalize_bbox(
146
+ expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
147
+ image.width,
148
+ image.height,
149
+ )
150
+ draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
151
+
152
+
153
  def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
154
  if not question or document is None:
155
  return None, None, None
 
165
  # prediction for now
166
  break
167
 
168
+ annotate_page(p, pages, document)
 
 
 
 
 
 
 
 
 
169
 
170
  return (
171
  gr.update(visible=True, value=pages),
 
177
  )
178
 
179
 
180
+ def process_fields(document, model=list(CHECKPOINTS.keys())[0]):
181
+ pages = [x.copy().convert("RGB") for x in document.preview]
182
+ ret = {}
183
+
184
+ for (field_name, questions) in FIELDS.items():
185
+ answers = [run_pipeline(model, q, document, top_k=1) for q in questions]
186
+ answers.sort(key=lambda x: -x.get("score", 0) if x else 0)
187
+ top = answers[0]
188
+ annotate_page(top, pages, document)
189
+ ret[field_name] = top
190
+ return (
191
+ gr.update(visible=True, value=pages),
192
+ gr.update(visible=True, value=ret),
193
+ )
194
+
195
+
196
+ def load_example_document(img, title, model):
197
  if img is not None:
198
+ if title in QUESTION_FILES:
199
+ print("using document")
200
+ document = load_document(QUESTION_FILES[title])
201
  else:
202
+ document = ImageDocument(Image.fromarray(img), ocr_reader=get_ocr_reader())
203
+ preview, answer = process_fields(document, model)
204
+ return document, preview, gr.update(visible=True), answer
205
  else:
206
+ return None, None, gr.update(visible=False), None
207
 
208
 
209
  CSS = """
 
303
  with gr.Blocks(css=CSS) as demo:
304
  gr.Markdown("# DocQuery: Document Query Engine")
305
  gr.Markdown(
306
+ "DocQuery (created by [Impira](https://impira.com)) uses LayoutLMv1 fine-tuned on an invoice dataset"
307
+ " as well as DocVQA and SQuAD, which boost its general comprehension skills. The model is an enhanced"
308
+ " QA architecture that supports selecting blocks of text which may be non-consecutive, which is a major"
309
+ " issue when dealing with invoice documents (e.g. addresses)."
310
+ " To use it, simply upload an image or PDF invoice and the model will predict values for several fields."
311
+ " You can also create additional fields by simply typing in a question."
312
+ " DocQuery is available on [Github](https://github.com/impira/docquery)."
313
  )
314
 
315
  document = gr.Variable()
 
319
  with gr.Row(equal_height=True):
320
  with gr.Column():
321
  with gr.Row():
322
+ gr.Markdown("## 1. Select an invoice", elem_id="select-a-file")
323
  img_clear_button = gr.Button(
324
  "Clear", variant="secondary", elem_id="file-clear", visible=False
325
  )
 
345
  gr.Markdown("— or —")
346
  upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
347
  gr.Examples(
348
+ examples=EXAMPLES,
349
  inputs=[example_image, example_question],
350
  )
351
 
 
435
  example_image.change(
436
  fn=load_example_document,
437
  inputs=[example_image, example_question, model],
438
+ outputs=[document, image, img_clear_button, output],
439
  )
440
 
441
  if __name__ == "__main__":