# Scraped page header (not code): ajimeno's picture / First commit / c12418b
import torch
import streamlit as st
import os
from PIL import Image
from io import BytesIO
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig , DonutProcessor
# Prompt string that run_prediction tokenizes (without special tokens) and
# passes to the model's generate() call as decoder_input_ids.
task_prompt = "<s_unstructured-invoices>"
def run_prediction(sample):
    """Run the invoice model on one sample and return ``(prediction, target)``.

    ``sample`` is either a dict carrying ``"pixel_values"`` and
    ``"target_sequence"``, or any other object, which is handed to
    ``processor`` as an image.

    The function reads the module-level ``processor``, ``pretrained_model``,
    ``device`` and ``task_prompt``; they must be defined before it is called.

    Returns:
        A 2-tuple: the ``processor.token2json`` conversion of the first
        decoded generated sequence, and either the ``token2json`` conversion
        of ``sample["target_sequence"]`` (dict samples) or the literal
        string ``"<not_provided>"`` (everything else).
    """
    has_reference = isinstance(sample, dict)

    # Encoder input: dict samples already carry pixel values and only need a
    # leading axis added (presumably the batch axis); anything else goes
    # through the processor first.
    if has_reference:
        encoder_input = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:
        encoder_input = processor(sample, return_tensors="pt").pixel_values

    # Decoder input: the task prompt, tokenized as-is (no special tokens added).
    prompt_ids = processor.tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids

    generated = pretrained_model.generate(
        encoder_input.to(device),
        decoder_input_ids=prompt_ids.to(device),
    )

    # Only the first decoded sequence is converted to JSON.
    decoded = processor.batch_decode(generated)
    prediction = processor.token2json(decoded[0])

    # A reference target exists only for dict samples.
    if has_reference:
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"

    return prediction, target
# ---------------------------------------------------------------------------
# Streamlit page: show the intro, let the user upload an invoice (or fall back
# to the bundled sample), then parse it with the Donut model and show the JSON.
# ---------------------------------------------------------------------------
logo = Image.open("./img/rsz_unstructured_logo.png")
st.image(logo)
st.markdown('''
### Invoice Parser
This is an OCR-free Document Understanding Transformer. It was fine-tuned with 1000 invoice images -> RVL-CDIP dataset.
The original implementation can be found on [here](https://github.com/clovaai/donut).
At [Unstructured.io](https://github.com/Unstructured-IO/unstructured) we are on a mission to build custom preprocessing pipelines for labeling, training, or production ML-ready pipelines.
Come and join us in our public repos and contribute! Each of your contributions and feedback holds great value and is very significant to the community.
''')


def _resource_cache(func):
    """Wrap *func* in Streamlit's resource cache when this version has one.

    Streamlit re-executes the whole script on every interaction, so without
    a cache the model would be re-loaded each time.  Newer releases expose
    ``st.cache_resource``; older ones ``st.experimental_singleton``.  If
    neither exists, *func* is returned unchanged (i.e. no caching).
    """
    decorator = (
        getattr(st, "cache_resource", None)
        or getattr(st, "experimental_singleton", None)
    )
    if decorator is None:
        return func
    # The page shows its own spinner, so the cache's built-in one is disabled.
    return decorator(show_spinner=False)(func)


@_resource_cache
def _load_donut(target_device):
    """Load the invoice processor and model and move the model to *target_device*.

    Reads the hub token from the ``TOKEN`` environment variable and raises
    ``KeyError`` if it is not set.  Returns ``(processor, model)``.
    """
    auth_token = os.environ['TOKEN']
    donut_processor = DonutProcessor.from_pretrained(
        "unstructuredio/donut-invoices", max_length=1200, use_auth_token=auth_token
    )
    donut_model = VisionEncoderDecoderModel.from_pretrained(
        "unstructuredio/donut-invoices", max_length=1200, use_auth_token=auth_token
    )
    donut_model.to(target_device)
    return donut_processor, donut_model


image_upload = None
# NOTE(review): `photo` is never read in the visible code; kept so the
# module-level names stay the same.
photo = None

with st.sidebar:
    # file upload
    uploaded_file = st.file_uploader("Upload an invoice")
    if uploaded_file is not None:
        # Read the upload as bytes and open it as an image.
        image_bytes_data = uploaded_file.getvalue()
        image_upload = Image.open(BytesIO(image_bytes_data))

col1, col2 = st.columns(2)

# Use the uploaded invoice when there is one, otherwise the bundled sample.
if image_upload is not None:
    image = image_upload
else:
    image = Image.open("./img/4fabfaab-1299.png")

with col1:
    st.image(image, caption='Your target invoice')

with st.spinner('baking the invoice ...'):
    # These three names are module-level on purpose: run_prediction reads them.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor, pretrained_model = _load_donut(device)

    # NOTE(review): parsing is kept under the spinner so it stays visible
    # while the model runs — confirm this matches the intended layout.
    with col2:
        st.info('Parsing invoice')
        parsed_info, _ = run_prediction(image.convert("RGB"))
        st.text('\nInvoice Summary:')
        st.json(parsed_info)