Spaces:
Running
Running
Merge pull request #23 from barun-saha/visual
Browse files
app.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import datetime
|
2 |
import logging
|
3 |
import pathlib
|
@@ -7,9 +10,7 @@ from typing import List
|
|
7 |
|
8 |
import json5
|
9 |
import streamlit as st
|
10 |
-
from langchain_community.chat_message_histories import
|
11 |
-
StreamlitChatMessageHistory
|
12 |
-
)
|
13 |
from langchain_core.messages import HumanMessage
|
14 |
from langchain_core.prompts import ChatPromptTemplate
|
15 |
|
@@ -47,17 +48,9 @@ def _get_prompt_template(is_refinement: bool) -> str:
|
|
47 |
return template
|
48 |
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
# Get Mistral tokenizer for counting tokens.
|
54 |
-
#
|
55 |
-
# :return: The tokenizer.
|
56 |
-
# """
|
57 |
-
#
|
58 |
-
# return AutoTokenizer.from_pretrained(
|
59 |
-
# pretrained_model_name_or_path=GlobalConfig.HF_LLM_MODEL_NAME
|
60 |
-
# )
|
61 |
|
62 |
|
63 |
APP_TEXT = _load_strings()
|
@@ -66,9 +59,10 @@ APP_TEXT = _load_strings()
|
|
66 |
CHAT_MESSAGES = 'chat_messages'
|
67 |
DOWNLOAD_FILE_KEY = 'download_file_name'
|
68 |
IS_IT_REFINEMENT = 'is_it_refinement'
|
|
|
|
|
69 |
|
70 |
logger = logging.getLogger(__name__)
|
71 |
-
progress_bar = st.progress(0, text='Setting up SlideDeck AI...')
|
72 |
|
73 |
texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
|
74 |
captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
|
@@ -110,7 +104,6 @@ def build_ui():
|
|
110 |
with st.expander('Usage Policies and Limitations'):
|
111 |
display_page_footer_content()
|
112 |
|
113 |
-
progress_bar.progress(50, text='Setting up chat interface...')
|
114 |
set_up_chat_ui()
|
115 |
|
116 |
|
@@ -131,8 +124,6 @@ def set_up_chat_ui():
|
|
131 |
st.chat_message('ai').write(
|
132 |
random.choice(APP_TEXT['ai_greetings'])
|
133 |
)
|
134 |
-
progress_bar.progress(100, text='Done!')
|
135 |
-
progress_bar.empty()
|
136 |
|
137 |
history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
|
138 |
|
@@ -156,8 +147,6 @@ def set_up_chat_ui():
|
|
156 |
placeholder=APP_TEXT['chat_placeholder'],
|
157 |
max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
|
158 |
):
|
159 |
-
|
160 |
-
progress_bar_pptx = st.progress(0, 'Preparing to run...')
|
161 |
if not text_helper.is_valid_prompt(prompt):
|
162 |
st.error(
|
163 |
'Not enough information provided!'
|
@@ -190,47 +179,22 @@ def set_up_chat_ui():
|
|
190 |
}
|
191 |
)
|
192 |
|
193 |
-
|
194 |
-
response
|
195 |
-
'inputs': formatted_template,
|
196 |
-
'parameters': {
|
197 |
-
'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
|
198 |
-
'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
|
199 |
-
'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
|
200 |
-
'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
|
201 |
-
'num_return_sequences': 1,
|
202 |
-
'return_full_text': False,
|
203 |
-
# "repetition_penalty": 0.0001
|
204 |
-
},
|
205 |
-
'options': {
|
206 |
-
'wait_for_model': True,
|
207 |
-
'use_cache': True
|
208 |
-
}
|
209 |
-
})
|
210 |
-
|
211 |
-
if len(response) > 0 and 'generated_text' in response[0]:
|
212 |
-
response: str = response[0]['generated_text'].strip()
|
213 |
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
history.add_user_message(prompt)
|
217 |
history.add_ai_message(response)
|
218 |
|
219 |
-
# if GlobalConfig.COUNT_TOKENS:
|
220 |
-
# tokenizer = _get_tokenizer()
|
221 |
-
# tokens_count_in = len(tokenizer.tokenize(formatted_template))
|
222 |
-
# tokens_count_out = len(tokenizer.tokenize(response))
|
223 |
-
# logger.debug(
|
224 |
-
# 'Tokens count:: input: %d, output: %d',
|
225 |
-
# tokens_count_in, tokens_count_out
|
226 |
-
# )
|
227 |
-
|
228 |
-
# _display_messages_history(view_messages)
|
229 |
-
|
230 |
# The content has been generated as JSON
|
231 |
# There maybe trailing ``` at the end of the response -- remove them
|
232 |
# To be careful: ``` may be part of the content as well when code is generated
|
233 |
-
progress_bar_pptx.progress(50, 'Analyzing response...')
|
234 |
response_cleaned = text_helper.get_clean_json(response)
|
235 |
|
236 |
logger.info(
|
@@ -240,9 +204,12 @@ def set_up_chat_ui():
|
|
240 |
logger.debug('Cleaned JSON: %s', response_cleaned)
|
241 |
|
242 |
# Now create the PPT file
|
243 |
-
|
244 |
-
generate_slide_deck(response_cleaned)
|
245 |
-
|
|
|
|
|
|
|
246 |
|
247 |
logger.info(
|
248 |
'#messages in history / 2: %d',
|
@@ -250,11 +217,13 @@ def set_up_chat_ui():
|
|
250 |
)
|
251 |
|
252 |
|
253 |
-
def generate_slide_deck(json_str: str):
|
254 |
"""
|
255 |
-
Create a slide deck.
|
|
|
256 |
|
257 |
:param json_str: The content in *valid* JSON format.
|
|
|
258 |
"""
|
259 |
|
260 |
if DOWNLOAD_FILE_KEY in st.session_state:
|
@@ -276,17 +245,6 @@ def generate_slide_deck(json_str: str):
|
|
276 |
output_file_path=path
|
277 |
)
|
278 |
except ValueError:
|
279 |
-
# st.error(
|
280 |
-
# f"{APP_TEXT['json_parsing_error']}"
|
281 |
-
# f"\n\nAdditional error info: {ve}"
|
282 |
-
# f"\n\nHere are some sample instructions that you could try to possibly fix this error;"
|
283 |
-
# f" if these don't work, try rephrasing or refreshing:"
|
284 |
-
# f"\n\n"
|
285 |
-
# "- Regenerate content and fix the JSON error."
|
286 |
-
# "\n- Regenerate content and fix the JSON error. Quotes inside quotes should be escaped."
|
287 |
-
# )
|
288 |
-
# logger.error('%s', APP_TEXT['json_parsing_error'])
|
289 |
-
# logger.error('Additional error info: %s', str(ve))
|
290 |
st.error(
|
291 |
'Encountered error while parsing JSON...will fix it and retry'
|
292 |
)
|
@@ -302,8 +260,8 @@ def generate_slide_deck(json_str: str):
|
|
302 |
except Exception as ex:
|
303 |
st.error(APP_TEXT['content_generation_error'])
|
304 |
logger.error('Caught a generic exception: %s', str(ex))
|
305 |
-
|
306 |
-
|
307 |
|
308 |
|
309 |
def _is_it_refinement() -> bool:
|
|
|
1 |
+
"""
|
2 |
+
Streamlit app containing the UI and the application logic.
|
3 |
+
"""
|
4 |
import datetime
|
5 |
import logging
|
6 |
import pathlib
|
|
|
10 |
|
11 |
import json5
|
12 |
import streamlit as st
|
13 |
+
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
|
|
|
|
|
14 |
from langchain_core.messages import HumanMessage
|
15 |
from langchain_core.prompts import ChatPromptTemplate
|
16 |
|
|
|
48 |
return template
|
49 |
|
50 |
|
51 |
+
@st.cache_resource
|
52 |
+
def _get_llm():
|
53 |
+
return llm_helper.get_hf_endpoint()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
|
56 |
APP_TEXT = _load_strings()
|
|
|
59 |
CHAT_MESSAGES = 'chat_messages'
|
60 |
DOWNLOAD_FILE_KEY = 'download_file_name'
|
61 |
IS_IT_REFINEMENT = 'is_it_refinement'
|
62 |
+
APPROX_TARGET_LENGTH = GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH / 2
|
63 |
+
|
64 |
|
65 |
logger = logging.getLogger(__name__)
|
|
|
66 |
|
67 |
texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
|
68 |
captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
|
|
|
104 |
with st.expander('Usage Policies and Limitations'):
|
105 |
display_page_footer_content()
|
106 |
|
|
|
107 |
set_up_chat_ui()
|
108 |
|
109 |
|
|
|
124 |
st.chat_message('ai').write(
|
125 |
random.choice(APP_TEXT['ai_greetings'])
|
126 |
)
|
|
|
|
|
127 |
|
128 |
history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
|
129 |
|
|
|
147 |
placeholder=APP_TEXT['chat_placeholder'],
|
148 |
max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
|
149 |
):
|
|
|
|
|
150 |
if not text_helper.is_valid_prompt(prompt):
|
151 |
st.error(
|
152 |
'Not enough information provided!'
|
|
|
179 |
}
|
180 |
)
|
181 |
|
182 |
+
progress_bar = st.progress(0, 'Preparing to call LLM...')
|
183 |
+
response = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
+
for chunk in _get_llm().stream(formatted_template):
|
186 |
+
response += chunk
|
187 |
+
|
188 |
+
# Update the progress bar
|
189 |
+
progress_percentage = min(len(response) / APPROX_TARGET_LENGTH, 0.95)
|
190 |
+
progress_bar.progress(progress_percentage, text='Streaming content...')
|
191 |
|
192 |
history.add_user_message(prompt)
|
193 |
history.add_ai_message(response)
|
194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
# The content has been generated as JSON
|
196 |
# There maybe trailing ``` at the end of the response -- remove them
|
197 |
# To be careful: ``` may be part of the content as well when code is generated
|
|
|
198 |
response_cleaned = text_helper.get_clean_json(response)
|
199 |
|
200 |
logger.info(
|
|
|
204 |
logger.debug('Cleaned JSON: %s', response_cleaned)
|
205 |
|
206 |
# Now create the PPT file
|
207 |
+
progress_bar.progress(0.95, text='Searching photos and generating the slide deck...')
|
208 |
+
path = generate_slide_deck(response_cleaned)
|
209 |
+
progress_bar.progress(1.0, text='Done!')
|
210 |
+
|
211 |
+
st.chat_message('ai').code(response, language='json')
|
212 |
+
_display_download_button(path)
|
213 |
|
214 |
logger.info(
|
215 |
'#messages in history / 2: %d',
|
|
|
217 |
)
|
218 |
|
219 |
|
220 |
+
def generate_slide_deck(json_str: str) -> pathlib.Path:
|
221 |
"""
|
222 |
+
Create a slide deck and return the file path. In case there is any error creating the slide
|
223 |
+
deck, the path may be to an empty file.
|
224 |
|
225 |
:param json_str: The content in *valid* JSON format.
|
226 |
+
:return: The file of the .pptx file.
|
227 |
"""
|
228 |
|
229 |
if DOWNLOAD_FILE_KEY in st.session_state:
|
|
|
245 |
output_file_path=path
|
246 |
)
|
247 |
except ValueError:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
st.error(
|
249 |
'Encountered error while parsing JSON...will fix it and retry'
|
250 |
)
|
|
|
260 |
except Exception as ex:
|
261 |
st.error(APP_TEXT['content_generation_error'])
|
262 |
logger.error('Caught a generic exception: %s', str(ex))
|
263 |
+
|
264 |
+
return path
|
265 |
|
266 |
|
267 |
def _is_it_refinement() -> bool:
|