Spaces:
No application file
No application file
import os | |
import queue | |
import re | |
import tempfile | |
import threading | |
import streamlit as st | |
from embedchain import App | |
from embedchain.config import BaseLlmConfig | |
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield, | |
generate) | |
def embedchain_bot(db_path, api_key):
    """Create an Embedchain ``App`` configured for chatting over PDFs.

    Args:
        db_path: Directory handed to the Chroma vector store as its ``dir``.
        api_key: OpenAI API key passed to both the LLM and the embedder.

    Returns:
        The object returned by ``App.from_config`` for this configuration.
    """
    # OpenAI chat model with streaming enabled.
    llm_settings = {
        "provider": "openai",
        "config": {
            "model": "gpt-3.5-turbo-1106",
            "temperature": 0.5,
            "max_tokens": 1000,
            "top_p": 1,
            "stream": True,
            "api_key": api_key,
        },
    }
    # Chroma collection stored under the caller-supplied directory.
    vectordb_settings = {
        "provider": "chroma",
        "config": {"collection_name": "chat-pdf", "dir": db_path, "allow_reset": True},
    }
    embedder_settings = {"provider": "openai", "config": {"api_key": api_key}}
    chunker_settings = {"chunk_size": 2000, "chunk_overlap": 0, "length_function": "len"}

    return App.from_config(
        config={
            "llm": llm_settings,
            "vectordb": vectordb_settings,
            "embedder": embedder_settings,
            "chunker": chunker_settings,
        }
    )
def get_db_path():
    """Return the path of a newly created temporary directory.

    The directory is created with ``tempfile.mkdtemp`` and is not removed
    by this function.
    """
    return tempfile.mkdtemp()
def get_ec_app(api_key):
    """Return the Embedchain app stored in the Streamlit session, creating it if needed.

    Args:
        api_key: OpenAI API key used only when a new app has to be built.

    Returns:
        The app held in ``st.session_state.app``.
    """
    # Reuse the app already stored for this session, if any.
    if "app" in st.session_state:
        print("Found app in session state")
        return st.session_state.app

    # First call in this session: build the app on a fresh temp directory
    # and remember it for later calls.
    print("Creating app")
    new_app = embedchain_bot(get_db_path(), api_key)
    st.session_state.app = new_app
    return new_app
with st.sidebar:
    # Sidebar: OpenAI key entry plus PDF upload and ingestion.
    # The key is stored by Streamlit under st.session_state.api_key (key="api_key").
    openai_access_token = st.text_input("OpenAI API Key", key="api_key", type="password")
    # Bare string expressions below are intentional and kept as in the original.
    "WE DO NOT STORE YOUR OPENAI KEY."
    "Just paste your OpenAI API key here and we'll use it to power the chatbot. [Get your OpenAI API key](https://platform.openai.com/api-keys)"  # noqa: E501

    if st.session_state.api_key:
        app = get_ec_app(st.session_state.api_key)

    pdf_files = st.file_uploader("Upload your PDF files", accept_multiple_files=True, type="pdf")

    # Keep the list of ingested file names in session state from the start.
    # Appends then survive even when a later file fails and st.stop() ends
    # this run before the loop finishes, so earlier files are not re-added.
    if "add_pdf_files" not in st.session_state:
        st.session_state["add_pdf_files"] = []
    add_pdf_files = st.session_state["add_pdf_files"]

    for pdf_file in pdf_files:
        file_name = pdf_file.name
        if file_name in add_pdf_files:
            # Already ingested on an earlier run.
            continue

        # `app` only exists when a key was entered, so check before using it.
        if not st.session_state.api_key:
            st.error("Please enter your OpenAI API Key")
            st.stop()

        try:
            # delete=False: the file is closed first and then passed to
            # app.add by path; it is removed explicitly below.
            with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=file_name, suffix=".pdf") as f:
                f.write(pdf_file.getvalue())
                temp_file_name = f.name
            try:
                st.markdown(f"Adding {file_name} to knowledge base...")
                app.add(temp_file_name, data_type="pdf_file")
                st.markdown("")
                add_pdf_files.append(file_name)
            finally:
                # Remove the temp copy whether or not ingestion succeeded.
                os.remove(temp_file_name)
            # NOTE(review): st.session_state.messages is created further down
            # the script, so this relies on an earlier run having set it.
            st.session_state.messages.append({"role": "assistant", "content": f"Added {file_name} to knowledge base!"})
        except Exception as e:
            st.error(f"Error adding {file_name} to knowledge base: {e}")
            st.stop()
# Page header.
# NOTE(review): the leading "π" in the title and caption looks like a
# mis-decoded emoji; confirm against the original file before changing it.
st.title("π Embedchain - Chat with PDF")
styled_caption = '<p style="font-size: 17px; color: #aaa;">π An <a href="https://github.com/embedchain/embedchain">Embedchain</a> app powered by OpenAI!</p>'  # noqa: E501
st.markdown(styled_caption, unsafe_allow_html=True)

# Seed the conversation with a greeting the first time this session runs.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {
            "role": "assistant",
            "content": """
Hi! I'm chatbot powered by Embedchain, which can answer questions about your pdf documents.\n
Upload your pdf documents here and I'll answer your questions about them!
""",
        }
    ]

# Replay the stored conversation, one chat bubble per message.
for message in st.session_state["messages"]:
    st.chat_message(message["role"]).markdown(message["content"])
# Handle a new user question: stream the answer, then list its sources.
if prompt := st.chat_input("Ask me anything!"):
    if not st.session_state.api_key:
        # The icon must be a real emoji; the previous value was a
        # mis-decoded form of the robot emoji.
        st.error("Please enter your OpenAI API Key", icon="🤖")
        st.stop()

    app = get_ec_app(st.session_state.api_key)

    with st.chat_message("user"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.markdown(prompt)

    with st.chat_message("assistant"):
        msg_placeholder = st.empty()
        msg_placeholder.markdown("Thinking...")
        full_response = ""

        # Chunks are handed from the worker thread to this thread through q.
        q = queue.Queue()

        def app_response(result):
            """Run the chat call and store its outcome in *result*.

            On success sets ``result["answer"]`` and ``result["citations"]``;
            on failure sets ``result["error"]`` to the raised exception so the
            main thread can report it.
            """
            try:
                llm_config = app.llm.config.as_dict()
                llm_config["callbacks"] = [StreamingStdOutCallbackHandlerYield(q=q)]
                config = BaseLlmConfig(**llm_config)
                answer, citations = app.chat(prompt, config=config, citations=True)
            except Exception as exc:
                result["error"] = exc
                return
            result["answer"] = answer
            result["citations"] = citations

        results = {}
        thread = threading.Thread(target=app_response, args=(results,))
        thread.start()

        # Re-render the placeholder as each chunk arrives.
        for answer_chunk in generate(q):
            full_response += answer_chunk
            msg_placeholder.markdown(full_response)

        thread.join()

        if "error" in results:
            # Show the worker's failure instead of a KeyError on the lookups below.
            msg_placeholder.error(f"Error generating answer: {results['error']}")
            st.stop()

        citations = results["citations"]
        if citations:
            full_response += "\n\n**Sources**:\n"
            # Matches "<name>.<anything without a dot>.pdf" at the end of the
            # path so the display name becomes "<name>.pdf".
            pattern = re.compile(r"([^/]+)\.[^\.]+\.pdf$")
            sources = []
            for citation in citations:
                source = citation[1]["url"]
                match = pattern.search(source)
                if match:
                    source = match.group(1) + ".pdf"
                sources.append(source)
            # De-duplicate while keeping first-seen order, so the list is
            # rendered the same way on every run.
            for source in dict.fromkeys(sources):
                full_response += f"- {source}\n"

        msg_placeholder.markdown(full_response)
        print("Answer: ", full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})