# audit_assistant/src/pipeline.py
"""Main pipeline orchestrator for the Audit QA system."""
import time
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from langchain.docstore.document import Document
from .logging import log_error
from .llm.adapters import LLMRegistry
from .loader import chunks_to_documents
from .vectorstore import VectorStoreManager
from .retrieval.context import ContextRetriever
from .config.loader import get_embedding_model_for_collection
@dataclass
class PipelineResult:
"""Result of pipeline execution."""
answer: str
sources: List[Document]
execution_time: float
metadata: Dict[str, Any]
query: str = "" # Add default value for query
def __post_init__(self):
"""Post-initialization processing."""
if not self.query:
self.query = "Unknown query"
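# Example (illustrative) of constructing a result by hand, e.g. in a unit test;
# the field values below are made-up placeholders:
#
#     result = PipelineResult(
#         answer="No significant findings were reported.",
#         sources=[],
#         execution_time=0.42,
#         metadata={"filters_applied": False},
#         query="Were any findings reported for 2021?",
#     )
#     # an empty query would be replaced with "Unknown query" by __post_init__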
class PipelineManager:
"""Main pipeline manager for the RAG system."""
def __init__(self, config: Optional[dict] = None):
"""
Initialize the pipeline manager.
"""
self.config = config or {}
self.vectorstore_manager = None
self.context_retriever = None  # created lazily once a vector store is available
self.llm_client = None
self.report_service = None
self.chunks = None
# Initialize components
self._initialize_components()
def update_config(self, new_config: dict):
"""
Update the pipeline configuration.
This is useful for experiments that need different settings.
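Example (illustrative; ``pm`` is an already constructed PipelineManager):
    pm.update_config({"retrieval": {"top_k": 10, "use_reranking": True}})
    # nested keys are deep-merged, so unrelated settings in self.config survive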
"""
if not isinstance(new_config, dict):
return
# Deep merge the new config with existing config
def deep_merge(base_dict, update_dict):
for key, value in update_dict.items():
if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
deep_merge(base_dict[key], value)
else:
base_dict[key] = value
deep_merge(self.config, new_config)
# Auto-infer embedding model from collection name if not "docling"
collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
if collection_name != 'docling':
inferred_model = get_embedding_model_for_collection(collection_name)
if inferred_model:
print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
if 'retriever' not in self.config:
self.config['retriever'] = {}
self.config['retriever']['model'] = inferred_model
# Set default normalize parameter if not present
if 'normalize' not in self.config['retriever']:
self.config['retriever']['normalize'] = True
# Also update vectorstore config if it exists
if 'vectorstore' in self.config:
self.config['vectorstore']['embedding_model'] = inferred_model
print(f"🔧 CONFIG UPDATED: Pipeline config updated with experiment settings")
# Re-initialize vectorstore manager with updated config
self._reinitialize_vectorstore_manager()
def _reinitialize_vectorstore_manager(self):
"""Re-initialize vectorstore manager with current config."""
try:
self.vectorstore_manager = VectorStoreManager(self.config)
print("🔄 VectorStore manager re-initialized with updated config")
except Exception as e:
print(f"❌ Error re-initializing vectorstore manager: {e}")
def _get_reranker_model_name(self) -> str:
"""
Get the reranker model name from configuration.
Returns:
Reranker model name or default
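Example (illustrative): with ``config = {'ranker': {'model': 'BAAI/bge-reranker-base'}}``
and no ``retrieval.reranker_model`` key, this returns ``'BAAI/bge-reranker-base'``;
with none of the keys set it falls back to ``'BAAI/bge-reranker-v2-m3'``.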
"""
return (
self.config.get('retrieval', {}).get('reranker_model') or
self.config.get('ranker', {}).get('model') or
self.config.get('reranker_model') or
'BAAI/bge-reranker-v2-m3'
)
def _initialize_components(self):
"""Initialize pipeline components."""
try:
# Load config if not provided
if not self.config:
from .config.loader import load_config
self.config = load_config()
# Auto-infer embedding model from collection name if not "docling"
collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
if collection_name != 'docling':
inferred_model = get_embedding_model_for_collection(collection_name)
if inferred_model:
print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
if 'retriever' not in self.config:
self.config['retriever'] = {}
self.config['retriever']['model'] = inferred_model
# Set default normalize parameter if not present
if 'normalize' not in self.config['retriever']:
self.config['retriever']['normalize'] = True
# Also update vectorstore config if it exists
if 'vectorstore' in self.config:
self.config['vectorstore']['embedding_model'] = inferred_model
self.vectorstore_manager = VectorStoreManager(self.config)
self.llm_manager = LLMRegistry()
# Try to get LLM client using the correct method
self.llm_client = None
try:
# Try using get_adapter method (most likely correct)
self.llm_client = self.llm_manager.get_adapter("openai")
print("✅ LLM CLIENT: Initialized using get_adapter method")
except Exception as e:
try:
# Try direct instantiation with config
from auditqa.llm.adapters import get_llm_client
self.llm_client = get_llm_client("openai", self.config)
print("✅ LLM CLIENT: Initialized using direct get_llm_client function with config")
except Exception as e2:
print(f"❌ LLM CLIENT: Registry methods failed - {e2}")
# Try to create a simple LLM client directly
try:
from langchain_openai import ChatOpenAI
import os
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
if api_key:
self.llm_client = ChatOpenAI(
model="gpt-3.5-turbo",
api_key=api_key,
temperature=0.1,
max_tokens=1000
)
print("✅ LLM CLIENT: Initialized using direct ChatOpenAI")
else:
print("❌ LLM CLIENT: No API key available")
except Exception as e3:
print(f"❌ LLM CLIENT: Direct instantiation also failed - {e3}")
self.llm_client = None
# Load system prompt
from auditqa.llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
# Initialize report service
try:
from auditqa.reporting.service import ReportService
self.report_service = ReportService()
except Exception as e:
print(f"Warning: Could not initialize report service: {e}")
self.report_service = None
except Exception as e:
print(f"Warning: Error initializing components: {e}")
def test_retrieval(
self,
query: str,
reports: List[str] = None,
sources: str = None,
subtype: List[str] = None,
k: int = None,
search_mode: str = None,
search_alpha: float = None,
use_reranking: bool = True
) -> Dict[str, Any]:
"""
Test retrieval only without LLM inference.
Args:
query: User query
reports: List of specific report filenames
sources: Source category
subtype: List of subtypes
k: Number of documents to retrieve
search_mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
search_alpha: Weight for vector scores in hybrid mode
use_reranking: Whether to use reranking (accepted for interface symmetry; not applied in this method)
Returns:
Dictionary with retrieval results and metadata
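Example (illustrative; assumes connect_vectorstore() has already succeeded):
    out = pm.test_retrieval(
        "What were the main audit findings?",
        k=5,
        search_mode="hybrid",
        search_alpha=0.5,
    )
    for hit in out["results"]:
        print(hit["rank"], round(hit["score"], 3), hit["content"][:80])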
"""
start_time = time.time()
try:
# Set default search parameters if not provided
if search_mode is None:
search_mode = self.config.get("hybrid", {}).get("default_mode", "vector_only")
if search_alpha is None:
search_alpha = self.config.get("hybrid", {}).get("default_alpha", 0.5)
# Get vector store
vectorstore = self.vectorstore_manager.get_vectorstore()
if not vectorstore:
raise ValueError(
"Vector store not available. Call connect_vectorstore() or create_vectorstore() first."
)
# Lazily create the context retriever; it is not built in __init__, so calling
# test_retrieval() before run() would otherwise fail on a None retriever
if self.context_retriever is None:
self.context_retriever = ContextRetriever(vectorstore, self.config)
# Retrieve context with scores for test retrieval
context_docs_with_scores = self.context_retriever.retrieve_with_scores(
vectorstore=vectorstore,
query=query,
reports=reports,
sources=sources,
subtype=subtype,
k=k,
search_mode=search_mode,
alpha=search_alpha,
)
# Extract documents and scores
context_docs = [doc for doc, score in context_docs_with_scores]
context_scores = [score for doc, score in context_docs_with_scores]
execution_time = time.time() - start_time
# Format results with actual scores
results = []
for i, (doc, score) in enumerate(zip(context_docs, context_scores)):
results.append({
"rank": i + 1,
"content": doc.page_content, # Return full content without truncation
"metadata": doc.metadata,
"score": score if score is not None else 0.0
})
return {
"results": results,
"num_results": len(results),
"execution_time": execution_time,
"search_mode": search_mode,
"search_alpha": search_alpha,
"query": query
}
except Exception as e:
print(f"❌ Error during retrieval test: {e}")
log_error(e, {"component": "retrieval_test", "query": query})
return {
"results": [],
"num_results": 0,
"execution_time": time.time() - start_time,
"error": str(e),
"search_mode": search_mode or "unknown",
"search_alpha": search_alpha or 0.5,
"query": query
}
def connect_vectorstore(self, force_recreate: bool = False) -> bool:
"""
Connect to existing vector store.
Args:
force_recreate: If True, recreate the collection if dimension mismatch occurs
Returns:
True if successful, False otherwise
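Example (illustrative):
    if not pm.connect_vectorstore():
        # e.g. after switching embedding models; recreates the collection
        pm.connect_vectorstore(force_recreate=True)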
"""
try:
vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=force_recreate)
if vectorstore:
print("✅ Connected to vector store")
return True
else:
print("❌ Failed to connect to vector store")
return False
except Exception as e:
print(f"❌ Error connecting to vector store: {e}")
log_error(e, {"component": "vectorstore_connection"})
# If it's a dimension mismatch error, try with force_recreate
if "dimensions" in str(e).lower() and not force_recreate:
print("🔄 Dimension mismatch detected, attempting to recreate collection...")
try:
vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=True)
if vectorstore:
print("✅ Connected to vector store (recreated)")
return True
except Exception as recreate_error:
print(f"❌ Failed to recreate vector store: {recreate_error}")
log_error(recreate_error, {"component": "vectorstore_recreation"})
return False
def create_vectorstore(self) -> bool:
"""
Create new vector store from chunks.
Returns:
True if successful, False otherwise
"""
try:
if not self.chunks:
raise ValueError("No chunks available for vector store creation")
documents = chunks_to_documents(self.chunks)
self.vectorstore_manager.create_from_documents(documents)
print("✅ Vector store created successfully")
return True
except Exception as e:
print(f"❌ Error creating vector store: {e}")
log_error(e, {"component": "vectorstore_creation"})
return False
def create_audit_prompt(self, query: str, context_docs: List[Document]) -> str:
"""Create a prompt for the LLM to generate an answer."""
try:
# Ensure query is not None
if not query or not isinstance(query, str) or query.strip() == "":
return "Error: No query provided"
# Ensure context_docs is not None and is a list
if context_docs is None:
context_docs = []
# Filter out None documents and ensure they have content
valid_docs = []
for doc in context_docs:
if doc is not None:
if hasattr(doc, 'page_content') and doc.page_content and isinstance(doc.page_content, str):
valid_docs.append(doc)
elif isinstance(doc, str) and doc.strip():
valid_docs.append(doc)
# Create context string
if valid_docs:
context_parts = []
for i, doc in enumerate(valid_docs, 1):
if hasattr(doc, 'page_content') and doc.page_content:
context_parts.append(f"Doc {i}: {doc.page_content}")
elif isinstance(doc, str) and doc.strip():
context_parts.append(f"Doc {i}: {doc}")
context_string = "\n\n".join(context_parts)
else:
context_string = "No relevant context found."
# Create the prompt
prompt = f"""
{self.system_prompt}
Context:
{context_string}
Query: {query}
Answer:"""
return prompt
except Exception as e:
print(f"Error creating audit prompt: {e}")
return f"Error creating prompt: {e}"
def _generate_answer(self, prompt: str) -> str:
"""Generate answer using the LLM."""
try:
if not prompt or not isinstance(prompt, str) or prompt.strip() == "":
return "Error: No prompt provided"
# Ensure LLM client is available
if not self.llm_client:
return "Error: LLM client not available"
# Generate response using the correct method
if hasattr(self.llm_client, 'generate'):
# Use the generate method (for adapters)
response = self.llm_client.generate([{"role": "user", "content": prompt}])
# Extract content from LLMResponse
if hasattr(response, 'content'):
answer = response.content
else:
answer = str(response)
elif hasattr(self.llm_client, 'invoke'):
# Use the invoke method (for direct LangChain models)
response = self.llm_client.invoke(prompt)
# Extract content safely
if hasattr(response, 'content') and response.content is not None:
answer = response.content
elif isinstance(response, str) and response.strip():
answer = response
else:
answer = str(response) if response is not None else "Error: LLM returned None response"
else:
return "Error: LLM client has no generate or invoke method"
# Ensure answer is not None and is a string
if answer is None or not isinstance(answer, str):
return "Error: LLM returned invalid response"
return answer.strip()
except Exception as e:
print(f"Error generating answer: {e}")
return f"Error generating answer: {e}"
def run(
self,
query: str,
reports: List[str] = None,
sources: List[str] = None,
subtype: List[str] = None,
llm_provider: str = None,
use_reranking: bool = True,
search_mode: str = None,
search_alpha: float = None,
auto_infer_filters: bool = True,
filters: Dict[str, Any] = None,
) -> PipelineResult:
"""
Run the complete RAG pipeline.
Args:
query: User query
reports: List of specific report filenames
sources: Source category filter
subtype: List of subtypes/filenames
llm_provider: LLM provider to use
use_reranking: Whether to use reranking
search_mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
search_alpha: Alpha value for hybrid search
auto_infer_filters: Whether to auto-infer filters from query
filters: Optional dict of explicit filter values ('reports', 'sources', 'subtype', 'year', 'district', 'filenames')
Returns:
PipelineResult object
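Example (illustrative; the report filename is a made-up placeholder):
    result = pm.run(
        "Summarise the irregularities found in district road works",
        reports=["Consolidated_Audit_Report_2022.pdf"],
        search_mode="hybrid",
        search_alpha=0.5,
        auto_infer_filters=False,
    )
    print(result.answer)
    print(len(result.sources), result.execution_time)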
"""
try:
# Validate input
if not query or not isinstance(query, str) or query.strip() == "":
return PipelineResult(
answer="Error: Invalid query provided",
sources=[],
execution_time=0.0,
metadata={'error': 'Invalid query'},
query=query
)
# Ensure lists are not None
if reports is None:
reports = []
if subtype is None:
subtype = []
start_time = time.time()
# Auto-infer filters if enabled and no explicit filters provided
inferred_filters = {}
filters_applied = False
qdrant_filter = None
filter_summary = None
if auto_infer_filters and not any([reports, sources, subtype]):
print(f"🤖 AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
try:
# Import get_available_metadata here to avoid circular imports
from auditqa.retrieval.filter import get_available_metadata, infer_filters_from_query
# Get available metadata
available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
# Infer filters from query - this returns a Qdrant filter
qdrant_filter, filter_summary = infer_filters_from_query(
query=query,
available_metadata=available_metadata,
llm_client=self.llm_client
)
if qdrant_filter:
print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
filters_applied = True
# Don't set sources/reports/subtype - use the Qdrant filter directly
else:
print(f"⚠️ NO QDRANT FILTER: Could not build Qdrant filter from query")
except Exception as e:
print(f"❌ AUTO-INFERENCE FAILED: {e}")
qdrant_filter = None
else:
# Check if any explicit filters were provided
filters_applied = any([reports, sources, subtype])
if filters_applied:
print("✅ EXPLICIT FILTERS: Using provided filters")
else:
print("⚠️ NO FILTERS: No explicit filters and auto-inference disabled")
# Merge in values supplied via the `filters` dict without discarding
# filters that were passed as explicit arguments
filters = filters or {}
reports = reports or filters.get('reports', [])
sources = sources or filters.get('sources', [])
subtype = subtype or filters.get('subtype', [])
year = filters.get('year', [])
district = filters.get('district', [])
filenames = filters.get('filenames', [])  # mutually exclusive filename filtering
# Get vectorstore
vectorstore = self.vectorstore_manager.get_vectorstore()
if not vectorstore:
return PipelineResult(
answer="Error: Vector store not available",
sources=[],
execution_time=0.0,
metadata={'error': 'Vector store not available'},
query=query
)
# Initialize context retriever if not already done
if not hasattr(self, 'context_retriever') or self.context_retriever is None:
# Get the actual vectorstore object
vectorstore_obj = self.vectorstore_manager.get_vectorstore()
if vectorstore_obj is None:
print("❌ ERROR: Vectorstore is None, cannot initialize ContextRetriever")
return PipelineResult(
answer="Error: Vector store not available",
sources=[],
execution_time=time.time() - start_time,
metadata={'error': 'Vector store not available'},
query=query
)
self.context_retriever = ContextRetriever(vectorstore_obj, self.config)
print("✅ ContextRetriever initialized successfully")
# Debug config access
print(f" CONFIG DEBUG: Full config keys: {list(self.config.keys()) if isinstance(self.config, dict) else 'Not a dict'}")
print(f"🔍 CONFIG DEBUG: Retriever config: {self.config.get('retriever', {})}")
print(f"🔍 CONFIG DEBUG: Retrieval config: {self.config.get('retrieval', {})}")
print(f"🔍 CONFIG DEBUG: use_reranking from config: {self.config.get('retrieval', {}).get('use_reranking', 'NOT_FOUND')}")
# Get the correct top_k value
# Priority: experiment config > retriever config > default
top_k = (
self.config.get('retrieval', {}).get('top_k') or
self.config.get('retriever', {}).get('top_k') or
5
)
# Reranking setting: config value takes precedence, falling back to the argument
use_reranking = self.config.get('retrieval', {}).get('use_reranking', use_reranking)
print(f"🔍 CONFIG DEBUG: Final top_k: {top_k}")
print(f"🔍 CONFIG DEBUG: Final use_reranking: {use_reranking}")
# Retrieve context using the context retriever
context_docs = self.context_retriever.retrieve_context(
query=query,
k=top_k,
reports=reports,
sources=sources,
subtype=subtype,
year=year,
district=district,
filenames=filenames,
use_reranking=use_reranking,
qdrant_filter=qdrant_filter
)
# Ensure context_docs is not None
if context_docs is None:
context_docs = []
# Generate answer
answer = self._generate_answer(self.create_audit_prompt(query, context_docs))
execution_time = time.time() - start_time
# Create result with comprehensive metadata
result = PipelineResult(
answer=answer,
sources=context_docs,
execution_time=execution_time,
metadata={
'llm_provider': llm_provider,
'use_reranking': use_reranking,
'search_mode': search_mode,
'search_alpha': search_alpha,
'auto_infer_filters': auto_infer_filters,
'filters_applied': filters_applied,
'with_filtering': filters_applied,
'filter_conditions': {
'reports': reports,
'sources': sources,
'subtype': subtype
},
'inferred_filters': inferred_filters,
'applied_filters': {
'reports': reports,
'sources': sources,
'subtype': subtype
},
# Store filter and reranking metadata
'filter_details': {
'explicit_filters': {
'reports': reports,
'sources': sources,
'subtype': subtype,
'year': year
},
'inferred_filters': inferred_filters if auto_infer_filters else {},
'auto_inference_enabled': auto_infer_filters,
'qdrant_filter_applied': qdrant_filter is not None,
'filter_summary': filter_summary
},
'reranker_model': self._get_reranker_model_name() if use_reranking else None,
'reranker_applied': use_reranking,
'reranking_info': {
'model': self._get_reranker_model_name(),
'applied': use_reranking,
'top_k': len(context_docs) if context_docs else 0,
# 'original_documents': [
# {
# 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
# 'metadata': doc.metadata,
# 'score': getattr(doc, 'score', getattr(doc, 'original_score', 0.0))
# } for doc in context_docs
# ] if use_reranking else None,
'reranked_documents': [
{
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
'metadata': doc.metadata,
'score': doc.metadata.get('original_score', getattr(doc, 'score', 0.0)),
'original_rank': doc.metadata.get('original_rank', None),
'final_rank': doc.metadata.get('final_rank', None),
'reranked_score': doc.metadata.get('reranked_score', None)
} for doc in context_docs
] if use_reranking else None
}
},
query=query
)
return result
except Exception as e:
print(f"Error in pipeline run: {e}")
return PipelineResult(
answer=f"Error processing query: {e}",
sources=[],
execution_time=0.0,
metadata={'error': str(e)},
query=query
)
def get_system_status(self) -> Dict[str, Any]:
"""
Get system status information.
Returns:
Dictionary with system status
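Example (illustrative) of the returned shape:
    {
        "config_loaded": True,
        "chunks_loaded": False,
        "vectorstore_connected": True,
        "components_initialized": False,
        "overall_status": "not_ready",
    }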
"""
status = {
"config_loaded": bool(self.config),
"chunks_loaded": bool(self.chunks),
"vectorstore_connected": bool(
self.vectorstore_manager and self.vectorstore_manager.get_vectorstore()
),
"components_initialized": bool(
self.context_retriever and self.report_service
),
}
if self.chunks:
status["num_chunks"] = len(self.chunks)
if self.report_service:
status["available_sources"] = self.report_service.get_available_sources()
status["available_reports"] = len(
self.report_service.get_available_reports()
)
status["overall_status"] = (
"ready"
if all(
[
status["config_loaded"],
status["chunks_loaded"],
status["vectorstore_connected"],
status["components_initialized"],
]
)
else "not_ready"
)
return status
def get_available_llm_providers(self) -> List[str]:
"""Get list of available LLM providers."""
providers = []
reader_config = self.config.get("reader", {})
for provider in [
"MISTRAL",
"OPENAI",
"OLLAMA",
"INF_PROVIDERS",
"NVIDIA",
"DEDICATED",
"OPENROUTER",
]:
if provider in reader_config:
providers.append(provider.lower())
return providers
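# Example (illustrative): with ``self.config["reader"] = {"OPENAI": {...}, "OPENROUTER": {...}}``
# this returns ``['openai', 'openrouter']`` (order follows the hard-coded provider list above).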