# SPARKNET / examples / rag_pipeline.py
# Committer: MHamdan
# Initial commit: SPARKNET framework
# Commit: d520909
"""
Example: RAG Pipeline
Demonstrates:
1. Indexing documents into vector store
2. Semantic search
3. Question answering with citations
"""
from pathlib import Path
from loguru import logger
# Import RAG components
from src.rag import (
VectorStoreConfig,
EmbeddingConfig,
RetrieverConfig,
GeneratorConfig,
get_document_indexer,
get_document_retriever,
get_grounded_generator,
)
def example_indexing():
    """Index the sample PDF and print a summary of the outcome.

    The indexer is obtained first, then ``./data/sample.pdf`` is checked for
    existence before being passed to ``index_document``.

    Returns:
        True when the document was indexed and the index statistics were
        printed; False when the sample file is missing or the indexer
        reports a failure.
    """
    divider = "=" * 50
    print(divider)
    print("Document Indexing")
    print(divider)

    # The indexer is created before the file check, as in the original flow.
    doc_indexer = get_document_indexer()

    pdf_path = Path("./data/sample.pdf")
    if not pdf_path.exists():
        print(f"Sample document not found: {pdf_path}")
        print("Create a sample PDF at ./data/sample.pdf")
        return False

    outcome = doc_indexer.index_document(pdf_path)
    if not outcome.success:
        # Bail out early so the success path below stays unnested.
        print(f"Indexing failed: {outcome.error}")
        return False

    print(f"\nIndexed: {outcome.source_path}")
    print(f" Document ID: {outcome.document_id}")
    print(f" Chunks indexed: {outcome.num_chunks_indexed}")
    print(f" Chunks skipped: {outcome.num_chunks_skipped}")

    # Report the overall index statistics, one labelled line per key.
    index_stats = doc_indexer.get_index_stats()
    print("\nIndex Stats:")
    stat_rows = (
        ("Total chunks", "total_chunks"),
        ("Documents", "num_documents"),
        ("Embedding model", "embedding_model"),
    )
    for label, key in stat_rows:
        print(f" {label}: {index_stats[key]}")
    return True
def example_search():
    """Run a few sample queries against the retriever and print the hits.

    For every query the top three chunks are requested; each hit is printed
    with its similarity score, its page (shown as ``page + 1`` when the chunk
    has one) and the first 150 characters of its text.
    """
    print("\n" + "=" * 50)
    print("Semantic Search")
    print("=" * 50)

    doc_retriever = get_document_retriever()

    sample_queries = (
        "What is the main topic?",
        "key findings",
        "conclusions and recommendations",
    )
    for text in sample_queries:
        print(f"\nQuery: '{text}'")
        hits = doc_retriever.retrieve(text, top_k=3)
        if hits:
            for rank, hit in enumerate(hits, start=1):
                print(f"\n [{rank}] Similarity: {hit.similarity:.3f}")
                # Chunks without page information simply omit the page line.
                if hit.page is not None:
                    print(f" Page: {hit.page + 1}")
                print(f" Text: {hit.text[:150]}...")
        else:
            print(" No results found")
def example_question_answering():
    """Ask a few sample questions and print each answer with its citations.

    For every question the generator is called with ``top_k=5``. The answer
    and confidence are always printed; the abstain reason is printed when the
    result is marked as abstained, and the citation list is printed when it
    is non-empty.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("Question Answering with Citations")
    print(rule)

    qa_generator = get_grounded_generator()

    sample_questions = (
        "What is the main purpose of this document?",
        "What are the key findings?",
        "What recommendations are made?",
    )
    for prompt in sample_questions:
        print(f"\nQuestion: {prompt}")
        print("-" * 40)

        response = qa_generator.answer_question(prompt, top_k=5)
        print(f"\nAnswer: {response.answer}")
        print(f"\nConfidence: {response.confidence:.2f}")

        if response.abstained:
            print(f"Note: {response.abstain_reason}")

        if not response.citations:
            continue

        print(f"\nCitations ({len(response.citations)}):")
        for cite in response.citations:
            # A citation without a page gets an empty label; the rest of the
            # line format is unchanged.
            if cite.page is None:
                page_label = ""
            else:
                page_label = f"Page {cite.page + 1}"
            print(f" [{cite.index}] {page_label}: {cite.text_snippet[:60]}...")
def _print_chunk_lines(chunks):
    """Print one summary line per chunk, tolerating chunks without a page."""
    for chunk in chunks:
        # The other examples in this file treat ``page`` as optional, so
        # guard here too: ``None + 1`` would raise TypeError and abort the
        # demo part-way through the listing.
        page_label = chunk.page + 1 if chunk.page is not None else "n/a"
        print(f" - Page {page_label}: {chunk.text[:100]}...")


def example_filtered_search():
    """Demonstrate retrieval restricted by metadata filters.

    Runs two searches and prints the results of each:

    1. ``retrieve_tables`` for the query "data values" (top 3).
    2. ``retrieve_by_page`` for the query "introduction" limited to
       ``page_range=(0, 2)`` (top 3).

    Each hit is printed with its page number (shown as ``page + 1``, or
    "n/a" when the chunk has no page) and the first 100 characters of text.
    """
    print("\n" + "=" * 50)
    print("Filtered Search")
    print("=" * 50)

    retriever = get_document_retriever()

    # Search only in tables
    print("\nSearching for tables only...")
    table_chunks = retriever.retrieve_tables("data values", top_k=3)
    if table_chunks:
        print(f"Found {len(table_chunks)} table chunks:")
        _print_chunk_lines(table_chunks)
    else:
        print("No table chunks found")

    # Search specific page range
    # NOTE(review): page_range looks zero-based and inclusive, matching the
    # "pages 1-3" label below — confirm against retrieve_by_page.
    print("\nSearching pages 1-3...")
    page_chunks = retriever.retrieve_by_page(
        "introduction",
        page_range=(0, 2),
        top_k=3,
    )
    if page_chunks:
        print(f"Found {len(page_chunks)} chunks in pages 1-3:")
        _print_chunk_lines(page_chunks)
    else:
        print("No chunks found in specified pages")
def example_full_pipeline():
    """Run the whole demo: indexing, then search, then question answering.

    The search and Q&A steps are skipped when the indexing step returns a
    falsy value.
    """
    banner = "=" * 50
    print("\n" + banner)
    print("Full RAG Pipeline Demo")
    print(banner)

    # Step 1: nothing downstream is useful without an indexed document.
    print("\n[Step 1] Indexing documents...")
    indexed = example_indexing()
    if not indexed:
        return

    # Step 2: semantic search over the freshly indexed chunks.
    print("\n[Step 2] Testing search...")
    example_search()

    # Step 3: grounded question answering.
    print("\n[Step 3] Question answering...")
    example_question_answering()

    print("\n" + banner)
    print("Pipeline demo complete!")
    print(banner)
if __name__ == "__main__":
    # Script entry point: run the end-to-end demo (index, search, Q&A).
    example_full_pipeline()