# Spaces: Sleeping (page-header residue, not part of the module)
"""
Query Expansion Module for Advanced RAG.

Uses a Groq chat model (GROQ_MODEL_LITE, defaulting to llama3-8b-8192) to
break complex questions into focused sub-queries for retrieval.
"""
import re | |
import time | |
import asyncio | |
from typing import List | |
from groq import Groq | |
from config.config import ( | |
ENABLE_QUERY_EXPANSION, | |
QUERY_EXPANSION_COUNT, | |
GROQ_API_KEY_LITE, | |
GROQ_MODEL_LITE, | |
) | |
class QueryExpansionManager:
    """Manages query expansion for better information retrieval.

    A Groq chat model splits a complex user question into up to
    ``QUERY_EXPANSION_COUNT`` focused sub-queries that can each be searched
    independently. Every failure path (expansion disabled, no API key, API
    error, unusable model output) degrades to searching with the original
    query, so ``expand_query`` always returns a non-empty list.
    """

    # Leading list markers the model may add despite the prompt ("1.", "2)",
    # "-", "*", "•"), possibly stacked ("- 1."). A number only counts as
    # numbering when followed by "." or ")" and not by another digit, so
    # queries that genuinely start with a number ("30 days ...", "3.5% ...")
    # are left intact.
    _LIST_MARKER_RE = re.compile(r"^(?:(?:\d+[.)](?!\d)|[-*\u2022])\s*)+")

    # Cleaned lines of this many characters or fewer are treated as fragments.
    _MIN_QUERY_CHARS = 10

    # component key -> (label recorded on a match, trigger keywords).
    # Insertion order defines the key order of _identify_query_components().
    _COMPONENT_KEYWORDS = {
        'processes': ('process identification',
                      ('process', 'procedure', 'steps', 'how to', 'submit',
                       'apply', 'claim', 'update', 'change', 'enroll')),
        'documents': ('document requirements',
                      ('documents', 'forms', 'papers', 'certificate', 'proof',
                       'evidence', 'requirements')),
        'contacts': ('contact information',
                     ('email', 'phone', 'contact', 'grievance',
                      'customer service', 'support', 'helpline')),
        'eligibility': ('eligibility criteria',
                        ('eligibility', 'criteria', 'qualify', 'eligible',
                         'conditions', 'requirements')),
        'timelines': ('timeline information',
                      ('timeline', 'period', 'duration', 'time', 'days',
                       'months', 'waiting', 'grace')),
        'benefits': ('benefit details',
                     ('benefits', 'coverage', 'limits', 'amount',
                      'reimbursement', 'claim amount')),
    }

    def __init__(self):
        """Create the Groq client from GROQ_API_KEY_LITE.

        The model is GROQ_MODEL_LITE, or "llama3-8b-8192" when that is unset.
        When the key is missing, ``self.client`` is ``None`` and
        ``expand_query`` returns ``[original_query]``.
        """
        self.model = GROQ_MODEL_LITE or "llama3-8b-8192"
        if not GROQ_API_KEY_LITE:
            print("WARNING: GROQ_API_KEY_LITE is not set. Query expansion will fall back to original query.")
            self.client = None
        else:
            self.client = Groq(api_key=GROQ_API_KEY_LITE)
            print(f"Query Expansion Manager initialized using Groq model: {self.model}")

    async def expand_query(self, original_query: str) -> List[str]:
        """Break a complex query into focused sub-queries using Groq.

        Args:
            original_query: The user's question as asked.

        Returns:
            Up to ``QUERY_EXPANSION_COUNT`` distinct search queries. If the
            model yields fewer usable sub-queries than requested, the original
            query is appended once to cover the gap. If it yields none, or if
            expansion is disabled, unconfigured or fails, the result is
            ``[original_query]``.
        """
        count = QUERY_EXPANSION_COUNT
        if (
            not ENABLE_QUERY_EXPANSION
            or not self.client
            or count < 1
            or not (original_query or "").strip()
        ):
            return [original_query]

        prompt = self._build_expansion_prompt(original_query, count)
        try:
            # The Groq SDK call is synchronous; run it in a worker thread so
            # the event loop is not blocked.
            response = await asyncio.to_thread(
                self.client.chat.completions.create,
                messages=[{"role": "user", "content": prompt}],
                model=self.model,
                temperature=0.3,
                max_tokens=300,
            )
            content = ""
            if response and response.choices:
                message = response.choices[0].message
                content = (message.content if message else "") or ""
        except Exception as e:
            # Best-effort feature: any API/network failure falls back to the
            # original query rather than breaking retrieval.
            print(f"WARNING: Query expansion failed: {e}")
            return [original_query]

        sub_queries = self._parse_sub_queries(content, count)
        if not sub_queries:
            print("WARNING: Query expansion produced no usable sub-queries; using original query.")
            return [original_query]
        if len(sub_queries) < count and original_query not in sub_queries:
            # Searching identical text several times adds nothing; cover the
            # shortfall once with the original question instead.
            sub_queries.append(original_query)

        print(f"Query broken down from 1 complex question to {len(sub_queries)} focused sub-queries using Groq {self.model}")
        for i, q in enumerate(sub_queries, start=1):
            ellipsis = "..." if len(q) > 80 else ""
            print(f"   Sub-query {i}: {q[:80]}{ellipsis}")
        return sub_queries

    @staticmethod
    def _build_expansion_prompt(original_query: str, count: int) -> str:
        """Return the instruction prompt asking the model for *count* sub-questions."""
        return "\n".join([
            f"Analyze this question and break it down into exactly {count} specific, "
            "focused sub-questions that can be searched independently in a document. "
            "Each sub-question should target a distinct piece of information or process.",
            "For complex questions with multiple parts, identify:",
            "1. Different processes or procedures mentioned",
            "2. Specific information requests (emails, contact details, forms, etc.)",
            "3. Different entities or subjects involved",
            "4. Sequential steps that might be documented separately",
            f"Original question: {original_query}",
            f"Break this into exactly {count} focused search queries that target different aspects:",
            "Examples of good breakdown:",
            '- "What is the dental claim submission process?"',
            '- "How to update surname/name in policy records?"',
            '- "What are the company contact details and grievance email?"',
            f"Provide only {count} focused sub-questions, one per line, "
            "without numbering or additional formatting:",
        ])

    @classmethod
    def _parse_sub_queries(cls, content: str, limit: int) -> List[str]:
        """Return up to *limit* cleaned, de-duplicated sub-queries from model output.

        Each line has leading list markers and wrapping double quotes removed.
        The prompt's own examples are quoted bullets, so output may mimic
        that format. Fragments and preamble lines ending in ":" (e.g.
        "Here are 3 sub-questions:") are skipped. Duplicates are compared
        case-insensitively, and the first occurrence is kept.
        """
        queries: List[str] = []
        seen = set()
        for raw_line in content.splitlines():
            if len(queries) >= limit:
                break
            line = cls._LIST_MARKER_RE.sub("", raw_line.strip())
            line = line.strip().strip('"\u201c\u201d').strip()
            if len(line) <= cls._MIN_QUERY_CHARS or line.endswith(":"):
                continue
            key = line.casefold()
            if key not in seen:
                seen.add(key)
                queries.append(line)
        return queries

    @staticmethod
    def _mentions_any(text: str, keywords) -> bool:
        """Return True if any keyword occurs in *text* starting at a word boundary.

        Anchoring only the start keeps inflections ("processes", "claims")
        matching, while rejecting hits buried inside other words ("claim" in
        "disclaimer", "time" in "sometimes").
        """
        return any(re.search(r"\b" + re.escape(kw), text) for kw in keywords)

    def _identify_query_components(self, query: str) -> dict:
        """Identify different components in a complex query for better breakdown.

        Returns a dict with the keys 'processes', 'documents', 'contacts',
        'eligibility', 'timelines' and 'benefits'. Each value is a one-item
        list holding that component's label when the lower-cased query
        mentions one of its keywords, and an empty list otherwise.
        """
        query_lower = query.lower()
        return {
            key: [label] if self._mentions_any(query_lower, keywords) else []
            for key, (label, keywords) in self._COMPONENT_KEYWORDS.items()
        }