Spaces:
Runtime error
Runtime error
""" | |
Python Streamlit app to chat with private data
08/16/2023
D.M. Theekshana Samaradiwakara
""" | |
import html
import os
import time

import streamlit as st
from streamlit.logger import get_logger
# loads environment variables
from dotenv import load_dotenv

from ui.htmlTemplates import css, bot_template, user_template, source_template
from config import MODELS, DATASETS
from qaPipeline import QAPipeline
import qaPipeline_functions
from faissDb import create_faiss

logger = get_logger(__name__)
# Load variables from a local .env file into os.environ before reading flags.
load_dotenv()

# Feature flags taken verbatim from the environment (raw strings, or None when
# unset). The Hugging Face flag was originally read under a misspelled name
# ("HUGGINGFSCE"); the correctly spelled name now wins, and the misspelled one
# is still honoured as a fallback so existing deployments keep working.
isHuggingFaceHubEnabled = os.environ.get(
    'ENABLE_HUGGINGFACE_HUB_MODELS',
    os.environ.get('ENABLE_HUGGINGFSCE_HUB_MODELS'),
)
isOpenAiApiEnabled = os.environ.get('ENABLE_OPENAI_API_MODELS')

# Keep st.set_page_config ahead of every other st.* call; older Streamlit
# releases reject it otherwise.
st.set_page_config(page_title="Chat with data",
                   page_icon=":books:")
# Inject the chat-bubble CSS used by the HTML message templates.
st.write(css, unsafe_allow_html=True)

# Shared question-answering pipeline used by get_answer_from_backend.
# NOTE(review): Streamlit re-executes this script on every interaction, so the
# pipeline is rebuilt each rerun; consider caching it if construction is costly.
qaPipeline = QAPipeline()
def initialize_session_state():
    """Seed st.session_state with defaults for every key that is not set yet.

    Keys that already exist (for example a model the user saved earlier in the
    session) are left as they are.

    NOTE(review): handle_userinput resolves MODELS[st.session_state.model], so
    MODELS["DEFAULT"] must itself be a key of MODELS for the unsaved default
    to work -- confirm against config.
    """
    defaults = dict(
        model=MODELS["DEFAULT"],
        dataset=DATASETS["DEFAULT"],
        chat_history=None,
        is_parameters_changed=False,
        show_source_files=False,
        user_question='',
    )
    unset = {name: value for name, value in defaults.items()
             if name not in st.session_state}
    for name, value in unset.items():
        st.session_state[name] = value
def side_bar():
    """Draw the sidebar: the chat-parameter form followed by usage instructions.

    Form choices take effect only when "Save Parameters" is pressed; they are
    then handed to parameters_change_button, which stores them in
    st.session_state.
    """
    usage_help = (
        "### How to use\n"
        "1. Select the chat model\n"  # noqa: E501
        "2. Select \"show source files\" to show the source files related to the answer.📄\n"
        "3. Ask a question about the documents💬\n"
    )

    with st.sidebar:
        st.subheader("Chat parameters")

        with st.form('param_form'):
            st.info('Info: use openai chat model for best results')
            # Options come from MODELS (presumably a dict, so its keys are shown).
            selected_model = st.selectbox(
                "Chat model",
                MODELS,
                key="chat_model",
                help="Select the LLM model for the chat",
            )
            # No dataset picker is exposed; always pin the default dataset key.
            st.session_state.dataset = "DEFAULT"
            wants_sources = st.checkbox(
                label="show source files",
                key="show_source",
                help="Select this to show relavant source files for the query",
            )
            save_clicked = st.form_submit_button("Save Parameters")

        if save_clicked:
            parameters_change_button(selected_model, wants_sources)

        st.markdown("\n")
        st.markdown(usage_help)
def chat_body():
    """Draw the main chat area: a question form that answers on submit."""
    st.header("Chat with your own data:")

    with st.form('chat_body'):
        # The typed text lands in st.session_state['user_question'], which is
        # where submit_user_question reads it from; the return value is unused.
        st.text_input(
            "Ask a question about your documents:",
            placeholder="enter question",
            key='user_question',
        )
        ask_clicked = st.form_submit_button("Submit")

    if ask_clicked:
        submit_user_question()
def submit_user_question():
    """Answer the question held in session state while showing a spinner."""
    with st.spinner("Processing"):
        handle_userinput(st.session_state.user_question)
def main():
    """Build the page: seed session state, then draw the sidebar and chat area."""
    for render_step in (initialize_session_state, side_bar, chat_body):
        render_step()
def update_parameters_change():
    """Mark the sidebar parameters as edited but not yet saved.

    Only referenced from commented-out widget callbacks in this file.
    """
    st.session_state["is_parameters_changed"] = True
def parameters_change_button(chat_model, show_source):
    """Store the sidebar form values in session state and briefly confirm it.

    Args:
        chat_model: the option chosen in the "Chat model" selectbox; it is
            later resolved through MODELS[...] in handle_userinput.
        show_source: whether source documents are listed under each answer.
    """
    st.session_state.model = chat_model
    # No dataset picker is exposed, so the default dataset key is always used.
    st.session_state.dataset = "DEFAULT"
    st.session_state.show_source_files = show_source
    st.session_state.is_parameters_changed = False

    alert = st.success("chat parameters updated")
    time.sleep(1)  # keep the confirmation visible for one second
    alert.empty()  # then remove it
def get_answer_from_backend(query, model, dataset):
    """Return the QA pipeline agent's response to *query*.

    The agent path (run_agent) is used; the plain qaPipeline.run call was used
    here previously.
    """
    return qaPipeline.run_agent(query=query, model=model, dataset=dataset)
def show_query_response(query, response, show_source_files):
    """Render one question/answer exchange and, optionally, its sources.

    Args:
        query: the question text typed by the user.
        response: either the answer itself, or a dict with an 'answer' entry
            and a 'source_documents' entry (each document is expected to
            expose ``metadata["source"]`` and ``page_content``).
        show_source_files: when truthy, list each source document in an
            expander below the answer.
    """
    if isinstance(response, dict):
        answer = response['answer']
        # A missing or None source list means "no sources" rather than a
        # KeyError / len(None) failure.
        docs = response.get('source_documents') or []
    else:
        answer, docs = response, []

    # The templates are raw HTML rendered with unsafe_allow_html=True, and the
    # question (user input) and answer (model output) are spliced into them.
    # Escape &, < and > so neither can inject markup into the page; str()
    # also keeps a non-string answer from raising inside str.replace.
    safe_query = html.escape(str(query), quote=False)
    safe_answer = html.escape(str(answer), quote=False)
    st.write(user_template.replace("{{MSG}}", safe_query),
             unsafe_allow_html=True)
    st.write(bot_template.replace("{{MSG}}", safe_answer),
             unsafe_allow_html=True)

    if show_source_files and docs:
        st.markdown("#### source files : ")
        for doc in docs:
            with st.expander(doc.metadata["source"]):
                st.markdown(doc.page_content)
def is_query_valid(query: str) -> bool:
    """Return True when *query* has non-whitespace text.

    Otherwise show an error message in the page and return False.
    """
    has_text = bool(query) and bool(query.strip())
    if has_text:
        return True
    st.error("Please enter a question!")
    return False
def handle_userinput(query):
    """Validate *query*, fetch an answer from the backend and render it.

    An empty/blank query shows an error and halts the script run via
    st.stop(). Any failure while resolving the model/dataset or while getting
    or rendering the answer is logged with its traceback and reported to the
    user as a generic error, so the app keeps running.
    """
    # Validate outside the try block: st.stop() halts the run by raising an
    # exception internally, which the broad handler below must not intercept.
    if not is_query_valid(query):
        st.stop()

    try:
        model = MODELS[st.session_state.model]
        dataset = DATASETS[st.session_state.dataset]
        show_source_files = st.session_state.show_source_files
        logger.info("model: %s | dataset: %s | show_source_files: %s",
                    model, dataset, show_source_files)
        response = get_answer_from_backend(query, model, dataset)
        show_query_response(query, response, show_source_files)
    except Exception:
        # The user is told to check the logs, so the traceback must actually
        # be written there.
        logger.exception("Answer retrieval failed")
        st.error("Error Occured! see log info for more details.")
# Script entry point: build the page (session state, sidebar, chat area).
if __name__ == "__main__":
    main()