lfoppiano committed on
Commit
452ec4c
1 Parent(s): 88c1cba

improve error handling and avoid model switching after upload

Browse files
Files changed (1) hide show
  1. streamlit_app.py +26 -18
streamlit_app.py CHANGED
@@ -20,8 +20,8 @@ from grobid_client_generic import GrobidClientGeneric
20
  if 'rqa' not in st.session_state:
21
  st.session_state['rqa'] = {}
22
 
23
- if 'api_key' not in st.session_state:
24
- st.session_state['api_key'] = False
25
 
26
  if 'api_keys' not in st.session_state:
27
  st.session_state['api_keys'] = {}
@@ -128,11 +128,10 @@ def play_old_messages():
128
  else:
129
  st.write(message['content'])
130
 
131
-
132
- is_api_key_provided = st.session_state['api_key']
133
 
134
  with st.sidebar:
135
- model = st.radio(
136
  "Model (cannot be changed after selection or upload)",
137
  ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"), # , "llama-2-70b-chat"),
138
  index=1,
@@ -141,7 +140,8 @@ with st.sidebar:
141
  "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
142
  # "LLama2-70B-Chat + Sentence BERT (embeddings)",
143
  ],
144
- help="Select the model you want to use.")
 
145
 
146
  if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
147
  api_key = st.text_input('Huggingface API Key',
@@ -151,23 +151,25 @@ with st.sidebar:
151
  "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
152
 
153
  if api_key:
154
- st.session_state['api_key'] = is_api_key_provided = True
155
- st.session_state['api_keys']['mistral-7b-instruct-v0.1'] = api_key
156
- if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
157
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
158
- st.session_state['rqa'][model] = init_qa(model)
 
159
 
160
  elif model == 'chatgpt-3.5-turbo':
161
  api_key = st.text_input('OpenAI API Key', type="password") if 'OPENAI_API_KEY' not in os.environ else \
162
- os.environ['OPENAI_API_KEY']
163
  st.markdown(
164
  "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
165
  if api_key:
166
- st.session_state['api_key'] = is_api_key_provided = True
167
- st.session_state['api_keys']['chatgpt-3.5-turbo'] = api_key
168
- if 'OPENAI_API_KEY' not in os.environ:
169
- os.environ['OPENAI_API_KEY'] = api_key
170
- st.session_state['rqa'][model] = init_qa(model)
 
171
  # else:
172
  # is_api_key_provided = st.session_state['api_key']
173
 
@@ -175,7 +177,7 @@ st.title("📝 Scientific Document Insight Q&A")
175
  st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
176
 
177
  uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
178
- disabled=not is_api_key_provided,
179
  help="The full-text is extracted using Grobid. ")
180
 
181
  question = st.chat_input(
@@ -220,6 +222,9 @@ with st.sidebar:
220
  """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
221
 
222
  if uploaded_file and not st.session_state.loaded_embeddings:
 
 
 
223
  with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
224
  binary = uploaded_file.getvalue()
225
  tmp_file = NamedTemporaryFile()
@@ -240,6 +245,9 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
240
  st.markdown(message["content"], unsafe_allow_html=True)
241
  elif message['mode'] == "Embeddings":
242
  st.write(message["content"])
 
 
 
243
 
244
  with st.chat_message("user"):
245
  st.markdown(question)
 
20
  if 'rqa' not in st.session_state:
21
  st.session_state['rqa'] = {}
22
 
23
+ if 'model' not in st.session_state:
24
+ st.session_state['model'] = None
25
 
26
  if 'api_keys' not in st.session_state:
27
  st.session_state['api_keys'] = {}
 
128
  else:
129
  st.write(message['content'])
130
 
131
+ # is_api_key_provided = st.session_state['api_key']
 
132
 
133
  with st.sidebar:
134
+ st.session_state['model'] = model = st.radio(
135
  "Model (cannot be changed after selection or upload)",
136
  ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"), # , "llama-2-70b-chat"),
137
  index=1,
 
140
  "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
141
  # "LLama2-70B-Chat + Sentence BERT (embeddings)",
142
  ],
143
+ help="Select the model you want to use.",
144
+ disabled=st.session_state['doc_id'] is not None)
145
 
146
  if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
147
  api_key = st.text_input('Huggingface API Key',
 
151
  "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
152
 
153
  if api_key:
154
+ # st.session_state['api_key'] = is_api_key_provided = True
155
+ with st.spinner("Preparing environment"):
156
+ st.session_state['api_keys']['mistral-7b-instruct-v0.1'] = api_key
157
+ if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
158
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
159
+ st.session_state['rqa'][model] = init_qa(model)
160
 
161
  elif model == 'chatgpt-3.5-turbo':
162
  api_key = st.text_input('OpenAI API Key', type="password") if 'OPENAI_API_KEY' not in os.environ else \
163
+ os.environ['OPENAI_API_KEY']
164
  st.markdown(
165
  "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
166
  if api_key:
167
+ # st.session_state['api_key'] = is_api_key_provided = True
168
+ with st.spinner("Preparing environment"):
169
+ st.session_state['api_keys']['chatgpt-3.5-turbo'] = api_key
170
+ if 'OPENAI_API_KEY' not in os.environ:
171
+ os.environ['OPENAI_API_KEY'] = api_key
172
+ st.session_state['rqa'][model] = init_qa(model)
173
  # else:
174
  # is_api_key_provided = st.session_state['api_key']
175
 
 
177
  st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
178
 
179
  uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
180
+ disabled=st.session_state['model'] is not None and st.session_state['model'] not in st.session_state['api_keys'],
181
  help="The full-text is extracted using Grobid. ")
182
 
183
  question = st.chat_input(
 
222
  """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
223
 
224
  if uploaded_file and not st.session_state.loaded_embeddings:
225
+ if model not in st.session_state['api_keys']:
226
+ st.error("Before uploading a document, you must enter the API key. ")
227
+ st.stop()
228
  with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
229
  binary = uploaded_file.getvalue()
230
  tmp_file = NamedTemporaryFile()
 
245
  st.markdown(message["content"], unsafe_allow_html=True)
246
  elif message['mode'] == "Embeddings":
247
  st.write(message["content"])
248
+ if model not in st.session_state['rqa']:
249
+ st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
250
+ st.stop()
251
 
252
  with st.chat_message("user"):
253
  st.markdown(question)