Spaces:
Running
Running
import datetime | |
import logging | |
import pathlib | |
import random | |
import tempfile | |
from typing import List | |
import json5 | |
import streamlit as st | |
from langchain_community.chat_message_histories import ( | |
StreamlitChatMessageHistory | |
) | |
from langchain_core.messages import HumanMessage | |
from langchain_core.prompts import ChatPromptTemplate | |
from transformers import AutoTokenizer | |
from global_config import GlobalConfig | |
from helpers import llm_helper, pptx_helper | |
def _load_strings() -> dict:
    """
    Read the app's display strings from the file named by `GlobalConfig.APP_STRINGS_FILE`.

    The file is read as UTF-8 text and parsed with `json5`.

    :return: The parsed strings.
    """

    with open(GlobalConfig.APP_STRINGS_FILE, 'r', encoding='utf-8') as strings_file:
        raw_text = strings_file.read()

    return json5.loads(raw_text)
def _get_prompt_template(is_refinement: bool) -> str:
    """
    Read the prompt template to use for the current kind of request.

    :param is_refinement: True to load the refinement template; False to load the initial one.
    :return: The template text, exactly as stored in the file.
    """

    template_path = (
        GlobalConfig.REFINEMENT_PROMPT_TEMPLATE
        if is_refinement
        else GlobalConfig.INITIAL_PROMPT_TEMPLATE
    )

    with open(template_path, 'r', encoding='utf-8') as template_file:
        return template_file.read()
@st.cache_resource
def _get_tokenizer() -> AutoTokenizer:
    """
    Get the tokenizer of the configured LLM (`GlobalConfig.HF_LLM_MODEL_NAME`) for counting tokens.

    Loading a pretrained tokenizer is costly and this function is called once per submitted
    prompt when token counting is enabled, so the loaded object is kept in Streamlit's
    resource cache and reused instead of being reloaded every time.

    :return: The tokenizer.
    """

    return AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=GlobalConfig.HF_LLM_MODEL_NAME
    )
# User-facing strings, loaded once when the script starts
APP_TEXT = _load_strings()

# Keys under which values are kept in Streamlit's session state
CHAT_MESSAGES = 'chat_messages'
DOWNLOAD_FILE_KEY = 'download_file_name'
IS_IT_REFINEMENT = 'is_it_refinement'

logger = logging.getLogger(__name__)

# Start-up progress indicator; it is advanced and finally cleared while the UI is built
progress_bar = st.progress(0, text='Setting up SlideDeck AI...')

# Template names and their captions, in matching order, for the sidebar selector
texts = list(GlobalConfig.PPTX_TEMPLATE_FILES)
captions = [
    GlobalConfig.PPTX_TEMPLATE_FILES[template_name]['caption'] for template_name in texts
]

pptx_template = st.sidebar.radio(
    'Select a presentation template:',
    texts,
    horizontal=True,
    captions=captions
)
def display_page_header_content():
    """
    Render the page header: the app name as the title, followed by the caption as a subheader.
    """

    header_parts = (
        (st.title, 'app_name'),
        (st.subheader, 'caption'),
    )

    for render, text_key in header_parts:
        render(APP_TEXT[text_key])

    # Visitor badge -- currently disabled
    # st.markdown(
    #     '![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbarunsaha%2Fslide-deck-ai&countColor=%23263759)'  # noqa: E501
    # )
def display_page_footer_content():
    """
    Render the footer text: the two terms-of-service strings separated by a blank line.
    """

    footer_text = '\n\n'.join((APP_TEXT['tos'], APP_TEXT['tos2']))
    st.text(footer_text)
def build_ui():
    """
    Lay out the page: the header, a collapsible section with the usage policies,
    and then the chat interface.
    """

    display_page_header_content()

    policies_section = st.expander('Usage Policies and Limitations')
    with policies_section:
        display_page_footer_content()

    # Static content is in place; the chat interface completes the progress indicator
    progress_bar.progress(50, text='Setting up chat interface...')
    set_up_chat_ui()
def _extract_generated_text(response) -> str:
    """
    Pull the generated text out of the reply returned by `llm_helper.hf_api_query()`.

    A reply counts as successful only when it is a non-empty list (or tuple) whose first
    item is a dict with a string under `generated_text`. Any other shape -- presumably an
    error payload from the API; verify against `llm_helper` -- is treated as a failure.

    :param response: The raw reply from the LLM API call.
    :return: The generated text with surrounding whitespace removed; an empty string
        if the reply does not contain any.
    """

    if isinstance(response, (list, tuple)) and response and isinstance(response[0], dict):
        text = response[0].get('generated_text')
        if isinstance(text, str):
            return text.strip()

    return ''


def set_up_chat_ui():
    """
    Prepare the chat interface and related functionality.

    Shows the usage instructions and a greeting, replays the chat history stored in the
    session state, and -- when the user submits a valid prompt -- calls the LLM, records the
    exchange in the history, and generates the slide deck from the response.
    """

    with st.expander('Usage Instructions'):
        st.write(GlobalConfig.CHAT_USAGE_INSTRUCTIONS)
        st.markdown(
            'SlideDeck AI is powered by'
            ' [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)'
        )

    # view_messages = st.expander('View the messages in the session state')

    st.chat_message('ai').write(
        random.choice(APP_TEXT['ai_greetings'])
    )
    progress_bar.progress(100, text='Done!')
    progress_bar.empty()

    # NOTE: keep this before `_is_it_refinement()`, which reads the same session key
    # (presumably initialized by the history object -- confirm against LangChain)
    history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)

    # Decide once per run; the history does not change until the LLM has responded
    is_refinement = _is_it_refinement()
    template = _get_prompt_template(is_refinement=is_refinement)
    logger.debug('Getting %s template', 'refinement' if is_refinement else 'initial')

    prompt_template = ChatPromptTemplate.from_template(template)

    # Since Streamlit app reloads at every interaction, display the chat history
    # from the saved session state
    for msg in history.messages:
        # Stored human messages do not report their type as 'user', so identify them by class,
        # the same way `_get_user_messages()` does
        if isinstance(msg, HumanMessage):
            st.chat_message('user').write(msg.content)
        else:
            st.chat_message(msg.type).code(msg.content, language='json')

    prompt = st.chat_input(
        placeholder=APP_TEXT['chat_placeholder'],
        max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
    )

    if not prompt:
        # Nothing submitted in this run
        return

    # Validate before showing any progress so that a rejected prompt leaves no stale progress bar
    if not _is_valid_prompt(prompt):
        return

    progress_bar_pptx = st.progress(0, 'Preparing to run...')

    logger.info('User input: %s', prompt)
    st.chat_message('user').write(prompt)

    user_messages = _get_user_messages()
    user_messages.append(prompt)
    list_of_msgs = '\n'.join(
        f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
    )

    if is_refinement:
        # Refinement: all instructions so far plus the previously generated content
        formatted_template = prompt_template.format(
            instructions=list_of_msgs,
            previous_content=_get_last_response()
        )
    else:
        formatted_template = prompt_template.format(question=prompt)

    progress_bar_pptx.progress(5, 'Calling LLM...will retry if connection times out...')
    response = llm_helper.hf_api_query({
        'inputs': formatted_template,
        'parameters': {
            'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
            'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
            'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
            'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
            'num_return_sequences': 1,
            'return_full_text': False,
            # "repetition_penalty": 0.0001
        },
        'options': {
            'wait_for_model': True,
            'use_cache': True
        }
    })

    generated_text = _extract_generated_text(response)

    if not generated_text:
        # Do not display, store, or parse a reply that carries no generated content
        logger.error('The LLM response contains no generated text: %s', response)
        progress_bar_pptx.empty()
        st.error(APP_TEXT['content_generation_error'])
        return

    st.chat_message('ai').code(generated_text, language='json')

    history.add_user_message(prompt)
    history.add_ai_message(generated_text)

    if GlobalConfig.COUNT_TOKENS:
        tokenizer = _get_tokenizer()
        tokens_count_in = len(tokenizer.tokenize(formatted_template))
        tokens_count_out = len(tokenizer.tokenize(generated_text))
        logger.debug(
            'Tokens count:: input: %d, output: %d',
            tokens_count_in, tokens_count_out
        )

    # _display_messages_history(view_messages)

    # The content has been generated as JSON
    # There may be trailing ``` at the end of the response -- remove them
    # To be careful: ``` may be part of the content as well when code is generated
    progress_bar_pptx.progress(50, 'Analyzing response...')
    response_cleaned = _clean_json(generated_text)

    # Now create the PPT file
    progress_bar_pptx.progress(75, 'Creating the slide deck...give it a moment')
    generate_slide_deck(response_cleaned)
    progress_bar_pptx.progress(100, text='Done!')
def generate_slide_deck(json_str: str):
    """
    Create a slide deck and offer it for download.

    The output file path is kept in the session state under `DOWNLOAD_FILE_KEY`, so the same
    file is reused for every deck generated in a session. Failures are reported in the UI
    and logged; no exception is propagated to the caller.

    :param json_str: The content in *valid* JSON format.
    """

    if DOWNLOAD_FILE_KEY in st.session_state:
        path = pathlib.Path(st.session_state[DOWNLOAD_FILE_KEY])
        logger.debug('DOWNLOAD_FILE_KEY found in session')
    else:
        # Only a unique file name is needed here. Close the handle immediately so that it
        # is not leaked; `delete=False` keeps the file on disk for the deck to be written to
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pptx') as temp:
            path = pathlib.Path(temp.name)

        st.session_state[DOWNLOAD_FILE_KEY] = str(path)
        logger.debug('DOWNLOAD_FILE_KEY not found in session')

    logger.debug('Creating PPTX file: %s...', st.session_state[DOWNLOAD_FILE_KEY])

    try:
        pptx_helper.generate_powerpoint_presentation(
            json_str,
            slides_template=pptx_template,
            output_file_path=path
        )

        _display_download_button(path)
    except ValueError as ve:
        # Reported to the user as a problem with parsing the JSON content
        st.error(APP_TEXT['json_parsing_error'])
        logger.error('%s', APP_TEXT['json_parsing_error'])
        logger.error('Additional error info: %s', str(ve))
    except Exception as ex:
        # Top-level boundary for the UI: tell the user and keep the traceback in the log
        st.error(APP_TEXT['content_generation_error'])
        logger.exception('Caught a generic exception: %s', str(ex))
def _is_valid_prompt(prompt: str) -> bool: | |
""" | |
Verify whether user input satisfies the concerned constraints. | |
:param prompt: The user input text. | |
:return: True if all criteria are satisfied; False otherwise. | |
""" | |
if len(prompt) < 5 or ' ' not in prompt: | |
st.error( | |
'Not enough information provided!' | |
' Please be a little more descriptive and type a few words with a few characters :)' | |
) | |
return False | |
return True | |
def _is_it_refinement() -> bool:
    """
    Whether the current prompt is a refinement of earlier content or the initial prompt.

    The conversation counts as a refinement once the session holds at least two chat
    messages. That decision is then remembered in the session state under
    `IS_IT_REFINEMENT`, so every later call also reports a refinement.

    :return: True if it is a refinement; False if it is the initial prompt.
    """

    if IS_IT_REFINEMENT in st.session_state:
        return True

    # A session in which no chat history has been stored yet is treated as having no messages
    if len(st.session_state.get(CHAT_MESSAGES, [])) >= 2:
        # Prepare for the next call
        st.session_state[IS_IT_REFINEMENT] = True
        return True

    return False
def _get_user_messages() -> List[str]:
    """
    Collect the text of every human message stored in the session's chat history so far.

    :return: The contents of the user messages, in the order in which they are stored.
    """

    human_messages = (
        message for message in st.session_state[CHAT_MESSAGES]
        if isinstance(message, HumanMessage)
    )

    return [message.content for message in human_messages]
def _get_last_response() -> str:
    """
    Get the content of the most recent message in the session's chat history, which the
    caller takes to be the last response generated by AI.

    :return: The response text.
    """

    stored_messages = st.session_state[CHAT_MESSAGES]
    latest_message = stored_messages[-1]

    return latest_message.content
def _display_messages_history(view_messages: st.expander):
    """
    Show the chat messages stored in the session state as JSON.

    :param view_messages: The container (e.g., an expander) in which to render the messages.
    """

    with view_messages:
        stored_messages = st.session_state[CHAT_MESSAGES]
        view_messages.json(stored_messages)
def _clean_json(json_str: str) -> str: | |
""" | |
Attempt to clean a JSON response string from the LLM by removing the trailing ``` | |
and any text beyond that. May not be always accurate. | |
:param json_str: The input string in JSON format. | |
:return: The "cleaned" JSON string. | |
""" | |
str_len = len(json_str) | |
response_cleaned = json_str | |
try: | |
idx = json_str.rindex('```') | |
logger.debug( | |
'Fixing JSON response: str_len: %d, idx of ```: %d', | |
str_len, idx | |
) | |
if idx + 3 == str_len: | |
# The response ends with ``` -- most likely the end of JSON response string | |
response_cleaned = json_str[:idx] | |
elif idx + 3 < str_len: | |
# Looks like there are some more content beyond the last ``` | |
# In the best case, it would be some additional plain-text response from the LLM | |
# and is unlikely to contain } or ] that are present in JSON | |
if '}' not in json_str[idx + 3:]: # the remainder of the text | |
response_cleaned = json_str[:idx] | |
except ValueError: | |
# No ``` found | |
pass | |
return response_cleaned | |
def _display_download_button(file_path: pathlib.Path):
    """
    Show a button that lets the user download the generated slide deck.

    The file is offered under the fixed name `Presentation.pptx`; the widget key is the
    current timestamp, so it differs every time the button is created.

    :param file_path: The path of the .pptx file.
    """

    with open(file_path, 'rb') as pptx_file:
        button_options = {
            'data': pptx_file,
            'file_name': 'Presentation.pptx',
            'key': datetime.datetime.now(),
        }
        st.download_button('Download PPTX file ⬇️', **button_options)
def main():
    """
    Entry point of the application: build the user interface.
    """

    build_ui()


# Run the app only when this file is executed as a script
if __name__ == '__main__':
    main()