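"""Gradio RAG demo: index an uploaded PDF with FAISS and HuggingFace embeddings,
then answer questions about its content with a Groq-hosted LLM via LangChain."""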
import gradio as gr
from langchain_community.document_loaders import PDFPlumberLoader
from langchain_experimental.text_splitter import SemanticChunker
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain_groq import ChatGroq
from pathlib import Path
# Function to load or create vector store
def load_or_create_vector_store(pdf_path, cache_path="vector_store.faiss"):
    try:
        # Initialize embeddings
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        # Check if a cached vector store exists (note: the cache is keyed only by cache_path,
        # so an existing index is reused even if a different PDF is supplied)
        if Path(cache_path).exists():
            print("Loading cached vector store...")
            # allow_dangerous_deserialization is required because FAISS.load_local unpickles the stored docstore
            return FAISS.load_local(cache_path, embedder, allow_dangerous_deserialization=True)
        # Load and process the PDF
        if not Path(pdf_path).exists():
            raise FileNotFoundError(f"PDF file not found at: {pdf_path}")
        loader = PDFPlumberLoader(pdf_path)
        docs = loader.load()
        if not docs:
            raise ValueError("No content extracted from the PDF.")
        # Split documents into semantically coherent chunks
        text_splitter = SemanticChunker(embedder)
        documents = text_splitter.split_documents(docs)
        # Create and save the vector store
        vector_store = FAISS.from_documents(documents, embedder)
        vector_store.save_local(cache_path)
        return vector_store
    except Exception as e:
        raise Exception(f"Error processing PDF: {str(e)}")
# Initialize language model
def initialize_llm(api_key):
    try:
        return ChatGroq(
            groq_api_key=api_key,
            temperature=0,
            model_name="deepseek-r1-distill-llama-70b"
        )
    except Exception as e:
        raise Exception(f"Error initializing LLM: {str(e)}")
# Define prompt template
prompt_template = """
Use the following context to answer the question. If you don't know the answer, say so. Keep the answer to 3 sentences. Always end with "Thanks for asking!"
{context}
Question: {question}
Helpful Answer:
"""
QA_PROMPT = PromptTemplate.from_template(prompt_template)
# Gradio interface function
def query_rag(pdf_file, question, api_key, k=2, chat_history=None):
    chat_history = chat_history or []
    try:
        if not question.strip():
            return "Please enter a valid question.", "", chat_history
        if not api_key:
            return "Please provide a valid Groq API key.", "", chat_history
        if not pdf_file:
            return "Please upload a valid PDF file.", "", chat_history
        # Handle the Gradio upload: newer Gradio passes a file path, older versions a tempfile-like object
        if isinstance(pdf_file, str):
            pdf_path = pdf_file
        elif hasattr(pdf_file, "name"):
            pdf_path = pdf_file.name
        else:
            pdf_path = "temp_uploaded_pdf.pdf"
            with open(pdf_path, "wb") as f:
                f.write(pdf_file)
        # Load or create the vector store
        vector_store = load_or_create_vector_store(pdf_path)
        retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k})
        # Initialize the LLM
        llm = initialize_llm(api_key)
        # Set up the conversational retrieval chain
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": QA_PROMPT},
            return_source_documents=True
        )
        # Run the query
        result = qa_chain({"question": question, "chat_history": chat_history})
        answer = result["answer"]
        sources = "\n\n".join(doc.page_content[:200] + "..." for doc in result["source_documents"])
        # Update chat history
        chat_history.append((question, answer))
        return answer, sources, chat_history
    except Exception as e:
        return f"Error: {str(e)}", "", chat_history
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# RAG Question-Answering System")
    gr.Markdown("Upload a PDF and ask questions about its content. Provide your Groq API key to proceed.")
    api_key = gr.Textbox(label="Groq API Key", type="password", placeholder="Enter your GROQ_API_KEY")
    pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
    question_input = gr.Textbox(label="Ask a Question", placeholder="e.g., What is a cost function in ML?")
    k_slider = gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Number of Retrieved Documents")
    output = gr.Textbox(label="Answer")
    sources = gr.Textbox(label="Source Documents (First 200 Characters)")
    chat_history = gr.State(value=[])
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=query_rag,
        inputs=[pdf_input, question_input, api_key, k_slider, chat_history],
        outputs=[output, sources, chat_history]
    )

# Launch the interface
demo.launch()
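# If running locally, demo.launch(share=True) can also create a temporary public link;
# on Hugging Face Spaces the default launch() above is sufficient.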