# HuggingFace Spaces app: document question-answering chatbot (Gradio + LangChain).
import os
import time

import gradio as gr
import langchain
from huggingface_hub import InferenceClient
from langchain import HuggingFaceHub
from langchain.cache import InMemoryCache
from langchain.chains import RetrievalQA
from langchain.document_loaders import (
    Docx2txtLoader,
    OnlinePDFLoader,
    PyPDFLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
# Read the HF token eagerly so a missing secret fails fast at startup.
# NOTE(review): HuggingFaceHub reads this same env var internally — the
# binding here mainly serves as an explicit startup check.
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Sentence embeddings used to index document chunks in the vector store.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Small instruction-tuned T5 served through the HuggingFace Hub inference API.
model = "MBZUAI/LaMini-Flan-T5-248M"
llm = HuggingFaceHub(repo_id=model, model_kwargs={"temperature": 0})

# Cache identical LLM calls in memory to avoid repeated remote inference.
langchain.llm_cache = InMemoryCache()
def build_context(files, urls):
    """Load uploaded files and online PDF urls, index them, and build the QA chain.

    Args:
        files: list of Gradio file objects (or None) — .pdf/.docx/.ppt/.pptx.
        urls: comma-separated string of online PDF urls; may be "" or None.

    Returns:
        The status string "ready" once the retrieval chain is built.

    Side effects:
        Rebinds the module globals ``vectordb`` and ``qa_chain``.
    """
    documents = []
    if files is not None:
        for file in files:
            name = file.name
            if name.endswith('.pdf'):
                documents.extend(PyPDFLoader(name).load())
            elif name.endswith('.docx'):
                documents.extend(Docx2txtLoader(name).load())
            elif name.endswith(('.ppt', '.pptx')):
                documents.extend(UnstructuredPowerPointLoader(name).load())
    # Truthiness check: the urls Textbox is created with value=None, so the
    # old `urls != ""` comparison would crash on .split() when it is None.
    if urls:
        for url in urls.split(","):
            url = url.strip()  # tolerate "a.pdf, b.pdf" style input
            if url:
                documents.extend(OnlinePDFLoader(url).load())
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunked_documents = text_splitter.split_documents(documents)
    global vectordb, qa_chain
    vectordb = Chroma.from_documents(
        documents=chunked_documents,
        embedding=embeddings,
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vectordb.as_retriever(search_kwargs={'k': 7}),
        chain_type="stuff",
    )
    return "ready"
def loading():
    """Status text shown in the UI while documents are being indexed."""
    status = "Loading..."
    return status
def respond(message, chat_history):
    """Answer a chat message against the loaded document context.

    Args:
        message: the user's question.
        chat_history: Gradio chat history, a list of (user, bot) tuples;
            mutated in place.

    Returns:
        ("", updated chat_history) — the empty string clears the input box.
    """
    try:
        chain = qa_chain  # global, built by build_context()
    except NameError:
        # User chatted before loading any documents — reply gracefully
        # instead of raising a bare NameError into the UI.
        chat_history.append((message, "Please load documents or URLs first."))
        return "", chat_history
    result = chain({"query": message})["result"]
    chat_history.append((message, result))
    # Removed the artificial time.sleep(2): it only delayed every answer.
    return "", chat_history
def clear_chromadb():
    """Delete every stored document from the global Chroma vector store.

    Side effects:
        Empties the underlying collection of ``vectordb`` (must have been
        created by build_context first).
    """
    ids = vectordb.get()["ids"]
    if ids:
        # Batch the deletion: one API call for all ids instead of one
        # round-trip per document (and no shadowing of builtin `id`).
        vectordb._collection.delete(ids=ids)
with gr.Blocks() as demo:
    # --- context-loading controls -------------------------------------
    with gr.Row():
        with gr.Column():
            pdf_docs = gr.Files(
                label="Load pdf/docx/ppt/pptx files",
                file_types=['.pdf', '.docx', '.ppt', '.pptx'],
                type="file",
            )
        with gr.Column():
            urls = gr.Textbox(
                label="Enter one of multiple online pdf urls (comma separated if multiple)",
                value=None,
            )
    with gr.Row():
        load_context = gr.Button("Load documents and urls")
    with gr.Row():
        loading_status = gr.Textbox(label="Status", placeholder="", interactive=False)

    # --- chat area -----------------------------------------------------
    with gr.Row():
        with gr.Column():
            hg_chatbot = gr.Chatbot()
            msg = gr.Textbox(label="User message")
            clear = gr.ClearButton([msg, hg_chatbot])
            cleardb = gr.Button(value="Réinitialiser le contexte")

    # --- event wiring ----------------------------------------------------
    # Two click handlers on the same button: first flashes "Loading...",
    # then build_context overwrites the status with "ready".
    load_context.click(loading, None, loading_status, queue=False)
    load_context.click(build_context, inputs=[pdf_docs, urls], outputs=[loading_status], queue=False)
    msg.submit(respond, [msg, hg_chatbot], [msg, hg_chatbot])
    cleardb.click(clear_chromadb)

demo.queue(concurrency_count=3)
demo.launch()