# Source path: rag-bajaj/RAG/rag_modules/context_manager.py
# Uploaded by: quantumbit
# Commit message: Upload 39 files
# Commit: e8051be (verified)
"""
Context Management Module for Advanced RAG
Handles context creation and management for LLM generation.
"""
from typing import List, Dict
from collections import defaultdict
from config.config import MAX_CONTEXT_LENGTH
class ContextManager:
    """Manages context creation for LLM generation.

    Retrieved chunks are grouped by the expanded query that matched them best,
    and each query is then given an (almost) equal number of slots in the
    final context string.
    """

    def __init__(self):
        """Initialize the context manager."""
        # NOTE(review): the emoji prefixes in this class's log strings look
        # mojibake-encoded in this copy of the file; they are left untouched
        # because they are runtime output.
        print("βœ… Context Manager initialized")

    @staticmethod
    def _owning_query_index(result: Dict) -> int:
        """Return the index of the expanded query that best matches ``result``.

        The best contributing query is the one with the highest
        ``semantic_score`` (falling back to ``bm25_score``, then 0). Results
        without any contributing query are attributed to query 0.
        """
        contributing = result.get('contributing_queries')
        if not contributing:
            return 0  # fallback to first query
        best_contrib = max(
            contributing,
            key=lambda cq: cq.get('semantic_score', cq.get('bm25_score', 0)),
        )
        return best_contrib['query_idx']

    @staticmethod
    def _relevance_score(result: Dict):
        """Return the ranking score of ``result``.

        Preference order: ``rerank_score``, ``final_score``, ``score``, 0.
        """
        return result.get('rerank_score', result.get('final_score', result.get('score', 0)))

    def create_enhanced_context(self, question: str, results: List[Dict], max_length: int = MAX_CONTEXT_LENGTH) -> str:
        """Create enhanced context ensuring each query contributes equally.

        Args:
            question: The user question. Currently unused by this method; kept
                so the call signature stays stable.
            results: Retrieved chunks. Each dict may carry
                ``contributing_queries`` (dicts with ``query_idx`` and
                ``semantic_score``/``bm25_score``), one of ``rerank_score`` /
                ``final_score`` / ``score``, and a ``payload`` dict whose
                ``text`` entry is the chunk text. A missing payload or text is
                treated as empty text.
            max_length: Maximum length, in characters, of the returned string
                (including the newlines that join the chunks).

        Returns:
            The selected chunks, each prefixed with a ``[Query i Doc j]``
            header, joined by newlines. Empty string when ``results`` is empty.

        Notes:
            Each query gets ``len(results) // num_queries`` slots and the
            remainder is handed out one slot each to the first queries in
            sorted order. Slots a query cannot fill are not redistributed to
            other queries. When a chunk does not fit in the remaining budget,
            that query stops contributing but later queries are still tried.
        """
        # Group results by the expanded query that matched them best.
        query_to_chunks = defaultdict(list)
        for result in results:
            query_to_chunks[self._owning_query_index(result)].append(result)

        if not query_to_chunks:
            return ""

        # Best chunks first within each query (stable for equal scores).
        for chunks in query_to_chunks.values():
            chunks.sort(key=self._relevance_score, reverse=True)

        num_queries = len(query_to_chunks)
        chunks_per_query, extra_chunks = divmod(len(results), num_queries)
        print(f"πŸ“Š Context Creation: {num_queries} queries, {chunks_per_query} chunks per query (+{extra_chunks} extra)")

        context_parts = []
        current_length = 0

        # Remainder slots are assigned by position among the queries actually
        # present, not by raw query index, so sparse or non-zero-based indices
        # still receive them.
        for position, q_idx in enumerate(sorted(query_to_chunks)):
            query_chunk_limit = chunks_per_query + (1 if position < extra_chunks else 0)
            query_chunks_added = 0
            print(f" Query {q_idx+1}: Adding up to {query_chunk_limit} chunks")

            for result in query_to_chunks[q_idx][:query_chunk_limit]:
                payload = result.get('payload') or {}
                text = payload.get('text', '')

                relevance_info = ""
                if 'rerank_score' in result:
                    relevance_info = f" [Relevance: {result['rerank_score']:.2f}]"
                elif 'final_score' in result:
                    relevance_info = f" [Score: {result['final_score']:.2f}]"

                doc_text = f"[Query {q_idx+1} Doc {len(context_parts)+1}]{relevance_info}\n{text}\n"

                # The final "\n".join adds one separator before every part but
                # the first; count it so the result never exceeds max_length.
                separator_length = 1 if context_parts else 0
                if current_length + separator_length + len(doc_text) > max_length:
                    print(f" ⚠️ Context length limit reached at {current_length} chars")
                    break

                context_parts.append(doc_text)
                current_length += separator_length + len(doc_text)
                query_chunks_added += 1

            print(f" Query {q_idx+1}: Added {query_chunks_added} chunks")

        print(f"πŸ“ Final context: {len(context_parts)} chunks, {current_length} chars")
        return "\n".join(context_parts)