padalavinaybhushan commited on
Commit
0e5d606
β€’
1 Parent(s): 60a9256
Files changed (9) hide show
  1. acze.png +0 -0
  2. acze_tech.pdf +0 -0
  3. acze_tech.png +0 -0
  4. app.py +497 -0
  5. invoice.png +0 -0
  6. north_sea.pdf +0 -0
  7. north_sea.png +0 -0
  8. packages.txt +2 -0
  9. requirements.txt +4 -0
acze.png ADDED
acze_tech.pdf ADDED
The diff for this file is too large to render. See raw diff
 
acze_tech.png ADDED
app.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ from PIL import Image, ImageDraw
6
+ import traceback
7
+
8
+ import gradio as gr
9
+ from gradio import processing_utils
10
+
11
+ import torch
12
+ from docquery import pipeline
13
+ from docquery.document import load_bytes, load_document, ImageDocument
14
+ from docquery.ocr_reader import get_ocr_reader
15
+
16
+
17
+ def ensure_list(x):
18
+ if isinstance(x, list):
19
+ return x
20
+ else:
21
+ return [x]
22
+
23
+
24
+ CHECKPOINTS = {
25
+ "LayoutLMv1 for Invoices 🧾": "impira/layoutlm-invoices",
26
+ }
27
+
28
+ PIPELINES = {}
29
+
30
+
31
+ def construct_pipeline(task, model):
32
+ global PIPELINES
33
+ if model in PIPELINES:
34
+ return PIPELINES[model]
35
+
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+ ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
38
+ PIPELINES[model] = ret
39
+ return ret
40
+
41
+
42
+ def run_pipeline(model, question, document, top_k):
43
+ pipeline = construct_pipeline("document-question-answering", model)
44
+ return pipeline(question=question, **document.context, top_k=top_k)
45
+
46
+
47
+ # TODO: Move into docquery
48
+ # TODO: Support words past the first page (or window?)
49
+ def lift_word_boxes(document, page):
50
+ return document.context["image"][page][1]
51
+
52
+
53
+ def expand_bbox(word_boxes):
54
+ if len(word_boxes) == 0:
55
+ return None
56
+
57
+ min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
58
+ min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
59
+ return [min_x, min_y, max_x, max_y]
60
+
61
+
62
+ # LayoutLM boxes are normalized to 0, 1000
63
+ def normalize_bbox(box, width, height, padding=0.005):
64
+ min_x, min_y, max_x, max_y = [c / 1000 for c in box]
65
+ if padding != 0:
66
+ min_x = max(0, min_x - padding)
67
+ min_y = max(0, min_y - padding)
68
+ max_x = min(max_x + padding, 1)
69
+ max_y = min(max_y + padding, 1)
70
+ return [min_x * width, min_y * height, max_x * width, max_y * height]
71
+
72
+
73
+ EXAMPLES = [
74
+ [
75
+ "acze_tech.png",
76
+ "Tech Invoice",
77
+ ],
78
+ [
79
+ "acze.png",
80
+ "Commercial Goods Invoice",
81
+ ],
82
+ [
83
+ "north_sea.png",
84
+ "Energy Invoice",
85
+ ],
86
+ ]
87
+
88
+ QUESTION_FILES = {
89
+ "Tech Invoice": "acze_tech.pdf",
90
+ "Energy Invoice": "north_sea.pdf",
91
+ }
92
+
93
+ for q in QUESTION_FILES.keys():
94
+ assert any(x[1] == q for x in EXAMPLES)
95
+
96
+ FIELDS = {
97
+ "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
98
+ "Vendor Address": ["Vendor Address?"],
99
+ "Customer Name": ["Customer Name?"],
100
+ "Customer Address": ["Customer Address?"],
101
+ "Invoice Number": ["Invoice Number?"],
102
+ "Invoice Date": ["Invoice Date?"],
103
+ "Due Date": ["Due Date?"],
104
+ "Subtotal": ["Subtotal?"],
105
+ "Total Tax": ["Total Tax?"],
106
+ "Invoice Total": ["Invoice Total?"],
107
+ "Amount Due": ["Amount Due?"],
108
+ "Payment Terms": ["Payment Terms?"],
109
+ "Remit To Name": ["Remit To Name?"],
110
+ "Remit To Address": ["Remit To Address?"],
111
+ }
112
+
113
+
114
+ def empty_table(fields):
115
+ return {"value": [[name, None] for name in fields.keys()], "interactive": False}
116
+
117
+
118
+ def process_document(document, fields, model, error=None):
119
+ if document is not None and error is None:
120
+ preview, json_output, table = process_fields(document, fields, model)
121
+ return (
122
+ document,
123
+ fields,
124
+ preview,
125
+ gr.update(visible=True),
126
+ gr.update(visible=False, value=None),
127
+ json_output,
128
+ table,
129
+ )
130
+ else:
131
+ return (
132
+ None,
133
+ fields,
134
+ None,
135
+ gr.update(visible=False),
136
+ gr.update(visible=True, value=error) if error is not None else None,
137
+ None,
138
+ gr.update(**empty_table(fields)),
139
+ )
140
+
141
+
142
+ def process_path(path, fields, model):
143
+ error = None
144
+ document = None
145
+ if path:
146
+ try:
147
+ document = load_document(path)
148
+ except Exception as e:
149
+ traceback.print_exc()
150
+ error = str(e)
151
+
152
+ return process_document(document, fields, model, error)
153
+
154
+
155
+ def process_upload(file, fields, model):
156
+ return process_path(file.name if file else None, fields, model)
157
+
158
+
159
+ colors = ["#64A087", "green", "black"]
160
+
161
+
162
+ def annotate_page(prediction, pages, document):
163
+ if prediction is not None and "word_ids" in prediction:
164
+ image = pages[prediction["page"]]
165
+ draw = ImageDraw.Draw(image, "RGBA")
166
+ word_boxes = lift_word_boxes(document, prediction["page"])
167
+ x1, y1, x2, y2 = normalize_bbox(
168
+ expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
169
+ image.width,
170
+ image.height,
171
+ )
172
+ draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
173
+
174
+
175
+ def process_question(
176
+ question, document, img_gallery, model, fields, output, output_table
177
+ ):
178
+ field_name = question
179
+ if field_name is not None:
180
+ fields = {field_name: [question], **fields}
181
+
182
+ if not question or document is None:
183
+ return None, document, fields, output, gr.update(value=output_table)
184
+
185
+ text_value = None
186
+ pages = [processing_utils.decode_base64_to_image(p) for p in img_gallery]
187
+ prediction = run_pipeline(model, question, document, 1)
188
+ annotate_page(prediction, pages, document)
189
+
190
+ output = {field_name: prediction, **output}
191
+ table = [[field_name, prediction.get("answer")]] + output_table.values.tolist()
192
+ return (
193
+ None,
194
+ gr.update(visible=True, value=pages),
195
+ fields,
196
+ output,
197
+ gr.update(value=table, interactive=False),
198
+ )
199
+
200
+
201
+ def process_fields(document, fields, model=list(CHECKPOINTS.keys())[0]):
202
+ pages = [x.copy().convert("RGB") for x in document.preview]
203
+
204
+ ret = {}
205
+ table = []
206
+
207
+ for (field_name, questions) in fields.items():
208
+ answers = [
209
+ a
210
+ for q in questions
211
+ for a in ensure_list(run_pipeline(model, q, document, top_k=1))
212
+ if a.get("score", 1) > 0.5
213
+ ]
214
+ answers.sort(key=lambda x: -x.get("score", 0) if x else 0)
215
+ top = answers[0] if len(answers) > 0 else None
216
+ annotate_page(top, pages, document)
217
+ ret[field_name] = top
218
+ table.append([field_name, top.get("answer") if top is not None else None])
219
+
220
+ return (
221
+ gr.update(visible=True, value=pages),
222
+ gr.update(visible=True, value=ret),
223
+ table
224
+ )
225
+
226
+
227
+ def load_example_document(img, title, fields, model):
228
+ document = None
229
+ if img is not None:
230
+ if title in QUESTION_FILES:
231
+ document = load_document(QUESTION_FILES[title])
232
+ else:
233
+ document = ImageDocument(Image.fromarray(img), ocr_reader=get_ocr_reader())
234
+
235
+ return process_document(document, fields, model)
236
+
237
+
238
+ CSS = """
239
+ #question input {
240
+ font-size: 16px;
241
+ }
242
+ #url-textbox, #question-textbox {
243
+ padding: 0 !important;
244
+ }
245
+ #short-upload-box .w-full {
246
+ min-height: 10rem !important;
247
+ }
248
+ /* I think something like this can be used to re-shape
249
+ * the table
250
+ */
251
+ /*
252
+ .gr-samples-table tr {
253
+ display: inline;
254
+ }
255
+ .gr-samples-table .p-2 {
256
+ width: 100px;
257
+ }
258
+ */
259
+ #select-a-file {
260
+ width: 100%;
261
+ }
262
+ #file-clear {
263
+ padding-top: 2px !important;
264
+ padding-bottom: 2px !important;
265
+ padding-left: 8px !important;
266
+ padding-right: 8px !important;
267
+ margin-top: 10px;
268
+ }
269
+ .gradio-container .gr-button-primary {
270
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
271
+ border: 1px solid #B0DCCC;
272
+ border-radius: 8px;
273
+ color: #1B8700;
274
+ }
275
+ .gradio-container.dark button#submit-button {
276
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
277
+ border: 1px solid #B0DCCC;
278
+ border-radius: 8px;
279
+ color: #1B8700
280
+ }
281
+
282
+ table.gr-samples-table tr td {
283
+ border: none;
284
+ outline: none;
285
+ }
286
+
287
+ table.gr-samples-table tr td:first-of-type {
288
+ width: 0%;
289
+ }
290
+
291
+ div#short-upload-box div.absolute {
292
+ display: none !important;
293
+ }
294
+
295
+ gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
296
+ gap: 0px 2%;
297
+ }
298
+
299
+ gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
300
+ gap: 0px;
301
+ }
302
+
303
+ gradio-app h2, .gradio-app h2 {
304
+ padding-top: 10px;
305
+ }
306
+
307
+ #answer {
308
+ overflow-y: scroll;
309
+ color: white;
310
+ background: #666;
311
+ border-color: #666;
312
+ font-size: 20px;
313
+ font-weight: bold;
314
+ }
315
+
316
+ #answer span {
317
+ color: white;
318
+ }
319
+
320
+ #answer textarea {
321
+ color:white;
322
+ background: #777;
323
+ border-color: #777;
324
+ font-size: 18px;
325
+ }
326
+
327
+ #url-error input {
328
+ color: red;
329
+ }
330
+
331
+ #results-table {
332
+ max-height: 600px;
333
+ overflow-y: scroll;
334
+ }
335
+
336
+ """
337
+
338
+ with gr.Blocks(css=CSS) as demo:
339
+ gr.Markdown("# DocQuery for Invoices")
340
+ gr.Markdown(
341
+ "DocQuery (created by [Impira](https://impira.com?utm_source=huggingface&utm_medium=referral&utm_campaign=invoices_space))"
342
+ " uses LayoutLMv1 fine-tuned on an invoice dataset"
343
+ " as well as DocVQA and SQuAD, which boot its general comprehension skills. The model is an enhanced"
344
+ " QA architecture that supports selecting blocks of text which may be non-consecutive, which is a major"
345
+ " issue when dealing with invoice documents (e.g. addresses)."
346
+ " To use it, simply upload an image or PDF invoice and the model will predict values for several fields."
347
+ " You can also create additional fields by simply typing in a question."
348
+ " DocQuery is available on [Github](https://github.com/impira/docquery)."
349
+ )
350
+
351
+ document = gr.Variable()
352
+ fields = gr.Variable(value={**FIELDS})
353
+ example_question = gr.Textbox(visible=False)
354
+ example_image = gr.Image(visible=False)
355
+
356
+ with gr.Row(equal_height=True):
357
+ with gr.Column():
358
+ with gr.Row():
359
+ gr.Markdown("## Select an invoice", elem_id="select-a-file")
360
+ img_clear_button = gr.Button(
361
+ "Clear", variant="secondary", elem_id="file-clear", visible=False
362
+ )
363
+ image = gr.Gallery(visible=False)
364
+ with gr.Row(equal_height=True):
365
+ with gr.Column():
366
+ with gr.Row():
367
+ url = gr.Textbox(
368
+ show_label=False,
369
+ placeholder="URL",
370
+ lines=1,
371
+ max_lines=1,
372
+ elem_id="url-textbox",
373
+ )
374
+ submit = gr.Button("Get")
375
+ url_error = gr.Textbox(
376
+ visible=False,
377
+ elem_id="url-error",
378
+ max_lines=1,
379
+ interactive=False,
380
+ label="Error",
381
+ )
382
+ gr.Markdown("β€” or β€”")
383
+ upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
384
+ gr.Examples(
385
+ examples=EXAMPLES,
386
+ inputs=[example_image, example_question],
387
+ )
388
+
389
+ with gr.Column() as col:
390
+ gr.Markdown("## Results")
391
+ with gr.Tabs():
392
+ with gr.TabItem("Table"):
393
+ output_table = gr.Dataframe(
394
+ headers=["Field", "Value"],
395
+ **empty_table(fields.value),
396
+ elem_id="results-table"
397
+ )
398
+
399
+ with gr.TabItem("JSON"):
400
+ output = gr.JSON(label="Output", visible=True)
401
+
402
+ model = gr.Radio(
403
+ choices=list(CHECKPOINTS.keys()),
404
+ value=list(CHECKPOINTS.keys())[0],
405
+ label="Model",
406
+ visible=False,
407
+ )
408
+
409
+ gr.Markdown("### Ask a question")
410
+ with gr.Row():
411
+ question = gr.Textbox(
412
+ label="Question",
413
+ show_label=False,
414
+ placeholder="e.g. What is the invoice number?",
415
+ lines=1,
416
+ max_lines=1,
417
+ elem_id="question-textbox",
418
+ )
419
+ clear_button = gr.Button("Clear", variant="secondary", visible=False)
420
+ submit_button = gr.Button(
421
+ "Add", variant="primary", elem_id="submit-button"
422
+ )
423
+
424
+ for cb in [img_clear_button, clear_button]:
425
+ cb.click(
426
+ lambda _: (
427
+ gr.update(visible=False, value=None), # image
428
+ None, # document
429
+ # {**FIELDS}, # fields
430
+ gr.update(value=None), # output
431
+ gr.update(**empty_table(fields.value)), # output_table
432
+ gr.update(visible=False),
433
+ None,
434
+ None,
435
+ None,
436
+ gr.update(visible=False, value=None),
437
+ None,
438
+ ),
439
+ inputs=clear_button,
440
+ outputs=[
441
+ image,
442
+ document,
443
+ # fields,
444
+ output,
445
+ output_table,
446
+ img_clear_button,
447
+ example_image,
448
+ upload,
449
+ url,
450
+ url_error,
451
+ question,
452
+ ],
453
+ )
454
+
455
+ submit_outputs = [
456
+ document,
457
+ fields,
458
+ image,
459
+ img_clear_button,
460
+ url_error,
461
+ output,
462
+ output_table,
463
+ ]
464
+
465
+ upload.change(
466
+ fn=process_upload,
467
+ inputs=[upload, fields, model],
468
+ outputs=submit_outputs,
469
+ )
470
+
471
+ submit.click(
472
+ fn=process_path,
473
+ inputs=[url, fields, model],
474
+ outputs=submit_outputs,
475
+ )
476
+
477
+ for action in [question.submit, submit_button.click]:
478
+ action(
479
+ fn=process_question,
480
+ inputs=[question, document, image, model, fields, output, output_table],
481
+ outputs=[question, image, fields, output, output_table],
482
+ )
483
+
484
+ # model.change(
485
+ # process_question,
486
+ # inputs=[question, document, model],
487
+ # outputs=[image, output, output_table],
488
+ # )
489
+
490
+ example_image.change(
491
+ fn=load_example_document,
492
+ inputs=[example_image, example_question, fields, model],
493
+ outputs=submit_outputs,
494
+ )
495
+
496
+ if __name__ == "__main__":
497
+ demo.launch(enable_queue=False)
invoice.png ADDED
north_sea.pdf ADDED
Binary file (70.9 kB). View file
 
north_sea.png ADDED
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ poppler-utils
2
+ tesseract-ocr
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers.git@21f6f58721dd9154357576be6de54eefef1f1818
2
+ git+https://github.com/impira/docquery.git@8d92692c36f63ef652f3c84cccedd5674ee7b383
3
+ sentencepiece
4
+ torch