Ankur Goyal commited on
Commit
bc6a638
1 Parent(s): 8171e8e

Improve state management/data flow

Browse files
Files changed (1) hide show
  1. app.py +51 -9
app.py CHANGED
@@ -2,13 +2,12 @@ import os
2
 
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
5
- print("Importing")
6
-
7
  import streamlit as st
8
 
9
  import torch
10
  from docquery.pipeline import get_pipeline
11
- from docquery.document import load_bytes
 
12
 
13
  def ensure_list(x):
14
  if isinstance(x, list):
@@ -16,27 +15,70 @@ def ensure_list(x):
16
  else:
17
  return [x]
18
 
 
19
  @st.experimental_singleton
20
  def construct_pipeline():
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  ret = get_pipeline(device=device)
23
  return ret
24
 
 
25
  @st.cache
26
  def run_pipeline(question, document):
27
  return construct_pipeline()(question=question, **document.context)
28
 
29
- st.title("DocQuery: Query Documents Using NLP")
30
- file = st.file_uploader("Upload a PDF or Image document")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  question = st.text_input("QUESTION", "")
32
 
33
- if file is not None:
 
 
34
  col1, col2 = st.columns(2)
35
-
36
- document = load_bytes(file, file.name)
37
  col1.image(document.preview, use_column_width=True)
38
 
39
- if file is not None and question is not None and len(question) > 0:
40
  predictions = run_pipeline(question=question, document=document)
41
 
42
  col2.header("Answers")
2
 
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
 
 
5
  import streamlit as st
6
 
7
  import torch
8
  from docquery.pipeline import get_pipeline
9
+ from docquery.document import load_bytes, load_document
10
+
11
 
12
  def ensure_list(x):
13
  if isinstance(x, list):
15
  else:
16
  return [x]
17
 
18
+
19
  @st.experimental_singleton
20
  def construct_pipeline():
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  ret = get_pipeline(device=device)
23
  return ret
24
 
25
+
26
  @st.cache
27
  def run_pipeline(question, document):
28
  return construct_pipeline()(question=question, **document.context)
29
 
30
+
31
+ st.markdown("# DocQuery: Query Documents w/ NLP")
32
+
33
+ if "document" not in st.session_state:
34
+ st.session_state["document"] = None
35
+
36
+ input_type = st.radio("Pick an input type", ["Upload", "URL"], horizontal=True)
37
+
38
+
39
+ def load_file_cb():
40
+ if st.session_state.file_input is None:
41
+ return
42
+
43
+ file = st.session_state.file_input
44
+ with loading_placeholder:
45
+ with st.spinner("Processing..."):
46
+ document = load_bytes(file, file.name)
47
+ _ = document.context
48
+ st.session_state.document = document
49
+
50
+
51
+ def load_url(url):
52
+ if st.session_state.url_input is None:
53
+ return
54
+
55
+ url = st.session_state.url_input
56
+ with loading_placeholder:
57
+ with st.spinner("Downloading..."):
58
+ document = load_document(url)
59
+ with st.spinner("Processing..."):
60
+ _ = document.context
61
+ st.session_state.document = document
62
+
63
+
64
+ if input_type == "Upload":
65
+ file = st.file_uploader(
66
+ "Upload a PDF or Image document", key="file_input", on_change=load_file_cb
67
+ )
68
+
69
+ elif input_type == "URL":
70
+ # url = st.text_input("URL", "", on_change=load_url_callback, key="url_input")
71
+ url = st.text_input("URL", "", key="url_input", on_change=load_url_cb)
72
+
73
  question = st.text_input("QUESTION", "")
74
 
75
+ document = st.session_state.document
76
+ loading_placeholder = st.empty()
77
+ if document is not None:
78
  col1, col2 = st.columns(2)
 
 
79
  col1.image(document.preview, use_column_width=True)
80
 
81
+ if document is not None and question is not None and len(question) > 0:
82
  predictions = run_pipeline(question=question, document=document)
83
 
84
  col2.header("Answers")