File size: 7,115 Bytes
cb4c41f
 
 
 
 
f50ef7b
dea123e
1592bb3
93fe32c
 
 
cb4c41f
789ac0d
 
cb4c41f
 
 
 
be4a233
66dee59
35135c8
 
 
 
 
 
 
 
 
 
cb4c41f
650ae3b
cb4c41f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b1dae
 
b7221a3
cb4c41f
8bfd542
93fee22
 
e45a37b
 
f6bce53
e45a37b
e60f9c6
7384818
 
 
da509d5
12d54a0
7384818
 
66dee59
f4ccbdc
 
93fee22
 
 
b7221a3
8732834
 
b7221a3
bc05740
 
c1b495f
c8c471f
 
b73fff5
0bba431
c8c471f
 
12d54a0
bc05740
b7221a3
 
 
d3ffd5d
 
b7221a3
cb4c41f
be4a233
35135c8
be4a233
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import re
import gradio as gr

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests
from io import BytesIO
import json
import os


# Local directory holding the fine-tuned Donut checkpoint; the processor and
# the model are both loaded from it.
MODEL_DIR = "./donut-base-finetuned-inv"

processor = DonutProcessor.from_pretrained(MODEL_DIR)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_DIR)

# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)


def update_status(state):
    """Return the updates for the two status images that match a UI state.

    Args:
        state: Name of the current UI state. The recognised values are
            "start_or_clear", "processing" and "finished_processing".

    Returns:
        A 2-tuple of ``gr.update(...)`` results, one per status image:
        both show "snowangel.gif" for "start_or_clear" and
        "finished_processing", both are cleared and hidden for
        "processing". ``None`` for any other value.

    Note:
        This function does not advance the state. Rebinding the ``state``
        parameter locally has no effect on the caller; a new state value
        would have to be returned to a ``gr.State`` output instead.
    """
    if state in ("start_or_clear", "finished_processing"):
        # Entering the "processing" phase: show the animation in both slots.
        return (
            gr.update(value="snowangel.gif", visible=True),
            gr.update(value="snowangel.gif", visible=True),
        )
    if state == "processing":
        # Processing is over: clear and hide both slots.
        return (
            gr.update(value="", visible=False),
            gr.update(value="", visible=False),
        )
    # Unknown state: nothing to update.
    return None

def process_document(image):
    """Run the Donut model on one document image and parse its output.

    Args:
        image: The document image as delivered by the ``gr.Image`` input
            (presumably a numpy array, the Gradio default -- confirm), or
            ``None`` when nothing was uploaded.

    Returns:
        A ``(fields, image)`` tuple. ``fields`` is whatever
        ``processor.token2json`` produces for the decoded sequence and
        ``image`` is the unchanged input, echoed back for display.
        ``(None, None)`` when ``image`` is ``None``.
    """
    # The button can be clicked before anything is uploaded; return empty
    # outputs instead of letting the processor raise on None.
    if image is None:
        return None, None

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: generation is seeded with the task start token
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # generate answer (single beam, unknown token banned from the output)
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: strip EOS/PAD markers, then the leading task start token
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    # Component visibility can only change through values returned to Gradio
    # outputs, so no UI component is touched from inside this handler.
    return processor.token2json(sequence), image

# Static HTML fragments rendered at the top of the page via gr.HTML, in the
# order title, paragraph1, paragraph2, paragraph3.

# Page header: a one-row table with the "Welcome" heading between two GIFs.
title = '<table align="center" border="0" cellpadding="1" cellspacing="1" ><tbody><tr><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling_small.gif" style="float:right; height:50px; width:50px" /></td><td style="text-align:center"><h1>&nbsp; &nbsp;Welcome</h1></td><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling2_small.gif" style="float:left; height:50px; width:50px" /></td></tr></tbody></table>'
# Short introduction with links to the paper and the base model.
paragraph1 = '<p>Basic idea of this 🍩 model is to give it an image as input and extract indexes as text. No bounding boxes or confidences are generated.<br /> For more info, see the <a href="https://arxiv.org/abs/2111.15664">original paper</a>&nbsp;and the 🤗&nbsp;<a href="https://huggingface.co/naver-clova-ix/donut-base">model</a>.</p>'
# Training notes followed by the list of indexes the model was trained on.
paragraph2 = '<p><strong>Training</strong>:<br />The model was trained with a few thousand of annotated invoices and non-invoices (for those the doctype will be &#39;Other&#39;). They span across different countries and languages. They are always one page only. The dataset is proprietary unfortunately.&nbsp;Model is set to input resolution of 1280x1920 pixels. So any sample you want to try with higher dpi than 150 has no added value.<br />It was trained for about 4 hours on a&nbsp;NVIDIA RTX A4000 for 20k steps with a val_metric of&nbsp;0.03413819904382196 at the end.<br />The <u>following indexes</u> were included in the train set:</p><ul><li><span style="font-family:Calibri"><span style="color:black">DocType</span></span></li><li><span style="font-family:Calibri"><span style="color:black">Currency</span></span></li><li><span style="font-family:Calibri"><span style="color:black">DocumentDate</span></span></li><li><span style="font-family:Calibri"><span style="color:black">GrossAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">InvoiceNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">NetAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">TaxAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">OrderNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">CreditorCountry</span></span></li></ul>'
# Disabled gr.Interface-based variant of the demo. It refers to gr_image,
# description and article, none of which are defined in the visible code.
#demo = gr.Interface(fn=process_document,inputs=gr_image,outputs="json",title="Demo: Donut 🍩 for invoice header retrieval", description=description,
#    article=article,enable_queue=True, examples=[["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]], cache_examples=False)
# Usage instructions and sign-off.
paragraph3 = '<p><strong>Try it out:</strong><br />To use it, simply upload your image and click &#39;submit&#39;, or click one of the examples to load them.<br /><em>(because this is running on the free cpu tier, it will take about 40 secs before you see a result. On a GPU it takes less than 2 seconds)</em></p><p>&nbsp;</p><p>Have fun&nbsp;😎</p><p>Toon Beerten</p>'

# Custom CSS handed to gr.Blocks. The "#inp" selector matches the component
# created with elem_id="inp" (the "Uploaded document:" image).
css = "#inp {height: auto !important; width: 100% !important;}"
# Disabled alternatives, none of them active:
# css = "@media screen and (max-width: 600px) { .output_image, .input_image {height:20rem !important; width: 100% !important;} }"
# css = ".output_image, .input_image {height: 600px !important}"

#css = ".image-preview {height: auto !important;}"
#css='div {margin-left: auto; margin-right: auto; width: 100%;background-image: url("background.gif"); repeat 0 0;}')

with gr.Blocks(css=css) as demo:
    # UI state value; only the disabled status handlers further down read it.
    state = gr.State(value="start_or_clear")

    # Static header and explanatory text, rendered top to bottom.
    for html_fragment in (title, paragraph1, paragraph2, paragraph3):
        gr.HTML(html_fragment)

    # Row 1: upload widget on the left, clickable example images on the right.
    with gr.Row().style():
        with gr.Column(scale=1):
            inp = gr.Image(label="Upload invoice here:")
        with gr.Column(scale=2):
            example_files = [["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]]
            gr.Examples(
                example_files,
                inputs=[inp],
                label="Or use one of these examples:",
            )

    # Row 2: the extract button, flanked by two initially hidden images.
    with gr.Row().style(equal_height=True, height=200, rounded=False):
        with gr.Column(scale=1):
            img2 = gr.Image("drinking.gif", label=" ", visible=False).style(rounded=True)
        with gr.Column(scale=1):
            btn = gr.Button("↓   Extract   ↓")
        with gr.Column(scale=1):
            img3 = gr.Image("snowangel.gif", label=" ", visible=False).style(rounded=True)

    # Row 3: the processed document next to the extracted fields.
    with gr.Row().style():
        with gr.Column(scale=2):
            imgout = gr.Image(label="Uploaded document:", elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label="Extracted information:")

    # Status-image handlers, currently disabled:
    # imgout.clear(fn=update_status, inputs=state, outputs=[img2, img3])
    # imgout.change(fn=update_status, inputs=state, outputs=[img2, img3])

    # Clicking the button runs the model and fills both outputs of row 3.
    btn.click(
        fn=process_document,
        inputs=inp,
        outputs=[jsonout, imgout],
    )

demo.launch()