lfoppiano commited on
Commit
88c1cba
β€’
1 Parent(s): 64372e1

move settings on the sidebar, allow env variables

Browse files
Files changed (1) hide show
  1. streamlit_app.py +78 -57
streamlit_app.py CHANGED
@@ -18,11 +18,14 @@ from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_t
18
  from grobid_client_generic import GrobidClientGeneric
19
 
20
  if 'rqa' not in st.session_state:
21
- st.session_state['rqa'] = None
22
 
23
  if 'api_key' not in st.session_state:
24
  st.session_state['api_key'] = False
25
 
 
 
 
26
  if 'doc_id' not in st.session_state:
27
  st.session_state['doc_id'] = None
28
 
@@ -42,13 +45,16 @@ if 'git_rev' not in st.session_state:
42
  if "messages" not in st.session_state:
43
  st.session_state.messages = []
44
 
 
 
 
45
 
46
  def new_file():
47
  st.session_state['loaded_embeddings'] = None
48
  st.session_state['doc_id'] = None
49
 
50
 
51
- @st.cache_resource
52
  def init_qa(model):
53
  if model == 'chatgpt-3.5-turbo':
54
  chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
@@ -67,6 +73,7 @@ def init_qa(model):
67
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
68
  else:
69
  st.error("The model was not loaded properly. Try reloading. ")
 
70
 
71
  return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
72
 
@@ -94,7 +101,6 @@ def init_ner():
94
  grobid_quantities_client=quantities_client,
95
  grobid_superconductors_client=materials_client
96
  )
97
-
98
  return gqa
99
 
100
 
@@ -125,51 +131,52 @@ def play_old_messages():
125
 
126
  is_api_key_provided = st.session_state['api_key']
127
 
128
- model = st.sidebar.radio("Model (cannot be changed after selection or upload)",
129
- ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"), # , "llama-2-70b-chat"),
130
- index=1,
131
- captions=[
132
- "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
133
- "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
134
- # "LLama2-70B-Chat + Sentence BERT (embeddings)",
135
- ],
136
- help="Select the model you want to use.",
137
- disabled=is_api_key_provided)
138
-
139
- if not st.session_state['api_key']:
140
  if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
141
- api_key = st.sidebar.text_input('Huggingface API Key',
142
- type="password") # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
 
 
 
 
143
  if api_key:
144
  st.session_state['api_key'] = is_api_key_provided = True
145
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
146
- st.session_state['rqa'] = init_qa(model)
 
 
 
147
  elif model == 'chatgpt-3.5-turbo':
148
- api_key = st.sidebar.text_input('OpenAI API Key',
149
- type="password") # if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
 
 
150
  if api_key:
151
  st.session_state['api_key'] = is_api_key_provided = True
152
- os.environ['OPENAI_API_KEY'] = api_key
153
- st.session_state['rqa'] = init_qa(model)
154
- else:
155
- is_api_key_provided = st.session_state['api_key']
 
 
156
 
157
  st.title("πŸ“ Scientific Document Insight Q&A")
158
  st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
159
 
160
- upload_col, radio_col, context_col = st.columns([7, 2, 2])
161
- with upload_col:
162
- uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
163
- disabled=not is_api_key_provided,
164
- help="The full-text is extracted using Grobid. ")
165
- with radio_col:
166
- mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0,
167
- help="LLM will respond the question, Embedding will show the "
168
- "paragraphs relevant to the question in the paper.")
169
- with context_col:
170
- context_size = st.slider("Context size", 3, 10, value=4,
171
- help="Number of paragraphs to consider when answering a question",
172
- disabled=not uploaded_file)
173
 
174
  question = st.chat_input(
175
  "Ask something about the article",
@@ -178,14 +185,29 @@ question = st.chat_input(
178
  )
179
 
180
  with st.sidebar:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  st.header("Documentation")
182
  st.markdown("https://github.com/lfoppiano/document-qa")
183
  st.markdown(
184
  """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
185
 
186
- st.markdown(
187
- '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
188
- unsafe_allow_html=True)
189
  if st.session_state['git_rev'] != "unknown":
190
  st.markdown("**Revision number**: [" + st.session_state[
191
  'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
@@ -203,9 +225,9 @@ if uploaded_file and not st.session_state.loaded_embeddings:
203
  tmp_file = NamedTemporaryFile()
204
  tmp_file.write(bytearray(binary))
205
  # hash = get_file_hash(tmp_file.name)[:10]
206
- st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name,
207
- chunk_size=250,
208
- perc_overlap=0.1)
209
  st.session_state['loaded_embeddings'] = True
210
  st.session_state.messages = []
211
 
@@ -226,27 +248,26 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
226
  text_response = None
227
  if mode == "Embeddings":
228
  with st.spinner("Generating LLM response..."):
229
- text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
230
- context_size=context_size)
231
  elif mode == "LLM":
232
  with st.spinner("Generating response..."):
233
- _, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
234
- context_size=context_size)
235
 
236
  if not text_response:
237
  st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
238
 
239
  with st.chat_message("assistant"):
240
  if mode == "LLM":
241
- with st.spinner("Processing NER on LLM response..."):
242
- entities = gqa.process_single_text(text_response)
243
- # for entity in entities:
244
- # entity
245
- decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
246
- decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
247
- decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
248
- st.markdown(decorated_text, unsafe_allow_html=True)
249
- text_response = decorated_text
250
  else:
251
  st.write(text_response)
252
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
18
  from grobid_client_generic import GrobidClientGeneric
19
 
20
  if 'rqa' not in st.session_state:
21
+ st.session_state['rqa'] = {}
22
 
23
  if 'api_key' not in st.session_state:
24
  st.session_state['api_key'] = False
25
 
26
+ if 'api_keys' not in st.session_state:
27
+ st.session_state['api_keys'] = {}
28
+
29
  if 'doc_id' not in st.session_state:
30
  st.session_state['doc_id'] = None
31
 
 
45
  if "messages" not in st.session_state:
46
  st.session_state.messages = []
47
 
48
+ if 'ner_processing' not in st.session_state:
49
+ st.session_state['ner_processing'] = False
50
+
51
 
52
  def new_file():
53
  st.session_state['loaded_embeddings'] = None
54
  st.session_state['doc_id'] = None
55
 
56
 
57
+ # @st.cache_resource
58
  def init_qa(model):
59
  if model == 'chatgpt-3.5-turbo':
60
  chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
 
73
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
74
  else:
75
  st.error("The model was not loaded properly. Try reloading. ")
76
+ st.stop()
77
 
78
  return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
79
 
 
101
  grobid_quantities_client=quantities_client,
102
  grobid_superconductors_client=materials_client
103
  )
 
104
  return gqa
105
 
106
 
 
131
 
132
  is_api_key_provided = st.session_state['api_key']
133
 
134
+ with st.sidebar:
135
+ model = st.radio(
136
+ "Model (cannot be changed after selection or upload)",
137
+ ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"), # , "llama-2-70b-chat"),
138
+ index=1,
139
+ captions=[
140
+ "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
141
+ "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
142
+ # "LLama2-70B-Chat + Sentence BERT (embeddings)",
143
+ ],
144
+ help="Select the model you want to use.")
145
+
146
  if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
147
+ api_key = st.text_input('Huggingface API Key',
148
+ type="password") if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ[
149
+ 'HUGGINGFACEHUB_API_TOKEN']
150
+ st.markdown(
151
+ "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
152
+
153
  if api_key:
154
  st.session_state['api_key'] = is_api_key_provided = True
155
+ st.session_state['api_keys']['mistral-7b-instruct-v0.1'] = api_key
156
+ if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
157
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
158
+ st.session_state['rqa'][model] = init_qa(model)
159
+
160
  elif model == 'chatgpt-3.5-turbo':
161
+ api_key = st.text_input('OpenAI API Key', type="password") if 'OPENAI_API_KEY' not in os.environ else \
162
+ os.environ['OPENAI_API_KEY']
163
+ st.markdown(
164
+ "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
165
  if api_key:
166
  st.session_state['api_key'] = is_api_key_provided = True
167
+ st.session_state['api_keys']['chatgpt-3.5-turbo'] = api_key
168
+ if 'OPENAI_API_KEY' not in os.environ:
169
+ os.environ['OPENAI_API_KEY'] = api_key
170
+ st.session_state['rqa'][model] = init_qa(model)
171
+ # else:
172
+ # is_api_key_provided = st.session_state['api_key']
173
 
174
  st.title("πŸ“ Scientific Document Insight Q&A")
175
  st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
176
 
177
+ uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
178
+ disabled=not is_api_key_provided,
179
+ help="The full-text is extracted using Grobid. ")
 
 
 
 
 
 
 
 
 
 
180
 
181
  question = st.chat_input(
182
  "Ask something about the article",
 
185
  )
186
 
187
  with st.sidebar:
188
+ st.header("Settings")
189
+ mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0, horizontal=True,
190
+ help="LLM will respond the question, Embedding will show the "
191
+ "paragraphs relevant to the question in the paper.")
192
+ chunk_size = st.slider("Chunks size", 100, 2000, value=250,
193
+ help="Size of chunks in which the document is partitioned",
194
+ disabled=not uploaded_file)
195
+ context_size = st.slider("Context size", 3, 10, value=4,
196
+ help="Number of chunks to consider when answering a question",
197
+ disabled=not uploaded_file)
198
+
199
+ st.session_state['ner_processing'] = st.checkbox("NER processing on LLM response")
200
+ st.markdown(
201
+ '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
202
+ unsafe_allow_html=True)
203
+
204
+ st.divider()
205
+
206
  st.header("Documentation")
207
  st.markdown("https://github.com/lfoppiano/document-qa")
208
  st.markdown(
209
  """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
210
 
 
 
 
211
  if st.session_state['git_rev'] != "unknown":
212
  st.markdown("**Revision number**: [" + st.session_state[
213
  'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
 
225
  tmp_file = NamedTemporaryFile()
226
  tmp_file.write(bytearray(binary))
227
  # hash = get_file_hash(tmp_file.name)[:10]
228
+ st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
229
+ chunk_size=chunk_size,
230
+ perc_overlap=0.1)
231
  st.session_state['loaded_embeddings'] = True
232
  st.session_state.messages = []
233
 
 
248
  text_response = None
249
  if mode == "Embeddings":
250
  with st.spinner("Generating LLM response..."):
251
+ text_response = st.session_state['rqa'][model].query_storage(question, st.session_state.doc_id,
252
+ context_size=context_size)
253
  elif mode == "LLM":
254
  with st.spinner("Generating response..."):
255
+ _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
256
+ context_size=context_size)
257
 
258
  if not text_response:
259
  st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
260
 
261
  with st.chat_message("assistant"):
262
  if mode == "LLM":
263
+ if st.session_state['ner_processing']:
264
+ with st.spinner("Processing NER on LLM response..."):
265
+ entities = gqa.process_single_text(text_response)
266
+ decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
267
+ decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
268
+ decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
269
+ text_response = decorated_text
270
+ st.markdown(text_response, unsafe_allow_html=True)
 
271
  else:
272
  st.write(text_response)
273
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})