Spaces:
Runtime error
Runtime error
""" | |
Streamlit front end for chatting with private data through QAPipeline.
08/16/2023
D.M. Theekshana Samaradiwakara
"""
import os | |
import time | |
import streamlit as st | |
from streamlit.logger import get_logger | |
logger = get_logger(__name__) | |
from ui.htmlTemplates import css, bot_template, user_template, source_template | |
from config import MODELS, DATASETS | |
from qaPipeline import QAPipeline | |
from faissDb import create_faiss | |
# loads environment variables | |
from dotenv import load_dotenv | |
load_dotenv() | |
# Feature flags read from the environment: raw strings, or None when unset.
# NOTE(review): neither flag is read anywhere else in this file, and the values
# are not parsed into booleans (the string "false" is still truthy) - confirm
# whether they are still needed.
# The Hugging Face flag was originally read under the misspelled name
# ENABLE_HUGGINGFSCE_HUB_MODELS. Prefer the correct spelling, but fall back to
# the legacy name so existing .env files keep working.
isHuggingFaceHubEnabled = os.environ.get(
    'ENABLE_HUGGINGFACE_HUB_MODELS',
    os.environ.get('ENABLE_HUGGINGFSCE_HUB_MODELS'),
)
isOpenAiApiEnabled = os.environ.get('ENABLE_OPENAI_API_MODELS')

st.set_page_config(page_title="Chat with data",
                   page_icon=":books:")
# Inject the chat CSS from ui.htmlTemplates.
st.write(css, unsafe_allow_html=True)

# Initial per-session values. Only keys not already present in session_state
# are written, so values set earlier in the session are never overwritten.
# NOTE(review): "model" defaults to MODELS["DEFAULT"] (a value of MODELS), while
# the sidebar later stores a key of MODELS and the chat handler indexes
# MODELS[...] with it - confirm MODELS["DEFAULT"] is itself a key of MODELS.
SESSION_DEFAULTS = {
    "model": MODELS["DEFAULT"],
    "dataset": DATASETS["DEFAULT"],
    "chat_history": None,
    "is_parameters_changed": False,
    "show_source_files": False,
    "user_question": '',
}

for key, default in SESSION_DEFAULTS.items():
    if key not in st.session_state:
        st.session_state[key] = default
with st.sidebar:
    st.subheader("Chat parameters")

    with st.form('param_form'):
        # Options are the keys of MODELS; the chat handler later resolves the
        # selection with MODELS[st.session_state.model].
        chat_model = st.selectbox(
            "Chat model",
            MODELS,
            key="chat_model",
            help="Select the LLM model for the chat",
        )
        # Only the "DEFAULT" dataset is offered, so pin it on every run.
        st.session_state.dataset = "DEFAULT"
        show_source = st.checkbox(
            label="show source files",
            key="show_source",
            help="Select this to show relevant source files for the query",
        )
        submitted = st.form_submit_button("Submit")

    if submitted:
        # Form widgets only take effect on submit: copy their values into the
        # session keys that the chat handler reads.
        st.session_state.model = chat_model
        st.session_state.show_source_files = show_source
        st.session_state.is_parameters_changed = False
        alert = st.success("chat parameters updated")
        time.sleep(1)  # keep the confirmation visible for 1 second
        alert.empty()  # then clear it

    st.markdown("\n")
    st.markdown(
        "### How to use\n"
        "1. Select the chat model\n"
        "2. Select \"show source files\" to show the source files related to the answer.📄\n"
        "3. Ask a question about the documents💬\n"
    )
st.header("Chat with your own data:")


@st.cache_resource
def load_QaPipeline():
    """Create the QAPipeline used by the chat form.

    Wrapped in ``st.cache_resource`` so the pipeline is built once and reused,
    instead of being rebuilt on every Streamlit script rerun (i.e. on every
    widget interaction).

    NOTE(review): a cached resource is shared by all user sessions. If
    QAPipeline keeps per-user state (e.g. agent memory), this sharing must be
    revisited - TODO confirm against qaPipeline.py.

    Returns:
        QAPipeline: the shared pipeline instance.
    """
    print('> QAPipeline loaded for front end')
    return QAPipeline()


qaPipeline = load_QaPipeline()
import html

with st.form('chat_body'):
    user_question = st.text_input(
        "Ask a question about your documents:",
        placeholder="enter question",
        key='user_question',
    )
    submitted = st.form_submit_button("Submit")

if submitted:
    with st.spinner("Processing"):
        query = st.session_state.user_question

        # Validate before the broad try below so that this early exit is
        # never routed through the generic error handler.
        if not query or not query.strip():
            st.error("Please enter a question!")
            st.stop()

        try:
            model = MODELS[st.session_state.model]
            dataset = DATASETS[st.session_state.dataset]
            show_source_files = st.session_state.show_source_files
            print(f">\n model: {model} \n dataset : {dataset} \n show_source_files : {show_source_files}")

            response = qaPipeline.run_agent(query=query, model=model, dataset=dataset)

            # run_agent may return either a plain answer or a dict carrying the
            # answer plus its source documents.
            docs = []
            if isinstance(response, dict):
                answer = response['answer']
                # A missing source list should not hide an answer we already have.
                docs = response.get('source_documents') or []
            else:
                answer = response

            # Both the user's text and the model output are untrusted; escape
            # them before splicing into HTML rendered with unsafe_allow_html.
            st.write(user_template.replace(
                "{{MSG}}", html.escape(query)), unsafe_allow_html=True)
            st.write(bot_template.replace(
                "{{MSG}}", html.escape(str(answer))), unsafe_allow_html=True)

            if show_source_files and docs:
                st.markdown("#### source files : ")
                for source in docs:
                    # assumes LangChain-style documents (.metadata mapping,
                    # .page_content) - matches the original attribute access.
                    label = str(source.metadata.get("source", "unknown source"))
                    with st.expander(label):
                        st.markdown(source.page_content)
        except Exception as e:
            # UI boundary: record the traceback server-side, show a short message.
            logger.exception("Answer retrieval failed")
            st.error(f"Error : {e}")