"""Gradio demo: extract invoice header fields from a document image with Donut."""
import re
import json
import os
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("to-be/donut-base-finetuned-invoices")
model = VisionEncoderDecoderModel.from_pretrained("to-be/donut-base-finetuned-invoices")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Counter endpoint hit once per extraction to keep track of demos served.
DEMO_COUNTER_URL = (
    'https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces'
    '%2Fto-be%2Finvoice_document_headers_extraction_with_donut%2Fdemo'
    '&label=demos%20served&labelColor=%23edd239&countColor=%23d9e3f0'
)
# Placeholder sent to Telegram when the user did not consent to sharing their upload.
NO_IMAGE_PATH = './no_image.jpg'
TELEGRAM_CAPTION = "New doc is being tried out:"
# Seconds; side-channel requests must never hang the extraction.
HTTP_TIMEOUT = 10


def update_status(state):
    """Return visibility updates for the two "busy" animation images.

    Not attached to any event in this file; kept for backward compatibility.

    Args:
        state: one of 'start_or_clear', 'processing', 'finished_processing'.

    Returns:
        A pair of ``gr.update`` dicts that show the snowangel gif for
        'start_or_clear' / 'finished_processing' and hide it for
        'processing'; None for any other state.
    """
    if state in ("start_or_clear", "finished_processing"):
        return (gr.update(value="snowangel.gif", visible=True),
                gr.update(value="snowangel.gif", visible=True))
    if state == "processing":
        return (gr.update(value="", visible=False),
                gr.update(value="", visible=False))
    return None


def _count_demo_served():
    """Best-effort ping of the visitor-badge counter; failures are only logged."""
    try:
        requests.get(DEMO_COUNTER_URL, timeout=HTTP_TIMEOUT)
    except requests.exceptions.RequestException:
        print("visitor counter error")


def _notify_telegram(image, sendimg):
    """Best-effort: post the upload (or the placeholder) to the Telegram channel.

    Args:
        image: the uploaded image array from ``gr.Image``.
        sendimg: truthy if the user consented to forwarding their image;
            otherwise the ``no_image.jpg`` placeholder is sent instead.
    """
    token = os.getenv('TELEGRAM_BOT_TOKEN')
    chat_id = os.getenv('TELEGRAM_CHANNEL_ID')
    if not token or not chat_id:
        # Without credentials the call would target ".../botNone/..." and fail.
        return
    try:
        if sendimg:
            photo = Image.fromarray(image).convert("RGB")
        else:
            with Image.open(NO_IMAGE_PATH) as placeholder:
                photo = placeholder.convert("RGB")  # JPEG cannot store alpha
        bio = BytesIO()
        bio.name = 'image.jpeg'
        photo.save(bio, 'JPEG')
        bio.seek(0)
        # sendPhoto takes `caption` directly; a `media` JSON field is only
        # understood by sendMediaGroup and would be ignored here.
        requests.post(
            f'https://api.telegram.org/bot{token}/sendPhoto',
            files={'photo': bio},
            data={"chat_id": chat_id, "caption": TELEGRAM_CAPTION},
            timeout=HTTP_TIMEOUT,
        )
    except (OSError, requests.exceptions.RequestException):
        # Do not print the exception: requests' messages contain the URL,
        # which embeds the bot token.
        print("telegram api error")


def _extract_fields(image):
    """Run the finetuned Donut model on *image* and return the parsed fields dict."""
    # Encoder inputs.
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Decoder inputs: this model is prompted with an empty task prompt.
    task_prompt = ""
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = (sequence.replace(processor.tokenizer.eos_token, "")
                        .replace(processor.tokenizer.pad_token, ""))
    # Remove the first tag, i.e. the task start token.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)


def process_document(image, sendimg):
    """Gradio handler: extract invoice header fields from an uploaded image.

    Before inference it pings the demo counter and sends a Telegram
    notification. Both calls are best effort and never abort extraction.

    Args:
        image: image array from ``gr.Image`` (None when nothing was uploaded).
        sendimg: user consent to forward the uploaded image to Telegram.

    Returns:
        ``(fields, image)``: the dict from ``processor.token2json`` and the
        input image, echoed back for display.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image or pick one of the examples first.")
    _count_demo_served()
    _notify_telegram(image, sendimg)
    return _extract_fields(image), image


title = """

Demo: invoice header extraction with Donut

"""
paragraph0 = """

(update 29/03/2023: for more info, you can read my article on medium)
(update 28/04/2023: want to finetune with your own data? Read this article)

"""
paragraph1 = """

Basic idea of the base 🍩 model is to give it an image as input and extract indexes as text. No bounding boxes or confidences are generated.
I finetuned it on invoices. For more info, see the original paper and the 🤗 model.

"""
paragraph2 = """

Training:
The model was trained with a few thousand of annotated invoices and non-invoices (for those the doctype will be 'Other'). They span across different countries and languages. They are always one page only. The dataset is proprietary unfortunately. Model is set to input resolution of 1280x1920 pixels. So any sample you want to try with higher dpi than 150 has no added value.
It was trained for about 4 hours on a NVIDIA RTX A4000 for 20k steps with a val_metric of 0.03413819904382196 at the end.
The following indexes were included in the train set:

"""
paragraph3 = """

Benchmark observations:
From all documents in the validation set,  60% of them had all indexes captured correctly.

Here are the results per index:

Some other observations:
- when trying with a non invoice document, it's quite reliably identified as Doctype: 'Other'
- validation set contained mostly same layout invoices as the train set. If it was validated against completely differently sourced invoices, the results would be different
- Document date is able to be recognized across different notations, however, it's often wrong because the data set was not diverse (as in time span of dates) enough

"""
paragraph4 = """

Try it out:
To use it, simply upload your image and click 'submit', or click one of the examples to load them.
(because this is running on the free cpu tier, it will take about 40 secs before you see a result. On a GPU it takes less than 2 seconds)

 

Have fun 😎

Toon Beerten

"""
smallprint = """

✤ To get an idea of the usage, you can opt to let me get personally notified via Telegram with the image uploaded. All data will be automatically deleted after 48 hours

"""
css = "#inp {height: auto !important; width: 100% !important;}"
visit_badge = ''

with gr.Blocks(css=css) as demo:
    state = gr.State(value='start_or_clear')
    for block_html in (title, paragraph0, paragraph1, paragraph2, paragraph3, paragraph4):
        gr.HTML(block_html)

    with gr.Row().style():
        with gr.Column(scale=1):
            inp = gr.Image(label='Upload invoice here:')
        with gr.Column(scale=2):
            gr.Examples([["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]],
                        inputs=[inp], label='Or use one of these examples:')

    with gr.Row().style(equal_height=True, height=200, rounded=False):
        with gr.Column(scale=1):
            img2 = gr.Image("drinking.gif", label=' ', visible=False).style(rounded=True)
        with gr.Column(scale=2):
            btn = gr.Button(" ↓ Extract ↓ ")
        with gr.Column(scale=2):
            sendimg = gr.Checkbox(value=True,
                                  label="Allow usage data collection for at most 48 hours ✤")

    with gr.Row().style():
        with gr.Column(scale=2):
            imgout = gr.Image(label='Uploaded document:', elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label='Extracted information:')

    btn.click(fn=process_document, inputs=[inp, sendimg], outputs=[jsonout, imgout])
    gr.HTML(smallprint)
    gr.HTML(visit_badge)

demo.launch()