import os

import gradio as gr
import openai
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma


class DocumentManager:
    """State shared by every Gradio event handler in ``create_demo``.

    Holds the OpenAI key, the documents loaded from uploaded files, the
    Chroma vector store built from them, and the RetrievalQA chain used to
    answer chat messages.
    """

    def __init__(self):
        self.api_key = None
        self.citation = ""
        self.docs = []  # LangChain documents loaded from every uploaded file
        self.retriever = None
        self.files = []  # original names of uploaded files, newest batch first
        self.provide_citation = True
        self.source_documents = []  # source documents returned by the last query
        self.user_prompt = "Be direct and cite your sources."
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=100)
        # Built once a valid key is available; initialised here so handlers can
        # check them instead of failing with AttributeError.
        self.db = None
        self.llm = None
        self.qa = None
        self.question = ""
        self.result = ""
        # Must run last: set_api_key relies on the attributes above.
        self.check_for_api_key_env()

    # ------------------------------------------------------------------ helpers

    def _embeddings(self):
        """Return an OpenAI embedding function bound to the current key."""
        return OpenAIEmbeddings(openai_api_key=self.api_key)

    def _new_vectorstore(self, documents=None):
        """Create a Chroma store, embedding *documents* into it when non-empty."""
        if documents:
            return Chroma.from_documents(documents, self._embeddings())
        return Chroma(embedding_function=self._embeddings())

    def _build_chain(self, store):
        """Point the retriever, chat model and RetrievalQA chain at *store*."""
        self.retriever = store.as_retriever(search_type="mmr", search_kwargs={"fetch_k": 30})
        self.llm = ChatOpenAI(model_name="gpt-4", temperature=0, streaming=True, openai_api_key=self.api_key)
        self.qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
        )

    def _files_markdown(self):
        """Return the uploaded-files list wrapped in a Markdown component."""
        return gr.Markdown(self.generate_file_markdown(), label="Uploaded files")

    # ------------------------------------------------------------------ API key

    def check_for_api_key_env(self):
        """Use ``OPENAI_API_KEY`` from the environment when it is set."""
        if "OPENAI_API_KEY" in os.environ:
            self.set_api_key(os.environ["OPENAI_API_KEY"])

    def set_api_key(self, value):
        """Validate and store an OpenAI key, then rebuild the store and QA chain.

        An invalid key only shows a Gradio warning and keeps the previous key.

        Returns:
            The status Textbox produced by ``create_api_key_status_display``.
        """
        value = (value or "").strip()  # pasted keys often carry stray whitespace
        if not value.startswith("sk"):
            gr.Warning("Please enter a valid OpenAI API key.")
            return self.create_api_key_status_display()
        self.api_key = value
        openai.api_key = self.api_key
        self.reset_qa()
        return self.create_api_key_status_display()

    def create_api_key_status_display(self):
        """Return a read-only Textbox showing whether a key is set (first 9 chars only)."""
        if self.api_key is None:
            return gr.Textbox("❌ Please enter an API key.", label=None, interactive=False, container=False)
        return gr.Textbox(f"✅ API key: {self.api_key[:9]}", label=None, interactive=False, container=False)

    # ------------------------------------------------------------------ settings

    def get_user_prompt(self):
        """Return the system prompt injected into every question."""
        return self.user_prompt

    def set_user_prompt(self, value):
        """Replace the system prompt injected into every question."""
        self.user_prompt = value

    def set_provide_citation(self, value):
        """Enable or disable the extra citation call after each answer."""
        self.provide_citation = value

    # ------------------------------------------------------------------ documents

    def delete_files(self):
        """Forget every uploaded document and wipe the Chroma client.

        Returns:
            The (now empty) uploaded-files Markdown component.
        """
        self.docs = []
        self.files = []
        self.source_documents = []
        store = Chroma(embedding_function=self._embeddings())
        # Chroma guards reset() behind allow_reset; this flips the (private)
        # client settings in place before wiping every collection.
        store._client_settings.allow_reset = True
        store._client.reset()
        # reset() dropped the collection `store` wraps, so rebuild on a fresh,
        # empty store and keep self.db / self.retriever / self.qa in sync.
        self.db = self._new_vectorstore()
        self._build_chain(self.db)
        return self._files_markdown()

    def reset_qa(self):
        """Re-split all loaded documents and rebuild the store and QA chain."""
        documents = self.text_splitter.split_documents(self.docs)
        if self.db is not None:
            # Every chunk is re-embedded below; drop the previous collection so
            # a default collection shared between Chroma instances does not
            # accumulate duplicates.
            self.db.delete_collection()
        self.db = self._new_vectorstore(documents)
        self._build_chain(self.db)

    def tokenize_doc(self, filepath):
        """Load *filepath* with UnstructuredFileLoader and append its documents."""
        loader = UnstructuredFileLoader(filepath)
        self.docs.extend(loader.load())

    # ------------------------------------------------------------------ chat

    def update_citation(self):
        """Ask the LLM for the passages supporting the last answer.

        Returns an empty string when citations are disabled, no key is set, or
        the last query produced no source documents.
        """
        if not (self.provide_citation and self.api_key and self.llm is not None and self.source_documents):
            self.citation = ""
            return self.citation
        # Separate passages so adjacent documents do not run together.
        summed = "\n\n".join(doc.page_content for doc in self.source_documents)
        self.citation = self.llm.predict(
            "Question: " + self.question
            + ". Answer: " + self.result
            + ". Citation: " + summed
            + ". From the citation, return the relevant passage and the exact articles"
        )
        return self.citation

    def predict(self, message, history):
        """Answer *message* with the RetrievalQA chain, streaming it into the chat.

        This is a generator: Gradio ignores a generator's return value, so every
        UI update (including the no-key case) must be yielded.

        Yields:
            ``(textbox_value, history)`` pairs for the message box and Chatbot.
        """
        history = list(history or [])
        if self.api_key is None or self.qa is None:
            gr.Warning("Please enter an OpenAI API key in the settings tab.")
            yield message, history
            return
        # Chatbot turns may be None (e.g. a pending bot reply); skip those.
        summed_history = " ".join(str(turn) for pair in history for turn in pair if turn)
        question = (
            "You have access to these documents:" + self.generate_file_markdown()
            + ". Do not make things up, only say what you have a primary source document for."
            + " ---- CHAT HISTORY : " + summed_history
            + " --- SYSTEM PROMPT: " + self.user_prompt
            + " -- Answer this question: " + message
        )
        result = self.qa({"query": question})
        self.source_documents = result["source_documents"]
        self.result = result["result"]
        self.question = question
        history.append([message, ""])
        yield "", history  # show the user's turn even if the answer is empty
        # The answer is already complete; reveal it one character at a time.
        for char in self.result:
            history[-1][1] += char
            yield "", history

    def generate_file_markdown(self):
        """Return a Markdown bullet list of uploaded file names (no directories)."""
        return "".join(f"- {os.path.basename(path)}\n" for path in self.files)

    def upload_file(self, files):
        """Load uploaded files, rebuild the QA chain, and return the file list Markdown."""
        if self.api_key is None:
            gr.Warning("Please enter an OpenAI API key.")
            return self._files_markdown()
        if not files:
            return self._files_markdown()
        for file in files:
            self.tokenize_doc(file.name)
        # Prefer the client-side name; fall back to the temp path if absent.
        filepaths = [getattr(file, "orig_name", file.name) for file in files]
        self.files = filepaths + self.files
        self.reset_qa()
        return self._files_markdown()

    def create_delete_button(self, value):
        """Return the delete button, interactive only when allowed and a key is set."""
        return gr.Button("Delete files", scale=4, interactive=bool(value and self.api_key))


def create_demo():
    """Build the Gradio Blocks UI (Chat and Settings tabs) around a DocumentManager."""
    doc_manager = DocumentManager()
    with gr.Blocks() as demo:
        with gr.Tab("Chat"):
            with gr.Row():
                chatbot = gr.Chatbot(scale=5, layout="panel", height=700)
                with gr.Column():
                    citation = gr.Textbox("", label="Citation", interactive=False, scale=3, container=False)
                    checkbox = gr.Checkbox(label="Provide document citation", value=True)
                    checkbox.change(doc_manager.set_provide_citation, checkbox)
            msg = gr.Textbox(label="Enter your message")
            with gr.Row():
                submit_button = gr.Button("Submit ➡️")
                gr.ClearButton([msg, chatbot, citation])
            # Button click and Enter in the textbox run the same pipeline.
            for trigger in (submit_button.click, msg.submit):
                trigger(doc_manager.predict, [msg, chatbot], [msg, chatbot]).then(
                    doc_manager.update_citation, None, citation
                )
        with gr.Tab("Settings") as settings_tab:
            with gr.Row():
                api_key_textbox = gr.Textbox(label="OpenAI API Key", scale=5)
                with gr.Column():
                    api_key_status = doc_manager.create_api_key_status_display()
                    save_key_button = gr.Button("Save Key")
                    # Clear the key from the textbox once it has been saved.
                    save_key_button.click(
                        doc_manager.set_api_key, inputs=api_key_textbox, outputs=api_key_status
                    ).then(lambda: None, None, api_key_textbox, queue=False)
            file_output = gr.Markdown("", label="Uploaded files")
            upload_button = gr.UploadButton("Upload Files", file_count="multiple")
            upload_button.upload(doc_manager.upload_file, upload_button, file_output)
            prompt_textbox = gr.Textbox(label="Prompt", value=doc_manager.get_user_prompt())
            prompt_textbox.change(doc_manager.set_user_prompt, prompt_textbox)
            with gr.Row():
                allow_delete_checkbox = gr.Checkbox(value=False, label="Allow deletion of files")
                delete_button = doc_manager.create_delete_button(False)
                delete_button.click(doc_manager.delete_files, outputs=file_output)
                allow_delete_checkbox.select(
                    doc_manager.create_delete_button, outputs=delete_button, inputs=allow_delete_checkbox
                )
        settings_tab.select(doc_manager.create_api_key_status_display, outputs=api_key_status)
    return demo


if __name__ == "__main__":
    demo = create_demo()
    # Defaults are the original hard-coded pair; override them via the
    # environment before exposing the app.
    auth = (os.environ.get("APP_USERNAME", "user"), os.environ.get("APP_PASSWORD", "pw"))
    demo.queue().launch(auth=auth)