|
|
""" |
|
|
Agentic RAG Orchestrator |
|
|
|
|
|
Coordinates the multi-agent RAG pipeline with self-correction loop. |
|
|
Follows FAANG best practices for production RAG systems. |
|
|
|
|
|
Pipeline: |
|
|
Query -> Plan -> Retrieve -> Rerank -> Synthesize -> Validate -> (Revise?) -> Response |
|
|
|
|
|
Key Features: |
|
|
- LangGraph-style state machine |
|
|
- Self-correction loop (up to N attempts) |
|
|
- Streaming support |
|
|
- Comprehensive logging and metrics |
|
|
- Graceful degradation |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Dict, Any, Generator, Tuple |
|
|
from pydantic import BaseModel, Field |
|
|
from loguru import logger |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
import time |
|
|
|
|
|
from ..store import VectorStore, get_vector_store, VectorStoreConfig |
|
|
from ..embeddings import EmbeddingAdapter, get_embedding_adapter, EmbeddingConfig |
|
|
|
|
|
from .query_planner import QueryPlannerAgent, QueryPlan, SubQuery |
|
|
from .retriever import RetrieverAgent, RetrievalResult, HybridSearchConfig |
|
|
from .reranker import RerankerAgent, RankedResult, RerankerConfig |
|
|
from .synthesizer import SynthesizerAgent, SynthesisResult, Citation, SynthesizerConfig |
|
|
from .critic import CriticAgent, CriticResult, ValidationIssue, CriticConfig |
|
|
|
|
|
|
|
|
class PipelineStage(str, Enum):
    """Stages in the RAG pipeline.

    Mixes in ``str`` so each member compares equal to its string value and
    serializes cleanly in logs / JSON payloads.
    """

    PLANNING = "planning"      # query decomposition and expansion
    RETRIEVAL = "retrieval"    # hybrid search over the vector store
    RERANKING = "reranking"    # re-scoring of the retrieved candidate pool
    SYNTHESIS = "synthesis"    # answer generation with citations
    VALIDATION = "validation"  # critic validation of the synthesized answer
    REVISION = "revision"      # re-synthesis after a failed validation
    COMPLETE = "complete"      # terminal stage; final response emitted
|
|
|
|
|
|
|
|
class RAGConfig(BaseModel):
    """Configuration for the agentic RAG system.

    A single config object shared by the orchestrator; per-agent configs
    (retriever, reranker, synthesizer, critic) are derived from it in
    ``AgenticRAG._init_agents``.
    """

    # LLM backend shared by all agents (Ollama-style HTTP endpoint).
    model: str = Field(default="llama3.2:3b")
    base_url: str = Field(default="http://localhost:11434")

    # Self-correction loop: how many failed validations may trigger a
    # re-synthesis before giving up (0 disables revision entirely).
    max_revision_attempts: int = Field(default=2, ge=0, le=5)
    # Stage toggles: planning/reranking flags are forwarded to the
    # respective agents; validation gates the critic loop in `query`.
    enable_query_planning: bool = Field(default=True)
    enable_reranking: bool = Field(default=True)
    enable_validation: bool = Field(default=True)

    # Size of the candidate pool fetched by the retriever vs. the number
    # of chunks kept after reranking and fed to the synthesizer.
    retrieval_top_k: int = Field(default=10, ge=1)
    final_top_k: int = Field(default=5, ge=1)

    # Forwarded to the synthesizer as its confidence threshold.
    min_confidence: float = Field(default=0.5, ge=0.0, le=1.0)

    # Emit per-stage progress logs when True.
    verbose: bool = Field(default=False)
|
|
|
|
|
|
|
|
@dataclass
class RAGState:
    """State maintained through the pipeline.

    A single mutable instance is threaded through every stage method
    (`_plan` -> `_retrieve` -> ... -> `_build_response`); each stage reads
    the outputs of the previous one and writes its own.
    """

    # The original user question (never rewritten in place).
    query: str
    # Current pipeline stage; updated at the top of each stage method.
    stage: PipelineStage = PipelineStage.PLANNING

    # Per-stage outputs, filled in pipeline order.
    query_plan: Optional[QueryPlan] = None
    retrieved_chunks: List[RetrievalResult] = field(default_factory=list)
    ranked_chunks: List[RankedResult] = field(default_factory=list)
    synthesis_result: Optional[SynthesisResult] = None
    critic_result: Optional[CriticResult] = None

    # Self-correction bookkeeping: attempt counter plus every superseded
    # synthesis, kept for debugging/auditing.
    revision_attempt: int = 0
    revision_history: List[SynthesisResult] = field(default_factory=list)

    # Timing: wall-clock start and per-stage durations in seconds,
    # keyed by stage name ("planning", "retrieval", ...).
    start_time: float = field(default_factory=time.time)
    stage_times: Dict[str, float] = field(default_factory=dict)

    # Error messages accumulated on pipeline failure (graceful degradation).
    errors: List[str] = field(default_factory=list)
|
|
|
|
|
|
|
|
class RAGResponse(BaseModel):
    """Final response from the RAG system.

    Built from a completed ``RAGState`` by ``AgenticRAG._build_response``
    (or ``_build_error_response`` on failure).
    """

    # Core answer payload.
    answer: str
    citations: List[Citation]
    confidence: float

    # Query-level metadata.
    query: str
    num_sources: int
    validated: bool          # True only when the critic accepted the answer
    revision_attempts: int   # number of self-correction rounds performed

    # Optional diagnostics: summarized plan and critic details
    # (plain dicts, suitable for JSON serialization).
    query_plan: Optional[Dict[str, Any]] = None
    validation_details: Optional[Dict[str, Any]] = None
    latency_ms: float = 0.0  # end-to-end wall-clock latency
|
|
|
|
|
|
|
|
class AgenticRAG:
    """
    Production-grade Multi-Agent RAG System.

    Orchestrates:
    - QueryPlannerAgent: Query decomposition and planning
    - RetrieverAgent: Hybrid retrieval
    - RerankerAgent: Cross-encoder reranking
    - SynthesizerAgent: Answer generation
    - CriticAgent: Validation and hallucination detection

    Features:
    - Self-correction loop (validate -> revise, up to
      ``RAGConfig.max_revision_attempts`` times)
    - Graceful degradation (any stage failure yields an error response
      instead of raising)
    - Comprehensive metrics (per-stage timings on ``RAGState.stage_times``)
    """

    def __init__(
        self,
        config: Optional[RAGConfig] = None,
        vector_store: Optional[VectorStore] = None,
        embedding_adapter: Optional[EmbeddingAdapter] = None,
    ):
        """
        Initialize the Agentic RAG system.

        Args:
            config: RAG configuration; defaults to ``RAGConfig()``.
            vector_store: Vector store for retrieval; when omitted it is
                lazily created via ``get_vector_store()`` on first access
                through the ``store`` property.
            embedding_adapter: Embedding adapter; when omitted it is lazily
                created via ``get_embedding_adapter()`` on first access
                through the ``embedder`` property.
        """
        self.config = config or RAGConfig()

        # Held privately; the `store`/`embedder` properties perform the
        # lazy initialization when either is first needed.
        self._store = vector_store
        self._embedder = embedding_adapter

        self._init_agents()

        logger.info(
            f"AgenticRAG initialized (model={self.config.model}, "
            f"revision_attempts={self.config.max_revision_attempts})"
        )

    def _init_agents(self) -> None:
        """Initialize all agents with shared configuration."""
        # Planner: LLM-backed decomposition; disabled (heuristic-only,
        # presumably) when enable_query_planning is False — confirm in
        # QueryPlannerAgent.
        self.planner = QueryPlannerAgent(
            model=self.config.model,
            base_url=self.config.base_url,
            use_llm=self.config.enable_query_planning,
        )

        # Retriever: dense, sparse, and fused pool sizes all use
        # retrieval_top_k; pruning to final_top_k happens at reranking.
        retriever_config = HybridSearchConfig(
            dense_top_k=self.config.retrieval_top_k,
            sparse_top_k=self.config.retrieval_top_k,
            final_top_k=self.config.retrieval_top_k,
        )
        # NOTE(review): _store/_embedder may still be None here (lazy init
        # happens in the properties) — verify RetrieverAgent handles None
        # by creating its own defaults.
        self.retriever = RetrieverAgent(
            config=retriever_config,
            vector_store=self._store,
            embedding_adapter=self._embedder,
        )

        # Reranker: trims the candidate pool down to final_top_k.
        reranker_config = RerankerConfig(
            model=self.config.model,
            base_url=self.config.base_url,
            top_k=self.config.final_top_k,
            use_llm_rerank=self.config.enable_reranking,
            min_relevance_score=0.1,
        )
        self.reranker = RerankerAgent(config=reranker_config)

        # Synthesizer: generates the cited answer from ranked chunks.
        synth_config = SynthesizerConfig(
            model=self.config.model,
            base_url=self.config.base_url,
            confidence_threshold=self.config.min_confidence,
        )
        self.synthesizer = SynthesizerAgent(config=synth_config)

        # Critic: validates the synthesis against its sources.
        critic_config = CriticConfig(
            model=self.config.model,
            base_url=self.config.base_url,
        )
        self.critic = CriticAgent(config=critic_config)

    @property
    def store(self) -> VectorStore:
        """Get vector store (lazy initialization)."""
        if self._store is None:
            self._store = get_vector_store()
        return self._store

    @property
    def embedder(self) -> EmbeddingAdapter:
        """Get embedding adapter (lazy initialization)."""
        if self._embedder is None:
            self._embedder = get_embedding_adapter()
        return self._embedder

    def query(
        self,
        question: str,
        filters: Optional[Dict[str, Any]] = None,
    ) -> RAGResponse:
        """
        Process a query through the full RAG pipeline.

        Pipeline: plan -> retrieve -> rerank -> synthesize ->
        (validate/revise, if enabled) -> build response.

        Args:
            question: User's question
            filters: Optional metadata filters for retrieval

        Returns:
            RAGResponse with answer and metadata; on any stage failure, an
            error response with confidence 0.0 (never raises).
        """
        state = RAGState(query=question)

        try:
            state = self._plan(state)
            state = self._retrieve(state, filters)
            state = self._rerank(state)
            state = self._synthesize(state)

            # Validation (and the self-correction loop) is optional.
            if self.config.enable_validation:
                state = self._validate_and_revise(state)

            return self._build_response(state)

        except Exception as e:
            # Graceful degradation: surface the failure as a response
            # instead of propagating the exception to the caller.
            logger.error(f"RAG pipeline error: {e}")
            state.errors.append(str(e))
            return self._build_error_response(state, str(e))

    def query_stream(
        self,
        question: str,
        filters: Optional[Dict[str, Any]] = None,
    ) -> Generator[Tuple[PipelineStage, Any], None, None]:
        """
        Process query with streaming updates.

        Runs the same pipeline as ``query`` but yields an update after each
        stage so callers can render progress. Note the payload type varies
        by stage: plan object, chunk counts (ints), synthesis result,
        critic result, then the final RAGResponse.

        Yields:
            Tuple of (stage, stage_result)
        """
        state = RAGState(query=question)

        try:
            state = self._plan(state)
            yield PipelineStage.PLANNING, state.query_plan

            state = self._retrieve(state, filters)
            yield PipelineStage.RETRIEVAL, len(state.retrieved_chunks)

            state = self._rerank(state)
            yield PipelineStage.RERANKING, len(state.ranked_chunks)

            state = self._synthesize(state)
            yield PipelineStage.SYNTHESIS, state.synthesis_result

            # Validation update is only emitted when the stage is enabled.
            if self.config.enable_validation:
                state = self._validate_and_revise(state)
                yield PipelineStage.VALIDATION, state.critic_result

            response = self._build_response(state)
            yield PipelineStage.COMPLETE, response

        except Exception as e:
            # Same graceful degradation as `query`, delivered as the
            # terminal COMPLETE event.
            logger.error(f"Streaming error: {e}")
            yield PipelineStage.COMPLETE, self._build_error_response(state, str(e))

    def _plan(self, state: RAGState) -> RAGState:
        """Execute query planning stage."""
        start = time.time()
        state.stage = PipelineStage.PLANNING

        if self.config.verbose:
            logger.info(f"Planning query: {state.query}")

        state.query_plan = self.planner.plan(state.query)

        state.stage_times["planning"] = time.time() - start

        if self.config.verbose:
            logger.info(
                f"Query plan: intent={state.query_plan.intent}, "
                f"sub_queries={len(state.query_plan.sub_queries)}"
            )

        return state

    def _retrieve(
        self,
        state: RAGState,
        filters: Optional[Dict[str, Any]],
    ) -> RAGState:
        """Execute retrieval stage.

        Passes the plan through so the retriever can fan out over
        sub-queries; `filters` narrows results by chunk metadata.
        """
        start = time.time()
        state.stage = PipelineStage.RETRIEVAL

        if self.config.verbose:
            logger.info("Retrieving relevant chunks...")

        state.retrieved_chunks = self.retriever.retrieve(
            query=state.query,
            plan=state.query_plan,
            top_k=self.config.retrieval_top_k,
            filters=filters,
        )

        state.stage_times["retrieval"] = time.time() - start

        if self.config.verbose:
            logger.info(f"Retrieved {len(state.retrieved_chunks)} chunks")

        return state

    def _rerank(self, state: RAGState) -> RAGState:
        """Execute reranking stage.

        Trims the retrieved pool down to ``final_top_k`` ranked chunks.
        """
        start = time.time()
        state.stage = PipelineStage.RERANKING

        # Nothing retrieved: short-circuit with an empty ranking.
        # NOTE(review): this early return skips the "reranking" entry in
        # stage_times (unlike every other stage) — confirm intended.
        if not state.retrieved_chunks:
            state.ranked_chunks = []
            return state

        if self.config.verbose:
            logger.info("Reranking results...")

        state.ranked_chunks = self.reranker.rerank(
            query=state.query,
            results=state.retrieved_chunks,
            top_k=self.config.final_top_k,
        )

        state.stage_times["reranking"] = time.time() - start

        if self.config.verbose:
            logger.info(f"Reranked to {len(state.ranked_chunks)} chunks")

        return state

    def _synthesize(self, state: RAGState) -> RAGState:
        """Execute synthesis stage."""
        start = time.time()
        state.stage = PipelineStage.SYNTHESIS

        if self.config.verbose:
            logger.info("Synthesizing answer...")

        state.synthesis_result = self.synthesizer.synthesize(
            query=state.query,
            results=state.ranked_chunks,
            plan=state.query_plan,
        )

        state.stage_times["synthesis"] = time.time() - start

        if self.config.verbose:
            logger.info(
                f"Synthesized answer (confidence={state.synthesis_result.confidence:.2f})"
            )

        return state

    def _validate_and_revise(self, state: RAGState) -> RAGState:
        """Execute validation and optional revision loop.

        Validates the current synthesis; on failure, revises and
        re-validates, up to ``max_revision_attempts`` revisions. The loop
        always exits via one of the two ``break``s (valid, or attempts
        exhausted). The recorded "validation" time covers the whole loop,
        including revisions.
        """
        start = time.time()

        while state.revision_attempt <= self.config.max_revision_attempts:
            state.stage = PipelineStage.VALIDATION

            if self.config.verbose:
                logger.info(f"Validating (attempt {state.revision_attempt + 1})...")

            state.critic_result = self.critic.validate(
                synthesis_result=state.synthesis_result,
                sources=state.ranked_chunks,
            )

            # Critic accepted the answer: done.
            if state.critic_result.is_valid:
                if self.config.verbose:
                    logger.info("Validation passed!")
                break

            # Budget exhausted: keep the last (invalid) synthesis;
            # `validated` will be False in the final response.
            if state.revision_attempt >= self.config.max_revision_attempts:
                if self.config.verbose:
                    logger.warning("Max revision attempts reached")
                break

            # Archive the rejected synthesis and try again.
            state.stage = PipelineStage.REVISION
            state.revision_attempt += 1
            state.revision_history.append(state.synthesis_result)

            if self.config.verbose:
                logger.info(f"Revising answer (attempt {state.revision_attempt})...")

            state.synthesis_result = self._revise_synthesis(state)

        state.stage_times["validation"] = time.time() - start
        return state

    def _revise_synthesis(self, state: RAGState) -> SynthesisResult:
        """Revise synthesis based on critic feedback.

        NOTE(review): the critic's feedback (``state.critic_result``) is
        not actually passed to the synthesizer — this simply re-runs
        synthesis with the same inputs, so a revision only helps if the
        synthesizer is non-deterministic. Consider feeding the critic's
        issues into the synthesize call if its API supports it.
        """
        return self.synthesizer.synthesize(
            query=state.query,
            results=state.ranked_chunks,
            plan=state.query_plan,
        )

    def _build_response(self, state: RAGState) -> RAGResponse:
        """Build final response from state.

        Falls back to an error response when synthesis never produced a
        result (e.g. validation was skipped and synthesis failed silently).
        """
        total_time = (time.time() - state.start_time) * 1000  # ms

        synthesis = state.synthesis_result
        if synthesis is None:
            return self._build_error_response(state, "No synthesis result")

        # Summarize the plan into a plain dict for the response payload.
        query_plan_dict = None
        if state.query_plan:
            query_plan_dict = {
                "intent": state.query_plan.intent.value,
                "sub_queries": len(state.query_plan.sub_queries),
                "expanded_terms": state.query_plan.expanded_terms[:5],
            }

        # Summarize critic output (present only when validation ran).
        validation_dict = None
        if state.critic_result:
            validation_dict = {
                "is_valid": state.critic_result.is_valid,
                "confidence": state.critic_result.confidence,
                "hallucination_score": state.critic_result.hallucination_score,
                "citation_accuracy": state.critic_result.citation_accuracy,
                "issues": len(state.critic_result.issues),
            }

        return RAGResponse(
            answer=synthesis.answer,
            citations=synthesis.citations,
            confidence=synthesis.confidence,
            query=state.query,
            num_sources=synthesis.num_sources_used,
            # validated is False both when the critic rejected the answer
            # and when validation never ran.
            validated=state.critic_result.is_valid if state.critic_result else False,
            revision_attempts=state.revision_attempt,
            query_plan=query_plan_dict,
            validation_details=validation_dict,
            latency_ms=total_time,
        )

    def _build_error_response(
        self,
        state: RAGState,
        error: str,
    ) -> RAGResponse:
        """Build error response (zero confidence, no citations)."""
        return RAGResponse(
            answer=f"I encountered an error processing your query: {error}",
            citations=[],
            confidence=0.0,
            query=state.query,
            num_sources=0,
            validated=False,
            revision_attempts=state.revision_attempt,
            latency_ms=(time.time() - state.start_time) * 1000,
        )

    def index_text(
        self,
        text: str,
        document_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> int:
        """
        Index text content into the vector store.

        Uses naive fixed-size character chunking (500 chars with 50-char
        overlap). Chunks whose stripped text is shorter than 50 characters
        (typically the document tail) are silently skipped.

        Args:
            text: Text content to index
            document_id: Unique document identifier
            metadata: Optional metadata; only "filename" is consumed here
                (stored as the chunk's source_path)

        Returns:
            Number of chunks indexed
        """
        chunk_size = 500
        overlap = 50
        chunks = []
        embeddings = []

        # Stride of (chunk_size - overlap) gives the 50-char overlap
        # between consecutive windows.
        for i in range(0, len(text), chunk_size - overlap):
            chunk_text = text[i:i + chunk_size]
            if len(chunk_text.strip()) < 50:
                continue

            # chunk index is len(chunks), i.e. only counted chunks that
            # passed the length filter.
            chunk_id = f"{document_id}_chunk_{len(chunks)}"
            chunks.append({
                "chunk_id": chunk_id,
                "document_id": document_id,
                "text": chunk_text,
                "page": 0,
                "chunk_type": "text",
                "source_path": metadata.get("filename", "") if metadata else "",
            })

            # NOTE(review): embeddings are computed one chunk at a time;
            # if the adapter exposes a batch API, batching here would cut
            # indexing latency substantially.
            embedding = self.embedder.embed_text(chunk_text)
            embeddings.append(embedding)

        if not chunks:
            return 0

        self.store.add_chunks(chunks, embeddings)

        logger.info(f"Indexed {len(chunks)} chunks for document {document_id}")
        return len(chunks)

    def get_stats(self) -> Dict[str, Any]:
        """Get system statistics (store size plus model identifiers)."""
        return {
            "total_chunks": self.store.count(),
            "model": self.config.model,
            "embedding_model": self.embedder.model_name,
            "embedding_dimension": self.embedder.embedding_dimension,
        }
|
|
|