barunsaha commited on
Commit
3256927
2 Parent(s): 13b960a b5434b8

Merge pull request #23 from barun-saha/visual

Browse files
Files changed (1) hide show
  1. app.py +29 -71
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import datetime
2
  import logging
3
  import pathlib
@@ -7,9 +10,7 @@ from typing import List
7
 
8
  import json5
9
  import streamlit as st
10
- from langchain_community.chat_message_histories import (
11
- StreamlitChatMessageHistory
12
- )
13
  from langchain_core.messages import HumanMessage
14
  from langchain_core.prompts import ChatPromptTemplate
15
 
@@ -47,17 +48,9 @@ def _get_prompt_template(is_refinement: bool) -> str:
47
  return template
48
 
49
 
50
- # @st.cache_resource
51
- # def _get_tokenizer() -> AutoTokenizer:
52
- # """
53
- # Get Mistral tokenizer for counting tokens.
54
- #
55
- # :return: The tokenizer.
56
- # """
57
- #
58
- # return AutoTokenizer.from_pretrained(
59
- # pretrained_model_name_or_path=GlobalConfig.HF_LLM_MODEL_NAME
60
- # )
61
 
62
 
63
  APP_TEXT = _load_strings()
@@ -66,9 +59,10 @@ APP_TEXT = _load_strings()
66
  CHAT_MESSAGES = 'chat_messages'
67
  DOWNLOAD_FILE_KEY = 'download_file_name'
68
  IS_IT_REFINEMENT = 'is_it_refinement'
 
 
69
 
70
  logger = logging.getLogger(__name__)
71
- progress_bar = st.progress(0, text='Setting up SlideDeck AI...')
72
 
73
  texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
74
  captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
@@ -110,7 +104,6 @@ def build_ui():
110
  with st.expander('Usage Policies and Limitations'):
111
  display_page_footer_content()
112
 
113
- progress_bar.progress(50, text='Setting up chat interface...')
114
  set_up_chat_ui()
115
 
116
 
@@ -131,8 +124,6 @@ def set_up_chat_ui():
131
  st.chat_message('ai').write(
132
  random.choice(APP_TEXT['ai_greetings'])
133
  )
134
- progress_bar.progress(100, text='Done!')
135
- progress_bar.empty()
136
 
137
  history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
138
 
@@ -156,8 +147,6 @@ def set_up_chat_ui():
156
  placeholder=APP_TEXT['chat_placeholder'],
157
  max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
158
  ):
159
-
160
- progress_bar_pptx = st.progress(0, 'Preparing to run...')
161
  if not text_helper.is_valid_prompt(prompt):
162
  st.error(
163
  'Not enough information provided!'
@@ -190,47 +179,22 @@ def set_up_chat_ui():
190
  }
191
  )
192
 
193
- progress_bar_pptx.progress(5, 'Calling LLM...will retry if connection times out...')
194
- response: dict = llm_helper.hf_api_query({
195
- 'inputs': formatted_template,
196
- 'parameters': {
197
- 'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
198
- 'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
199
- 'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
200
- 'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
201
- 'num_return_sequences': 1,
202
- 'return_full_text': False,
203
- # "repetition_penalty": 0.0001
204
- },
205
- 'options': {
206
- 'wait_for_model': True,
207
- 'use_cache': True
208
- }
209
- })
210
-
211
- if len(response) > 0 and 'generated_text' in response[0]:
212
- response: str = response[0]['generated_text'].strip()
213
 
214
- st.chat_message('ai').code(response, language='json')
 
 
 
 
 
215
 
216
  history.add_user_message(prompt)
217
  history.add_ai_message(response)
218
 
219
- # if GlobalConfig.COUNT_TOKENS:
220
- # tokenizer = _get_tokenizer()
221
- # tokens_count_in = len(tokenizer.tokenize(formatted_template))
222
- # tokens_count_out = len(tokenizer.tokenize(response))
223
- # logger.debug(
224
- # 'Tokens count:: input: %d, output: %d',
225
- # tokens_count_in, tokens_count_out
226
- # )
227
-
228
- # _display_messages_history(view_messages)
229
-
230
  # The content has been generated as JSON
231
  # There maybe trailing ``` at the end of the response -- remove them
232
  # To be careful: ``` may be part of the content as well when code is generated
233
- progress_bar_pptx.progress(50, 'Analyzing response...')
234
  response_cleaned = text_helper.get_clean_json(response)
235
 
236
  logger.info(
@@ -240,9 +204,12 @@ def set_up_chat_ui():
240
  logger.debug('Cleaned JSON: %s', response_cleaned)
241
 
242
  # Now create the PPT file
243
- progress_bar_pptx.progress(75, 'Creating the slide deck...give it a moment...')
244
- generate_slide_deck(response_cleaned)
245
- progress_bar_pptx.progress(100, text='Done!')
 
 
 
246
 
247
  logger.info(
248
  '#messages in history / 2: %d',
@@ -250,11 +217,13 @@ def set_up_chat_ui():
250
  )
251
 
252
 
253
- def generate_slide_deck(json_str: str):
254
  """
255
- Create a slide deck.
 
256
 
257
  :param json_str: The content in *valid* JSON format.
 
258
  """
259
 
260
  if DOWNLOAD_FILE_KEY in st.session_state:
@@ -276,17 +245,6 @@ def generate_slide_deck(json_str: str):
276
  output_file_path=path
277
  )
278
  except ValueError:
279
- # st.error(
280
- # f"{APP_TEXT['json_parsing_error']}"
281
- # f"\n\nAdditional error info: {ve}"
282
- # f"\n\nHere are some sample instructions that you could try to possibly fix this error;"
283
- # f" if these don't work, try rephrasing or refreshing:"
284
- # f"\n\n"
285
- # "- Regenerate content and fix the JSON error."
286
- # "\n- Regenerate content and fix the JSON error. Quotes inside quotes should be escaped."
287
- # )
288
- # logger.error('%s', APP_TEXT['json_parsing_error'])
289
- # logger.error('Additional error info: %s', str(ve))
290
  st.error(
291
  'Encountered error while parsing JSON...will fix it and retry'
292
  )
@@ -302,8 +260,8 @@ def generate_slide_deck(json_str: str):
302
  except Exception as ex:
303
  st.error(APP_TEXT['content_generation_error'])
304
  logger.error('Caught a generic exception: %s', str(ex))
305
- finally:
306
- _display_download_button(path)
307
 
308
 
309
  def _is_it_refinement() -> bool:
 
1
+ """
2
+ Streamlit app containing the UI and the application logic.
3
+ """
4
  import datetime
5
  import logging
6
  import pathlib
 
10
 
11
  import json5
12
  import streamlit as st
13
+ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 
 
14
  from langchain_core.messages import HumanMessage
15
  from langchain_core.prompts import ChatPromptTemplate
16
 
 
48
  return template
49
 
50
 
51
+ @st.cache_resource
52
+ def _get_llm():
53
+ return llm_helper.get_hf_endpoint()
 
 
 
 
 
 
 
 
54
 
55
 
56
  APP_TEXT = _load_strings()
 
59
  CHAT_MESSAGES = 'chat_messages'
60
  DOWNLOAD_FILE_KEY = 'download_file_name'
61
  IS_IT_REFINEMENT = 'is_it_refinement'
62
+ APPROX_TARGET_LENGTH = GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH / 2
63
+
64
 
65
  logger = logging.getLogger(__name__)
 
66
 
67
  texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
68
  captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
 
104
  with st.expander('Usage Policies and Limitations'):
105
  display_page_footer_content()
106
 
 
107
  set_up_chat_ui()
108
 
109
 
 
124
  st.chat_message('ai').write(
125
  random.choice(APP_TEXT['ai_greetings'])
126
  )
 
 
127
 
128
  history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
129
 
 
147
  placeholder=APP_TEXT['chat_placeholder'],
148
  max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
149
  ):
 
 
150
  if not text_helper.is_valid_prompt(prompt):
151
  st.error(
152
  'Not enough information provided!'
 
179
  }
180
  )
181
 
182
+ progress_bar = st.progress(0, 'Preparing to call LLM...')
183
+ response = ''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
+ for chunk in _get_llm().stream(formatted_template):
186
+ response += chunk
187
+
188
+ # Update the progress bar
189
+ progress_percentage = min(len(response) / APPROX_TARGET_LENGTH, 0.95)
190
+ progress_bar.progress(progress_percentage, text='Streaming content...')
191
 
192
  history.add_user_message(prompt)
193
  history.add_ai_message(response)
194
 
 
 
 
 
 
 
 
 
 
 
 
195
  # The content has been generated as JSON
196
  # There maybe trailing ``` at the end of the response -- remove them
197
  # To be careful: ``` may be part of the content as well when code is generated
 
198
  response_cleaned = text_helper.get_clean_json(response)
199
 
200
  logger.info(
 
204
  logger.debug('Cleaned JSON: %s', response_cleaned)
205
 
206
  # Now create the PPT file
207
+ progress_bar.progress(0.95, text='Searching photos and generating the slide deck...')
208
+ path = generate_slide_deck(response_cleaned)
209
+ progress_bar.progress(1.0, text='Done!')
210
+
211
+ st.chat_message('ai').code(response, language='json')
212
+ _display_download_button(path)
213
 
214
  logger.info(
215
  '#messages in history / 2: %d',
 
217
  )
218
 
219
 
220
+ def generate_slide_deck(json_str: str) -> pathlib.Path:
221
  """
222
+ Create a slide deck and return the file path. In case there is any error creating the slide
223
+ deck, the path may be to an empty file.
224
 
225
  :param json_str: The content in *valid* JSON format.
226
+ :return: The file of the .pptx file.
227
  """
228
 
229
  if DOWNLOAD_FILE_KEY in st.session_state:
 
245
  output_file_path=path
246
  )
247
  except ValueError:
 
 
 
 
 
 
 
 
 
 
 
248
  st.error(
249
  'Encountered error while parsing JSON...will fix it and retry'
250
  )
 
260
  except Exception as ex:
261
  st.error(APP_TEXT['content_generation_error'])
262
  logger.error('Caught a generic exception: %s', str(ex))
263
+
264
+ return path
265
 
266
 
267
  def _is_it_refinement() -> bool: