barunsaha commited on
Commit
b5434b8
1 Parent(s): f809d41

Use LangChain to get streaming response from the LLM; update progress bar to display the current status

Browse files
Files changed (1) hide show
  1. app.py +44 -79
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import datetime
2
  import logging
3
  import pathlib
@@ -7,9 +10,7 @@ from typing import List
7
 
8
  import json5
9
  import streamlit as st
10
- from langchain_community.chat_message_histories import (
11
- StreamlitChatMessageHistory
12
- )
13
  from langchain_core.messages import HumanMessage
14
  from langchain_core.prompts import ChatPromptTemplate
15
 
@@ -47,17 +48,9 @@ def _get_prompt_template(is_refinement: bool) -> str:
47
  return template
48
 
49
 
50
- # @st.cache_resource
51
- # def _get_tokenizer() -> AutoTokenizer:
52
- # """
53
- # Get Mistral tokenizer for counting tokens.
54
- #
55
- # :return: The tokenizer.
56
- # """
57
- #
58
- # return AutoTokenizer.from_pretrained(
59
- # pretrained_model_name_or_path=GlobalConfig.HF_LLM_MODEL_NAME
60
- # )
61
 
62
 
63
  APP_TEXT = _load_strings()
@@ -66,9 +59,10 @@ APP_TEXT = _load_strings()
66
  CHAT_MESSAGES = 'chat_messages'
67
  DOWNLOAD_FILE_KEY = 'download_file_name'
68
  IS_IT_REFINEMENT = 'is_it_refinement'
 
 
69
 
70
  logger = logging.getLogger(__name__)
71
- progress_bar = st.progress(0, text='Setting up SlideDeck AI...')
72
 
73
  texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
74
  captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
@@ -110,7 +104,6 @@ def build_ui():
110
  with st.expander('Usage Policies and Limitations'):
111
  display_page_footer_content()
112
 
113
- progress_bar.progress(50, text='Setting up chat interface...')
114
  set_up_chat_ui()
115
 
116
 
@@ -131,8 +124,6 @@ def set_up_chat_ui():
131
  st.chat_message('ai').write(
132
  random.choice(APP_TEXT['ai_greetings'])
133
  )
134
- progress_bar.progress(100, text='Done!')
135
- progress_bar.empty()
136
 
137
  history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
138
 
@@ -188,66 +179,51 @@ def set_up_chat_ui():
188
  }
189
  )
190
 
191
- with st.status(
192
- 'Calling LLM...will retry if connection times out...',
193
- expanded=False
194
- ) as status:
195
- response: dict = llm_helper.hf_api_query({
196
- 'inputs': formatted_template,
197
- 'parameters': {
198
- 'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
199
- 'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
200
- 'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
201
- 'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
202
- 'num_return_sequences': 1,
203
- 'return_full_text': False,
204
- # "repetition_penalty": 0.0001
205
- },
206
- 'options': {
207
- 'wait_for_model': True,
208
- 'use_cache': True
209
- }
210
- })
211
 
212
- if len(response) > 0 and 'generated_text' in response[0]:
213
- response: str = response[0]['generated_text'].strip()
214
 
215
- st.chat_message('ai').code(response, language='json')
 
 
216
 
217
- history.add_user_message(prompt)
218
- history.add_ai_message(response)
219
 
220
- # The content has been generated as JSON
221
- # There maybe trailing ``` at the end of the response -- remove them
222
- # To be careful: ``` may be part of the content as well when code is generated
223
- response_cleaned = text_helper.get_clean_json(response)
224
 
225
- logger.info(
226
- 'Cleaned JSON response:: original length: %d | cleaned length: %d',
227
- len(response), len(response_cleaned)
228
- )
229
- logger.debug('Cleaned JSON: %s', response_cleaned)
230
 
231
- # Now create the PPT file
232
- status.update(
233
- label='Searching photos and creating the slide deck...give it a moment...',
234
- state='running',
235
- expanded=False
236
- )
237
- generate_slide_deck(response_cleaned)
238
- status.update(label='Done!', state='complete', expanded=True)
239
 
240
- logger.info(
241
- '#messages in history / 2: %d',
242
- len(st.session_state[CHAT_MESSAGES]) / 2
243
- )
 
 
 
244
 
245
 
246
- def generate_slide_deck(json_str: str):
247
  """
248
- Create a slide deck.
 
249
 
250
  :param json_str: The content in *valid* JSON format.
 
251
  """
252
 
253
  if DOWNLOAD_FILE_KEY in st.session_state:
@@ -269,17 +245,6 @@ def generate_slide_deck(json_str: str):
269
  output_file_path=path
270
  )
271
  except ValueError:
272
- # st.error(
273
- # f"{APP_TEXT['json_parsing_error']}"
274
- # f"\n\nAdditional error info: {ve}"
275
- # f"\n\nHere are some sample instructions that you could try to possibly fix this error;"
276
- # f" if these don't work, try rephrasing or refreshing:"
277
- # f"\n\n"
278
- # "- Regenerate content and fix the JSON error."
279
- # "\n- Regenerate content and fix the JSON error. Quotes inside quotes should be escaped."
280
- # )
281
- # logger.error('%s', APP_TEXT['json_parsing_error'])
282
- # logger.error('Additional error info: %s', str(ve))
283
  st.error(
284
  'Encountered error while parsing JSON...will fix it and retry'
285
  )
@@ -295,8 +260,8 @@ def generate_slide_deck(json_str: str):
295
  except Exception as ex:
296
  st.error(APP_TEXT['content_generation_error'])
297
  logger.error('Caught a generic exception: %s', str(ex))
298
- finally:
299
- _display_download_button(path)
300
 
301
 
302
  def _is_it_refinement() -> bool:
 
1
+ """
2
+ Streamlit app containing the UI and the application logic.
3
+ """
4
  import datetime
5
  import logging
6
  import pathlib
 
10
 
11
  import json5
12
  import streamlit as st
13
+ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 
 
14
  from langchain_core.messages import HumanMessage
15
  from langchain_core.prompts import ChatPromptTemplate
16
 
 
48
  return template
49
 
50
 
51
+ @st.cache_resource
52
+ def _get_llm():
53
+ return llm_helper.get_hf_endpoint()
 
 
 
 
 
 
 
 
54
 
55
 
56
  APP_TEXT = _load_strings()
 
59
  CHAT_MESSAGES = 'chat_messages'
60
  DOWNLOAD_FILE_KEY = 'download_file_name'
61
  IS_IT_REFINEMENT = 'is_it_refinement'
62
+ APPROX_TARGET_LENGTH = GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH / 2
63
+
64
 
65
  logger = logging.getLogger(__name__)
 
66
 
67
  texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
68
  captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
 
104
  with st.expander('Usage Policies and Limitations'):
105
  display_page_footer_content()
106
 
 
107
  set_up_chat_ui()
108
 
109
 
 
124
  st.chat_message('ai').write(
125
  random.choice(APP_TEXT['ai_greetings'])
126
  )
 
 
127
 
128
  history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
129
 
 
179
  }
180
  )
181
 
182
+ progress_bar = st.progress(0, 'Preparing to call LLM...')
183
+ response = ''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
+ for chunk in _get_llm().stream(formatted_template):
186
+ response += chunk
187
 
188
+ # Update the progress bar
189
+ progress_percentage = min(len(response) / APPROX_TARGET_LENGTH, 0.95)
190
+ progress_bar.progress(progress_percentage, text='Streaming content...')
191
 
192
+ history.add_user_message(prompt)
193
+ history.add_ai_message(response)
194
 
195
+ # The content has been generated as JSON
196
+ # There may be trailing ``` at the end of the response -- remove them
197
+ # To be careful: ``` may be part of the content as well when code is generated
198
+ response_cleaned = text_helper.get_clean_json(response)
199
 
200
+ logger.info(
201
+ 'Cleaned JSON response:: original length: %d | cleaned length: %d',
202
+ len(response), len(response_cleaned)
203
+ )
204
+ logger.debug('Cleaned JSON: %s', response_cleaned)
205
 
206
+ # Now create the PPT file
207
+ progress_bar.progress(0.95, text='Searching photos and generating the slide deck...')
208
+ path = generate_slide_deck(response_cleaned)
209
+ progress_bar.progress(1.0, text='Done!')
 
 
 
 
210
 
211
+ st.chat_message('ai').code(response, language='json')
212
+ _display_download_button(path)
213
+
214
+ logger.info(
215
+ '#messages in history / 2: %d',
216
+ len(st.session_state[CHAT_MESSAGES]) / 2
217
+ )
218
 
219
 
220
+ def generate_slide_deck(json_str: str) -> pathlib.Path:
221
  """
222
+ Create a slide deck and return the file path. In case there is any error creating the slide
223
+ deck, the path may be to an empty file.
224
 
225
  :param json_str: The content in *valid* JSON format.
226
+ :return: The path of the .pptx file.
227
  """
228
 
229
  if DOWNLOAD_FILE_KEY in st.session_state:
 
245
  output_file_path=path
246
  )
247
  except ValueError:
 
 
 
 
 
 
 
 
 
 
 
248
  st.error(
249
  'Encountered error while parsing JSON...will fix it and retry'
250
  )
 
260
  except Exception as ex:
261
  st.error(APP_TEXT['content_generation_error'])
262
  logger.error('Caught a generic exception: %s', str(ex))
263
+
264
+ return path
265
 
266
 
267
  def _is_it_refinement() -> bool: