File size: 7,658 Bytes
cb4c41f
 
 
 
 
f50ef7b
dea123e
1592bb3
93fe32c
 
 
cb4c41f
789ac0d
 
cb4c41f
 
 
 
be4a233
 
35135c8
 
 
 
 
 
 
 
 
 
 
cb4c41f
46b1dae
ee93a12
6b45627
5c58f13
6b45627
f50ef7b
21ab7aa
 
 
1592bb3
 
9485a97
1592bb3
4fa7e16
 
 
dea123e
cb4c41f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b1dae
 
b7221a3
cb4c41f
bd7f10b
93fee22
 
e45a37b
 
93fee22
e45a37b
df954f0
7384818
 
 
da509d5
7384818
 
 
f4ccbdc
 
93fee22
 
 
b7221a3
8732834
 
b7221a3
bc05740
 
c1b495f
b73fff5
46b1dae
b73fff5
0bba431
b73fff5
c1b495f
93fee22
bc05740
b7221a3
 
 
0f0cb7d
b7221a3
cb4c41f
be4a233
35135c8
be4a233
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import re
import gradio as gr

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests
from io import BytesIO
import json
import os


# Directory holding the fine-tuned Donut checkpoint; the processor and the
# model weights are both read from it.
_CHECKPOINT_DIR = "./donut-base-finetuned-inv"

processor = DonutProcessor.from_pretrained(_CHECKPOINT_DIR)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT_DIR)

# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Current step of the status-animation state machine read by update_status().
state = "start_or_clear"

def update_status(inp):
    """Advance the status state machine and return updates for the two GIFs.

    Transitions of the module-level ``state`` variable:

    * ``"start_or_clear"``      -> ``"processing"``          (GIFs shown)
    * ``"processing"``          -> ``"finished_processing"`` (GIFs hidden)
    * ``"finished_processing"`` -> ``"processing"``          (GIFs shown)

    Args:
        inp: value of the component wired as the event input; it is not used.

    Returns:
        A 2-tuple of ``gr.update(...)`` objects, one per output component.
        For an unrecognised state both entries are empty updates and
        ``state`` is left untouched.

    Note:
        ``state`` is a single module-level variable, so it is shared by every
        caller of this function within the process.
    """
    # The function rebinds ``state``; without this declaration Python treats
    # the name as local and the first comparison raises UnboundLocalError.
    global state

    if state == "processing":
        state = "finished_processing"
        return (
            gr.update(value="", visible=False),
            gr.update(value="", visible=False),
        )

    if state in ("start_or_clear", "finished_processing"):
        state = "processing"
        return (
            gr.update(value="snowangel.gif", visible=True),
            gr.update(value="snowangel.gif", visible=True),
        )

    # Unknown state: still hand back one (no-op) update per output instead of
    # an implicit None.
    return (gr.update(), gr.update())

def _notify_telegram(photo):
    """Best-effort: send *photo* (a PIL image) to the configured Telegram chat.

    The bot token and chat id are read from the ``TELEGRAM_BOT_TOKEN`` and
    ``TELEGRAM_CHANNEL_ID`` environment variables. When either is unset the
    function returns without doing anything. Request failures are reported on
    stdout and otherwise ignored so that a notification problem never breaks
    the extraction itself.
    """
    token = os.getenv('TELEGRAM_BOT_TOKEN')
    chat_id = os.getenv('TELEGRAM_CHANNEL_ID')
    if not token or not chat_id:
        return

    bio = BytesIO()
    bio.name = 'image.jpeg'
    # JPEG cannot store alpha/palette modes, so normalise to RGB before saving.
    photo.convert('RGB').save(bio, 'JPEG')
    bio.seek(0)

    url = f'https://api.telegram.org/bot{token}/sendPhoto'
    # sendPhoto takes the caption as a plain form field.
    data = {'chat_id': chat_id, 'caption': 'New doc is being tried out:'}
    try:
        response = requests.post(url, files={'photo': bio}, data=data, timeout=10)
        response.raise_for_status()
    except requests.RequestException as exc:
        # Only the exception type is printed: the request URL embeds the bot
        # token and requests includes the URL in its error messages.
        print(f'Telegram notification failed: {type(exc).__name__}')


def process_document(image):
    """Run the Donut model on an uploaded document and parse its output.

    Args:
        image: the uploaded document as an array accepted by
            ``PIL.Image.fromarray`` (presumably the numpy array delivered by
            the ``gr.Image`` input; confirm against the UI wiring).

    Returns:
        A 2-tuple ``(fields, image)``: the structure produced by
        ``processor.token2json`` for the generated sequence, and the input
        image unchanged so it can be displayed next to the result.

    Raises:
        ValueError: if ``image`` is ``None`` (nothing was uploaded).

    Note:
        No component ``.update()`` calls are made here; component updates are
        only applied when they are returned from an event handler.
    """
    if image is None:
        raise ValueError("No document supplied: upload an image or pick an example first.")

    # The upload cannot be saved locally, so it is converted from an ndarray
    # to a PIL image in memory for the Telegram notification.
    _notify_telegram(Image.fromarray(image))

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: drop EOS/PAD tokens, then the leading task start token
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    return processor.token2json(sequence), image

# Static HTML fragments rendered at the top of the page with gr.HTML.
# Page header: a one-row table with the "Welcome" heading between two GIFs.
title = '<table align="center" border="0" cellpadding="1" cellspacing="1" style="width:100pc"><tbody><tr><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling_small.gif" style="float:right; height:50px; width:50px" /></td><td style="text-align:center"><h1>&nbsp; &nbsp;Welcome</h1></td><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling2_small.gif" style="float:left; height:50px; width:50px" /></td></tr></tbody></table>'
# Short introduction with links to the paper and the base model.
paragraph1 = '<p>Basic idea of this 🍩 model is to give it an image as input and extract indexes as text. No bounding boxes or confidences are generated.<br /> For more info, see the <a href="https://arxiv.org/abs/2111.15664">original paper</a>&nbsp;and the 🤗&nbsp;<a href="https://huggingface.co/naver-clova-ix/donut-base">model</a>.</p>'
# Training details and the list of index fields included in the train set.
paragraph2 = '<p><strong>Training</strong>:<br />The model was trained with a few thousand of annotated invoices and non-invoices (for those the doctype will be &#39;Other&#39;). They span across different countries and languages. They are always one page only. The dataset is proprietary unfortunately.&nbsp;Model is set to input resolution of 1280x1920 pixels. So any sample you want to try with higher dpi than 150 has no added value.<br />It was trained for about 4 hours on a&nbsp;NVIDIA RTX A4000 for 20k steps with a val_metric of&nbsp;0.03413819904382196 at the end.<br />The <u>following indexes</u> were included in the train set:</p><ul><li><span style="font-family:Calibri"><span style="color:black">DocType</span></span></li><li><span style="font-family:Calibri"><span style="color:black">Currency</span></span></li><li><span style="font-family:Calibri"><span style="color:black">DocumentDate</span></span></li><li><span style="font-family:Calibri"><span style="color:black">GrossAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">InvoiceNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">NetAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">TaxAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">OrderNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">CreditorCountry</span></span></li></ul>'
# Disabled gr.Interface-based version of the demo (not executed):
#demo = gr.Interface(fn=process_document,inputs=gr_image,outputs="json",title="Demo: Donut 🍩 for invoice header retrieval", description=description,
#    article=article,enable_queue=True, examples=[["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]], cache_examples=False)
# Usage instructions shown above the upload widget.
paragraph3 = '<p><strong>Try it out:</strong><br />To use it, simply upload your image and click &#39;submit&#39;, or click one of the examples to load them.<br /><em>(because this is running on the free cpu tier, it will take about 40 secs before you see a result)</em></p><p>&nbsp;</p><p>Have fun&nbsp;😎</p><p>Toon Beerten</p>'

# Stylesheet passed to gr.Blocks; the rule targets the element with id "inp".
css = "#inp {height: auto !important; width: 100% !important;}"
# Alternative stylesheets that are currently disabled:
# css = "@media screen and (max-width: 600px) { .output_image, .input_image {height:20rem !important; width: 100% !important;} }"
# css = ".output_image, .input_image {height: 600px !important}"

#css = ".image-preview {height: auto !important;}"


# Build the page: intro HTML, an upload/examples row, a button row flanked by
# two status GIFs, and a results row. Components appear in creation order.
with gr.Blocks(css=css) as demo:
    
    # Static header and explanatory paragraphs.
    gr.HTML(title)
    gr.HTML(paragraph1)
    gr.HTML(paragraph2)
    gr.HTML(paragraph3)
    
    # Row 1: image upload on the left, clickable example documents on the right.
    with gr.Row().style():
        with gr.Column(scale=1):
            inp = gr.Image(label='Upload invoice here:')   #.style(height=400)          
        with gr.Column(scale=2):
             gr.Examples([["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]], inputs=[inp],label='Or use one of these examples:')
    # Row 2: the extract button between two GIFs; img2 is created hidden.
    with gr.Row().style(equal_height=True,height=200,rounded=False):       
        with gr.Column(scale=1):
            img2 = gr.Image("drinking.gif",label=' ',visible=False).style(rounded=True)
        with gr.Column(scale=2):
            btn = gr.Button("↓   Extract   ↓")
        with gr.Column(scale=1):
            img3 = gr.Image("snowangel.gif",label=' ').style(rounded=True)
    # Row 3: the processed document next to the extracted JSON.
    # NOTE(review): a `css` keyword is passed to gr.Row here - confirm the
    # installed Gradio version accepts and applies it.
    with gr.Row(css='div {margin-left: auto; margin-right: auto; width: 100%;background-image: url("background.gif"); repeat 0 0;}').style():
        with gr.Column(scale=2):
            # elem_id "inp" is the id targeted by the "#inp" rule in `css`.
            imgout = gr.Image(label='Uploaded document:',elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label='Extracted information:')
    # Whenever the output image changes, update_status runs and its two return
    # values are applied to img2 and img3.
    # NOTE(review): imgout is only written by the click handler below, so this
    # presumably fires after extraction has finished rather than when it
    # starts - confirm this is the intended moment to toggle the GIFs.
    imgout.change(fn=update_status,inputs=inp,outputs=[img2,img3])        
    # Clicking the button runs the extraction on the uploaded image and fills
    # the JSON panel and the output image.
    btn.click(fn=process_document, inputs=inp, outputs=[jsonout,imgout])

# Start the Gradio server as soon as the module is executed.
demo.launch()