File size: 6,611 Bytes
bc12901
 
 
 
1af0b6d
bc12901
d229b67
bc12901
ab36703
 
bc12901
 
bc6a638
 
bc12901
 
 
 
 
 
 
bc6a638
225fcc2
 
 
 
 
 
 
 
8171e8e
225fcc2
8171e8e
 
bc6a638
225fcc2
 
 
 
1af0b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc12901
bc6a638
 
 
 
 
 
d229b67
 
 
 
bc6a638
87ad231
d229b67
 
 
87ad231
 
225fcc2
 
bc6a638
 
 
 
 
 
 
 
 
 
 
 
 
2919076
bc6a638
 
 
 
 
 
 
 
 
 
 
 
d229b67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc6a638
 
 
 
 
 
d229b67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc12901
bc6a638
 
 
588673f
1af0b6d
bc12901
d229b67
1af0b6d
 
225fcc2
87ad231
d229b67
 
87ad231
225fcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1af0b6d
 
225fcc2
bc12901
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from PIL import ImageDraw
import streamlit as st
from st_clickable_images import clickable_images

st.set_page_config(layout="wide")

import torch
from docquery.pipeline import get_pipeline
from docquery.document import load_bytes, load_document


def ensure_list(x):
    """Return *x* unchanged if it is already a list, otherwise wrap it as ``[x]``.

    Used to iterate uniformly over pipeline output that may be a single
    prediction or a list of predictions.
    """
    return x if isinstance(x, list) else [x]


# Model label shown in the "Pick a model" radio -> checkpoint name handed to
# docquery's get_pipeline. Insertion order is the order of the radio options.
CHECKPOINTS = dict(
    [
        ("LayoutLMv1 🦉", "impira/layoutlm-document-qa"),
        ("Donut 🍩", "naver-clova-ix/donut-base-finetuned-docvqa"),
    ]
)


@st.experimental_singleton(show_spinner=False)
def construct_pipeline(model):
    """Build the docquery pipeline for the model label *model* (a CHECKPOINTS key).

    The pipeline is placed on CUDA when torch reports a GPU, otherwise on CPU.
    Memoized with ``st.experimental_singleton`` so the model is not rebuilt on
    every Streamlit rerun.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    checkpoint = CHECKPOINTS[model]
    return get_pipeline(checkpoint=checkpoint, device=device)


@st.cache(show_spinner=False)
def run_pipeline(model, question, document, top_k):
    """Ask *question* about *document* using the pipeline for *model*.

    ``document.context`` is unpacked as keyword arguments to the pipeline and
    *top_k* is forwarded unchanged. Results are memoized by ``st.cache``, keyed
    on the arguments.
    """
    qa = construct_pipeline(model)
    return qa(question=question, top_k=top_k, **document.context)


# TODO: Move into docquery
# TODO: Support words past the first page (or window?)
def lift_word_boxes(document):
    """Return the word boxes stored with the first entry of ``document.context["image"]``.

    Each entry of ``context["image"]`` is indexed as a pair whose second item
    holds the word boxes (presumably ``(page_image, word_boxes)`` -- confirm
    against docquery). Only the first page is consulted.
    """
    pages = document.context["image"]
    first_page = pages[0]
    return first_page[1]


def expand_bbox(word_boxes):
    if len(word_boxes) == 0:
        return None

    min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
    return [min(min_x), min(min_y), max(max_x), max(max_y)]


# LayoutLM boxes are normalized to the 0-1000 range; scale them back to pixels.
def normalize_bbox(box, width, height):
    """Map a 0-1000 normalized ``[x0, y0, x1, y1]`` box onto a *width* x *height* image.

    x coordinates are scaled by *width*, y coordinates by *height*; the result
    is a list of four floats.
    """
    fractions = [coord / 1000 for coord in box]
    x_scale, y_scale = width, height
    return [
        fractions[0] * x_scale,
        fractions[1] * y_scale,
        fractions[2] * x_scale,
        fractions[3] * y_scale,
    ]


st.markdown("# DocQuery: Query Documents w/ NLP")

# Seed per-session state on the first run; later reruns keep existing values.
for _state_key in ("document", "last_clicked"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# Two side-by-side pickers: where the document comes from, and which model to use.
input_col, model_col = st.columns(2)
with input_col:
    input_type = st.radio(
        "Pick an input type", ["Upload", "URL", "Examples"], horizontal=True
    )
with model_col:
    model_type = st.radio("Pick a model", list(CHECKPOINTS), horizontal=True)


def load_file_cb():
    """``on_change`` callback for the file uploader: turn the upload into a document.

    Does nothing when the uploader holds no file. Otherwise builds the document
    with ``load_bytes`` and reads ``document.context`` under the "Processing..."
    spinner (presumably it is computed lazily, so this does the work up front).
    Finally stores the document in ``st.session_state.document``.
    """
    uploaded = st.session_state.file_input
    if uploaded is None:
        return

    # ``loading_placeholder`` is the module-level st.empty() created later in the script.
    with loading_placeholder, st.spinner("Processing..."):
        doc = load_bytes(uploaded, uploaded.name)
        # Assigned (not a bare expression) so Streamlit "magic" does not render it.
        _ = doc.context
        st.session_state.document = doc


def load_url_cb():
    """``on_change`` callback for the URL box: download and process the document.

    ``st.text_input`` reports an empty string (not ``None``) when the box is
    cleared. An empty or whitespace-only value is therefore treated as "nothing
    to load" instead of being passed to ``load_document``. On success the
    document is stored in ``st.session_state.document``.
    """
    # ``or ""`` keeps this safe even if the widget state is ever None.
    url = (st.session_state.url_input or "").strip()
    if not url:
        return

    # ``loading_placeholder`` is the module-level st.empty() created later in the script.
    with loading_placeholder:
        with st.spinner("Downloading..."):
            document = load_document(url)
        with st.spinner("Processing..."):
            # Reading context presumably triggers the (lazy) processing step;
            # assigned so Streamlit "magic" does not render it.
            _ = document.context
        st.session_state.document = document


# (image URL, suggested question) pairs rendered when "Examples" is selected.
examples = [
    (
        "https://templates.invoicehome.com/invoice-template-us-neat-750px.png",
        "What is the invoice number?",
    ),
    (
        "https://miro.medium.com/max/787/1*iECQRIiOGTmEFLdWkVIH2g.jpeg",
        "What is the purchase amount?",
    ),
    (
        "https://www.accountingcoach.com/wp-content/uploads/2013/10/income-statement-example@2x.png",
        "What are net sales for 2020?",
    ),
]
# Return value of each example's clickable_images component, one per example
# (presumably -1 until the image is clicked -- confirm against the component).
# TODO: clicks are collected but not acted on yet; selecting an example does
# not load its document or question.
imgs_clicked = []

if input_type == "Upload":
    file = st.file_uploader(
        "Upload a PDF or Image document", key="file_input", on_change=load_file_cb
    )
elif input_type == "URL":
    url = st.text_input("URL", "", key="url_input", on_change=load_url_cb)
elif input_type == "Examples":
    example_cols = st.columns(len(examples))
    for col, (path, example_question) in zip(example_cols, examples):
        with col:
            imgs_clicked.append(
                clickable_images(
                    [path],
                    div_style={
                        "display": "flex",
                        "justify-content": "center",
                        "flex-wrap": "wrap",
                        "cursor": "pointer",
                    },
                    img_style={"margin": "5px", "height": "200px"},
                )
            )
            st.markdown(
                f"<p style='text-align: center'>{example_question}</p>",
                unsafe_allow_html=True,
            )


# Question box; because of key="question" its value is also kept in st.session_state.
question = st.text_input("QUESTION", "", key="question")

# Currently loaded document (set by the upload/URL callbacks), or None.
document = st.session_state.document
# Placeholder the load callbacks render their spinners into. It is created here,
# before the columns, so those spinners show above the preview/answers layout.
loading_placeholder = st.empty()
if document is not None:
    # col1 later shows the preview image, col2 the answers.
    col1, col2 = st.columns(2)
    image = document.preview

question = st.session_state.question
# Outline colors for answer boxes; cycled if there are more answers than colors.
colors = ["blue", "red", "green"]
if document is not None and question:
    col2.header(f"Answers ({model_type})")
    with col2:
        answers_placeholder = st.container()
        answers_loading_placeholder = st.container()

        with answers_loading_placeholder:
            # Run this (one-time) expensive operation outside of the processing
            # question placeholder
            with st.spinner("Constructing pipeline..."):
                construct_pipeline(model_type)

            with st.spinner("Processing question..."):
                predictions = run_pipeline(
                    model=model_type, question=question, document=document, top_k=1
                )

        with answers_placeholder:
            # Draw on a copy so document.preview itself is left unmodified.
            image = image.copy()
            draw = ImageDraw.Draw(image)
            for i, p in enumerate(ensure_list(predictions)):
                col2.markdown(f"#### { p['answer'] }: ({round(p['score'] * 100, 1)}%)")
                # Only predictions carrying word-span indices can be located on the page.
                if "start" not in p or "end" not in p:
                    continue
                # "end" is inclusive, hence the + 1.
                span = lift_word_boxes(document)[p["start"] : p["end"] + 1]
                bbox = expand_bbox(span)
                if bbox is None:
                    # Empty span (e.g. indices past the first page's words): nothing to draw.
                    continue
                x1, y1, x2, y2 = normalize_bbox(bbox, image.width, image.height)
                draw.rectangle(
                    ((x1, y1), (x2, y2)), outline=colors[i % len(colors)], width=3
                )

# Show the document preview (with any answer boxes drawn on it) in the left column.
if document is not None:
    col1.image(image, use_column_width="auto")

# Bare strings below are rendered by Streamlit "magic" as markdown.
"DocQuery uses LayoutLMv1 fine-tuned on DocVQA, a document visual question answering dataset, as well as SQuAD, which boosts its English-language comprehension. You can also switch to the Donut model with the model picker. To use it, simply upload an image or PDF (or enter its URL), type a question, and press Enter. The Examples tab shows sample documents with questions to try."

"[Github Repo](https://github.com/impira/docquery)"