File size: 10,282 Bytes
cb4c41f
 
 
 
 
f50ef7b
dea123e
1592bb3
93fe32c
 
 
cb4c41f
a773f5d
 
cb4c41f
 
 
 
be4a233
66dee59
35135c8
 
 
 
 
 
 
 
 
 
429bc8a
 
 
 
 
 
3cc134f
 
429bc8a
 
 
 
17851d3
429bc8a
 
 
 
 
 
2497ea4
 
de95cee
281fc5d
cb4c41f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b1dae
 
b7221a3
cb4c41f
82aea69
fceae27
82aea69
93fee22
4fd7d4e
e45a37b
 
de09c5c
f7b4868
e60f9c6
3cc134f
 
7384818
 
 
da509d5
12d54a0
7384818
 
66dee59
f4ccbdc
 
facb0ae
93fee22
 
 
de09c5c
b7221a3
8732834
 
b7221a3
bc05740
 
c1b495f
c8c471f
 
429bc8a
 
 
 
762ba24
12d54a0
bc05740
b7221a3
 
 
d3ffd5d
 
429bc8a
 
 
3cc134f
 
cb4c41f
be4a233
35135c8
be4a233
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import re
import gradio as gr

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests
from io import BytesIO
import json
import os


# Hugging Face Hub repository holding the invoice-fine-tuned Donut checkpoint;
# both the processor and the model are loaded from it.
_CHECKPOINT = "to-be/donut-base-finetuned-invoices"
processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Prefer the GPU when CUDA is available, otherwise run inference on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)


def update_status(state):
    """Return visibility updates for the two status-animation image slots.

    Parameters
    ----------
    state : str
        Expected to be one of ``'start_or_clear'``, ``'processing'`` or
        ``'finished_processing'``.

    Returns
    -------
    tuple or None
        Two ``gr.update`` dicts, one per output image:

        * ``'start_or_clear'`` / ``'finished_processing'``: show ``snowangel.gif``.
        * ``'processing'``: hide the image and clear its value.
        * Any other value: ``None``.

    This function does not advance ``state``. A caller that wires it to a
    ``gr.State`` must update that state itself.
    """
    if state in ("start_or_clear", "finished_processing"):
        # Two independent update dicts, one per output component.
        return (
            gr.update(value="snowangel.gif", visible=True),
            gr.update(value="snowangel.gif", visible=True),
        )
    if state == "processing":
        return (
            gr.update(value="", visible=False),
            gr.update(value="", visible=False),
        )
    # Unknown state: nothing to update.
    return None

# visitorbadge.io endpoint whose hit counter backs the "demos served" badge.
_DEMO_COUNTER_URL = 'https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut%2Fdemo&label=demos%20served&labelColor=%23edd239&countColor=%23d9e3f0'


def _count_demo():
    """Best-effort: bump the public "demos served" counter (network errors are logged, not raised)."""
    try:
        requests.get(_DEMO_COUNTER_URL, timeout=5)
    except requests.RequestException as err:
        # A counter outage must never prevent the extraction itself.
        print(f"visitor counter error: {type(err).__name__}")


def _notify_telegram(image, sendimg):
    """Best-effort: post the uploaded document (or a placeholder) to Telegram.

    Parameters
    ----------
    image : array-like
        The uploaded image, passed to ``Image.fromarray``.
    sendimg : bool
        User consent. When truthy the real document is sent; otherwise the
        local ``./no_image.jpg`` placeholder is sent so usage is still visible.

    Does nothing when ``TELEGRAM_BOT_TOKEN`` or ``TELEGRAM_CHANNEL_ID`` is
    unset. All failures are logged and swallowed.
    """
    token = os.getenv('TELEGRAM_BOT_TOKEN')
    chat_id = os.getenv('TELEGRAM_CHANNEL_ID')
    if not token or not chat_id:
        # Notifications are not configured; nothing to do.
        return

    try:
        im1 = Image.fromarray(image) if sendimg else Image.open('./no_image.jpg')
        bio = BytesIO()
        bio.name = 'image.jpeg'
        # JPEG cannot store alpha/palette modes; normalise to RGB before saving.
        im1.convert('RGB').save(bio, 'JPEG')
        bio.seek(0)
    except (OSError, TypeError, ValueError) as err:
        print(f"telegram image preparation error: {type(err).__name__}")
        return

    # NOTE(review): the Bot API is addressed by raw IP over HTTPS. TLS hostname
    # verification against an IP may fail (api.telegram.org is the documented
    # host), so confirm this endpoint actually delivers.
    url = f'https://149.154.167.220/bot{token}/sendPhoto?chat_id={chat_id}'
    try:
        response = requests.post(
            url,
            files={'photo': bio},
            # sendPhoto reads the caption from a plain `caption` field; a JSON
            # `media` object is the sendMediaGroup format.
            data={'caption': 'New doc is being tried out:'},
            timeout=10,
        )
    except requests.RequestException as err:
        # Log only the exception type: its message can contain the request
        # URL, which embeds the bot token.
        print(f"telegram api error: {type(err).__name__}")
        return
    if not response.ok:
        print(f"telegram api error: HTTP {response.status_code}")


def _extract_fields(image):
    """Run the Donut model on *image* and return the fields parsed by ``processor.token2json``."""
    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Prepare decoder inputs. Task-start prompt for this checkpoint; assumed to
    # match the prompt used during fine-tuning (TODO confirm).
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # generate answer: greedy decoding (num_beams=1) with <unk> tokens banned
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: strip eos/pad tokens, then the first (task start) tag
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)


def process_document(image, sendimg):
    """Extract invoice header fields from an uploaded document image.

    Parameters
    ----------
    image : array-like or None
        Value of the ``gr.Image`` input, presumably a numpy array (it is
        passed to ``Image.fromarray``). ``None`` when nothing was uploaded.
    sendimg : bool
        Whether the user allowed the document to be shared via Telegram.

    Returns
    -------
    tuple
        ``(fields, image)``. ``fields`` is the dict from
        ``processor.token2json``; the input image is echoed back for display.

    Raises
    ------
    gr.Error
        If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an invoice image or pick one of the examples first.")

    # Usage tracking is best-effort and never blocks the extraction.
    _count_demo()
    _notify_telegram(image, sendimg)

    return _extract_fields(image), image

# Static HTML fragments shown at the top of the demo page via gr.HTML:
# the title banner, update notes, model description, training details and
# benchmark observations. They are display text only; no code parses them.
title = '<table align="center" border="0" cellpadding="1" cellspacing="1" ><tbody><tr><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling_small.gif" style="float:right; height:50px; width:50px" /></td><td style="text-align:center"><h1>Demo: invoice header extraction with Donut</h1></td><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling2_small.gif" style="float:left; height:50px; width:50px" /></td></tr></tbody></table>'
paragraph0 = '<p><strong>(update 29/03/2023: for more info, you can read <a href="https://toon-beerten.medium.com/hands-on-document-data-extraction-with-transformer-7130df3b6132">my article on medium</a>)<br />(update 28/04/2023: want to finetune with your own data? Read&nbsp;<a href="https://towardsdatascience.com/ocr-free-document-data-extraction-with-transformers-1-2-b5a826bc2ac3">this article</a>)</strong></p>'
paragraph1 = '<p>Basic idea of the base 🍩 model is to give it an image as input and extract indexes as text. No bounding boxes or confidences are generated.<br /> I finetuned it on invoices. For more info, see the <a href="https://arxiv.org/abs/2111.15664">original paper</a>&nbsp;and the 🤗&nbsp;<a href="https://huggingface.co/naver-clova-ix/donut-base">model</a>.</p>'
paragraph2 = '<p><strong>Training</strong>:<br />The model was trained with a few thousand of annotated invoices and non-invoices (for those the doctype will be &#39;Other&#39;). They span across different countries and languages. They are always one page only. The dataset is proprietary unfortunately.&nbsp;Model is set to input resolution of 1280x1920 pixels. So any sample you want to try with higher dpi than 150 has no added value.<br />It was trained for about 4 hours on a&nbsp;NVIDIA RTX A4000 for 20k steps with a val_metric of&nbsp;0.03413819904382196 at the end.<br />The <u>following indexes</u> were included in the train set:</p><ul><li><span style="font-family:Calibri"><span style="color:black">DocType</span></span></li><li><span style="font-family:Calibri"><span style="color:black">Currency</span></span></li><li><span style="font-family:Calibri"><span style="color:black">DocumentDate</span></span></li><li><span style="font-family:Calibri"><span style="color:black">GrossAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">InvoiceNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">NetAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">TaxAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">OrderNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">CreditorCountry</span></span></li></ul>'
paragraph3 = '<p><strong>Benchmark observations:</strong><br />From all documents in the validation set,&nbsp; 60% of them had all indexes captured correctly.</p><p>Here are the results per index:</p><p style="margin-left:40px"><img alt="" src="https://s3.amazonaws.com/moonup/production/uploads/1677749023966-6335a49ceb6132ca653239a0.png" style="height:70%; width:70%" /></p><p>Some other observations:<br />- when trying with a non invoice document, it&#39;s quite reliably identified as Doctype: &#39;Other&#39;<br />- validation set contained mostly same layout invoices as the train set. If it was validated against completely differently sourced invoices, the results would be different<br />- Document date is able to be recognized across different notations, however, it&#39;s often wrong because the data set was not diverse (as in time span of dates) enough</p>'
# "Try it out" instructions. The quoted button name must match the label of
# the extract button in the Blocks layout (" ↓   Extract   ↓ ").
paragraph4 = '<p><strong>Try it out:</strong><br />To use it, simply upload your image and click &#39;Extract&#39;, or click one of the examples to load them.<br /><em>(because this is running on the free cpu tier, it will take about 40 secs before you see a result. On a GPU it takes less than 2 seconds)</em></p><p>&nbsp;</p><p>Have fun&nbsp;😎</p><p>Toon Beerten</p>'
# Notice explaining the usage-data checkbox, which is marked with the same ✤ symbol.
smallprint = '<p>✤&nbsp;<span style="font-size:11px">To get an idea of the usage, you can opt to let me get personally notified via Telegram with the image uploaded. All data will be automatically deleted after 48 hours</span></p>'
# Page stylesheet; `#inp` targets the output image created with elem_id="inp".
css = "#inp {height: auto !important; width: 100% !important;}"
# visitorbadge.io badge linking to the visitor statistics page.
visit_badge = '<a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut"><img src="https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut&labelColor=%23edd239&countColor=%23d9e3f0&style=flat" /></a>'

# Page layout: components appear in the order they are created inside this
# Blocks context, so the statement order below defines the page.
with gr.Blocks(css=css) as demo:
    # NOTE(review): only referenced by the commented-out update_status
    # handlers further down; no active handler uses it.
    state = gr.State(value='start_or_clear')
    
    # Header: title banner and explanatory HTML paragraphs.
    gr.HTML(title)
    gr.HTML(paragraph0)
    gr.HTML(paragraph1)
    gr.HTML(paragraph2)
    gr.HTML(paragraph3)
    gr.HTML(paragraph4)
    
    # Row 1: image upload (left) and clickable example documents (right).
    with gr.Row().style():
        with gr.Column(scale=1):
            inp = gr.Image(label='Upload invoice here:')   #.style(height=400)          
        with gr.Column(scale=2):
             gr.Examples([["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]], inputs=[inp],label='Or use one of these examples:')
    # Row 2: hidden animation slot, the extract button, and the consent checkbox.
    with gr.Row().style(equal_height=True,height=200,rounded=False):       
        with gr.Column(scale=1):
            # Created with visible=False; no active event handler in this
            # block makes it visible.
            img2 = gr.Image("drinking.gif",label=' ',visible=False).style(rounded=True)
        with gr.Column(scale=2):
            btn = gr.Button(" ↓   Extract   ↓ ")
        with gr.Column(scale=2):
            #img3 = gr.Image("snowangel.gif",label=' ',visible=False).style(rounded=True)
            # Its value is passed to process_document as the `sendimg` argument.
            sendimg = gr.Checkbox(value=True, label="Allow usage data collection for at most 48 hours ✤")
    # Row 3: echoed document (left) and extracted fields as JSON (right).
    with gr.Row().style():
        with gr.Column(scale=2):
            # elem_id="inp" is the target of the `#inp` rule in `css`.
            imgout = gr.Image(label='Uploaded document:',elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label='Extracted information:')
    # Disabled status-animation wiring; img3 no longer exists (see comment above).
    #imgout.clear(fn=update_status,inputs=state,outputs=[img2,img3])    
    #imgout.change(fn=update_status,inputs=state,outputs=[img2,img3])        
    # Extract: process_document(inp, sendimg) -> (jsonout, imgout).
    btn.click(fn=process_document, inputs=[inp,sendimg], outputs=[jsonout,imgout])

    # Footer: consent notice and visitor badge.
    gr.HTML(smallprint)
    gr.HTML(visit_badge)
    

# Start the web server only when this file is executed as a script, so the
# module can be imported (e.g. for tests) without launching the UI.
if __name__ == "__main__":
    demo.launch()