# Author: jamesredd
# Commit: 62124cd ("initial commit")
import openai
import gradio as gr
import os
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
class DocumentManager:
    """Hold the uploaded documents, the OpenAI key and the LangChain QA chain behind the Gradio UI.

    Every event handler in the app is a bound method of one instance, so all state (key,
    documents, chain, last question/answer) lives on that instance rather than per session.
    """

    # Chat model used both for answering questions and for extracting citations.
    _MODEL_NAME = "gpt-4"

    def __init__(self):
        self.api_key = None
        self.citation = ""
        self.docs = []  # LangChain documents loaded from every uploaded file
        self.retriever = None
        self.files = []  # client-side names of uploaded files, newest first
        self.provide_citation = True
        self.source_documents = []  # documents the chain returned for the last answer
        self.user_prompt = "Be direct and cite your sources."
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=100)
        # Populated once a valid API key is set / a question has been answered.
        self.llm = None
        self.qa = None
        self.db = None
        self.question = None
        self.result = None
        self.check_for_api_key_env()

    def check_for_api_key_env(self):
        """Adopt ``OPENAI_API_KEY`` from the environment when it is set to a non-empty value."""
        env_key = os.environ.get("OPENAI_API_KEY")
        if env_key:
            self.set_api_key(env_key)

    def set_api_key(self, value):
        """Validate and store an OpenAI API key, then rebuild the retriever and QA chain.

        Args:
            value: Key text from the settings textbox (or the environment). Surrounding
                whitespace is ignored; the key must start with ``"sk"``.

        Returns:
            A ``gr.Textbox`` describing the current key status.
        """
        if value is not None:
            value = value.strip()
        if not value or not value.startswith("sk"):
            gr.Warning("Please enter a valid OpenAI API key.")
            return self.create_api_key_status_display()
        self.api_key = value
        openai.api_key = self.api_key
        if self.docs:
            # Re-embed the already-loaded documents with the new key.
            self.reset_qa()
        else:
            self.retriever = self._make_retriever(Chroma(embedding_function=self._embeddings()))
            self._build_qa()
        return self.create_api_key_status_display()

    def create_api_key_status_display(self):
        """Return a read-only textbox showing whether a key is set (with its first 9 characters)."""
        if self.api_key is None:
            return gr.Textbox("❌ Please enter an API key.", label=None, interactive=False, container=False)
        return gr.Textbox(f"✅ API key: {self.api_key[:9]}", label=None, interactive=False, container=False)

    def get_user_prompt(self):
        """Return the extra instructions appended to every question."""
        return self.user_prompt

    def set_user_prompt(self, value):
        """Replace the extra instructions appended to every question."""
        self.user_prompt = value

    def set_provide_citation(self, value):
        """Enable or disable the follow-up citation request made by ``update_citation``."""
        self.provide_citation = value

    def delete_files(self):
        """Forget all uploaded files, wipe the vector store and rebuild an empty QA chain.

        Returns:
            A ``gr.Markdown`` listing the (now empty) uploaded files.
        """
        self.docs = []
        self.files = []
        self.source_documents = []
        if self.api_key is not None:
            self.db = Chroma(embedding_function=self._embeddings())
            # Enable resets on this client's settings, then wipe all stored data.
            self.db._client_settings.allow_reset = True
            self.db._client.reset()
            # Build the retriever over a store created after the reset, and keep
            # self.retriever identical to the retriever the chain uses.
            self.retriever = self._make_retriever(Chroma(embedding_function=self._embeddings()))
            self._build_qa()
        return self._file_list_markdown()

    def reset_qa(self):
        """Split and embed every loaded document into a Chroma store and rebuild the chain.

        NOTE(review): this re-embeds all of ``self.docs`` on each call. If Chroma's in-memory
        clients share storage (the client reset in ``delete_files`` suggests they may), chunks
        from earlier uploads could be stored again — confirm and consider indexing only new docs.
        """
        chunks = self.text_splitter.split_documents(self.docs)
        self.retriever = self._make_retriever(Chroma.from_documents(chunks, self._embeddings()))
        self._build_qa()

    def tokenize_doc(self, filepath):
        """Load *filepath* with ``UnstructuredFileLoader`` and append its documents to ``self.docs``."""
        self.docs.extend(UnstructuredFileLoader(filepath).load())

    def update_citation(self):
        """Ask the LLM to pick the passage that supports the last answer.

        Returns:
            The citation text, or ``""`` when citations are disabled, no key is set, or no
            question has been answered yet. The value is also stored in ``self.citation``.
        """
        if self.provide_citation and self.api_key and self.question is not None:
            context = "".join(doc.page_content for doc in self.source_documents)
            self.citation = self.llm.predict(
                "Question: " + self.question + ". Answer: " + self.result + ". Citation: " + context
                + ". From the citation, return the relevant passage and the exact articles"
            )
        else:
            self.citation = ""
        return self.citation

    def predict(self, message, history):
        """Answer *message* with the QA chain and stream the reply into the chat history.

        This is a generator used as a Gradio event handler. Each yield is ``("", history)``:
        the empty string clears the message box and ``history`` is the chatbot's list of
        ``[user, assistant]`` pairs, with the new pair appended and filled in progressively.

        Side effects: stores ``question``, ``result`` and ``source_documents`` on the instance
        so ``update_citation`` can use them afterwards.
        """
        if history is None:
            history = []
        if self.api_key is None:
            gr.Warning("Please enter an OpenAI API key in the settings tab.")
            # A `return` value is discarded inside a generator, so yield the update explicitly.
            yield "", history
            return
        # Flatten earlier turns into one string, skipping non-text entries (e.g. a None reply).
        transcript = " ".join(
            part for turn in history for part in turn if isinstance(part, str)
        )
        question = (
            "You have access to these documents:" + self.generate_file_markdown()
            + ". Do not make things up, only say what you have a primary source document for. ---- CHAT HISTORY : "
            + transcript + " --- SYSTEM PROMPT: " + self.user_prompt
            + " -- Answer this question: " + message
        )
        result = self.qa({"query": question})
        self.source_documents = result["source_documents"]
        self.result = result["result"]
        self.question = question
        history.append([message, ""])
        if not self.result:
            yield "", history
            return
        # The chain returns the whole answer at once; feed it to the UI one character at a
        # time so the chatbot renders it progressively.
        for char in self.result:
            history[-1][1] += char
            yield "", history

    def generate_file_markdown(self):
        """Return a Markdown bullet list of the uploaded file names (last path segment only)."""
        return "".join("- " + path.rsplit("/", 1)[-1] + "\n" for path in self.files)

    def upload_file(self, files):
        """Load and index the uploaded files, then return the updated file list.

        Args:
            files: Gradio upload objects; ``.name`` is the path handed to the loader and
                ``.orig_name`` is the name shown in the file list.

        Returns:
            A ``gr.Markdown`` listing all uploaded files.
        """
        if self.api_key is None:
            gr.Warning("Please enter an OpenAI API key.")
            return self._file_list_markdown()
        if not files:
            return self._file_list_markdown()
        for upload in files:
            self.tokenize_doc(upload.name)
        self.files = [upload.orig_name for upload in files] + self.files
        self.reset_qa()
        return self._file_list_markdown()

    def create_delete_button(self, value):
        """Return the "Delete files" button, interactive only if *value* is truthy and a key is set."""
        return gr.Button("Delete files", scale=4, interactive=bool(value and self.api_key))

    # ----- private helpers -----

    def _embeddings(self):
        """Return an ``OpenAIEmbeddings`` client bound to the current API key."""
        return OpenAIEmbeddings(openai_api_key=self.api_key)

    @staticmethod
    def _make_retriever(vectorstore):
        """Wrap *vectorstore* in the MMR retriever configuration shared by every chain."""
        return vectorstore.as_retriever(
            search_type="mmr",
            search_kwargs={"fetch_k": 30},
            return_source_documents=True,
        )

    def _build_qa(self):
        """(Re)create the streaming chat model and the "stuff" RetrievalQA chain over ``self.retriever``."""
        self.llm = ChatOpenAI(
            model_name=self._MODEL_NAME, temperature=0, streaming=True, openai_api_key=self.api_key
        )
        self.qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
        )

    def _file_list_markdown(self):
        """Return the "Uploaded files" Markdown component for the current file list."""
        return gr.Markdown(self.generate_file_markdown(), label="Uploaded files")
def create_demo():
    """Build the Gradio Blocks app.

    The app has a "Chat" tab for asking questions and a "Settings" tab for the API key, file
    uploads, the prompt and file deletion. Every event handler is a bound method of the single
    ``DocumentManager`` created here.

    Returns:
        The ``gr.Blocks`` app, which is not yet queued or launched.
    """
    # NOTE(review): the original indentation was lost; the nesting of the layout contexts
    # below is a reconstruction — confirm it against the intended UI.
    doc_manager = DocumentManager()
    with gr.Blocks() as demo:
        with gr.Tab("Chat"):
            with gr.Row():
                chatbot = gr.Chatbot(scale=5, layout="panel", height=700)
                # Side column: the citation for the last answer plus the toggle that enables it.
                with gr.Column():
                    citation = gr.Textbox("", label="Citation", interactive=False, scale=3, container=False)
                    checkbox = gr.Checkbox(label="Provide document citation", value=True)
                    checkbox.change(doc_manager.set_provide_citation, checkbox)
            msg = gr.Textbox(label="Enter your message")
            with gr.Row():
                submit_button = gr.Button("Submit ➡️")
                # Answer first (predict is a generator that streams into the chatbot), then
                # fill the citation box with update_citation's return value.
                submit_button.click(doc_manager.predict, [msg, chatbot], [msg, chatbot]).then(doc_manager.update_citation, None, citation)
                clear = gr.ClearButton([msg, chatbot, citation])
            # Submitting the message box runs the same predict -> citation chain as the button.
            msg.submit(doc_manager.predict, [msg, chatbot], [msg, chatbot]).then(doc_manager.update_citation, None, citation)
        with gr.Tab("Settings") as settings_tab:
            with gr.Row():
                api_key_textbox = gr.Textbox(label="OpenAI API Key", scale=5)
                with gr.Column():
                    api_key_status = doc_manager.create_api_key_status_display()
                    save_key_button = gr.Button("Save Key")
            # Save the key and refresh the status box, then blank the key textbox.
            save_key_button.click(doc_manager.set_api_key, inputs=api_key_textbox, outputs=api_key_status).then(lambda:None, None, api_key_textbox, queue=False)
            file_output = gr.Markdown("", label="Uploaded files")
            upload_button = gr.UploadButton("Upload Files", file_count="multiple")
            upload_button.upload(doc_manager.upload_file, upload_button, file_output)
            prompt_textbox = gr.Textbox(label="Prompt", value=doc_manager.get_user_prompt())
            prompt_textbox.change(doc_manager.set_user_prompt, prompt_textbox)
            with gr.Row():
                allow_delete_checkbox = gr.Checkbox(value=False, label="Allow deletion of files")
                # Created non-interactive; the checkbox below re-creates it via create_delete_button.
                delete_button = doc_manager.create_delete_button(False)
                delete_button.click(doc_manager.delete_files, outputs=file_output)
                allow_delete_checkbox.select(doc_manager.create_delete_button, outputs=delete_button, inputs=allow_delete_checkbox)
        # Refresh the key status display whenever the Settings tab is selected.
        settings_tab.select(doc_manager.create_api_key_status_display, outputs=api_key_status)
    return demo
if __name__ == "__main__":
    demo = create_demo()
    # Basic-auth credentials for the app. They can be overridden from the environment so real
    # credentials need not live in source; the defaults are the previously hard-coded values.
    auth = (
        os.environ.get("APP_USERNAME", "user"),
        os.environ.get("APP_PASSWORD", "pw"),
    )
    demo.queue().launch(auth=auth)