Geraldine's picture
Update app.py
12fb790
raw
history blame
No virus
4.42 kB
import os
import time
import langchain
# loaders
from langchain.document_loaders import PyPDFLoader, OnlinePDFLoader, Docx2txtLoader, UnstructuredWordDocumentLoader, UnstructuredPowerPointLoader
# splits
from langchain.text_splitter import RecursiveCharacterTextSplitter
# embeddings
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
# vector stores
from langchain.vectorstores import Chroma
# huggingface hub
from huggingface_hub import InferenceClient
from langchain import HuggingFaceHub
# models
from langchain.llms import OpenAI
# retrievers
from langchain.chains import RetrievalQA
import gradio as gr
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
def build_context(openai_key, files, urls):
    """Build the module-global Chroma vector store from uploads and/or PDF URLs.

    Args:
        openai_key: OpenAI API key; when non-empty, OpenAI embeddings are used,
            otherwise a local sentence-transformers model on CPU.
        files: iterable of uploaded file objects (each exposing a ``.name``
            path, as Gradio provides) or ``None`` when nothing was uploaded.
        urls: comma-separated string of online PDF URLs; may be empty.

    Returns:
        The status string ``"loaded"`` shown in the Gradio status textbox.
    """
    if openai_key != "":
        # NOTE(review): some langchain versions name this parameter ``model``
        # rather than ``model_name`` — confirm against the pinned version.
        embeddings = OpenAIEmbeddings(model_name="text-embedding-ada-002", openai_api_key=openai_key)
    else:
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={'device': 'cpu'}
        )

    documents = []
    if files is not None:
        for file in files:
            # Case-insensitive suffix match so e.g. "report.PDF" is not
            # silently skipped.
            name = file.name.lower()
            if name.endswith('.pdf'):
                loader = PyPDFLoader(file.name)
            elif name.endswith('.docx'):
                loader = Docx2txtLoader(file.name)
            elif name.endswith(('.ppt', '.pptx')):
                loader = UnstructuredPowerPointLoader(file.name)
            else:
                # Unsupported extension: skip rather than crash the callback.
                continue
            documents.extend(loader.load())

    if urls.strip() != "":
        # Strip whitespace around each URL so "a.pdf, b.pdf" parses correctly,
        # and ignore empty entries from trailing commas.
        for url in urls.split(","):
            url = url.strip()
            if url:
                loader = OnlinePDFLoader(url)
                documents.extend(loader.load())

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=0,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    chunked_documents = text_splitter.split_documents(documents)

    # The store is intentionally module-global: llm_response() and
    # clear_chromadb() read it.
    global vectordb
    vectordb = Chroma.from_documents(
        documents=chunked_documents,
        embedding=embeddings,
    )
    return "loaded"
def llm_response(openai_key, message, chat_history):
    """Answer *message* against the loaded vector store and update the chat.

    Args:
        openai_key: OpenAI API key; when non-empty gpt-3.5-turbo is used,
            otherwise LaMini-Flan-T5-248M via the HuggingFace Hub.
        message: the user's question.
        chat_history: list of (user, assistant) tuples maintained by Gradio.

    Returns:
        ``("", chat_history)``: an empty string to clear the input textbox,
        plus the updated history.
    """
    # Guard: vectordb is only created by build_context(); without it the
    # retriever lookup below would raise NameError and crash the callback.
    if "vectordb" not in globals():
        chat_history.append((message, "Please load documents or urls first."))
        return "", chat_history

    if openai_key != "":
        llm = OpenAI(
            temperature=0, openai_api_key=openai_key, model_name="gpt-3.5-turbo", verbose=False
        )
    else:
        llm = HuggingFaceHub(
            repo_id='MBZUAI/LaMini-Flan-T5-248M',
            huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
            model_kwargs={"max_length": 512, "do_sample": True, "temperature": 0.2},
        )

    # "stuff" chain: all k=10 retrieved chunks are packed into one prompt.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectordb.as_retriever(search_kwargs={"k": 10}),
        return_source_documents=False,
        verbose=True,
    )
    result = qa_chain(message)["result"]
    chat_history.append((message, result))
    time.sleep(2)  # brief pause so the answer is visible before the box clears
    return "", chat_history
def loading():
    """Status-textbox callback shown while documents are being ingested."""
    return "Loading..."
def clear_chromadb():
    """Delete every stored chunk from the global Chroma collection.

    Lets the user load a fresh document set without restarting the app.
    Assumes build_context() has already created ``vectordb``.
    """
    ids = vectordb.get()["ids"]
    if ids:
        # Batch delete: one call with the full id list instead of one call
        # per id (the original looped, passing a bare string where the
        # Chroma API expects a list of ids).
        vectordb._collection.delete(ids=ids)
# ----- Gradio front-end: key entry, document loading, chat, context reset ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        openai_key = gr.Textbox(
            label="Enter your OpenAI API Key if you want to use the gpt-3.5-turbo-16k model. If not, the open source LaMini-Flan-T5-248M is used"
        )
    with gr.Row():
        pdf_docs = gr.Files(
            label="Load pdf files",
            file_types=['.pdf', '.docx', '.ppt', '.pptx'],
            type="file",
        )
        urls = gr.Textbox(
            label="Enter one of multiple online pdf urls (comma separated if multiple)"
        )
    with gr.Row():
        load_docs = gr.Button("Load documents and urls", variant="primary", scale=1)
        loading_status = gr.Textbox(
            label="Loading status", placeholder="", interactive=False, scale=0
        )
    with gr.Row():
        with gr.Column(scale=1):
            msg = gr.Textbox(label="User message")
            chatbot = gr.Chatbot()
            with gr.Row():
                clearchat = gr.ClearButton([msg, chatbot], value="New chat")
                cleardb = gr.Button(
                    value="Reset context (for loading new documents)",
                    variant="secondary",
                )

    # Two handlers on the same click: show "Loading..." immediately, then
    # build the vector store and overwrite the status when done.
    load_docs.click(loading, None, loading_status, queue=False)
    load_docs.click(
        build_context,
        inputs=[openai_key, pdf_docs, urls],
        outputs=[loading_status],
        queue=False,
    )
    msg.submit(llm_response, [openai_key, msg, chatbot], [msg, chatbot])
    cleardb.click(clear_chromadb)

demo.queue(concurrency_count=3)
demo.launch()