Ankur Goyal committed on
Commit
225fcc2
1 Parent(s): 588673f

Support Donut

Browse files
Files changed (2) hide show
  1. app.py +43 -26
  2. requirements.txt +1 -0
app.py CHANGED
@@ -19,16 +19,23 @@ def ensure_list(x):
19
  return [x]
20
 
21
 
22
- @st.experimental_singleton
23
- def construct_pipeline():
 
 
 
 
 
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
- ret = get_pipeline(device=device)
26
  return ret
27
 
28
 
29
- @st.cache
30
- def run_pipeline(question, document, top_k):
31
- return construct_pipeline()(question=question, **document.context, top_k=top_k)
 
32
 
33
 
34
  # TODO: Move into docquery
@@ -56,13 +63,14 @@ st.markdown("# DocQuery: Query Documents w/ NLP")
56
  if "document" not in st.session_state:
57
  st.session_state["document"] = None
58
 
59
- input_col, model_col = st.columns([2,1])
60
 
61
  with input_col:
62
  input_type = st.radio("Pick an input type", ["Upload", "URL"], horizontal=True)
63
 
64
  with model_col:
65
- model_type = st.radio("Pick a model", ["Upload", "URL"], horizontal=True)
 
66
 
67
  def load_file_cb():
68
  if st.session_state.file_input is None:
@@ -109,30 +117,39 @@ if document is not None:
109
 
110
  colors = ["blue", "red", "green"]
111
  if document is not None and question is not None and len(question) > 0:
112
- col2.header("Answers")
113
  with col2:
114
  answers_placeholder = st.empty()
115
  answers_loading_placeholder = st.empty()
116
 
117
- with answers_loading_placeholder:
118
- with st.spinner("Processing question..."):
119
- predictions = run_pipeline(question=question, document=document, top_k=1)
120
-
121
- with answers_placeholder:
122
- word_boxes = lift_word_boxes(document)
123
- image = image.copy()
124
- draw = ImageDraw.Draw(image)
125
- for i, p in enumerate(ensure_list(predictions)):
126
- col2.markdown(f"#### { p['answer'] }: ({round(p['score'] * 100, 1)}%)")
127
- x1, y1, x2, y2 = normalize_bbox(
128
- expand_bbox(word_boxes[p["start"] : p["end"] + 1]),
129
- image.width,
130
- image.height,
131
- )
132
- draw.rectangle(((x1, y1), (x2, y2)), outline=colors[i], width=3)
 
 
 
 
 
 
 
 
 
133
 
134
  if document is not None:
135
- col1.image(image, use_column_width='auto')
136
 
137
  "DocQuery uses LayoutLMv1 fine-tuned on DocVQA, a document visual question answering dataset, as well as SQuAD, which boosts its English-language comprehension. To use it, simply upload an image or PDF, type a question, and click 'submit', or click one of the examples to load them."
138
 
19
  return [x]
20
 
21
 
22
+ CHECKPOINTS = {
23
+ "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
24
+ "Donut 🍩": "naver-clova-ix/donut-base-finetuned-docvqa",
25
+ }
26
+
27
+
28
+ @st.experimental_singleton(show_spinner=False)
29
+ def construct_pipeline(model):
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ ret = get_pipeline(checkpoint=CHECKPOINTS[model], device=device)
32
  return ret
33
 
34
 
35
+ @st.cache(show_spinner=False)
36
+ def run_pipeline(model, question, document, top_k):
37
+ pipeline = construct_pipeline(model)
38
+ return pipeline(question=question, **document.context, top_k=top_k)
39
 
40
 
41
  # TODO: Move into docquery
63
  if "document" not in st.session_state:
64
  st.session_state["document"] = None
65
 
66
+ input_col, model_col = st.columns([2, 1])
67
 
68
  with input_col:
69
  input_type = st.radio("Pick an input type", ["Upload", "URL"], horizontal=True)
70
 
71
  with model_col:
72
+ model_type = st.radio("Pick a model", list(CHECKPOINTS.keys()), horizontal=True)
73
+
74
 
75
  def load_file_cb():
76
  if st.session_state.file_input is None:
117
 
118
  colors = ["blue", "red", "green"]
119
  if document is not None and question is not None and len(question) > 0:
120
+ col2.header(f"Answers ({model_type})")
121
  with col2:
122
  answers_placeholder = st.empty()
123
  answers_loading_placeholder = st.empty()
124
 
125
+ with answers_loading_placeholder:
126
+ # Run this (one-time) expensive operation outside of the processing
127
+ # question placeholder
128
+ with st.spinner("Constructing pipeline..."):
129
+ construct_pipeline(model_type)
130
+
131
+ with st.spinner("Processing question..."):
132
+ predictions = run_pipeline(
133
+ model=model_type, question=question, document=document, top_k=1
134
+ )
135
+
136
+ with answers_placeholder:
137
+ image = image.copy()
138
+ draw = ImageDraw.Draw(image)
139
+ for i, p in enumerate(ensure_list(predictions)):
140
+ col2.markdown(f"#### { p['answer'] }: ({round(p['score'] * 100, 1)}%)")
141
+ if "start" in p and "end" in p:
142
+ x1, y1, x2, y2 = normalize_bbox(
143
+ expand_bbox(
144
+ lift_word_boxes(document)[p["start"] : p["end"] + 1]
145
+ ),
146
+ image.width,
147
+ image.height,
148
+ )
149
+ draw.rectangle(((x1, y1), (x2, y2)), outline=colors[i], width=3)
150
 
151
  if document is not None:
152
+ col1.image(image, use_column_width="auto")
153
 
154
  "DocQuery uses LayoutLMv1 fine-tuned on DocVQA, a document visual question answering dataset, as well as SQuAD, which boosts its English-language comprehension. To use it, simply upload an image or PDF, type a question, and click 'submit', or click one of the examples to load them."
155
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  torch
2
  git+https://github.com/huggingface/transformers.git@21f6f58721dd9154357576be6de54eefef1f1818
3
  git+https://github.com/impira/docquery.git@43683e0dae72cadf8e8b4927191978109153458c
 
1
  torch
2
  git+https://github.com/huggingface/transformers.git@21f6f58721dd9154357576be6de54eefef1f1818
3
  git+https://github.com/impira/docquery.git@43683e0dae72cadf8e8b4927191978109153458c
4
+ sentencepiece