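"""Gradio app for question answering over uploaded PDFs with retrieval-augmented generation (RAG).

Flow: index the PDFs into a vector store, connect the store to a Hugging Face LLM through a
conversational retrieval chain, then chat with answers attributed to a source file and page.
"""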

import os

import chromadb
import gradio as gr
from dotenv import find_dotenv, load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS, Chroma, LanceDB

# Read the Hugging Face API token from a local .env file.
_ = load_dotenv(find_dotenv())
hf_api = os.getenv("HUGGINGFACEHUB_API_TOKEN")
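# The token is expected in a .env file next to this script, e.g. (placeholder value):
#   HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxx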

# Hugging Face Hub repo IDs offered in the UI and served through HuggingFaceEndpoint.
llms = ["google/flan-t5-xxl", "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1",
        "google/gemma-7b-it", "google/gemma-2b-it", "HuggingFaceH4/zephyr-7b-beta",
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct"]
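# Note: some of these repos (e.g. the Gemma and Mistral instruct models) may be gated on the
# Hugging Face Hub and may require accepting their license with the account behind the token.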

def indexdocs(docs, chunk_size, chunk_overlap, vector_store, progress=gr.Progress()):

    progress(0.1, desc="Loading documents...")

    # Load every uploaded PDF and collect its pages.
    loaders = [PyPDFLoader(x) for x in docs]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())

    progress(0.2, desc="Splitting documents...")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)

    progress(0.3, desc="Generating embeddings...")

    embedding = HuggingFaceEmbeddings()

    progress(0.5, desc="Generating vectorstore...")

    # vector_store is the index of the selected radio option: 0 = Chroma, 1 = FAISS, 2 = LanceDB.
    if vector_store == 0:
        new_client = chromadb.EphemeralClient()
        vector_store_db = Chroma.from_documents(
            documents=doc_splits,
            embedding=embedding,
            client=new_client
        )
    elif vector_store == 1:
        vector_store_db = FAISS.from_documents(
            documents=doc_splits,
            embedding=embedding
        )
    else:
        vector_store_db = LanceDB.from_documents(
            documents=doc_splits,
            embedding=embedding
        )

    progress(0.9, desc="Vector store generated from the documents.")

    # Return the store, reveal the LLM setup column, and update the progress textbox.
    return vector_store_db, gr.Column(visible=True), "Vector store generated from the documents"
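
# setup_llm wires the vector store to a Hugging Face model. ConversationalRetrievalChain first
# condenses the question plus chat history into a standalone query, retrieves matching chunks,
# and "stuff"s them into a single prompt for the LLM.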

def setup_llm(vector_store, llm_model, temp, max_tokens):
    retriever = vector_store.as_retriever()
    # Keep the running conversation so follow-up questions have context.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True
    )
    llm = HuggingFaceEndpoint(
        repo_id=llms[llm_model],
        huggingfacehub_api_token=hf_api,
        temperature=temp,
        max_new_tokens=max_tokens,
        top_k=1
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    # Reveal the chat column once the chain is ready.
    return qa_chain, gr.Column(visible=True)
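
# Helper: flatten Gradio's [(user, bot), ...] history into plain "User:"/"Assistant:" strings
# before it is passed to the chain.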

def format_chat_history(chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

def chat(qa_chain, msg, history):
    formatted_chat_history = format_chat_history(history)
    response = qa_chain.invoke({"question": msg, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Report the first retrieved source document and its (1-based) page number.
    response_sources = response["source_documents"]
    response_source1 = os.path.basename(response_sources[0].metadata["source"])
    response_source_page = response_sources[0].metadata["page"] + 1
    new_history = history + [(msg, response_answer)]
    # Clear the prompt box and append the new turn to the chatbot history.
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source_page
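
# UI: three stages are revealed in sequence: create the vector store, set up the LLM, then chat.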

with gr.Blocks() as demo:
    vector_store_db = gr.State()
    qa_chain = gr.State()

    gr.Markdown(
        """
        # PDF Knowledge Base QA using RAG
        """
    )
    with gr.Accordion(label="Create Vectorstore", open=True):
        with gr.Column():
            file_list = gr.File(label='Upload your PDF files...', file_count='multiple', file_types=['.pdf'])
            chunk_size = gr.Slider(minimum=100, maximum=1000, value=500, step=25, label="Chunk Size", interactive=True)
            chunk_overlap = gr.Slider(minimum=10, maximum=200, value=30, step=10, label="Chunk Overlap", interactive=True)
            vector_store = gr.Radio(["Chroma", "FAISS", "LanceDB"], value="FAISS", label="Vectorstore", type="index", interactive=True)
            vectorstore_db_progress = gr.Textbox(label="Vectorstore database progress", value="Not started yet")
            fileuploadbtn = gr.Button("Generate Vectorstore and Move to LLM Setup Step")
    with gr.Column(visible=False) as llm_column:
        llm = gr.Radio(llms, label="Choose LLM Model", value=llms[0], type="index")
        model_temp = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.3, label="Temperature", interactive=True)
        model_max_tokens = gr.Slider(minimum=100, maximum=1000, step=50, value=200, label="Maximum Tokens", interactive=True)
        setup_llm_btn = gr.Button("Set up LLM and Start Chat")
    with gr.Column(visible=False) as chat_column:
        with gr.Row():
            chatbot = gr.Chatbot(height=300)
        with gr.Row():
            source = gr.Textbox(info="Source", container=False, scale=4)
            source_page = gr.Textbox(info="Page", container=False, scale=1)
        with gr.Row():
            prompt = gr.Textbox(container=False, scale=4, interactive=True)
            promptsubmit = gr.Button("Submit", scale=1, interactive=True)
    gr.Markdown(
        """
        # Responsible AI Usage
        Documents you upload and your interactions with the chatbot are not saved.
        """
    )

    fileuploadbtn.click(fn=indexdocs, inputs=[file_list, chunk_size, chunk_overlap, vector_store],
                        outputs=[vector_store_db, llm_column, vectorstore_db_progress])
    setup_llm_btn.click(fn=setup_llm, inputs=[vector_store_db, llm, model_temp, model_max_tokens],
                        outputs=[qa_chain, chat_column])
    promptsubmit.click(fn=chat, inputs=[qa_chain, prompt, chatbot],
                       outputs=[qa_chain, prompt, chatbot, source, source_page])
    prompt.submit(fn=chat, inputs=[qa_chain, prompt, chatbot],
                  outputs=[qa_chain, prompt, chatbot, source, source_page], queue=False)

if __name__ == "__main__":
    demo.launch()
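    # A temporary public link can be created by launching with share=True, e.g. demo.launch(share=True).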