# Streamlit app: chat with an uploaded PDF via RAG.
import os
import shutil

import streamlit as st

import embed_pdf
def clear_directory(directory):
    """Delete every entry inside *directory*, keeping the directory itself.

    Files and symlinks are unlinked; subdirectories are removed recursively.
    Failures on individual entries are printed and skipped so one bad entry
    does not abort the whole sweep.
    """
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
def clear_pdf_files(directory):
    """Delete all regular files ending in '.pdf' directly inside *directory*.

    Non-PDF files and subdirectories are left untouched; matching is
    case-sensitive (``.PDF`` is not removed). Failures are printed and
    skipped.
    """
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) and file_path.endswith('.pdf'):
                os.remove(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
# clear_pdf_files("pdf")
# clear_directory("index")

# Sidebar: ask the user for an OpenAI API key and expose it to downstream
# libraries via the environment. (Loading the key from
# .streamlit/secrets.toml was disabled in favor of always prompting.)
os.environ["OPENAI_API_KEY"] = st.sidebar.text_input(
    "OpenAI API Key", type="password"
)
# Sidebar: document upload plus the embedding trigger. The embed button is
# disabled until a document has been uploaded.
uploaded_file = st.sidebar.file_uploader(
    "Upload Document", type=['pdf', 'docx'], disabled=False
)
file_uploaded_bool = uploaded_file is not None

if st.sidebar.button("Embed Documents", disabled=not file_uploaded_bool):
    st.sidebar.info("Embedding documents...")
    try:
        embed_pdf.embed_all_inputed_pdf_docs(uploaded_file)
        st.sidebar.info("Done!")
    except Exception as e:
        # Surface the failure in the sidebar rather than crashing the app.
        st.sidebar.error(e)
        st.sidebar.error("Failed to embed documents.")
# create the app
st.title("Chat with your PDF")

# Stop rendering the rest of the page until a plausible OpenAI API key
# ("sk-..." prefix) has been entered in the sidebar.
if not os.getenv('OPENAI_API_KEY', '').startswith("sk-"):
    st.warning("Please enter your OpenAI API key!", icon="⚠")
    st.stop()

# Imported here (not at the top) so the key check above runs before any
# OpenAI client setup inside llm_helper.
from llm_helper import convert_message, get_rag_chain, get_rag_fusion_chain

# Map the user-visible method name to its chain factory.
rag_method_map = {
    'Basic RAG': get_rag_chain,
    'RAG Fusion': get_rag_fusion_chain
}
chosen_rag_method = st.radio(
    "Choose a RAG method", rag_method_map.keys(), index=0
)
get_rag_chain_func = rag_method_map[chosen_rag_method]
# Chat history persists in Streamlit session state across script reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Re-render all prior messages on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# render the chat input
prompt = st.chat_input("Enter your message...")
if prompt:
    # The chain is built from uploaded_file.name below; without an uploaded
    # document that would raise AttributeError, so fail gracefully instead.
    if uploaded_file is None:
        st.warning("Please upload a document first!", icon="⚠")
        st.stop()

    st.session_state.messages.append({"role": "user", "content": prompt})

    # render the user's new message
    with st.chat_message("user"):
        st.markdown(prompt)

    # render the assistant's response
    with st.chat_message("assistant"):
        retrival_container = st.container()
        message_placeholder = st.empty()

        queried_questions = []
        rendered_questions = set()

        def update_retrieval_status():
            # Render each retrieval query exactly once, as it arrives.
            for q in queried_questions:
                if q in rendered_questions:
                    continue
                rendered_questions.add(q)
                retrival_container.markdown(f"\n\n`- {q}`")

        def retrieval_cb(qs):
            # Called by the chain with its retrieval queries; collect them
            # for display and pass them through unchanged.
            for q in qs:
                if q not in queried_questions:
                    queried_questions.append(q)
            return qs

        # get the chain with the retrieval callback
        custom_chain = get_rag_chain_func(uploaded_file.name, retrieval_cb=retrieval_cb)

        # Everything except the prompt just appended forms the history.
        chat_history = [convert_message(m) for m in st.session_state.messages[:-1]]

        # Stream the answer, updating the placeholder with a cursor ("▌")
        # while tokens arrive.
        full_response = ""
        for response in custom_chain.stream(
            {"input": prompt, "chat_history": chat_history}
        ):
            if "output" in response:
                full_response += response["output"]
            else:
                full_response += response.content
            message_placeholder.markdown(full_response + "▌")
            update_retrieval_status()

        message_placeholder.markdown(full_response)

    # add the full response to the message history
    st.session_state.messages.append({"role": "assistant", "content": full_response})