Spaces:
Sleeping
Sleeping
File size: 5,686 Bytes
b7bb8ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import streamlit as st
import os
import embed_pdf
import shutil
def clear_directory(directory):
    """Empty *directory* without removing the directory itself.

    Regular files and symlinks are unlinked (a symlink is removed, its target
    is not followed). Real subdirectories are deleted recursively with
    ``shutil.rmtree``. Any failure is printed and that entry is skipped, so
    one undeletable item does not stop the rest from being cleared.
    """
    for entry_name in os.listdir(directory):
        entry_path = os.path.join(directory, entry_name)
        try:
            is_file_or_link = os.path.isfile(entry_path) or os.path.islink(entry_path)
            if is_file_or_link:
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as err:
            # Deliberately best-effort: report and move on to the next entry.
            print(f'Failed to delete {entry_path}. Reason: {err}')
def clear_pdf_files(directory):
    """Delete the regular files directly inside *directory* whose path ends in '.pdf'.

    The scan is not recursive, and directories are never removed even if
    their name ends in '.pdf'. The suffix test is case-sensitive, so a file
    such as 'REPORT.PDF' is kept. Deletion errors are printed and skipped.
    """
    for name in os.listdir(directory):
        candidate = os.path.join(directory, name)
        try:
            if not candidate.endswith('.pdf'):
                continue
            if os.path.isfile(candidate):
                os.remove(candidate)
        except Exception as err:
            # Best-effort cleanup: report the problem and keep going.
            print(f'Failed to delete {candidate}. Reason: {err}')
# Optional start-up cleanup of uploaded PDFs and the "index" directory
# (disabled):
# clear_pdf_files("pdf")
# clear_directory("index")

# Path to Streamlit's secrets file. It was only used by an earlier,
# now-removed flow that read OPENAI_API_KEY from st.secrets; the name is kept
# so the module-level variable still exists.
secrets_file_path = os.path.join(".streamlit", "secrets.toml")

# The OpenAI key is taken from the sidebar on every script run and written to
# the environment unconditionally, so an empty field resets it to "".
# NOTE(review): os.environ is shared by every session served by this process,
# so concurrent users can overwrite each other's key. Consider passing the key
# to the LLM clients explicitly instead.
api_key_input = st.sidebar.text_input("OpenAI API Key", type="password")
os.environ["OPENAI_API_KEY"] = api_key_input
# Sidebar uploader. The "Embed Documents" button stays disabled until a file
# has been chosen.
uploaded_file = st.sidebar.file_uploader(
    "Upload Document",
    type=['pdf', 'docx'],
    disabled=False,
)
file_uploaded_bool = uploaded_file is not None

if st.sidebar.button("Embed Documents", disabled=not file_uploaded_bool):
    st.sidebar.info("Embedding documents...")
    try:
        # NOTE(review): the uploader also accepts .docx, but the helper's name
        # suggests PDF-only handling. Confirm embed_pdf supports .docx.
        embed_pdf.embed_all_inputed_pdf_docs(uploaded_file)
        st.sidebar.info("Done!")
    except Exception as exc:
        # Show the raw error first, then a generic failure message.
        st.sidebar.error(exc)
        st.sidebar.error("Failed to embed documents.")
# Page heading.
st.title("Chat with your PDF")

# The chains loaded below presumably call OpenAI. Halt this script run
# (st.stop) unless the key from the sidebar at least has the "sk-" prefix.
key_looks_valid = os.environ.get('OPENAI_API_KEY', '').startswith("sk-")
if not key_looks_valid:
    st.warning("Please enter your OpenAI API key!", icon="⚠")
    st.stop()
# Deferred import: it only runs once the key check above has passed.
# NOTE(review): presumably llm_helper needs OPENAI_API_KEY at import time.
# Confirm before moving this import to the top of the file.
from llm_helper import convert_message, get_rag_chain, get_rag_fusion_chain

# Radio-button label -> factory that builds the matching retrieval chain.
rag_method_map = {
    'Basic RAG': get_rag_chain,
    'RAG Fusion': get_rag_fusion_chain,
}

chosen_rag_method = st.radio(
    "Choose a RAG method",
    rag_method_map.keys(),
    index=0,
)
# Only the factory is selected here. The chain itself is built per prompt,
# because each turn supplies its own retrieval callback.
get_rag_chain_func = rag_method_map[chosen_rag_method]
# Chat transcript kept across reruns as a list of {"role", "content"} dicts.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Replay the transcript so far, since Streamlit redraws the page on each run.
for past in st.session_state["messages"]:
    role, content = past["role"], past["content"]
    with st.chat_message(role):
        st.markdown(content)
# Chat box at the bottom of the page. It yields None until the user submits.
prompt = st.chat_input("Enter your message...")

if prompt and uploaded_file is None:
    # The RAG chain is built from the uploaded file's name. Without an upload
    # there is nothing to query (this previously raised AttributeError on
    # `None.name` after the prompt had already been added to the history).
    st.warning("Please upload a document and embed it before asking a question.")
elif prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Echo the user's new message.
    with st.chat_message("user"):
        st.markdown(prompt)

    # Stream the assistant's answer, listing retrieval queries above it.
    with st.chat_message("assistant"):
        retrieval_container = st.container()
        message_placeholder = st.empty()

        queried_questions = []     # queries in first-seen order
        rendered_questions = set()  # queries already written to the container

        def update_retrieval_status():
            """Write each retrieval query to the container once, in first-seen order."""
            for q in queried_questions:
                if q in rendered_questions:
                    continue
                rendered_questions.add(q)
                retrieval_container.markdown(f"\n\n`- {q}`")

        def retrieval_cb(qs):
            """Record new queries (deduplicated) and hand *qs* back unchanged to the chain."""
            for q in qs:
                if q not in queried_questions:
                    queried_questions.append(q)
            return qs

        custom_chain = get_rag_chain_func(uploaded_file.name, retrieval_cb=retrieval_cb)

        # The history excludes the prompt just appended, which is sent as
        # "input". The "messages" key is guaranteed to exist by the history
        # initialisation above.
        chat_history = [convert_message(m) for m in st.session_state.messages[:-1]]

        full_response = ""
        for response in custom_chain.stream(
            {"input": prompt, "chat_history": chat_history}
        ):
            # Chunks are either dicts carrying an "output" piece or message
            # chunks exposing `.content`.
            if "output" in response:
                full_response += response["output"]
            else:
                full_response += response.content
            message_placeholder.markdown(full_response + "▌")  # typing cursor
            update_retrieval_status()
        message_placeholder.markdown(full_response)

    st.session_state.messages.append({"role": "assistant", "content": full_response})
|