import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
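# NOTE: these imports assume the classic `langchain` (pre-0.1) package layout,
# with gradio, pypdf, faiss-cpu, sentence-transformers, and transformers
# installed. Newer LangChain releases moved several of these classes into
# `langchain_community`.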
# Step 1: Load and Split Documents
def load_documents(pdf_files):
    loaders = [PyPDFLoader(file.name) for file in pdf_files]
    docs = []
    for loader in loaders:
        docs.extend(loader.load())
    # Split documents into smaller chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50
    )
    return text_splitter.split_documents(docs)
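# A 500-character chunk with 50 characters of overlap is a reasonable default
# for the small MiniLM embedder used below; both values are tunable.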
# Step 2: Create Vector Database
def create_vector_db(splits):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_db = FAISS.from_documents(splits, embeddings)
    return vector_db
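# The original snippet never defines the LLM that the retrieval chain needs,
# so a CPU-friendly stand-in is sketched here. The choice of google/flan-t5-small
# and the generation settings are assumptions, not part of the original app;
# any other LangChain-compatible LLM would work in its place.
from langchain.llms import HuggingFacePipeline

def initialize_llm():
    # Runs a small seq2seq model locally via transformers; no API key needed.
    return HuggingFacePipeline.from_model_id(
        model_id="google/flan-t5-small",      # assumed model choice
        task="text2text-generation",
        pipeline_kwargs={"max_new_tokens": 256}
    )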
# Step 3: Initialize Conversational Retrieval Chain
def initialize_qa_chain(vector_db):
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True
    )
    # ConversationalRetrievalChain is built with from_llm (from_chain_type does
    # not exist on this class) and requires an LLM alongside the retriever.
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=initialize_llm(),
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory
    )
    return qa_chain
# Step 4: Handle Conversation
def handle_conversation(qa_chain, query, history):
    # The chain's memory already tracks chat history, so only the question is
    # passed in; a manually supplied chat_history would be overridden by the
    # memory anyway.
    result = qa_chain({"question": query})
    response = result["answer"]
    history.append((query, response))
    return history, ""
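# Gradio's Chatbot component renders its history as a list of (user, bot)
# tuples, which is exactly what this helper appends to and returns. The empty
# string in the return value clears the query textbox after each turn.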
# Gradio UI
def demo():
    with gr.Blocks() as interface:
        # State components must be created inside the Blocks context so that
        # Gradio registers them with the app.
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<h1><center>CPU-Friendly RAG Chatbot</center></h1>")
        with gr.Tab("Step 1: Upload PDFs"):
            pdf_files = gr.File(file_types=[".pdf"], label="Upload PDF Files", file_count="multiple")
            create_db_button = gr.Button("Create Vector Database")
            db_status = gr.Textbox(label="Database Status", value="Not created", interactive=False)
        with gr.Tab("Step 2: Chat"):
            chatbot = gr.Chatbot()
            query = gr.Textbox(label="Your Query")
            send_button = gr.Button("Ask")
        # Create the database, then initialize the QA chain once the new state
        # is available; a second independent click handler could read stale state.
        create_db_button.click(
            fn=lambda files: (create_vector_db(load_documents(files)), "Database created successfully!"),
            inputs=[pdf_files],
            outputs=[vector_db, db_status]
        ).then(
            fn=initialize_qa_chain,
            inputs=[vector_db],
            outputs=[qa_chain]
        )
        # Handle conversation and clear the query box after each turn
        send_button.click(
            fn=handle_conversation,
            inputs=[qa_chain, query, chatbot],
            outputs=[chatbot, query]
        )
    return interface
# Launch the app
if __name__ == "__main__":
    demo().launch()
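# Running `python app.py` serves the interface on Gradio's default local
# address (http://127.0.0.1:7860). On a Hugging Face Space, naming the file
# app.py is enough for the same entry point to be picked up.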