File size: 3,713 Bytes
d8e07ba
 
 
 
 
 
 
 
 
 
cd03817
d8e07ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import re
import streamlit as st
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import os
from PIL import Image
import PyPDF2
from pypdf.errors import PdfReadError
from pypdf import PdfReader
import pypdfium2 as pdfium

@st.cache_resource
def _load_model():
    """Load the Donut DocVQA processor and model once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the ~800MB model would be re-instantiated on each
    rerun. ``st.cache_resource`` keeps a single shared instance alive.

    Returns:
        tuple: (DonutProcessor, VisionEncoderDecoderModel) ready on CPU.
    """
    proc = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
    mdl = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
    mdl.to("cpu")
    return proc, mdl


# Module-level names preserved for the rest of the script.
processor, model = _load_model()
device = "cpu"

def _ask_donut(image, user_question):
    """Run Donut document-VQA on a PIL image and return the parsed answer.

    Args:
        image: PIL.Image of the document page to query.
        user_question: free-text question inserted into the task prompt.

    Returns:
        dict: Donut's ``token2json`` parse of the generated answer sequence.
    """
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    prompt = task_prompt.replace("{user_input}", user_question)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Forbid the unknown token so it never appears in the answer.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # drop first task start token
    return processor.token2json(sequence)


# --- UI: uploader + question box ---
document = st.file_uploader(
    label="Upload the document you want to explore",
    type=["png", "jpg", "jpeg", "pdf"],
)

question = st.text_input("Insert your question here")

if document is None:
    st.write("Please upload the document in the box above")
else:
    try:
        # PdfReader raises PdfReadError for non-PDF uploads; this is how we
        # branch between the PDF path and the plain-image path (EAFP).
        PdfReader(document)
        pdf = pdfium.PdfDocument(document)
        page = pdf.get_page(0)  # only the first page is queried
        # Render at 300 DPI (pdfium's default unit is 72 DPI).
        pil_image = page.render(scale=300 / 72).to_pil()
        answer = _ask_donut(pil_image, question)
        st.image(pil_image, "Document uploaded")
        st.write(answer)
        print(answer)  # also log to server console (PDF path only, as before)
    except PdfReadError:
        # Not a PDF: treat the upload as a plain image.
        document = Image.open(document)
        answer = _ask_donut(document, question)
        st.image(document, "Document uploaded")
        st.write(answer)