import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
def initialise_vectorstore(pdf, progress=gr.Progress()):
    """Load a PDF, split it into chunks, and embed it into an in-memory Chroma vectorstore.

    Args:
        pdf: Uploaded file object from a gr.File component (its .name is the temp path).
        progress: Gradio progress tracker (injected by Gradio when used as an event handler).

    Returns:
        Tuple of (Chroma vectorstore, status string for the progress Textbox).
    """
    progress(0, desc="Reading PDF")
    loader = PyPDFLoader(pdf.name)
    pages = loader.load()
    # 1000-char chunks with 200-char overlap preserve context across chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(pages)
    progress(0.5, desc="Initialising Vectorstore")
    vectorstore = Chroma.from_documents(
        splits,
        embedding=HuggingFaceEmbeddings()
    )
    progress(1, desc="Complete")
    # Fixed: return a plain status string for the Textbox output instead of the
    # gr.Progress object, which is not a displayable component value.
    return vectorstore, "Complete"
def initialise_chain(llm, vectorstore, progress=gr.Progress()):
    """Build a RAG chain: retriever + hub prompt + HuggingFace chat model + string parser.

    Args:
        llm: Hugging Face repo id selected in the Radio component.
        vectorstore: Chroma vectorstore produced by initialise_vectorstore.
        progress: Gradio progress tracker (injected by Gradio when used as an event handler).

    Returns:
        Tuple of (runnable RAG chain, status string for the progress Textbox).
    """
    progress(0, desc="Initialising LLM")
    # Use a distinct local name so the `llm` parameter (the repo id) is not shadowed.
    endpoint = HuggingFaceEndpoint(
        repo_id=llm,
        task="text-generation",
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.03
    )
    chat = ChatHuggingFace(
        llm=endpoint,
        verbose=True
    )
    progress(0.5, desc="Initialising RAG Chain")
    retriever = vectorstore.as_retriever()
    # Standard community RAG prompt expecting {"context", "question"} inputs.
    prompt = hub.pull("rlm/rag-prompt")
    parser = StrOutputParser()
    rag_chain = {"context": retriever, "question": RunnablePassthrough()} | prompt | chat | parser
    progress(1, desc="Complete")
    # Fixed: return a plain status string for the Textbox output instead of the
    # gr.Progress object, which is not a displayable component value.
    return rag_chain, "Complete"
def send(message, rag_chain, chat_history):
    """Answer a user message via the RAG chain and record the exchange.

    Args:
        message: The user's question from the input Textbox.
        rag_chain: The runnable chain built by initialise_chain.
        chat_history: List of (user, assistant) tuples shown in the Chatbot.

    Returns:
        Tuple of ("" to clear the input box, the updated chat history).
    """
    answer = rag_chain.invoke(message)
    chat_history.append((message, answer))
    return "", chat_history
def restart():
    """Return the status message emitted when the Restart button is clicked.

    Note: the click handler is wired with no outputs, so this value is currently
    discarded by the UI.
    """
    # Fixed idiom: plain literal instead of an f-string with no placeholders.
    return "Restarting"
with gr.Blocks() as demo:
    # Cross-tab state: the Chroma vectorstore and the assembled RAG chain.
    vectorstore = gr.State()
    rag_chain = gr.State()
    # Fixed: these Markdown strings previously spanned multiple lines without
    # triple quotes, which is a SyntaxError. They are now valid literals.
    gr.Markdown("Talk to Documents")
    gr.Markdown("Upload and ask questions about your PDF files")
    gr.Markdown(
        "Note: This project uses LangChain to perform RAG (Retrieval Augmented "
        "Generation) on PDF files, allowing users to ask any questions related to "
        "their contents. When a PDF file is uploaded, it is embedded and stored in "
        "an in-memory Chroma vectorstore, which the chatbot uses as a source of "
        "knowledge when answering user questions."
    )
    # Vectorstore Tab: upload a PDF and build the embedding store.
    with gr.Tab("Vectorstore"):
        with gr.Row():
            input_pdf = gr.File()
        with gr.Row():
            # Empty side columns center the button in the row.
            with gr.Column(scale=1, min_width=0):
                pass
            with gr.Column(scale=2, min_width=0):
                initialise_vectorstore_btn = gr.Button(
                    "Initialise Vectorstore",
                    variant="primary"
                )
            with gr.Column(scale=1, min_width=0):
                pass
        with gr.Row():
            vectorstore_initialisation_progress = gr.Textbox(value="None", label="Initialization")
    # RAG Chain Tab: pick a model and build the retrieval chain.
    with gr.Tab("RAG Chain"):
        with gr.Row():
            language_model = gr.Radio([
                "microsoft/Phi-3-mini-4k-instruct",
                "mistralai/Mistral-7B-Instruct-v0.2",
                "nvidia/Mistral-NeMo-Minitron-8B-Base",
            ])
        with gr.Row():
            with gr.Column(scale=1, min_width=0):
                pass
            with gr.Column(scale=2, min_width=0):
                initialise_chain_btn = gr.Button(
                    "Initialise RAG Chain",
                    variant="primary"
                )
            with gr.Column(scale=1, min_width=0):
                pass
        with gr.Row():
            chain_initialisation_progress = gr.Textbox(value="None", label="Initialization")
    # Chatbot Tab: conversation UI plus (currently unwired) document references.
    with gr.Tab("Chatbot"):
        with gr.Row():
            chatbot = gr.Chatbot()
        with gr.Accordion("Advanced - Document references", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                source3_page = gr.Number(label="Page", scale=1)
        with gr.Row():
            message = gr.Textbox()
        with gr.Row():
            # Fixed: `variant` must be a string, not a one-element list.
            send_btn = gr.Button(
                "Send",
                variant="primary"
            )
            restart_btn = gr.Button(
                "Restart",
                variant="secondary"
            )
    # Event wiring: each button drives its handler; State components carry
    # the vectorstore and chain between tabs.
    initialise_vectorstore_btn.click(fn=initialise_vectorstore, inputs=input_pdf, outputs=[vectorstore, vectorstore_initialisation_progress])
    initialise_chain_btn.click(fn=initialise_chain, inputs=[language_model, vectorstore], outputs=[rag_chain, chain_initialisation_progress])
    send_btn.click(fn=send, inputs=[message, rag_chain, chatbot], outputs=[message, chatbot])
    restart_btn.click(fn=restart)
demo.launch()