Geraldine's picture
Update app.py
f040926
raw
history blame
No virus
4.15 kB
import os
import time
from langchain.embeddings import HuggingFaceEmbeddings
from huggingface_hub import InferenceClient
import langchain
from langchain import HuggingFaceHub
from langchain.cache import InMemoryCache
from langchain.document_loaders import PyPDFLoader, OnlinePDFLoader, Docx2txtLoader, UnstructuredWordDocumentLoader, UnstructuredPowerPointLoader
from langchain.text_splitter import CharacterTextSplitter,RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
import gradio as gr
def define_embeddings_llm(openai_key):
    """Build the embedding model and LLM used by the retrieval QA chain.

    Args:
        openai_key: OpenAI API key. When non-empty, OpenAI embeddings and the
            gpt-3.5-turbo-16k model are used; when empty, the free
            sentence-transformers embeddings and LaMini-Flan-T5-248M are used.

    Returns:
        Tuple ``(embeddings, llm)``.

    Raises:
        KeyError: if no OpenAI key is given and HUGGINGFACEHUB_API_TOKEN is
            not set in the environment.
    """
    if openai_key != "":
        # Imported lazily: these names were referenced but never imported at
        # module level, so this branch previously raised NameError. A local
        # import also keeps the OpenAI dependency optional.
        from langchain.embeddings import OpenAIEmbeddings
        from langchain.llms import OpenAI

        embeddings = OpenAIEmbeddings(openai_api_key=openai_key)
        llm = OpenAI(
            temperature=0,
            model_name="gpt-3.5-turbo-16k",
            openai_api_key=openai_key,
            verbose=False,
        )
    else:
        # HuggingFaceHub reads the token from the environment itself; we only
        # check presence up front so a missing token fails fast and clearly.
        if "HUGGINGFACEHUB_API_TOKEN" not in os.environ:
            raise KeyError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        llm = HuggingFaceHub(
            repo_id="MBZUAI/LaMini-Flan-T5-248M",
            model_kwargs={"max_length": 2048, "temperature": 0.2},
        )
    # Cache identical LLM calls in memory to avoid repeated API requests.
    langchain.llm_cache = InMemoryCache()
    return embeddings, llm
def build_context(openai_key, files, urls):
    """Load, chunk and embed documents, then build the global QA chain.

    Args:
        openai_key: optional OpenAI API key ("" selects open-source models).
        files: list of uploaded file objects from gradio (or None); each has a
            ``.name`` path on disk.
        urls: comma-separated string of online PDF urls, possibly empty.

    Returns:
        The status string ``"ready"`` once the chain is built.

    Side effects:
        Sets the module-level globals ``vectordb`` and ``qa_chain``.
    """
    embeddings, llm = define_embeddings_llm(openai_key)
    documents = []
    # One loader class per supported extension; .doc support uses the
    # already-imported UnstructuredWordDocumentLoader (previously unused).
    loaders = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".doc": UnstructuredWordDocumentLoader,
        ".ppt": UnstructuredPowerPointLoader,
        ".pptx": UnstructuredPowerPointLoader,
    }
    if files is not None:
        for file in files:
            # Lower-case the extension so ".PDF" etc. is also recognised.
            ext = os.path.splitext(file.name)[1].lower()
            loader_cls = loaders.get(ext)
            if loader_cls is not None:
                documents.extend(loader_cls(file.name).load())
    if urls != "":
        for url in urls.split(sep=","):
            # Tolerate whitespace around commas and skip empty entries
            # (e.g. a trailing comma).
            url = url.strip()
            if url:
                documents.extend(OnlinePDFLoader(url).load())
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunked_documents = text_splitter.split_documents(documents)
    global vectordb
    vectordb = Chroma.from_documents(
        documents=chunked_documents,
        embedding=embeddings,
    )
    global qa_chain
    # "stuff" chain type: the 7 retrieved chunks are concatenated into a
    # single prompt for the LLM.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vectordb.as_retriever(search_kwargs={'k': 7}),
        chain_type="stuff",
    )
    return "ready"
def loading():
    """Return the transient status text shown while documents are indexed."""
    status_message = "Loading..."
    return status_message
def respond(message, chat_history):
    """Answer *message* with the global retrieval QA chain.

    Appends the (question, answer) pair to *chat_history* in place and
    returns an empty string (clears the input box) plus the updated history.
    """
    answer = qa_chain({"query": message})["result"]
    # In-place extension, same mutation the UI relies on.
    chat_history += [(message, answer)]
    # Brief pause so the chat update does not flash by instantly.
    time.sleep(2)
    return "", chat_history
def clear_chromadb():
    """Delete every embedded document from the global Chroma vector store."""
    ids = vectordb.get()["ids"]
    if ids:
        # Chroma's delete accepts a list of ids: one bulk call instead of one
        # round trip per document (the original also passed a bare string and
        # shadowed the builtin `id`).
        vectordb._collection.delete(ids=ids)
# ---------------------------------------------------------------------------
# Gradio UI: API-key entry, document/url upload, indexing status and chatbot.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        # Optional OpenAI key; left empty, the open-source models are used.
        openai_key = gr.Textbox(label="Enter your OpenAI API Key if you want to use the gpt-3.5-turbo-16k model. If not, the open source LaMini-Flan-T5-248M is used")
    with gr.Row():
        with gr.Column():
            # Multi-file upload restricted to the supported office/pdf formats.
            pdf_docs = gr.Files(label="Load pdf/docx/ppt/pptx files", file_types=['.pdf','.docx','.ppt','.pptx'], type="file")
        with gr.Column():
            urls = gr.Textbox(label="Enter one of multiple online pdf urls (comma separated if multiple)", value=None)
    with gr.Row():
        load_context = gr.Button("Load documents and urls")
    with gr.Row():
        # Read-only status line driven by the two click handlers below.
        loading_status = gr.Textbox(label="Status", placeholder="", interactive=False)
    with gr.Row():
        with gr.Column():
            hg_chatbot = gr.Chatbot()
            msg = gr.Textbox(label="User message")
            clear = gr.ClearButton([msg, hg_chatbot])
            # French label: "Reset the context" — wipes the vector store.
            cleardb = gr.Button(value="Réinitialiser le contexte")
    # Two handlers on the same button: the first immediately displays
    # "Loading...", the second runs the (slow) indexing and overwrites the
    # status with "ready" when done.
    load_context.click(loading, None, loading_status, queue=False)
    load_context.click(build_context, inputs=[openai_key,pdf_docs, urls], outputs=[loading_status], queue=False)
    # Submitting the textbox sends (message, history) and gets back
    # ("", updated history), which clears the input field.
    msg.submit(respond, [msg, hg_chatbot], [msg, hg_chatbot])
    cleardb.click(clear_chromadb)
# Allow up to 3 concurrent queued events before launching the app.
demo.queue(concurrency_count=3)
demo.launch()