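"""Streamlit demo: OCR-free receipt parsing with Donut (Document Understanding
Transformer) checkpoints fine-tuned on SROIE and CORD."""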
import torch
import streamlit as st

from PIL import Image
from io import BytesIO
from transformers import VisionEncoderDecoderModel, DonutProcessor
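

# Cache the heavy checkpoint loading across Streamlit reruns; a small helper
# sketch assuming st.cache_resource is available (Streamlit >= 1.18).
@st.cache_resource
def load_model(checkpoint):
    """Load a Donut processor/model pair once and reuse it on later reruns."""
    processor = DonutProcessor.from_pretrained(checkpoint)
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    return processor, model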


def run_prediction(sample):
    # `pretrained_model`, `processor`, `task_prompt`, and `device` are
    # module-level globals, set by the UI code below before this is called.
    if isinstance(sample, dict):
        # dataset sample: pixel values are already preprocessed
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:  # sample is a PIL image
        # prepare encoder inputs from the raw image
        pixel_values = processor(sample, return_tensors="pt").pixel_values
    
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # run inference
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # process output
    prediction = processor.batch_decode(outputs.sequences)[0]
    
    # post-processing
    if "cord" in task_prompt:
        prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    prediction = processor.token2json(prediction)
    
    # load reference target
    if isinstance(sample, dict):
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"
    
    return prediction, target
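
# Example usage, assuming a checkpoint pair and `device` have been set by the
# UI code below (the image path follows the repo's ./img naming scheme):
#     prediction, _ = run_prediction(Image.open("./img/receipt-1.jpg"))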
    

task_prompt = "<s>"  # default task prompt; overridden per mode below

logo = Image.open("./img/rsz_unstructured_logo.png")
st.image(logo)

st.markdown('''
### Receipt Parser
This is an OCR-free Document Understanding Transformer, nicknamed 🍩. It was fine-tuned on 1,000 receipt images from the SROIE dataset.
The original 🍩 implementation can be found [here](https://github.com/clovaai/donut).

At [Unstructured.io](https://github.com/Unstructured-IO/unstructured) we are on a mission to build custom preprocessing pipelines that make your data ready for labeling, training, and production 🤩.
Come join us in our public repos and contribute! Every contribution and piece of feedback is valuable to the community 😊.
''')

image_upload = None
photo = None
with st.sidebar:
    information = st.radio(
        "What information inside the 🧾s are you interested in extracting?",
        ('Receipt Summary', 'Receipt Menu Details', 'Extract all', 'Unstructured.io Parser'),
    )
    receipt = st.selectbox('Pick one 🧾', ['1', '2', '3', '4', '5', '6'], index=1)

    # file upload
    uploaded_file = st.file_uploader("Upload a 🧾")
    if uploaded_file is not None:
        # To read file as bytes:
        image_bytes_data = uploaded_file.getvalue()
        image_upload = Image.open(BytesIO(image_bytes_data))

    # A checkbox (rather than st.button) keeps the camera widget visible
    # across reruns: st.button is only True for the single rerun triggered
    # by the click, so the camera would vanish right after taking a picture.
    camera_click = st.checkbox('Use my camera')
    img_file_buffer = None
    if camera_click:
        img_file_buffer = st.camera_input("Take a picture of your receipt!")

    if img_file_buffer:
        # read the image file buffer as a PIL Image
        photo = Image.open(img_file_buffer)
        st.info("picture taken!")
        
st.text(f'{information} mode is ON!\nTarget 🧾: {receipt}')

col1, col2 = st.columns(2)

if photo:
    image = photo
    st.info("photo loaded to image")
elif image_upload:
    image = image_upload
else:
    image = Image.open(f"./img/receipt-{receipt}.jpg")
with col1:
    st.image(image, caption='Your target receipt')

if st.button('Parse receipt! 🐍'):
    with st.spinner('baking the 🍩s...'):
        device = "cuda" if torch.cuda.is_available() else "cpu"

        if information == 'Receipt Summary':
            processor, pretrained_model = load_model("unstructuredio/donut-base-sroie")
            task_prompt = "<s>"
            pretrained_model.to(device)

        elif information == 'Receipt Menu Details':
            processor, pretrained_model = load_model("naver-clova-ix/donut-base-finetuned-cord-v2")
            task_prompt = "<s_cord-v2>"
            pretrained_model.to(device)

        elif information == 'Unstructured.io Parser':
            processor, pretrained_model = load_model("unstructuredio/donut-base-labelstudio-A1.0")
            task_prompt = "<s>"
            pretrained_model.to(device)

        else:  # Extract all
            processor_a, pretrained_model_a = load_model("unstructuredio/donut-base-sroie")
            processor_b, pretrained_model_b = load_model("naver-clova-ix/donut-base-finetuned-cord-v2")
        
    with col2:
        if information == 'Extract all':
            st.info('parsing 🧾 (extracting all)...')
            pretrained_model, processor, task_prompt = pretrained_model_a, processor_a, "<s>"
            pretrained_model.to(device)
            parsed_receipt_info_a, _ = run_prediction(image)
            pretrained_model, processor, task_prompt = pretrained_model_b, processor_b, "<s_cord-v2>"
            pretrained_model.to(device)
            parsed_receipt_info_b, _ = run_prediction(image)
            st.text('\nReceipt Summary:')
            st.json(parsed_receipt_info_a)
            st.text('\nReceipt Menu Details:')
            st.json(parsed_receipt_info_b)
        else:
            st.info('parsing 🧾...')
            parsed_receipt_info, _ = run_prediction(image)
            st.text(f'\n{information}')
            st.json(parsed_receipt_info)