# RAG_UI / app.py
# NOTE(review): the original capture included Hugging Face page chrome here
# ("darthPanda's picture / hooha / b7bb8ad / raw history blame / 5.69 kB");
# it is web-UI residue from extraction, not program text, preserved as this
# comment so the file parses.
import streamlit as st
import os
import embed_pdf
import shutil
def clear_directory(directory):
    """Best-effort wipe of everything directly inside ``directory``.

    Regular files and symlinks are unlinked; subdirectories are removed
    recursively. The directory itself is kept. A failure on one entry is
    printed and the sweep continues with the remaining entries.
    """
    for entry in os.listdir(directory):
        target = os.path.join(directory, entry)
        try:
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f'Failed to delete {target}. Reason: {err}')
def clear_pdf_files(directory):
    """Delete the regular ``*.pdf`` files sitting directly in ``directory``.

    Non-PDF entries and subdirectories are untouched. Any removal failure is
    printed and the loop proceeds to the next entry.
    """
    for entry in os.listdir(directory):
        candidate = os.path.join(directory, entry)
        try:
            is_pdf = os.path.isfile(candidate) and candidate.endswith('.pdf')
            if is_pdf:
                os.remove(candidate)
        except Exception as err:
            print(f'Failed to delete {candidate}. Reason: {err}')
# Housekeeping helpers, kept for manual use during development.
# clear_pdf_files("pdf")
# clear_directory("index")

# create sidebar and ask for openai api key if not set in secrets
# NOTE(review): `secrets_file_path` is computed but currently unused — the
# secrets-based key lookup below is entirely commented out.
secrets_file_path = os.path.join(".streamlit", "secrets.toml")
# if os.path.exists(secrets_file_path):
#     try:
#         if "OPENAI_API_KEY" in st.secrets:
#             os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
#         else:
#             print("OpenAI API Key not found in environment variables")
#     except FileNotFoundError:
#         print('Secrets file not found')
# else:
#     print('Secrets file not found')
# if not os.getenv('OPENAI_API_KEY', '').startswith("sk-"):
#     os.environ["OPENAI_API_KEY"] = st.sidebar.text_input(
#         "OpenAI API Key", type="password"
#     )
# else:
#     if st.sidebar.button("Embed Documents"):
#         st.sidebar.info("Embedding documents...")
#         try:
#             embed_pdf.embed_all_pdf_docs()
#             st.sidebar.info("Done!")
#         except Exception as e:
#             st.sidebar.error(e)
#             st.sidebar.error("Failed to embed documents.")

# The key always comes from the sidebar password field; it is exported via
# the environment so downstream libraries can pick it up.
os.environ["OPENAI_API_KEY"] = st.sidebar.text_input(
    "OpenAI API Key", type="password"
)

uploaded_file = st.sidebar.file_uploader("Upload Document", type=['pdf', 'docx'], disabled=False)
if uploaded_file is None:
    file_uploaded_bool = False
else:
    file_uploaded_bool = True

# Embedding is only clickable once a document has been uploaded.
if st.sidebar.button("Embed Documents", disabled=not file_uploaded_bool):
    st.sidebar.info("Embedding documents...")
    try:
        # NOTE(review): helper lives in embed_pdf; presumably chunks and
        # indexes the uploaded document — confirm against the embed_pdf module.
        embed_pdf.embed_all_inputed_pdf_docs(uploaded_file)
        # embed_pdf.embed_all_pdf_docs()
        st.sidebar.info("Done!")
    except Exception as e:
        st.sidebar.error(e)
        st.sidebar.error("Failed to embed documents.")

# create the app
st.title("Chat with your PDF")
# chosen_file = st.radio(
#     "Choose a file to search", embed_pdf.get_all_index_files(), index=0
# )

# check if openai api key is set; halt rendering until a plausible key exists
if not os.getenv('OPENAI_API_KEY', '').startswith("sk-"):
    st.warning("Please enter your OpenAI API key!", icon="⚠")
    st.stop()

# load the agent (import deferred so it only runs once a key is present)
from llm_helper import convert_message, get_rag_chain, get_rag_fusion_chain

# Maps the radio-button label to the chain factory that builds it.
rag_method_map = {
    'Basic RAG': get_rag_chain,
    'RAG Fusion': get_rag_fusion_chain
}
chosen_rag_method = st.radio(
    "Choose a RAG method", rag_method_map.keys(), index=0
)
get_rag_chain_func = rag_method_map[chosen_rag_method]
## get the chain WITHOUT the retrieval callback (not used)
# custom_chain = get_rag_chain_func(chosen_file)

# create the message history state
if "messages" not in st.session_state:
    st.session_state.messages = []

# render older messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# render the chat input
prompt = st.chat_input("Enter your message...")
if prompt:
    # Record and echo the user's new message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    # render the user's new message
    with st.chat_message("user"):
        st.markdown(prompt)

    # render the assistant's response
    with st.chat_message("assistant"):
        retrival_container = st.container()  # trace of retrieval queries
        message_placeholder = st.empty()     # target for the streamed answer
        # retrieval_status = retrival_container.status("**Context Retrieval**")

        # Queries issued by the retriever in arrival order, plus the subset
        # already rendered, so each query is displayed at most once.
        queried_questions = []
        rendered_questions = set()

        def update_retrieval_status():
            # Render any retrieval queries collected since the last call.
            for q in queried_questions:
                if q in rendered_questions:
                    continue
                rendered_questions.add(q)
                # retrieval_status.markdown(f"\n\n`- {q}`")
                retrival_container.markdown(f"\n\n`- {q}`")

        def retrieval_cb(qs):
            # Chain callback: record unseen retrieval queries, pass-through.
            for q in qs:
                if q not in queried_questions:
                    queried_questions.append(q)
            return qs

        # get the chain with the retrieval callback
        # NOTE(review): this dereferences uploaded_file.name, which assumes a
        # document was uploaded; uploaded_file is None until then — confirm
        # the UI flow guarantees an upload before chatting.
        custom_chain = get_rag_chain_func(uploaded_file.name, retrieval_cb=retrieval_cb)
        if "messages" in st.session_state:
            # Everything but the just-appended user message is the history.
            chat_history = [convert_message(m) for m in st.session_state.messages[:-1]]
        else:
            chat_history = []

        # Stream the answer chunk-by-chunk, showing a cursor glyph while
        # generation is in progress and refreshing the retrieval trace.
        full_response = ""
        for response in custom_chain.stream(
            {"input": prompt, "chat_history": chat_history}
        ):
            if "output" in response:
                full_response += response["output"]
            else:
                full_response += response.content
            message_placeholder.markdown(full_response + "▌")
            update_retrieval_status()
        # retrival_container.update(state="complete")
        # retrieval_status.update(state="complete")

        # Final render without the cursor glyph.
        message_placeholder.markdown(full_response)
        # add the full response to the message history
        st.session_state.messages.append({"role": "assistant", "content": full_response})