# NOTE: removed Hugging Face Spaces page-header residue ("Spaces: / Sleeping / Sleeping")
# that was captured when this file was scraped — it is not valid Python.
import os

import gradio as gr
import openai
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
class DocumentManager:
    """Hold uploaded documents, the Chroma vector store, and the RetrievalQA
    chain that backs the Gradio chat UI.

    All UI callbacks (upload, delete, predict, settings) are methods on a
    single shared instance created in ``create_demo``.
    """

    def __init__(self):
        self.api_key = None
        self.citation = ""
        self.docs = []              # raw loaded documents (pre-split)
        self.retriever = None
        self.files = []             # display names of uploaded files
        self.provide_citation = True
        self.source_documents = []  # sources returned by the last QA call
        # Cached by predict() so update_citation() can reference the last
        # question/answer. Initialized here so update_citation() cannot hit
        # an AttributeError if invoked before the first predict() call.
        self.question = ""
        self.result = ""
        self.user_prompt = "Be direct and cite your sources."
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=100)
        self.check_for_api_key_env()

    def check_for_api_key_env(self):
        """Pick up an API key provisioned via the environment (e.g. Space secrets)."""
        if "OPENAI_API_KEY" in os.environ:
            self.set_api_key(os.environ["OPENAI_API_KEY"])

    def _build_retriever(self):
        """(Re)build and return an MMR retriever over the current documents."""
        embeddings = OpenAIEmbeddings(openai_api_key=self.api_key)
        if self.docs:
            documents = self.text_splitter.split_documents(self.docs)
            store = Chroma.from_documents(documents, embeddings)
        else:
            store = Chroma(embedding_function=embeddings)
        return store.as_retriever(
            search_type="mmr", search_kwargs={'fetch_k': 30},
            return_source_documents=True)

    def _build_qa(self):
        """(Re)build the LLM and the RetrievalQA chain over ``self.retriever``."""
        self.llm = ChatOpenAI(model_name="gpt-4", temperature=0, streaming=True,
                              openai_api_key=self.api_key)
        self.qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True)

    def set_api_key(self, value):
        """Validate and store the OpenAI key, then (re)build the QA stack.

        Returns the status Textbox component for the settings tab.
        """
        if (value is None) or (not value.startswith("sk")):
            gr.Warning("Please enter a valid OpenAI API key.")
            return self.create_api_key_status_display()
        self.api_key = value
        openai.api_key = self.api_key
        self.retriever = self._build_retriever()
        self._build_qa()
        return self.create_api_key_status_display()

    def create_api_key_status_display(self):
        """Return a read-only Textbox reflecting whether a key is set."""
        if self.api_key is None:
            return gr.Textbox("❌ Please enter an API key.", label=None,
                              interactive=False, container=False)
        else:
            # Show only a short prefix of the key, never the full secret.
            return gr.Textbox(f"✅ API key: {self.api_key[:9]}", label=None,
                              interactive=False, container=False)

    def get_user_prompt(self):
        """Return the user-configurable system prompt."""
        return self.user_prompt

    def set_user_prompt(self, value):
        """Store the user-configurable system prompt."""
        self.user_prompt = value

    def set_provide_citation(self, value):
        """Toggle whether update_citation() generates a citation."""
        self.provide_citation = value

    def delete_files(self):
        """Wipe all uploaded documents and reset the vector store and QA chain."""
        self.docs = []
        self.files = []
        self.source_documents = []
        self.db = Chroma(embedding_function=OpenAIEmbeddings(openai_api_key=self.api_key))
        # NOTE(review): reaches into Chroma private attributes to force a
        # collection reset — confirm against the pinned chromadb version.
        self.db._client_settings.allow_reset = True
        self.db._client.reset()
        # BUG FIX: the original assigned to the misspelled ``self.retreiver``,
        # leaving the real retriever (and the QA chain) bound to the old,
        # now-deleted collection.
        self.retriever = self.db.as_retriever(
            search_type="mmr", search_kwargs={'fetch_k': 30},
            return_source_documents=True)
        self._build_qa()
        return gr.Markdown(self.generate_file_markdown(), label="Uploaded files")

    def reset_qa(self):
        """Rebuild retriever and QA chain after the document set changes."""
        self.retriever = self._build_retriever()
        self._build_qa()

    def tokenize_doc(self, filepath):
        """Load ``filepath`` with UnstructuredFileLoader and append its docs."""
        loader = UnstructuredFileLoader(filepath)
        doc = loader.load()
        self.docs.extend(doc)

    def update_citation(self):
        """Ask the LLM for a citation supporting the last answer.

        Returns the citation text ("" when citations are disabled or no key
        is configured).
        """
        if self.provide_citation and self.api_key:
            summed = ""
            for doc in self.source_documents:
                summed += doc.page_content
            self.citation = self.llm.predict(
                "Question: " + self.question + ". Answer: " + self.result + ". Citation: " + summed +
                ". From the citation, return the relevant passage and the exact articles")
        else:
            self.citation = ""
        return self.citation

    def predict(self, message, history):
        """Answer a chat message, yielding (textbox_value, history) updates.

        This is a generator (Gradio streams each yield into the UI).
        """
        if self.api_key is None:
            gr.Warning("Please enter an OpenAI API key in the settings tab.")
            # BUG FIX: ``return value`` inside a generator discards the value,
            # so the original never updated the outputs on this path; yield
            # the unchanged state instead, then stop.
            yield message, history or []
            return
        if history is None:
            history = []
        summed_history = " ".join(sum(history, []))
        question = "You have access to these documents:" + self.generate_file_markdown() + ". Do not make things up, only say what you have a primary source document for. ---- CHAT HISTORY : " + summed_history + " --- SYSTEM PROMPT: " + self.user_prompt + " -- Answer this question: " + message
        result = self.qa({"query": question})
        self.source_documents = result["source_documents"]
        self.result = result["result"]
        self.question = question
        history.append([message, ""])
        # Stream the answer character-by-character into the last chat turn.
        # (renamed from ``message``, which shadowed the parameter)
        for char in result["result"]:
            history[-1][1] += char
            yield "", history

    def generate_file_markdown(self):
        """Return a markdown bullet list of the uploaded file basenames."""
        files_md = ""
        for file in self.files:
            filename = file.split("/")[-1]
            files_md += "- " + filename + "\n"
        return files_md

    def upload_file(self, files):
        """Load the uploaded files, record their names, and rebuild the QA chain."""
        if self.api_key is None:
            gr.Warning("Please enter an OpenAI API key.")
            return self.files
        for file in files:
            self.tokenize_doc(file.name)
        filepaths = [file.orig_name for file in files]
        self.files = filepaths + self.files
        self.reset_qa()
        return gr.Markdown(self.generate_file_markdown(), label="Uploaded files")

    def create_delete_button(self, value):
        """Return the delete button, enabled only when allowed and a key is set."""
        if value and self.api_key:
            return gr.Button("Delete files", scale=4, interactive=True)
        else:
            return gr.Button("Delete files", scale=4, interactive=False)
def create_demo():
    """Assemble the Gradio Blocks app: a Chat tab and a Settings tab.

    All callbacks are bound to one shared DocumentManager instance.
    Component creation order is significant — it defines the layout.
    """
    manager = DocumentManager()
    with gr.Blocks() as demo:
        # --- Chat tab: chatbot panel, citation sidebar, message input ---
        with gr.Tab("Chat"):
            with gr.Row():
                chat_display = gr.Chatbot(scale=5, layout="panel", height=700)
                with gr.Column():
                    citation_box = gr.Textbox("", label="Citation", interactive=False, scale=3, container=False)
                    citation_checkbox = gr.Checkbox(label="Provide document citation", value=True)
                    citation_checkbox.change(manager.set_provide_citation, citation_checkbox)
            message_box = gr.Textbox(label="Enter your message")
            with gr.Row():
                send_button = gr.Button("Submit ➡️")
                # After each answer, refresh the citation sidebar.
                send_button.click(manager.predict, [message_box, chat_display], [message_box, chat_display]).then(manager.update_citation, None, citation_box)
                clear_button = gr.ClearButton([message_box, chat_display, citation_box])
            # Enter in the textbox mirrors the submit button.
            message_box.submit(manager.predict, [message_box, chat_display], [message_box, chat_display]).then(manager.update_citation, None, citation_box)
        # --- Settings tab: API key, uploads, prompt, deletion controls ---
        with gr.Tab("Settings") as settings:
            with gr.Row():
                key_input = gr.Textbox(label="OpenAI API Key", scale=5)
                with gr.Column():
                    key_status = manager.create_api_key_status_display()
                    save_button = gr.Button("Save Key")
                    # Clear the key field after saving so the secret isn't left on screen.
                    save_button.click(manager.set_api_key, inputs=key_input, outputs=key_status).then(lambda: None, None, key_input, queue=False)
            files_display = gr.Markdown("", label="Uploaded files")
            uploader = gr.UploadButton("Upload Files", file_count="multiple")
            uploader.upload(manager.upload_file, uploader, files_display)
            prompt_input = gr.Textbox(label="Prompt", value=manager.get_user_prompt())
            prompt_input.change(manager.set_user_prompt, prompt_input)
            with gr.Row():
                delete_toggle = gr.Checkbox(value=False, label="Allow deletion of files")
                delete_btn = manager.create_delete_button(False)
                delete_btn.click(manager.delete_files, outputs=files_display)
                delete_toggle.select(manager.create_delete_button, outputs=delete_btn, inputs=delete_toggle)
            # Re-render the key status whenever the tab is opened.
            settings.select(manager.create_api_key_status_display, outputs=key_status)
    return demo
if __name__ == "__main__":
    demo = create_demo()
    # Basic-auth credentials were hard-coded as ("user", "pw"); allow them to
    # be overridden via environment variables while keeping the same defaults.
    auth_user = os.environ.get("GRADIO_AUTH_USER", "user")
    auth_password = os.environ.get("GRADIO_AUTH_PASSWORD", "pw")
    demo.queue().launch(auth=(auth_user, auth_password))