# Source: rag-bajaj / RAG / rag_modules / query_expansion.py
# Uploaded by quantumbit ("Upload 39 files", commit e8051be, verified)
"""
Query Expansion Module for Advanced RAG
Uses a Groq chat model (GROQ_MODEL_LITE, defaulting to llama3-8b-8192) to generate focused sub-queries.
"""
import re
import time
import asyncio
from typing import List
from groq import Groq
from config.config import (
ENABLE_QUERY_EXPANSION,
QUERY_EXPANSION_COUNT,
GROQ_API_KEY_LITE,
GROQ_MODEL_LITE,
)
class QueryExpansionManager:
    """Break a complex user question into focused retrieval sub-queries.

    The manager asks a Groq chat model to split one question into
    ``QUERY_EXPANSION_COUNT`` independent search queries. Every failure mode
    (feature disabled, missing API key, API error, unusable model output)
    degrades to returning the original query instead of raising.
    """

    # Sub-queries of this length or shorter are treated as noise and dropped.
    _MIN_SUB_QUERY_LENGTH = 10

    # Leading list markers the model may add despite being told not to:
    # "1." / "1)" / "1:" numbering and "-", "*", "\u2022" bullets.  The
    # lookahead keeps numbers that are part of the query itself ("2.5 lakh"),
    # and a bare leading number such as "30-day" never matches at all.
    _LIST_MARKER_RE = re.compile(r"^(?:(?:\d+[.):](?!\d)|[-*\u2022]+)\s*)+")

    # (component name, label recorded on a match, trigger keywords).
    # Order matters: it defines the key order of the returned dict.
    _COMPONENT_RULES = (
        (
            'processes',
            'process identification',
            ('process', 'procedure', 'steps', 'how to', 'submit', 'apply',
             'claim', 'update', 'change', 'enroll'),
        ),
        (
            'documents',
            'document requirements',
            ('documents', 'forms', 'papers', 'certificate', 'proof',
             'evidence', 'requirements'),
        ),
        (
            'contacts',
            'contact information',
            ('email', 'phone', 'contact', 'grievance', 'customer service',
             'support', 'helpline'),
        ),
        (
            'eligibility',
            'eligibility criteria',
            ('eligibility', 'criteria', 'qualify', 'eligible', 'conditions',
             'requirements'),
        ),
        (
            'timelines',
            'timeline information',
            ('timeline', 'period', 'duration', 'time', 'days', 'months',
             'waiting', 'grace'),
        ),
        (
            'benefits',
            'benefit details',
            ('benefits', 'coverage', 'limits', 'amount', 'reimbursement',
             'claim amount'),
        ),
    )

    def __init__(self):
        """Create the Groq client, or disable expansion when no key is set."""
        # Fall back to the documented default model when none is configured.
        self.model = GROQ_MODEL_LITE or "llama3-8b-8192"
        if not GROQ_API_KEY_LITE:
            print("⚠️ GROQ_API_KEY_LITE is not set. Query expansion will fall back to original query.")
            # A missing client makes expand_query() return the original query.
            self.client = None
        else:
            self.client = Groq(api_key=GROQ_API_KEY_LITE)
            print(f"βœ… Query Expansion Manager initialized using Groq model: {self.model}")

    async def expand_query(self, original_query: str) -> List[str]:
        """Split ``original_query`` into focused sub-queries via Groq.

        Args:
            original_query: The user's question.

        Returns:
            * ``[original_query]`` when expansion is disabled, no client is
              configured, the configured count is below 1, the query is
              blank, or the API call fails.
            * ``[original_query] * QUERY_EXPANSION_COUNT`` when the model
              returned fewer usable sub-queries than requested.
            * Otherwise exactly ``QUERY_EXPANSION_COUNT`` sub-queries; the
              original query is deliberately not among them.
        """
        if not ENABLE_QUERY_EXPANSION or not self.client:
            return [original_query]

        try:
            # A count below 1 or a blank query cannot yield useful sub-queries,
            # so skip the API round-trip instead of returning an empty list.
            if QUERY_EXPANSION_COUNT < 1 or not (original_query or "").strip():
                return [original_query]

            prompt = self._build_expansion_prompt(original_query, QUERY_EXPANSION_COUNT)

            # The Groq client is synchronous; run it in a worker thread so the
            # event loop is not blocked while waiting for the response.
            response = await asyncio.to_thread(
                self.client.chat.completions.create,
                messages=[{"role": "user", "content": prompt}],
                model=self.model,
                temperature=0.3,
                max_tokens=300,
            )

            content = ""
            if response and response.choices:
                message = response.choices[0].message
                content = (message.content if message else "") or ""

            final_queries = self._parse_sub_queries(content, QUERY_EXPANSION_COUNT)

            # All-or-nothing: a partial breakdown is discarded in favour of
            # searching with the original question for every slot.
            if len(final_queries) < QUERY_EXPANSION_COUNT:
                print(
                    f"⚠️ Query expansion produced only {len(final_queries)} usable sub-queries "
                    f"(expected {QUERY_EXPANSION_COUNT}); falling back to the original query"
                )
                return [original_query] * QUERY_EXPANSION_COUNT

            print(f"πŸ”„ Query broken down from 1 complex question to {len(final_queries)} focused sub-queries using Groq {self.model}")
            print("πŸ“Œ Original query will be used for final LLM generation only")
            for i, q in enumerate(final_queries):
                print(f" Sub-query {i+1}: {q[:80]}...")
            return final_queries
        except Exception as e:
            # Best-effort feature: never let expansion break the caller.
            print(f"⚠️ Query expansion failed: {e}")
            return [original_query]

    @staticmethod
    def _build_expansion_prompt(original_query: str, count: int) -> str:
        """Return the instruction prompt asking for ``count`` sub-questions."""
        return "\n".join([
            f"Analyze this question and break it down into exactly {count} specific, focused sub-questions that can be searched independently in a document. Each sub-question should target a distinct piece of information or process.",
            "For complex questions with multiple parts, identify:",
            "1. Different processes or procedures mentioned",
            "2. Specific information requests (emails, contact details, forms, etc.)",
            "3. Different entities or subjects involved",
            "4. Sequential steps that might be documented separately",
            f"Original question: {original_query}",
            f"Break this into exactly {count} focused search queries that target different aspects:",
            "Examples of good breakdown:",
            '- "What is the dental claim submission process?"',
            '- "How to update surname/name in policy records?"',
            '- "What are the company contact details and grievance email?"',
            f"Provide only {count} focused sub-questions, one per line, without numbering or additional formatting:",
        ])

    @classmethod
    def _parse_sub_queries(cls, content, limit: int) -> List[str]:
        """Extract at most ``limit`` cleaned sub-queries from model output.

        Each line is stripped of surrounding whitespace and of leading list
        markers; lines of ``_MIN_SUB_QUERY_LENGTH`` characters or fewer after
        cleaning are skipped. ``content`` may be ``None`` or empty.
        """
        parsed: List[str] = []
        for line in (content or "").splitlines():
            if len(parsed) >= limit:
                break
            candidate = cls._LIST_MARKER_RE.sub("", line.strip()).strip()
            if len(candidate) > cls._MIN_SUB_QUERY_LENGTH:
                parsed.append(candidate)
        return parsed

    def _identify_query_components(self, query: str) -> dict:
        """Tag which kinds of information a query appears to ask about.

        Matching is a case-insensitive substring test, so a keyword also
        matches inside longer words (e.g. 'time' matches 'lifetime').

        Args:
            query: The user's question.

        Returns:
            A dict with the keys 'processes', 'documents', 'contacts',
            'eligibility', 'timelines' and 'benefits'. Each value is a list
            holding one descriptive label when any keyword of that category
            occurs in the query, and an empty list otherwise.
        """
        query_lower = query.lower()
        return {
            name: [label] if any(keyword in query_lower for keyword in keywords) else []
            for name, label, keywords in self._COMPONENT_RULES
        }