"""Gradio demo: extract invoice header fields with a fine-tuned Donut model.

The script loads the model once at import time, builds a small Blocks UI and
launches it. Every processed upload is also forwarded, best-effort, to a
Telegram channel when ``TELEGRAM_BOT_TOKEN`` and ``TELEGRAM_CHANNEL_ID`` are
set in the environment.
"""

import json
import logging
import os
import re
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

logger = logging.getLogger(__name__)

MODEL_PATH = "./donut-base-finetuned-inv"
EXAMPLE_IMAGES = [["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]]
TELEGRAM_CAPTION = "New doc is being tried out:"
# Upper bound for the notification request so a slow or unreachable Telegram
# API cannot stall document extraction indefinitely.
TELEGRAM_TIMEOUT_SECONDS = 10

processor = DonutProcessor.from_pretrained(MODEL_PATH)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_PATH)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def _encode_jpeg(image_array):
    """Return an in-memory JPEG (rewound ``BytesIO``) of a numpy image array."""
    # The upload cannot be saved locally, so it is converted from a numpy
    # array to a PIL image and serialized into memory instead.
    picture = Image.fromarray(image_array)
    if picture.mode != "RGB":
        # JPEG cannot store alpha/palette modes; normalise before saving.
        picture = picture.convert("RGB")
    buffer = BytesIO()
    buffer.name = "image.jpeg"
    picture.save(buffer, "JPEG")
    buffer.seek(0)
    return buffer


def _notify_telegram(image_array):
    """Send the uploaded image to the configured Telegram channel.

    This is a best-effort side channel: missing configuration, encoding
    problems and network/API failures are logged and never propagate to the
    caller.
    """
    token = os.getenv("TELEGRAM_BOT_TOKEN")
    chat_id = os.getenv("TELEGRAM_CHANNEL_ID")
    if not token or not chat_id:
        logger.debug("Telegram notification skipped: credentials not configured")
        return

    try:
        photo = _encode_jpeg(image_array)
    except (OSError, TypeError, ValueError):
        logger.warning("Telegram notification skipped: upload could not be encoded as JPEG")
        return

    try:
        # sendPhoto takes the caption as a plain form field; the
        # "media"/"attach://" JSON payload belongs to sendMediaGroup and is
        # not a sendPhoto field.
        response = requests.post(
            f"https://api.telegram.org/bot{token}/sendPhoto",
            data={"chat_id": chat_id, "caption": TELEGRAM_CAPTION},
            files={"photo": photo},
            timeout=TELEGRAM_TIMEOUT_SECONDS,
        )
    except requests.RequestException as exc:
        # Log only the exception type: its message contains the request URL,
        # which embeds the bot token.
        logger.warning("Telegram notification failed (%s)", type(exc).__name__)
        return

    if not response.ok:
        logger.warning("Telegram notification rejected with HTTP %s", response.status_code)


def process_document(image):
    """Run Donut on an uploaded invoice image.

    Args:
        image: The upload as delivered by ``gr.Image`` (a numpy array), or
            ``None`` when nothing was uploaded.

    Returns:
        A ``(parsed_json, image)`` tuple feeding the JSON and image outputs;
        ``(None, None)`` when no image was provided.
    """
    if image is None:
        # "Extract" was clicked without an upload: clear both outputs instead
        # of failing inside PIL.
        return None, None

    _notify_telegram(image)

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs
    # NOTE(review): the prompt is empty, so no start token is fed to the
    # decoder. The regex below strips a "first task start token", which
    # suggests a tag-like task token was intended here -- confirm against the
    # fine-tuning setup.
    task_prompt = ""
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    return processor.token2json(sequence), image


description = """

Using Donut model finetuned on Invoices for retrieval of following information:

To use it, simply upload your image and click 'Extract', or click one of the examples to load them. Read more at the links below.

 

(because this is running on the free cpu tier, it will take about 40 secs before you see a result)

Have fun 😎

Toon Beerten

"""

# NOTE(review): `article` is defined but not rendered by the Blocks layout below.
article = """

Donut: OCR-free Document Understanding Transformer | Github Repo

"""

title = "Demo: Donut 🍩 for invoice header retrieval"

css = "#inp {height: auto !important; width: 100% !important;}"

with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(label='Upload invoice here:')
        with gr.Column():
            gr.Examples(EXAMPLE_IMAGES, inputs=[inp])
    with gr.Row():
        btn = gr.Button("Extract")
    with gr.Row():
        with gr.Column(scale=1):
            imgout = gr.Image(label='Uploaded document:', elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label='Extracted information:')

    btn.click(fn=process_document, inputs=inp, outputs=[jsonout, imgout])

demo.launch()